p2p: integrate p2p/discover

Overview of changes:

- ClientIdentity has been removed, use discover.NodeID
- Server now requires a private key to be set (instead of public key)
- Server performs the encryption handshake before launching Peer
- Dial logic takes peers from discover table
- Encryption handshake code has been cleaned up a bit
- baseProtocol is gone because we don't exchange peers anymore
- Some parts of baseProtocol have moved into Peer instead
This commit is contained in:
Felix Lange 2015-02-05 03:07:58 +01:00
parent 739066ec56
commit 5bdc115943
15 changed files with 1080 additions and 1683 deletions

View File

@ -1,68 +0,0 @@
package p2p
import (
"fmt"
"runtime"
)
// ClientIdentity represents the identity of a peer.
type ClientIdentity interface {
String() string // human readable identity
Pubkey() []byte // 512-bit public key
}
type SimpleClientIdentity struct {
clientIdentifier string
version string
customIdentifier string
os string
implementation string
privkey []byte
pubkey []byte
}
func NewSimpleClientIdentity(clientIdentifier string, version string, customIdentifier string, pubkey []byte) *SimpleClientIdentity {
clientIdentity := &SimpleClientIdentity{
clientIdentifier: clientIdentifier,
version: version,
customIdentifier: customIdentifier,
os: runtime.GOOS,
implementation: runtime.Version(),
pubkey: pubkey,
}
return clientIdentity
}
func (c *SimpleClientIdentity) init() {
}
func (c *SimpleClientIdentity) String() string {
var id string
if len(c.customIdentifier) > 0 {
id = "/" + c.customIdentifier
}
return fmt.Sprintf("%s/v%s%s/%s/%s",
c.clientIdentifier,
c.version,
id,
c.os,
c.implementation)
}
func (c *SimpleClientIdentity) Privkey() []byte {
return c.privkey
}
func (c *SimpleClientIdentity) Pubkey() []byte {
return c.pubkey
}
func (c *SimpleClientIdentity) SetCustomIdentifier(customIdentifier string) {
c.customIdentifier = customIdentifier
}
func (c *SimpleClientIdentity) GetCustomIdentifier() string {
return c.customIdentifier
}

View File

@ -1,35 +0,0 @@
package p2p
import (
"bytes"
"fmt"
"runtime"
"testing"
)
func TestClientIdentity(t *testing.T) {
clientIdentity := NewSimpleClientIdentity("Ethereum(G)", "0.5.16", "test", []byte("pubkey"))
key := clientIdentity.Pubkey()
if !bytes.Equal(key, []byte("pubkey")) {
t.Errorf("Expected Pubkey to be %x, got %x", key, []byte("pubkey"))
}
clientString := clientIdentity.String()
expected := fmt.Sprintf("Ethereum(G)/v0.5.16/test/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
customIdentifier := clientIdentity.GetCustomIdentifier()
if customIdentifier != "test" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test', got %v", customIdentifier)
}
clientIdentity.SetCustomIdentifier("test2")
customIdentifier = clientIdentity.GetCustomIdentifier()
if customIdentifier != "test2" {
t.Errorf("Expected clientIdentity.GetCustomIdentifier() to be 'test2', got %v", customIdentifier)
}
clientString = clientIdentity.String()
expected = fmt.Sprintf("Ethereum(G)/v0.5.16/test2/%s/%s", runtime.GOOS, runtime.Version())
if clientString != expected {
t.Errorf("Expected clientIdentity to be %v, got %v", expected, clientString)
}
}

View File

@ -10,28 +10,25 @@ import (
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/crypto/secp256k1" "github.com/ethereum/go-ethereum/crypto/secp256k1"
ethlogger "github.com/ethereum/go-ethereum/logger" ethlogger "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/obscuren/ecies" "github.com/obscuren/ecies"
) )
var clogger = ethlogger.NewLogger("CRYPTOID") var clogger = ethlogger.NewLogger("CRYPTOID")
const ( const (
sskLen int = 16 // ecies.MaxSharedKeyLength(pubKey) / 2 sskLen = 16 // ecies.MaxSharedKeyLength(pubKey) / 2
sigLen int = 65 // elliptic S256 sigLen = 65 // elliptic S256
pubLen int = 64 // 512 bit pubkey in uncompressed representation without format byte pubLen = 64 // 512 bit pubkey in uncompressed representation without format byte
shaLen int = 32 // hash length (for nonce etc) shaLen = 32 // hash length (for nonce etc)
msgLen int = 194 // sigLen + shaLen + pubLen + shaLen + 1 = 194
resLen int = 97 // pubLen + shaLen + 1
iHSLen int = 307 // size of the final ECIES payload sent as initiator's handshake
rHSLen int = 210 // size of the final ECIES payload sent as receiver's handshake
)
// secretRW implements a message read writer with encryption and authentication authMsgLen = sigLen + shaLen + pubLen + shaLen + 1
// it is initialised by cryptoId.Run() after a successful crypto handshake authRespLen = pubLen + shaLen + 1
// aesSecret, macSecret, egressMac, ingress
type secretRW struct { eciesBytes = 65 + 16 + 32
aesSecret, macSecret, egressMac, ingressMac []byte iHSLen = authMsgLen + eciesBytes // size of the final ECIES payload sent as initiator's handshake
} rHSLen = authRespLen + eciesBytes // size of the final ECIES payload sent as receiver's handshake
)
type hexkey []byte type hexkey []byte
@ -39,150 +36,73 @@ func (self hexkey) String() string {
return fmt.Sprintf("(%d) %x", len(self), []byte(self)) return fmt.Sprintf("(%d) %x", len(self), []byte(self))
} }
var nonceF = func(b []byte) (n int, err error) { func encHandshake(conn io.ReadWriter, prv *ecdsa.PrivateKey, dial *discover.Node) (
return rand.Read(b) remoteID discover.NodeID,
} sessionToken []byte,
err error,
var step = 0 ) {
var detnonceF = func(b []byte) (n int, err error) { if dial == nil {
step++ var remotePubkey []byte
copy(b, crypto.Sha3([]byte("privacy"+string(step)))) sessionToken, remotePubkey, err = inboundEncHandshake(conn, prv, nil)
fmt.Printf("detkey %v: %v\n", step, hexkey(b)) copy(remoteID[:], remotePubkey)
return } else {
} remoteID = dial.ID
sessionToken, err = outboundEncHandshake(conn, prv, remoteID[:], nil)
var keyF = func() (priv *ecdsa.PrivateKey, err error) {
priv, err = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
if err != nil {
return
} }
return return remoteID, sessionToken, err
} }
var detkeyF = func() (priv *ecdsa.PrivateKey, err error) { // outboundEncHandshake negotiates a session token on conn.
s := make([]byte, 32) // it should be called on the dialing side of the connection.
detnonceF(s) //
priv = crypto.ToECDSA(s) // privateKey is the local client's private key
return // remotePublicKey is the remote peer's node ID
} // sessionToken is the token from a previous session with this node.
func outboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePublicKey []byte, sessionToken []byte) (
/* newSessionToken []byte,
NewSecureSession(connection, privateKey, remotePublicKey, sessionToken, initiator) is called when the peer connection starts to set up a secure session by performing a crypto handshake. err error,
) {
connection is (a buffered) network connection. auth, initNonce, randomPrivKey, err := authMsg(prvKey, remotePublicKey, sessionToken)
if err != nil {
privateKey is the local client's private key (*ecdsa.PrivateKey) return nil, err
remotePublicKey is the remote peer's node Id ([]byte)
sessionToken is the token from the previous session with this same peer. Nil if no token is found.
initiator is a boolean flag. True if the node is the initiator of the connection (ie., remote is an outbound peer reached by dialing out). False if the connection was established by accepting a call from the remote peer via a listener.
It returns a secretRW which implements the MsgReadWriter interface.
*/
func NewSecureSession(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, remotePubKeyS []byte, sessionToken []byte, initiator bool) (token []byte, rw *secretRW, err error) {
var auth, initNonce, recNonce []byte
var read int
var randomPrivKey *ecdsa.PrivateKey
var remoteRandomPubKey *ecdsa.PublicKey
clogger.Debugf("attempting session with %v", hexkey(remotePubKeyS))
if initiator {
if auth, initNonce, randomPrivKey, _, err = startHandshake(prvKey, remotePubKeyS, sessionToken); err != nil {
return
} }
if sessionToken != nil { if sessionToken != nil {
clogger.Debugf("session-token: %v", hexkey(sessionToken)) clogger.Debugf("session-token: %v", hexkey(sessionToken))
} }
clogger.Debugf("initiator-nonce: %v", hexkey(initNonce)) clogger.Debugf("initiator-nonce: %v", hexkey(initNonce))
clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey))) clogger.Debugf("initiator-random-private-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
randomPublicKeyS, _ := ExportPublicKey(&randomPrivKey.PublicKey) randomPublicKeyS, _ := exportPublicKey(&randomPrivKey.PublicKey)
clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS)) clogger.Debugf("initiator-random-public-key: %v", hexkey(randomPublicKeyS))
if _, err = conn.Write(auth); err != nil { if _, err = conn.Write(auth); err != nil {
return return nil, err
} }
clogger.Debugf("initiator handshake (sent to %v):\n%v", hexkey(remotePubKeyS), hexkey(auth)) clogger.Debugf("initiator handshake: %v", hexkey(auth))
var response []byte = make([]byte, rHSLen)
if read, err = conn.Read(response); err != nil || read == 0 { response := make([]byte, rHSLen)
return if _, err = io.ReadFull(conn, response); err != nil {
return nil, err
} }
if read != rHSLen { recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prvKey)
err = fmt.Errorf("remote receiver's handshake has invalid length. expect %v, got %v", rHSLen, read) if err != nil {
return return nil, err
}
// write out auth message
// wait for response, then call complete
if recNonce, remoteRandomPubKey, _, err = completeHandshake(response, prvKey); err != nil {
return
} }
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce)) clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
remoteRandomPubKeyS, _ := ExportPublicKey(remoteRandomPubKey) remoteRandomPubKeyS, _ := exportPublicKey(remoteRandomPubKey)
clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS)) clogger.Debugf("receiver-random-public-key: %v", hexkey(remoteRandomPubKeyS))
return newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
} else {
auth = make([]byte, iHSLen)
clogger.Debugf("waiting for initiator handshake (from %v)", hexkey(remotePubKeyS))
if read, err = conn.Read(auth); err != nil {
return
}
if read != iHSLen {
err = fmt.Errorf("remote initiator's handshake has invalid length. expect %v, got %v", iHSLen, read)
return
}
clogger.Debugf("received initiator handshake (from %v):\n%v", hexkey(remotePubKeyS), hexkey(auth))
// we are listening connection. we are responders in the handshake.
// Extract info from the authentication. The initiator starts by sending us a handshake that we need to respond to.
// so we read auth message first, then respond
var response []byte
if response, recNonce, initNonce, randomPrivKey, remoteRandomPubKey, err = respondToHandshake(auth, prvKey, remotePubKeyS, sessionToken); err != nil {
return
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil {
return
}
clogger.Debugf("receiver handshake (sent to %v):\n%v", hexkey(remotePubKeyS), hexkey(response))
}
return newSession(initiator, initNonce, recNonce, auth, randomPrivKey, remoteRandomPubKey)
} }
/* // authMsg creates the initiator handshake.
ImportPublicKey creates a 512 bit *ecsda.PublicKey from a byte slice. It accepts the simple 64 byte uncompressed format or the 65 byte format given by calling elliptic.Marshal on the EC point represented by the key. Any other length will result in an invalid public key error. func authMsg(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (
*/ auth, initNonce []byte,
func ImportPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) { randomPrvKey *ecdsa.PrivateKey,
var pubKey65 []byte err error,
switch len(pubKey) { ) {
case 64:
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
return crypto.ToECDSAPub(pubKey65), nil
}
/*
ExportPublicKey exports a *ecdsa.PublicKey into a byte slice using a simple 64-byte format. and is used for simple serialisation in network communication
*/
func ExportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
if pubKeyEC == nil {
return nil, fmt.Errorf("no ECDSA public key given")
}
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
}
/* startHandshake is called by if the node is the initiator of the connection.
The caller provides the public key of the peer as conjuctured from lookup based on IP:port, given as user input or proven by signatures. The caller must have access to persistant information about the peers, and pass the previous session token as an argument to cryptoId.
The first return value is the auth message that is to be sent out to the remote receiver.
*/
func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (auth []byte, initNonce []byte, randomPrvKey *ecdsa.PrivateKey, remotePubKey *ecdsa.PublicKey, err error) {
// session init, common to both parties // session init, common to both parties
if remotePubKey, err = ImportPublicKey(remotePubKeyS); err != nil { remotePubKey, err := importPublicKey(remotePubKeyS)
if err != nil {
return return
} }
@ -203,20 +123,18 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
//E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0) //E(remote-pubk, S(ecdhe-random, ecdh-shared-secret^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x0)
// E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1) // E(remote-pubk, S(ecdhe-random, token^nonce) || H(ecdhe-random-pubk) || pubk || nonce || 0x1)
// allocate msgLen long message, // allocate msgLen long message,
var msg []byte = make([]byte, msgLen) var msg []byte = make([]byte, authMsgLen)
initNonce = msg[msgLen-shaLen-1 : msgLen-1] initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
fmt.Printf("init-nonce: ") if _, err = rand.Read(initNonce); err != nil {
if _, err = nonceF(initNonce); err != nil {
return return
} }
// create known message // create known message
// ecdh-shared-secret^nonce for new peers // ecdh-shared-secret^nonce for new peers
// token^nonce for old peers // token^nonce for old peers
var sharedSecret = Xor(sessionToken, initNonce) var sharedSecret = xor(sessionToken, initNonce)
// generate random keypair to use for signing // generate random keypair to use for signing
fmt.Printf("init-random-ecdhe-private-key: ") if randomPrvKey, err = crypto.GenerateKey(); err != nil {
if randomPrvKey, err = keyF(); err != nil {
return return
} }
// sign shared secret (message known to both parties): shared-secret // sign shared secret (message known to both parties): shared-secret
@ -232,11 +150,11 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
copy(msg, signature) // copy signed-shared-secret copy(msg, signature) // copy signed-shared-secret
// H(ecdhe-random-pubk) // H(ecdhe-random-pubk)
var randomPubKey64 []byte var randomPubKey64 []byte
if randomPubKey64, err = ExportPublicKey(&randomPrvKey.PublicKey); err != nil { if randomPubKey64, err = exportPublicKey(&randomPrvKey.PublicKey); err != nil {
return return
} }
var pubKey64 []byte var pubKey64 []byte
if pubKey64, err = ExportPublicKey(&prvKey.PublicKey); err != nil { if pubKey64, err = exportPublicKey(&prvKey.PublicKey); err != nil {
return return
} }
copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64)) copy(msg[sigLen:sigLen+shaLen], crypto.Sha3(randomPubKey64))
@ -244,36 +162,98 @@ func startHandshake(prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte
copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64) copy(msg[sigLen+shaLen:sigLen+shaLen+pubLen], pubKey64)
// nonce is already in the slice // nonce is already in the slice
// stick tokenFlag byte to the end // stick tokenFlag byte to the end
msg[msgLen-1] = tokenFlag msg[authMsgLen-1] = tokenFlag
// encrypt using remote-pubk // encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg) // auth = eciesEncrypt(remote-pubk, msg)
if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil { if auth, err = crypto.Encrypt(remotePubKey, msg); err != nil {
return return
} }
return return
} }
/* // completeHandshake is called when the initiator receives an
respondToHandshake is called by peer if it accepted (but not initiated) the connection from the remote. It is passed the initiator handshake received, the public key and session token belonging to the remote initiator. // authentication response (aka receiver handshake). It completes the
// handshake by reading off parameters the remote peer provides needed
The first return value is the authentication response (aka receiver handshake) that is to be sent to the remote initiator. // to set up the secure session.
*/ func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (
func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, sessionToken []byte) (authResp []byte, respNonce []byte, initNonce []byte, randomPrivKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey, err error) { respNonce []byte,
remoteRandomPubKey *ecdsa.PublicKey,
tokenFlag bool,
err error,
) {
var msg []byte var msg []byte
var remotePubKey *ecdsa.PublicKey
if remotePubKey, err = ImportPublicKey(remotePubKeyS); err != nil {
return
}
// they prove that msg is meant for me, // they prove that msg is meant for me,
// I prove I possess private key if i can read it // I prove I possess private key if i can read it
if msg, err = crypto.Decrypt(prvKey, auth); err != nil { if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
return return
} }
respNonce = msg[pubLen : pubLen+shaLen]
var remoteRandomPubKeyS = msg[:pubLen]
if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return
}
if msg[authRespLen-1] == 0x01 {
tokenFlag = true
}
return
}
// inboundEncHandshake negotiates a session token on conn.
// it should be called on the listening side of the connection.
//
// privateKey is the local client's private key
// sessionToken is the token from a previous session with this node.
func inboundEncHandshake(conn io.ReadWriter, prvKey *ecdsa.PrivateKey, sessionToken []byte) (
token, remotePubKey []byte,
err error,
) {
// we are listening connection. we are responders in the
// handshake. Extract info from the authentication. The initiator
// starts by sending us a handshake that we need to respond to. so
// we read auth message first, then respond.
auth := make([]byte, iHSLen)
if _, err := io.ReadFull(conn, auth); err != nil {
return nil, nil, err
}
response, recNonce, initNonce, remotePubKey, randomPrivKey, remoteRandomPubKey, err := authResp(auth, sessionToken, prvKey)
if err != nil {
return nil, nil, err
}
clogger.Debugf("receiver-nonce: %v", hexkey(recNonce))
clogger.Debugf("receiver-random-priv-key: %v", hexkey(crypto.FromECDSA(randomPrivKey)))
if _, err = conn.Write(response); err != nil {
return nil, nil, err
}
clogger.Debugf("receiver handshake:\n%v", hexkey(response))
token, err = newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
return token, remotePubKey, err
}
// authResp is called by peer if it accepted (but not
// initiated) the connection from the remote. It is passed the initiator
// handshake received and the session token belonging to the
// remote initiator.
//
// The first return value is the authentication response (aka receiver
// handshake) that is to be sent to the remote initiator.
func authResp(auth, sessionToken []byte, prvKey *ecdsa.PrivateKey) (
authResp, respNonce, initNonce, remotePubKeyS []byte,
randomPrivKey *ecdsa.PrivateKey,
remoteRandomPubKey *ecdsa.PublicKey,
err error,
) {
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
msg, err := crypto.Decrypt(prvKey, auth)
if err != nil {
return
}
remotePubKeyS = msg[sigLen+shaLen : sigLen+shaLen+pubLen]
remotePubKey, _ := importPublicKey(remotePubKeyS)
var tokenFlag byte var tokenFlag byte
if sessionToken == nil { if sessionToken == nil {
// no session token found means we need to generate shared secret. // no session token found means we need to generate shared secret.
@ -289,42 +269,42 @@ func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, se
} }
// the initiator nonce is read off the end of the message // the initiator nonce is read off the end of the message
initNonce = msg[msgLen-shaLen-1 : msgLen-1] initNonce = msg[authMsgLen-shaLen-1 : authMsgLen-1]
// I prove that i own prv key (to derive shared secret, and read nonce off encrypted msg) and that I own shared secret // I prove that i own prv key (to derive shared secret, and read
// they prove they own the private key belonging to ecdhe-random-pubk // nonce off encrypted msg) and that I own shared secret they
// we can now reconstruct the signed message and recover the peers pubkey // prove they own the private key belonging to ecdhe-random-pubk
var signedMsg = Xor(sessionToken, initNonce) // we can now reconstruct the signed message and recover the peers
// pubkey
var signedMsg = xor(sessionToken, initNonce)
var remoteRandomPubKeyS []byte var remoteRandomPubKeyS []byte
if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil { if remoteRandomPubKeyS, err = secp256k1.RecoverPubkey(signedMsg, msg[:sigLen]); err != nil {
return return
} }
// convert to ECDSA standard // convert to ECDSA standard
if remoteRandomPubKey, err = ImportPublicKey(remoteRandomPubKeyS); err != nil { if remoteRandomPubKey, err = importPublicKey(remoteRandomPubKeyS); err != nil {
return return
} }
// now we find ourselves a long task too, fill it random // now we find ourselves a long task too, fill it random
var resp = make([]byte, resLen) var resp = make([]byte, authRespLen)
// generate shaLen long nonce // generate shaLen long nonce
respNonce = resp[pubLen : pubLen+shaLen] respNonce = resp[pubLen : pubLen+shaLen]
fmt.Printf("rec-nonce: ") if _, err = rand.Read(respNonce); err != nil {
if _, err = nonceF(respNonce); err != nil {
return return
} }
// generate random keypair for session // generate random keypair for session
fmt.Printf("rec-random-ecdhe-private-key: ") if randomPrivKey, err = crypto.GenerateKey(); err != nil {
if randomPrivKey, err = keyF(); err != nil {
return return
} }
// responder auth message // responder auth message
// E(remote-pubk, ecdhe-random-pubk || nonce || 0x0) // E(remote-pubk, ecdhe-random-pubk || nonce || 0x0)
var randomPubKeyS []byte var randomPubKeyS []byte
if randomPubKeyS, err = ExportPublicKey(&randomPrivKey.PublicKey); err != nil { if randomPubKeyS, err = exportPublicKey(&randomPrivKey.PublicKey); err != nil {
return return
} }
copy(resp[:pubLen], randomPubKeyS) copy(resp[:pubLen], randomPubKeyS)
// nonce is already in the slice // nonce is already in the slice
resp[resLen-1] = tokenFlag resp[authRespLen-1] = tokenFlag
// encrypt using remote-pubk // encrypt using remote-pubk
// auth = eciesEncrypt(remote-pubk, msg) // auth = eciesEncrypt(remote-pubk, msg)
@ -335,70 +315,49 @@ func respondToHandshake(auth []byte, prvKey *ecdsa.PrivateKey, remotePubKeyS, se
return return
} }
/* // newSession is called after the handshake is completed. The
completeHandshake is called when the initiator receives an authentication response (aka receiver handshake). It completes the handshake by reading off parameters the remote peer provides needed to set up the secure session // arguments are values negotiated in the handshake. The return value
*/ // is a new session Token to be remembered for the next time we
func completeHandshake(auth []byte, prvKey *ecdsa.PrivateKey) (respNonce []byte, remoteRandomPubKey *ecdsa.PublicKey, tokenFlag bool, err error) { // connect with this peer.
var msg []byte func newSession(initNonce, respNonce []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) ([]byte, error) {
// they prove that msg is meant for me,
// I prove I possess private key if i can read it
if msg, err = crypto.Decrypt(prvKey, auth); err != nil {
return
}
respNonce = msg[pubLen : pubLen+shaLen]
var remoteRandomPubKeyS = msg[:pubLen]
if remoteRandomPubKey, err = ImportPublicKey(remoteRandomPubKeyS); err != nil {
return
}
if msg[resLen-1] == 0x01 {
tokenFlag = true
}
return
}
/*
newSession is called after the handshake is completed. The arguments are values negotiated in the handshake and the return value is a new session : a new session Token to be remembered for the next time we connect with this peer. And a MsgReadWriter that implements an encrypted and authenticated connection with key material obtained from the crypto handshake key exchange
*/
func newSession(initiator bool, initNonce, respNonce, auth []byte, privKey *ecdsa.PrivateKey, remoteRandomPubKey *ecdsa.PublicKey) (sessionToken []byte, rw *secretRW, err error) {
// 3) Now we can trust ecdhe-random-pubk to derive new keys // 3) Now we can trust ecdhe-random-pubk to derive new keys
//ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk) //ecdhe-shared-secret = ecdh.agree(ecdhe-random, remote-ecdhe-random-pubk)
var dhSharedSecret []byte
pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey) pubKey := ecies.ImportECDSAPublic(remoteRandomPubKey)
if dhSharedSecret, err = ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen); err != nil { dhSharedSecret, err := ecies.ImportECDSA(privKey).GenerateShared(pubKey, sskLen, sskLen)
return if err != nil {
return nil, err
} }
var sharedSecret = crypto.Sha3(append(dhSharedSecret, crypto.Sha3(append(respNonce, initNonce...))...)) sharedSecret := crypto.Sha3(dhSharedSecret, crypto.Sha3(respNonce, initNonce))
sessionToken = crypto.Sha3(sharedSecret) sessionToken := crypto.Sha3(sharedSecret)
var aesSecret = crypto.Sha3(append(dhSharedSecret, sharedSecret...)) return sessionToken, nil
var macSecret = crypto.Sha3(append(dhSharedSecret, aesSecret...))
var egressMac, ingressMac []byte
if initiator {
egressMac = Xor(macSecret, respNonce)
ingressMac = Xor(macSecret, initNonce)
} else {
egressMac = Xor(macSecret, initNonce)
ingressMac = Xor(macSecret, respNonce)
}
rw = &secretRW{
aesSecret: aesSecret,
macSecret: macSecret,
egressMac: egressMac,
ingressMac: ingressMac,
}
clogger.Debugf("aes-secret: %v", hexkey(aesSecret))
clogger.Debugf("mac-secret: %v", hexkey(macSecret))
clogger.Debugf("egress-mac: %v", hexkey(egressMac))
clogger.Debugf("ingress-mac: %v", hexkey(ingressMac))
return
} }
// TODO: optimisation // importPublicKey unmarshals 512 bit public keys.
// should use cipher.xorBytes from crypto/cipher/xor.go for fast xor func importPublicKey(pubKey []byte) (pubKeyEC *ecdsa.PublicKey, err error) {
func Xor(one, other []byte) (xor []byte) { var pubKey65 []byte
switch len(pubKey) {
case 64:
// add 'uncompressed key' flag
pubKey65 = append([]byte{0x04}, pubKey...)
case 65:
pubKey65 = pubKey
default:
return nil, fmt.Errorf("invalid public key length %v (expect 64/65)", len(pubKey))
}
return crypto.ToECDSAPub(pubKey65), nil
}
func exportPublicKey(pubKeyEC *ecdsa.PublicKey) (pubKey []byte, err error) {
if pubKeyEC == nil {
return nil, fmt.Errorf("no ECDSA public key given")
}
return crypto.FromECDSAPub(pubKeyEC)[1:], nil
}
func xor(one, other []byte) (xor []byte) {
xor = make([]byte, len(one)) xor = make([]byte, len(one))
for i := 0; i < len(one); i++ { for i := 0; i < len(one); i++ {
xor[i] = one[i] ^ other[i] xor[i] = one[i] ^ other[i]
} }
return return xor
} }

View File

@ -3,10 +3,9 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa" "crypto/ecdsa"
"fmt" "crypto/rand"
"net" "net"
"testing" "testing"
"time"
"github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/crypto"
"github.com/obscuren/ecies" "github.com/obscuren/ecies"
@ -16,7 +15,7 @@ func TestPublicKeyEncoding(t *testing.T) {
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader) prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey pub0 := &prv0.PublicKey
pub0s := crypto.FromECDSAPub(pub0) pub0s := crypto.FromECDSAPub(pub0)
pub1, err := ImportPublicKey(pub0s) pub1, err := importPublicKey(pub0s)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
@ -24,18 +23,18 @@ func TestPublicKeyEncoding(t *testing.T) {
if eciesPub1 == nil { if eciesPub1 == nil {
t.Errorf("invalid ecdsa public key") t.Errorf("invalid ecdsa public key")
} }
pub1s, err := ExportPublicKey(pub1) pub1s, err := exportPublicKey(pub1)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
if len(pub1s) != 64 { if len(pub1s) != 64 {
t.Errorf("wrong length expect 64, got", len(pub1s)) t.Errorf("wrong length expect 64, got", len(pub1s))
} }
pub2, err := ImportPublicKey(pub1s) pub2, err := importPublicKey(pub1s)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
pub2s, err := ExportPublicKey(pub2) pub2s, err := exportPublicKey(pub2)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
@ -69,95 +68,53 @@ func TestSharedSecret(t *testing.T) {
} }
func TestCryptoHandshake(t *testing.T) { func TestCryptoHandshake(t *testing.T) {
testCryptoHandshakeWithGen(false, t) testCryptoHandshake(newkey(), newkey(), nil, t)
} }
func TestTokenCryptoHandshake(t *testing.T) { func TestCryptoHandshakeWithToken(t *testing.T) {
testCryptoHandshakeWithGen(true, t) sessionToken := make([]byte, shaLen)
} rand.Read(sessionToken)
testCryptoHandshake(newkey(), newkey(), sessionToken, t)
func TestDetCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(false, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func TestDetTokenCryptoHandshake(t *testing.T) {
defer testlog(t).detach()
tmpkeyF := keyF
keyF = detkeyF
tmpnonceF := nonceF
nonceF = detnonceF
testCryptoHandshakeWithGen(true, t)
keyF = tmpkeyF
nonceF = tmpnonceF
}
func testCryptoHandshakeWithGen(token bool, t *testing.T) {
fmt.Printf("init-private-key: ")
prv0, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
fmt.Printf("rec-private-key: ")
prv1, err := keyF()
if err != nil {
t.Errorf("%v", err)
return
}
var nonce []byte
if token {
fmt.Printf("session-token: ")
nonce = make([]byte, shaLen)
nonceF(nonce)
}
testCryptoHandshake(prv0, prv1, nonce, t)
} }
func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) { func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *testing.T) {
var err error var err error
pub0 := &prv0.PublicKey // pub0 := &prv0.PublicKey
pub1 := &prv1.PublicKey pub1 := &prv1.PublicKey
pub0s := crypto.FromECDSAPub(pub0) // pub0s := crypto.FromECDSAPub(pub0)
pub1s := crypto.FromECDSAPub(pub1) pub1s := crypto.FromECDSAPub(pub1)
// simulate handshake by feeding output to input // simulate handshake by feeding output to input
// initiator sends handshake 'auth' // initiator sends handshake 'auth'
auth, initNonce, randomPrivKey, _, err := startHandshake(prv0, pub1s, sessionToken) auth, initNonce, randomPrivKey, err := authMsg(prv0, pub1s, sessionToken)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
fmt.Printf("-> %v\n", hexkey(auth)) t.Logf("-> %v", hexkey(auth))
// receiver reads auth and responds with response // receiver reads auth and responds with response
response, remoteRecNonce, remoteInitNonce, remoteRandomPrivKey, remoteInitRandomPubKey, err := respondToHandshake(auth, prv1, pub0s, sessionToken) response, remoteRecNonce, remoteInitNonce, _, remoteRandomPrivKey, remoteInitRandomPubKey, err := authResp(auth, sessionToken, prv1)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("%v", err)
} }
fmt.Printf("<- %v\n", hexkey(response)) t.Logf("<- %v\n", hexkey(response))
// initiator reads receiver's response and the key exchange completes // initiator reads receiver's response and the key exchange completes
recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0) recNonce, remoteRandomPubKey, _, err := completeHandshake(response, prv0)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("completeHandshake error: %v", err)
} }
// now both parties should have the same session parameters // now both parties should have the same session parameters
initSessionToken, initSecretRW, err := newSession(true, initNonce, recNonce, auth, randomPrivKey, remoteRandomPubKey) initSessionToken, err := newSession(initNonce, recNonce, randomPrivKey, remoteRandomPubKey)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("newSession error: %v", err)
} }
recSessionToken, recSecretRW, err := newSession(false, remoteInitNonce, remoteRecNonce, auth, remoteRandomPrivKey, remoteInitRandomPubKey) recSessionToken, err := newSession(remoteInitNonce, remoteRecNonce, remoteRandomPrivKey, remoteInitRandomPubKey)
if err != nil { if err != nil {
t.Errorf("%v", err) t.Errorf("newSession error: %v", err)
} }
// fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response) // fmt.Printf("\nauth (%v) %x\n\nresp (%v) %x\n\n", len(auth), auth, len(response), response)
@ -173,76 +130,38 @@ func testCryptoHandshake(prv0, prv1 *ecdsa.PrivateKey, sessionToken []byte, t *t
if !bytes.Equal(initSessionToken, recSessionToken) { if !bytes.Equal(initSessionToken, recSessionToken) {
t.Errorf("session tokens do not match") t.Errorf("session tokens do not match")
} }
// aesSecret, macSecret, egressMac, ingressMac
if !bytes.Equal(initSecretRW.aesSecret, recSecretRW.aesSecret) {
t.Errorf("AES secrets do not match")
}
if !bytes.Equal(initSecretRW.macSecret, recSecretRW.macSecret) {
t.Errorf("macSecrets do not match")
}
if !bytes.Equal(initSecretRW.egressMac, recSecretRW.ingressMac) {
t.Errorf("initiator's egressMac do not match receiver's ingressMac")
}
if !bytes.Equal(initSecretRW.ingressMac, recSecretRW.egressMac) {
t.Errorf("initiator's inressMac do not match receiver's egressMac")
}
} }
func TestPeersHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
var err error
// var sessionToken []byte prv0, _ := crypto.GenerateKey()
prv0, _ := crypto.GenerateKey() // = ecdsa.GenerateKey(crypto.S256(), rand.Reader)
pub0 := &prv0.PublicKey
prv1, _ := crypto.GenerateKey() prv1, _ := crypto.GenerateKey()
pub1 := &prv1.PublicKey pub0s, _ := exportPublicKey(&prv0.PublicKey)
pub1s, _ := exportPublicKey(&prv1.PublicKey)
rw0, rw1 := net.Pipe()
tokens := make(chan []byte)
prv0s := crypto.FromECDSA(prv0)
pub0s := crypto.FromECDSAPub(pub0)
prv1s := crypto.FromECDSA(prv1)
pub1s := crypto.FromECDSAPub(pub1)
conn1, conn2 := net.Pipe()
initiator := newPeer(conn1, []Protocol{}, nil)
receiver := newPeer(conn2, []Protocol{}, nil)
initiator.dialAddr = &peerAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: pub1s[1:]}
initiator.privateKey = prv0s
// this is cheating. identity of initiator/dialler not available to listener/receiver
// its public key should be looked up based on IP address
receiver.identity = &peerId{nil, pub0s}
receiver.privateKey = prv1s
initiator.pubkeyHook = func(*peerAddr) error { return nil }
receiver.pubkeyHook = func(*peerAddr) error { return nil }
initiator.cryptoHandshake = true
receiver.cryptoHandshake = true
errc0 := make(chan error, 1)
errc1 := make(chan error, 1)
go func() { go func() {
_, err := initiator.loop() token, err := outboundEncHandshake(rw0, prv0, pub1s, nil)
errc0 <- err if err != nil {
t.Errorf("outbound side error: %v", err)
}
tokens <- token
}() }()
go func() { go func() {
_, err := receiver.loop() token, remotePubkey, err := inboundEncHandshake(rw1, prv1, nil)
errc1 <- err if err != nil {
t.Errorf("inbound side error: %v", err)
}
if !bytes.Equal(remotePubkey, pub0s) {
t.Errorf("inbound side returned wrong remote pubkey\n got: %x\n want: %x", remotePubkey, pub0s)
}
tokens <- token
}() }()
ready := make(chan bool)
go func() { t1, t2 := <-tokens, <-tokens
<-initiator.cryptoReady if !bytes.Equal(t1, t2) {
<-receiver.cryptoReady t.Error("session token mismatch")
close(ready)
}()
timeout := time.After(10 * time.Second)
select {
case <-ready:
case <-timeout:
t.Errorf("crypto handshake hanging for too long")
case err = <-errc0:
t.Errorf("peer 0 quit with error: %v", err)
case err = <-errc1:
t.Errorf("peer 1 quit with error: %v", err)
} }
} }

View File

@ -1,6 +1,7 @@
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
@ -8,7 +9,10 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"math/big" "math/big"
"net"
"sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/ethereum/go-ethereum/ethutil" "github.com/ethereum/go-ethereum/ethutil"
"github.com/ethereum/go-ethereum/rlp" "github.com/ethereum/go-ethereum/rlp"
@ -74,11 +78,14 @@ type MsgWriter interface {
// WriteMsg sends a message. It will block until the message's // WriteMsg sends a message. It will block until the message's
// Payload has been consumed by the other end. // Payload has been consumed by the other end.
// //
// Note that messages can be sent only once. // Note that messages can be sent only once because their
// payload reader is drained.
WriteMsg(Msg) error WriteMsg(Msg) error
} }
// MsgReadWriter provides reading and writing of encoded messages. // MsgReadWriter provides reading and writing of encoded messages.
// Implementations should ensure that ReadMsg and WriteMsg can be
// called simultaneously from multiple goroutines.
type MsgReadWriter interface { type MsgReadWriter interface {
MsgReader MsgReader
MsgWriter MsgWriter
@ -90,8 +97,45 @@ func EncodeMsg(w MsgWriter, code uint64, data ...interface{}) error {
return w.WriteMsg(NewMsg(code, data...)) return w.WriteMsg(NewMsg(code, data...))
} }
// frameRW is a MsgReadWriter that reads and writes devp2p message frames.
// As required by the interface, ReadMsg and WriteMsg can be called from
// multiple goroutines.
type frameRW struct {
net.Conn // make Conn methods available. be careful.
bufconn *bufio.ReadWriter
// this channel is used to 'lend' bufconn to a caller of ReadMsg
// until the message payload has been consumed. the channel
// receives a value when EOF is reached on the payload, unblocking
// a pending call to ReadMsg.
rsync chan struct{}
// this mutex guards writes to bufconn.
writeMu sync.Mutex
}
func newFrameRW(conn net.Conn, timeout time.Duration) *frameRW {
rsync := make(chan struct{}, 1)
rsync <- struct{}{}
return &frameRW{
Conn: conn,
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
rsync: rsync,
}
}
var magicToken = []byte{34, 64, 8, 145} var magicToken = []byte{34, 64, 8, 145}
func (rw *frameRW) WriteMsg(msg Msg) error {
rw.writeMu.Lock()
defer rw.writeMu.Unlock()
rw.SetWriteDeadline(time.Now().Add(msgWriteTimeout))
if err := writeMsg(rw.bufconn, msg); err != nil {
return err
}
return rw.bufconn.Flush()
}
func writeMsg(w io.Writer, msg Msg) error { func writeMsg(w io.Writer, msg Msg) error {
// TODO: handle case when Size + len(code) + len(listhdr) overflows uint32 // TODO: handle case when Size + len(code) + len(listhdr) overflows uint32
code := ethutil.Encode(uint32(msg.Code)) code := ethutil.Encode(uint32(msg.Code))
@ -120,12 +164,16 @@ func makeListHeader(length uint32) []byte {
return append([]byte{lenb}, enc...) return append([]byte{lenb}, enc...)
} }
// readMsg reads a message header from r. func (rw *frameRW) ReadMsg() (msg Msg, err error) {
// It takes an rlp.ByteReader to ensure that the decoding doesn't buffer. <-rw.rsync // wait until bufconn is ours
func readMsg(r rlp.ByteReader) (msg Msg, err error) {
// this read timeout applies also to the payload.
// TODO: proper read timeout
rw.SetReadDeadline(time.Now().Add(msgReadTimeout))
// read magic and payload size // read magic and payload size
start := make([]byte, 8) start := make([]byte, 8)
if _, err = io.ReadFull(r, start); err != nil { if _, err = io.ReadFull(rw.bufconn, start); err != nil {
return msg, newPeerError(errRead, "%v", err) return msg, newPeerError(errRead, "%v", err)
} }
if !bytes.HasPrefix(start, magicToken) { if !bytes.HasPrefix(start, magicToken) {
@ -134,17 +182,33 @@ func readMsg(r rlp.ByteReader) (msg Msg, err error) {
size := binary.BigEndian.Uint32(start[4:]) size := binary.BigEndian.Uint32(start[4:])
// decode start of RLP message to get the message code // decode start of RLP message to get the message code
posr := &postrack{r, 0} posr := &postrack{rw.bufconn, 0}
s := rlp.NewStream(posr) s := rlp.NewStream(posr)
if _, err := s.List(); err != nil { if _, err := s.List(); err != nil {
return msg, err return msg, err
} }
code, err := s.Uint() msg.Code, err = s.Uint()
if err != nil { if err != nil {
return msg, err return msg, err
} }
payloadsize := size - posr.p msg.Size = size - posr.p
return Msg{code, payloadsize, io.LimitReader(r, int64(payloadsize))}, nil
if msg.Size <= wholePayloadSize {
// msg is small, read all of it and move on to the next message.
pbuf := make([]byte, msg.Size)
if _, err := io.ReadFull(rw.bufconn, pbuf); err != nil {
return msg, err
}
rw.rsync <- struct{}{} // bufconn is available again
msg.Payload = bytes.NewReader(pbuf)
} else {
// lend bufconn to the caller until it has
// consumed the payload. eofSignal will send a value
// on rw.rsync when EOF is reached.
pr := &eofSignal{rw.bufconn, msg.Size, rw.rsync}
msg.Payload = pr
}
return msg, nil
} }
// postrack wraps an rlp.ByteReader with a position counter. // postrack wraps an rlp.ByteReader with a position counter.
@ -167,6 +231,39 @@ func (r *postrack) ReadByte() (byte, error) {
return b, err return b, err
} }
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
type eofSignal struct {
wrapped io.Reader
count uint32 // number of bytes left
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
if r.count == 0 {
if r.eof != nil {
r.eof <- struct{}{}
r.eof = nil
}
return 0, io.EOF
}
max := len(buf)
if int(r.count) < len(buf) {
max = int(r.count)
}
n, err := r.wrapped.Read(buf[:max])
r.count -= uint32(n)
if (err != nil || r.count == 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}
// MsgPipe creates a message pipe. Reads on one end are matched // MsgPipe creates a message pipe. Reads on one end are matched
// with writes on the other. The pipe is full-duplex, both ends // with writes on the other. The pipe is full-duplex, both ends
// implement MsgReadWriter. // implement MsgReadWriter.
@ -198,7 +295,7 @@ type MsgPipeRW struct {
func (p *MsgPipeRW) WriteMsg(msg Msg) error { func (p *MsgPipeRW) WriteMsg(msg Msg) error {
if atomic.LoadInt32(p.closed) == 0 { if atomic.LoadInt32(p.closed) == 0 {
consumed := make(chan struct{}, 1) consumed := make(chan struct{}, 1)
msg.Payload = &eofSignal{msg.Payload, int64(msg.Size), consumed} msg.Payload = &eofSignal{msg.Payload, msg.Size, consumed}
select { select {
case p.w <- msg: case p.w <- msg:
if msg.Size > 0 { if msg.Size > 0 {

View File

@ -3,12 +3,11 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"runtime" "runtime"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/ethutil"
) )
func TestNewMsg(t *testing.T) { func TestNewMsg(t *testing.T) {
@ -26,51 +25,51 @@ func TestNewMsg(t *testing.T) {
} }
} }
func TestEncodeDecodeMsg(t *testing.T) { // func TestEncodeDecodeMsg(t *testing.T) {
msg := NewMsg(3, 1, "000") // msg := NewMsg(3, 1, "000")
buf := new(bytes.Buffer) // buf := new(bytes.Buffer)
if err := writeMsg(buf, msg); err != nil { // if err := writeMsg(buf, msg); err != nil {
t.Fatalf("encodeMsg error: %v", err) // t.Fatalf("encodeMsg error: %v", err)
} // }
// t.Logf("encoded: %x", buf.Bytes()) // // t.Logf("encoded: %x", buf.Bytes())
decmsg, err := readMsg(buf) // decmsg, err := readMsg(buf)
if err != nil { // if err != nil {
t.Fatalf("readMsg error: %v", err) // t.Fatalf("readMsg error: %v", err)
} // }
if decmsg.Code != 3 { // if decmsg.Code != 3 {
t.Errorf("incorrect code %d, want %d", decmsg.Code, 3) // t.Errorf("incorrect code %d, want %d", decmsg.Code, 3)
} // }
if decmsg.Size != 5 { // if decmsg.Size != 5 {
t.Errorf("incorrect size %d, want %d", decmsg.Size, 5) // t.Errorf("incorrect size %d, want %d", decmsg.Size, 5)
} // }
var data struct { // var data struct {
I uint // I uint
S string // S string
} // }
if err := decmsg.Decode(&data); err != nil { // if err := decmsg.Decode(&data); err != nil {
t.Fatalf("Decode error: %v", err) // t.Fatalf("Decode error: %v", err)
} // }
if data.I != 1 { // if data.I != 1 {
t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1) // t.Errorf("incorrect data.I: got %v, expected %d", data.I, 1)
} // }
if data.S != "000" { // if data.S != "000" {
t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000") // t.Errorf("incorrect data.S: got %q, expected %q", data.S, "000")
} // }
} // }
func TestDecodeRealMsg(t *testing.T) { // func TestDecodeRealMsg(t *testing.T) {
data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb") // data := ethutil.Hex2Bytes("2240089100000080f87e8002b5457468657265756d282b2b292f5065657220536572766572204f6e652f76302e372e382f52656c656173652f4c696e75782f672b2bc082765fb84086dd80b7aefd6a6d2e3b93f4f300a86bfb6ef7bdc97cb03f793db6bb")
msg, err := readMsg(bytes.NewReader(data)) // msg, err := readMsg(bytes.NewReader(data))
if err != nil { // if err != nil {
t.Fatalf("unexpected error: %v", err) // t.Fatalf("unexpected error: %v", err)
} // }
if msg.Code != 0 { // if msg.Code != 0 {
t.Errorf("incorrect code %d, want %d", msg.Code, 0) // t.Errorf("incorrect code %d, want %d", msg.Code, 0)
} // }
} // }
func ExampleMsgPipe() { func ExampleMsgPipe() {
rw1, rw2 := MsgPipe() rw1, rw2 := MsgPipe()
@ -131,3 +130,58 @@ func TestMsgPipeConcurrentClose(t *testing.T) {
go rw1.Close() go rw1.Close()
} }
} }
func TestEOFSignal(t *testing.T) {
rb := make([]byte, 10)
// empty reader
eof := make(chan struct{}, 1)
sig := &eofSignal{new(bytes.Buffer), 0, eof}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
}
}

View File

@ -1,10 +1,6 @@
package p2p package p2p
import ( import (
"bufio"
"bytes"
"crypto/ecdsa"
"crypto/rand"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -13,179 +9,118 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/event"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
) )
// peerAddr is the structure of a peer list element. const (
// It is also a valid net.Addr. // maximum amount of time allowed for reading a message
type peerAddr struct { msgReadTimeout = 5 * time.Second
IP net.IP // maximum amount of time allowed for writing a message
Port uint64 msgWriteTimeout = 5 * time.Second
Pubkey []byte // optional // messages smaller than this many bytes will be read at
// once before passing them to a protocol.
wholePayloadSize = 64 * 1024
disconnectGracePeriod = 2 * time.Second
)
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the RLP structure of the protocol handshake.
type handshake struct {
Version uint64
Name string
Caps []Cap
ListenPort uint64
NodeID discover.NodeID
} }
func newPeerAddr(addr net.Addr, pubkey []byte) *peerAddr { // Peer represents a connected remote node.
n := addr.Network()
if n != "tcp" && n != "tcp4" && n != "tcp6" {
// for testing with non-TCP
return &peerAddr{net.ParseIP("127.0.0.1"), 30303, pubkey}
}
ta := addr.(*net.TCPAddr)
return &peerAddr{ta.IP, uint64(ta.Port), pubkey}
}
func (d peerAddr) Network() string {
if d.IP.To4() != nil {
return "tcp4"
} else {
return "tcp6"
}
}
func (d peerAddr) String() string {
return fmt.Sprintf("%v:%d", d.IP, d.Port)
}
func (d *peerAddr) RlpData() interface{} {
return []interface{}{string(d.IP), d.Port, d.Pubkey}
}
// Peer represents a remote peer.
type Peer struct { type Peer struct {
// Peers have all the log methods. // Peers have all the log methods.
// Use them to display messages related to the peer. // Use them to display messages related to the peer.
*logger.Logger *logger.Logger
infolock sync.Mutex infoMu sync.Mutex
identity ClientIdentity name string
caps []Cap caps []Cap
listenAddr *peerAddr // what remote peer is listening on
dialAddr *peerAddr // non-nil if dialing
// The mutex protects the connection ourID, remoteID *discover.NodeID
// so only one protocol can write at a time. ourName string
writeMu sync.Mutex
conn net.Conn rw *frameRW
bufconn *bufio.ReadWriter
// These fields maintain the running protocols. // These fields maintain the running protocols.
protocols []Protocol protocols []Protocol
runBaseProtocol bool // for testing
cryptoHandshake bool // for testing
cryptoReady chan struct{}
privateKey []byte
runlock sync.RWMutex // protects running runlock sync.RWMutex // protects running
running map[string]*proto running map[string]*proto
protocolHandshakeEnabled bool
protoWG sync.WaitGroup protoWG sync.WaitGroup
protoErr chan error protoErr chan error
closed chan struct{} closed chan struct{}
disc chan DiscReason disc chan DiscReason
activity event.TypeMux // for activity events
slot int // index into Server peer list
// These fields are kept so base protocol can access them.
// TODO: this should be one or more interfaces
ourID ClientIdentity // client id of the Server
ourListenAddr *peerAddr // listen addr of Server, nil if not listening
newPeerAddr chan<- *peerAddr // tell server about received peers
otherPeers func() []*Peer // should return the list of all peers
pubkeyHook func(*peerAddr) error // called at end of handshake to validate pubkey
} }
// NewPeer returns a peer for testing purposes. // NewPeer returns a peer for testing purposes.
func NewPeer(id ClientIdentity, caps []Cap) *Peer { func NewPeer(id discover.NodeID, name string, caps []Cap) *Peer {
conn, _ := net.Pipe() conn, _ := net.Pipe()
peer := newPeer(conn, nil, nil) peer := newPeer(conn, nil, "", nil, &id)
peer.setHandshakeInfo(id, nil, caps) peer.setHandshakeInfo(name, caps)
close(peer.closed) close(peer.closed) // ensures Disconnect doesn't block
return peer return peer
} }
func newServerPeer(server *Server, conn net.Conn, dialAddr *peerAddr) *Peer { // ID returns the node's public key.
p := newPeer(conn, server.Protocols, dialAddr) func (p *Peer) ID() discover.NodeID {
p.ourID = server.Identity return *p.remoteID
p.newPeerAddr = server.peerConnect
p.otherPeers = server.Peers
p.pubkeyHook = server.verifyPeer
p.runBaseProtocol = true
// laddr can be updated concurrently by NAT traversal.
// newServerPeer must be called with the server lock held.
if server.laddr != nil {
p.ourListenAddr = newPeerAddr(server.laddr, server.Identity.Pubkey())
}
return p
} }
func newPeer(conn net.Conn, protocols []Protocol, dialAddr *peerAddr) *Peer { // Name returns the node name that the remote node advertised.
p := &Peer{ func (p *Peer) Name() string {
Logger: logger.NewLogger("P2P " + conn.RemoteAddr().String()), // this needs a lock because the information is part of the
conn: conn, // protocol handshake.
dialAddr: dialAddr, p.infoMu.Lock()
bufconn: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), name := p.name
protocols: protocols, p.infoMu.Unlock()
running: make(map[string]*proto), return name
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
cryptoReady: make(chan struct{}),
}
return p
}
// Identity returns the client identity of the remote peer. The
// identity can be nil if the peer has not yet completed the
// handshake.
func (p *Peer) Identity() ClientIdentity {
p.infolock.Lock()
defer p.infolock.Unlock()
return p.identity
}
func (self *Peer) Pubkey() (pubkey []byte) {
self.infolock.Lock()
defer self.infolock.Unlock()
switch {
case self.identity != nil:
pubkey = self.identity.Pubkey()[1:]
case self.dialAddr != nil:
pubkey = self.dialAddr.Pubkey
case self.listenAddr != nil:
pubkey = self.listenAddr.Pubkey
}
return
} }
// Caps returns the capabilities (supported subprotocols) of the remote peer. // Caps returns the capabilities (supported subprotocols) of the remote peer.
func (p *Peer) Caps() []Cap { func (p *Peer) Caps() []Cap {
p.infolock.Lock() // this needs a lock because the information is part of the
defer p.infolock.Unlock() // protocol handshake.
return p.caps p.infoMu.Lock()
} caps := p.caps
p.infoMu.Unlock()
func (p *Peer) setHandshakeInfo(id ClientIdentity, laddr *peerAddr, caps []Cap) { return caps
p.infolock.Lock()
p.identity = id
p.listenAddr = laddr
p.caps = caps
p.infolock.Unlock()
} }
// RemoteAddr returns the remote address of the network connection. // RemoteAddr returns the remote address of the network connection.
func (p *Peer) RemoteAddr() net.Addr { func (p *Peer) RemoteAddr() net.Addr {
return p.conn.RemoteAddr() return p.rw.RemoteAddr()
} }
// LocalAddr returns the local address of the network connection. // LocalAddr returns the local address of the network connection.
func (p *Peer) LocalAddr() net.Addr { func (p *Peer) LocalAddr() net.Addr {
return p.conn.LocalAddr() return p.rw.LocalAddr()
} }
// Disconnect terminates the peer connection with the given reason. // Disconnect terminates the peer connection with the given reason.
@ -199,201 +134,167 @@ func (p *Peer) Disconnect(reason DiscReason) {
// String implements fmt.Stringer. // String implements fmt.Stringer.
func (p *Peer) String() string { func (p *Peer) String() string {
kind := "inbound" return fmt.Sprintf("Peer %.8x %v", p.remoteID, p.RemoteAddr())
p.infolock.Lock()
if p.dialAddr != nil {
kind = "outbound"
}
p.infolock.Unlock()
return fmt.Sprintf("Peer(%p %v %s)", p, p.conn.RemoteAddr(), kind)
} }
const ( func newPeer(conn net.Conn, protocols []Protocol, ourName string, ourID, remoteID *discover.NodeID) *Peer {
// maximum amount of time allowed for reading a message logtag := fmt.Sprintf("Peer %.8x %v", remoteID, conn.RemoteAddr())
msgReadTimeout = 5 * time.Second return &Peer{
// maximum amount of time allowed for writing a message Logger: logger.NewLogger(logtag),
msgWriteTimeout = 5 * time.Second rw: newFrameRW(conn, msgWriteTimeout),
// messages smaller than this many bytes will be read at ourID: ourID,
// once before passing them to a protocol. ourName: ourName,
wholePayloadSize = 64 * 1024 remoteID: remoteID,
) protocols: protocols,
running: make(map[string]*proto),
disc: make(chan DiscReason),
protoErr: make(chan error),
closed: make(chan struct{}),
}
}
var ( func (p *Peer) setHandshakeInfo(name string, caps []Cap) {
inactivityTimeout = 2 * time.Second p.infoMu.Lock()
disconnectGracePeriod = 2 * time.Second p.name = name
) p.caps = caps
p.infoMu.Unlock()
}
func (p *Peer) loop() (reason DiscReason, err error) { func (p *Peer) run() DiscReason {
defer p.activity.Stop() var readErr = make(chan error, 1)
defer p.closeProtocols() defer p.closeProtocols()
defer close(p.closed) defer close(p.closed)
defer p.conn.Close() defer p.rw.Close()
var readLoop func(chan<- Msg, chan<- error, <-chan bool) // start the read loop
if p.cryptoHandshake { go func() { readErr <- p.readLoop() }()
if readLoop, err = p.handleCryptoHandshake(); err != nil {
// from here on everything can be encrypted, authenticated if p.protocolHandshakeEnabled {
return DiscProtocolError, err // no graceful disconnect if err := writeProtocolHandshake(p.rw, p.ourName, *p.ourID, p.protocols); err != nil {
p.DebugDetailf("Protocol handshake error: %v\n", err)
return DiscProtocolError
} }
} else {
readLoop = p.readLoop
} }
// read loop // wait for an error or disconnect
readMsg := make(chan Msg) var reason DiscReason
readErr := make(chan error)
readNext := make(chan bool, 1)
protoDone := make(chan struct{}, 1)
go readLoop(readMsg, readErr, readNext)
readNext <- true
close(p.cryptoReady)
if p.runBaseProtocol {
p.startBaseProtocol()
}
loop:
for {
select { select {
case msg := <-readMsg:
// a new message has arrived.
var wait bool
if wait, err = p.dispatch(msg, protoDone); err != nil {
p.Errorf("msg dispatch error: %v\n", err)
reason = discReasonForError(err)
break loop
}
if !wait {
// Msg has already been read completely, continue with next message.
readNext <- true
}
p.activity.Post(time.Now())
case <-protoDone:
// protocol has consumed the message payload,
// we can continue reading from the socket.
readNext <- true
case err := <-readErr: case err := <-readErr:
// read failed. there is no need to run the // We rely on protocols to abort if there is a write error. It
// polite disconnect sequence because the connection // might be more robust to handle them here as well.
// is probably dead anyway. p.DebugDetailf("Read error: %v\n", err)
// TODO: handle write errors as well reason = DiscNetworkError
return DiscNetworkError, err case err := <-p.protoErr:
case err = <-p.protoErr:
reason = discReasonForError(err) reason = discReasonForError(err)
break loop
case reason = <-p.disc: case reason = <-p.disc:
break loop
} }
if reason != DiscNetworkError {
p.politeDisconnect(reason)
} }
p.Debugf("Disconnected: %v\n", reason)
return reason
}
// wait for read loop to return. func (p *Peer) politeDisconnect(reason DiscReason) {
close(readNext)
<-readErr
// tell the remote end to disconnect
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
p.conn.SetDeadline(time.Now().Add(disconnectGracePeriod)) // send reason
p.writeMsg(NewMsg(discMsg, reason), disconnectGracePeriod) EncodeMsg(p.rw, discMsg, uint(reason))
io.Copy(ioutil.Discard, p.conn) // discard any data that might arrive
io.Copy(ioutil.Discard, p.rw)
close(done) close(done)
}() }()
select { select {
case <-done: case <-done:
case <-time.After(disconnectGracePeriod): case <-time.After(disconnectGracePeriod):
} }
return reason, err
} }
func (p *Peer) readLoop(msgc chan<- Msg, errc chan<- error, unblock <-chan bool) { func (p *Peer) readLoop() error {
for _ = range unblock { if p.protocolHandshakeEnabled {
p.conn.SetReadDeadline(time.Now().Add(msgReadTimeout)) if err := readProtocolHandshake(p, p.rw); err != nil {
if msg, err := readMsg(p.bufconn); err != nil { return err
errc <- err
} else {
msgc <- msg
} }
} }
close(errc) for {
msg, err := p.rw.ReadMsg()
if err != nil {
return err
}
if err = p.handle(msg); err != nil {
return err
}
}
return nil
} }
func (p *Peer) dispatch(msg Msg, protoDone chan struct{}) (wait bool, err error) { func (p *Peer) handle(msg Msg) error {
switch {
case msg.Code == pingMsg:
msg.Discard()
go EncodeMsg(p.rw, pongMsg)
case msg.Code == discMsg:
var reason DiscReason
// no need to discard or for error checking, we'll close the
// connection after this.
rlp.Decode(msg.Payload, &reason)
p.Disconnect(DiscRequested)
return discRequestedError(reason)
case msg.Code < baseProtocolLength:
// ignore other base protocol messages
return msg.Discard()
default:
// it's a subprotocol message
proto, err := p.getProto(msg.Code) proto, err := p.getProto(msg.Code)
if err != nil { if err != nil {
return false, err return fmt.Errorf("msg code out of range: %v", msg.Code)
} }
if msg.Size <= wholePayloadSize { proto.in <- msg
// optimization: msg is small enough, read all }
// of it and move on to the next message return nil
buf, err := ioutil.ReadAll(msg.Payload) }
func readProtocolHandshake(p *Peer, rw MsgReadWriter) error {
// read and handle remote handshake
msg, err := rw.ReadMsg()
if err != nil { if err != nil {
return false, err return err
} }
msg.Payload = bytes.NewReader(buf) if msg.Code != handshakeMsg {
proto.in <- msg return newPeerError(errProtocolBreach, "expected handshake, got %x", msg.Code)
} else {
wait = true
pr := &eofSignal{msg.Payload, int64(msg.Size), protoDone}
msg.Payload = pr
proto.in <- msg
} }
return wait, nil if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "required version %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if hs.NodeID == *p.remoteID {
return newPeerError(errPubkeyForbidden, "node ID mismatch")
}
// TODO: remove Caps with empty name
p.setHandshakeInfo(hs.Name, hs.Caps)
p.startSubprotocols(hs.Caps)
return nil
} }
type readLoop func(chan<- Msg, chan<- error, <-chan bool) func writeProtocolHandshake(w MsgWriter, name string, id discover.NodeID, ps []Protocol) error {
var caps []interface{}
func (p *Peer) PrivateKey() (prv *ecdsa.PrivateKey, err error) { for _, proto := range ps {
if prv = crypto.ToECDSA(p.privateKey); prv == nil { caps = append(caps, proto.cap())
err = fmt.Errorf("invalid private key")
} }
return return EncodeMsg(w, handshakeMsg, baseProtocolVersion, name, caps, 0, id)
}
func (p *Peer) handleCryptoHandshake() (loop readLoop, err error) {
// cryptoId is just created for the lifecycle of the handshake
// it is survived by an encrypted readwriter
var initiator bool
var sessionToken []byte
sessionToken = make([]byte, shaLen)
if _, err = rand.Read(sessionToken); err != nil {
return
}
if p.dialAddr != nil { // this should have its own method Outgoing() bool
initiator = true
}
// run on peer
// this bit handles the handshake and creates a secure communications channel with
// var rw *secretRW
var prvKey *ecdsa.PrivateKey
if prvKey, err = p.PrivateKey(); err != nil {
err = fmt.Errorf("unable to access private key for client: %v", err)
return
}
// initialise a new secure session
if sessionToken, _, err = NewSecureSession(p.conn, prvKey, p.Pubkey(), sessionToken, initiator); err != nil {
p.Debugf("unable to setup secure session: %v", err)
return
}
loop = func(msg chan<- Msg, err chan<- error, next <-chan bool) {
// this is the readloop :)
}
return
}
func (p *Peer) startBaseProtocol() {
p.runlock.Lock()
defer p.runlock.Unlock()
p.running[""] = p.startProto(0, Protocol{
Length: baseProtocolLength,
Run: runBaseProtocol,
})
} }
// startProtocols starts matching named subprotocols. // startProtocols starts matching named subprotocols.
func (p *Peer) startSubprotocols(caps []Cap) { func (p *Peer) startSubprotocols(caps []Cap) {
sort.Sort(capsByName(caps)) sort.Sort(capsByName(caps))
p.runlock.Lock() p.runlock.Lock()
defer p.runlock.Unlock() defer p.runlock.Unlock()
offset := baseProtocolLength offset := baseProtocolLength
@ -412,20 +313,22 @@ outer:
} }
func (p *Peer) startProto(offset uint64, impl Protocol) *proto { func (p *Peer) startProto(offset uint64, impl Protocol) *proto {
p.DebugDetailf("Starting protocol %s/%d\n", impl.Name, impl.Version)
rw := &proto{ rw := &proto{
name: impl.Name,
in: make(chan Msg), in: make(chan Msg),
offset: offset, offset: offset,
maxcode: impl.Length, maxcode: impl.Length,
peer: p, w: p.rw,
} }
p.protoWG.Add(1) p.protoWG.Add(1)
go func() { go func() {
err := impl.Run(p, rw) err := impl.Run(p, rw)
if err == nil { if err == nil {
p.Infof("protocol %q returned", impl.Name) p.DebugDetailf("Protocol %s/%d returned\n", impl.Name, impl.Version)
err = newPeerError(errMisc, "protocol returned") err = newPeerError(errMisc, "protocol returned")
} else { } else {
p.Errorf("protocol %q error: %v\n", impl.Name, err) p.DebugDetailf("Protocol %s/%d error: %v\n", impl.Name, impl.Version, err)
} }
select { select {
case p.protoErr <- err: case p.protoErr <- err:
@ -459,6 +362,7 @@ func (p *Peer) closeProtocols() {
} }
// writeProtoMsg sends the given message on behalf of the given named protocol. // writeProtoMsg sends the given message on behalf of the given named protocol.
// this exists because of Server.Broadcast.
func (p *Peer) writeProtoMsg(protoName string, msg Msg) error { func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
p.runlock.RLock() p.runlock.RLock()
proto, ok := p.running[protoName] proto, ok := p.running[protoName]
@ -470,25 +374,14 @@ func (p *Peer) writeProtoMsg(protoName string, msg Msg) error {
return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName) return newPeerError(errInvalidMsgCode, "code %x is out of range for protocol %q", msg.Code, protoName)
} }
msg.Code += proto.offset msg.Code += proto.offset
return p.writeMsg(msg, msgWriteTimeout) return p.rw.WriteMsg(msg)
}
// writeMsg writes a message to the connection.
func (p *Peer) writeMsg(msg Msg, timeout time.Duration) error {
p.writeMu.Lock()
defer p.writeMu.Unlock()
p.conn.SetWriteDeadline(time.Now().Add(timeout))
if err := writeMsg(p.bufconn, msg); err != nil {
return newPeerError(errWrite, "%v", err)
}
return p.bufconn.Flush()
} }
type proto struct { type proto struct {
name string name string
in chan Msg in chan Msg
maxcode, offset uint64 maxcode, offset uint64
peer *Peer w MsgWriter
} }
func (rw *proto) WriteMsg(msg Msg) error { func (rw *proto) WriteMsg(msg Msg) error {
@ -496,11 +389,7 @@ func (rw *proto) WriteMsg(msg Msg) error {
return newPeerError(errInvalidMsgCode, "not handled") return newPeerError(errInvalidMsgCode, "not handled")
} }
msg.Code += rw.offset msg.Code += rw.offset
return rw.peer.writeMsg(msg, msgWriteTimeout) return rw.w.WriteMsg(msg)
}
func (rw *proto) EncodeMsg(code uint64, data ...interface{}) error {
return rw.WriteMsg(NewMsg(code, data...))
} }
func (rw *proto) ReadMsg() (Msg, error) { func (rw *proto) ReadMsg() (Msg, error) {
@ -511,26 +400,3 @@ func (rw *proto) ReadMsg() (Msg, error) {
msg.Code -= rw.offset msg.Code -= rw.offset
return msg, nil return msg, nil
} }
// eofSignal wraps a reader with eof signaling. the eof channel is
// closed when the wrapped reader returns an error or when count bytes
// have been read.
//
type eofSignal struct {
wrapped io.Reader
count int64
eof chan<- struct{}
}
// note: when using eofSignal to detect whether a message payload
// has been read, Read might not be called for zero sized messages.
func (r *eofSignal) Read(buf []byte) (int, error) {
n, err := r.wrapped.Read(buf)
r.count -= int64(n)
if (err != nil || r.count <= 0) && r.eof != nil {
r.eof <- struct{}{} // tell Peer that msg has been consumed
r.eof = nil
}
return n, err
}

View File

@ -12,7 +12,6 @@ const (
errInvalidMsgCode errInvalidMsgCode
errInvalidMsg errInvalidMsg
errP2PVersionMismatch errP2PVersionMismatch
errPubkeyMissing
errPubkeyInvalid errPubkeyInvalid
errPubkeyForbidden errPubkeyForbidden
errProtocolBreach errProtocolBreach
@ -22,20 +21,19 @@ const (
) )
var errorToString = map[int]string{ var errorToString = map[int]string{
errMagicTokenMismatch: "Magic token mismatch", errMagicTokenMismatch: "magic token mismatch",
errRead: "Read error", errRead: "read error",
errWrite: "Write error", errWrite: "write error",
errMisc: "Misc error", errMisc: "misc error",
errInvalidMsgCode: "Invalid message code", errInvalidMsgCode: "invalid message code",
errInvalidMsg: "Invalid message", errInvalidMsg: "invalid message",
errP2PVersionMismatch: "P2P Version Mismatch", errP2PVersionMismatch: "P2P Version Mismatch",
errPubkeyMissing: "Public key missing", errPubkeyInvalid: "public key invalid",
errPubkeyInvalid: "Public key invalid", errPubkeyForbidden: "public key forbidden",
errPubkeyForbidden: "Public key forbidden", errProtocolBreach: "protocol Breach",
errProtocolBreach: "Protocol Breach", errPingTimeout: "ping timeout",
errPingTimeout: "Ping timeout", errInvalidNetworkId: "invalid network id",
errInvalidNetworkId: "Invalid network id", errInvalidProtocolVersion: "invalid protocol version",
errInvalidProtocolVersion: "Invalid protocol version",
} }
type peerError struct { type peerError struct {
@ -62,22 +60,22 @@ func (self *peerError) Error() string {
type DiscReason byte type DiscReason byte
const ( const (
DiscRequested DiscReason = 0x00 DiscRequested DiscReason = iota
DiscNetworkError = 0x01 DiscNetworkError
DiscProtocolError = 0x02 DiscProtocolError
DiscUselessPeer = 0x03 DiscUselessPeer
DiscTooManyPeers = 0x04 DiscTooManyPeers
DiscAlreadyConnected = 0x05 DiscAlreadyConnected
DiscIncompatibleVersion = 0x06 DiscIncompatibleVersion
DiscInvalidIdentity = 0x07 DiscInvalidIdentity
DiscQuitting = 0x08 DiscQuitting
DiscUnexpectedIdentity = 0x09 DiscUnexpectedIdentity
DiscSelf = 0x0a DiscSelf
DiscReadTimeout = 0x0b DiscReadTimeout
DiscSubprotocolError = 0x10 DiscSubprotocolError
) )
var discReasonToString = [DiscSubprotocolError + 1]string{ var discReasonToString = [...]string{
DiscRequested: "Disconnect requested", DiscRequested: "Disconnect requested",
DiscNetworkError: "Network error", DiscNetworkError: "Network error",
DiscProtocolError: "Breach of protocol", DiscProtocolError: "Breach of protocol",
@ -117,7 +115,7 @@ func discReasonForError(err error) DiscReason {
switch peerError.Code { switch peerError.Code {
case errP2PVersionMismatch: case errP2PVersionMismatch:
return DiscIncompatibleVersion return DiscIncompatibleVersion
case errPubkeyMissing, errPubkeyInvalid: case errPubkeyInvalid:
return DiscInvalidIdentity return DiscInvalidIdentity
case errPubkeyForbidden: case errPubkeyForbidden:
return DiscUselessPeer return DiscUselessPeer

View File

@ -1,15 +1,17 @@
package p2p package p2p
import ( import (
"bufio"
"bytes" "bytes"
"encoding/hex" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net" "net"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/rlp"
) )
var discard = Protocol{ var discard = Protocol{
@ -28,17 +30,13 @@ var discard = Protocol{
}, },
} }
func testPeer(protos []Protocol) (net.Conn, *Peer, <-chan error) { func testPeer(handshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
peer := newPeer(conn1, protos, nil) peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
peer.ourID = &peerId{} peer.protocolHandshakeEnabled = handshake
peer.pubkeyHook = func(*peerAddr) error { return nil } errc := make(chan DiscReason, 1)
errc := make(chan error, 1) go func() { errc <- peer.run() }()
go func() { return newFrameRW(conn2, msgWriteTimeout), peer, errc
_, err := peer.loop()
errc <- err
}()
return conn2, peer, errc
} }
func TestPeerProtoReadMsg(t *testing.T) { func TestPeerProtoReadMsg(t *testing.T) {
@ -49,31 +47,28 @@ func TestPeerProtoReadMsg(t *testing.T) {
Name: "a", Name: "a",
Length: 5, Length: 5,
Run: func(peer *Peer, rw MsgReadWriter) error { Run: func(peer *Peer, rw MsgReadWriter) error {
msg, err := rw.ReadMsg() if err := expectMsg(rw, 2, []uint{1}); err != nil {
if err != nil { t.Error(err)
t.Errorf("read error: %v", err)
} }
if msg.Code != 2 { if err := expectMsg(rw, 3, []uint{2}); err != nil {
t.Errorf("incorrect msg code %d relayed to protocol", msg.Code) t.Error(err)
} }
data, err := ioutil.ReadAll(msg.Payload) if err := expectMsg(rw, 4, []uint{3}); err != nil {
if err != nil { t.Error(err)
t.Errorf("payload read error: %v", err)
}
expdata, _ := hex.DecodeString("0183303030")
if !bytes.Equal(expdata, data) {
t.Errorf("incorrect msg data %x", data)
} }
close(done) close(done)
return nil return nil
}, },
} }
net, peer, errc := testPeer([]Protocol{proto}) rw, peer, errc := testPeer(false, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, 1, "000")) EncodeMsg(rw, baseProtocolLength+2, 1)
EncodeMsg(rw, baseProtocolLength+3, 2)
EncodeMsg(rw, baseProtocolLength+4, 3)
select { select {
case <-done: case <-done:
case err := <-errc: case err := <-errc:
@ -105,11 +100,11 @@ func TestPeerProtoReadLargeMsg(t *testing.T) {
}, },
} }
net, peer, errc := testPeer([]Protocol{proto}) rw, peer, errc := testPeer(false, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
writeMsg(net, NewMsg(18, make([]byte, msgsize))) EncodeMsg(rw, 18, make([]byte, msgsize))
select { select {
case <-done: case <-done:
case err := <-errc: case err := <-errc:
@ -135,32 +130,20 @@ func TestPeerProtoEncodeMsg(t *testing.T) {
return nil return nil
}, },
} }
net, peer, _ := testPeer([]Protocol{proto}) rw, peer, _ := testPeer(false, []Protocol{proto})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{proto.cap()}) peer.startSubprotocols([]Cap{proto.cap()})
bufr := bufio.NewReader(net) if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
msg, err := readMsg(bufr) t.Error(err)
if err != nil {
t.Errorf("read error: %v", err)
}
if msg.Code != 17 {
t.Errorf("incorrect message code: got %d, expected %d", msg.Code, 17)
}
var data []string
if err := msg.Decode(&data); err != nil {
t.Errorf("payload decode error: %v", err)
}
if !reflect.DeepEqual(data, []string{"foo", "bar"}) {
t.Errorf("payload RLP mismatch, got %#v, want %#v", data, []string{"foo", "bar"})
} }
} }
func TestPeerWrite(t *testing.T) { func TestPeerWriteForBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
net, peer, peerErr := testPeer([]Protocol{discard}) rw, peer, peerErr := testPeer(false, []Protocol{discard})
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()}) peer.startSubprotocols([]Cap{discard.cap()})
// test write errors // test write errors
@ -176,18 +159,13 @@ func TestPeerWrite(t *testing.T) {
// setup for reading the message on the other end // setup for reading the message on the other end
read := make(chan struct{}) read := make(chan struct{})
go func() { go func() {
bufr := bufio.NewReader(net) if err := expectMsg(rw, 16, nil); err != nil {
msg, err := readMsg(bufr) t.Error()
if err != nil {
t.Errorf("read error: %v", err)
} else if msg.Code != 16 {
t.Errorf("wrong code, got %d, expected %d", msg.Code, 16)
} }
msg.Discard()
close(read) close(read)
}() }()
// test succcessful write // test successful write
if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
t.Errorf("expect no error for known protocol: %v", err) t.Errorf("expect no error for known protocol: %v", err)
} }
@ -198,104 +176,152 @@ func TestPeerWrite(t *testing.T) {
} }
} }
func TestPeerActivity(t *testing.T) { func TestPeerPing(t *testing.T) {
// shorten inactivityTimeout while this test is running defer testlog(t).detach()
oldT := inactivityTimeout
defer func() { inactivityTimeout = oldT }()
inactivityTimeout = 20 * time.Millisecond
net, peer, peerErr := testPeer([]Protocol{discard}) rw, _, _ := testPeer(false, nil)
defer net.Close() defer rw.Close()
peer.startSubprotocols([]Cap{discard.cap()}) if err := EncodeMsg(rw, pingMsg); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, pongMsg, nil); err != nil {
t.Error(err)
}
}
sub := peer.activity.Subscribe(time.Time{}) func TestPeerDisconnect(t *testing.T) {
defer sub.Unsubscribe() defer testlog(t).detach()
for i := 0; i < 6; i++ { rw, _, disc := testPeer(false, nil)
writeMsg(net, NewMsg(16)) defer rw.Close()
if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
t.Fatal(err)
}
if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
t.Error(err)
}
rw.Close() // make test end faster
if reason := <-disc; reason != DiscRequested {
t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
}
}
func TestPeerHandshake(t *testing.T) {
defer testlog(t).detach()
// remote has two matching protocols: a and c
remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
remoteID := randomID()
remote.ourID = &remoteID
remote.ourName = "remote peer"
start := make(chan string)
stop := make(chan struct{})
run := func(p *Peer, rw MsgReadWriter) error {
name := rw.(*proto).name
if name != "a" && name != "c" {
t.Errorf("protocol %q should not be started", name)
} else {
start <- name
}
<-stop
return nil
}
protocols := []Protocol{
{Name: "a", Version: 1, Length: 1, Run: run},
{Name: "b", Version: 2, Length: 1, Run: run},
{Name: "c", Version: 3, Length: 1, Run: run},
{Name: "d", Version: 4, Length: 1, Run: run},
}
rw, p, disc := testPeer(true, protocols)
p.remoteID = remote.ourID
defer rw.Close()
// run the handshake
remoteProtocols := []Protocol{protocols[0], protocols[2]}
if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
t.Fatalf("handshake write error: %v", err)
}
if err := readProtocolHandshake(remote, rw); err != nil {
t.Fatalf("handshake read error: %v", err)
}
// check that all protocols have been started
var started []string
for i := 0; i < 2; i++ {
select { select {
case <-sub.Chan(): case name := <-start:
case <-time.After(inactivityTimeout / 2): started = append(started, name)
t.Fatal("no event within ", inactivityTimeout/2) case <-time.After(100 * time.Millisecond):
case err := <-peerErr:
t.Fatal("peer error", err)
} }
} }
sort.Strings(started)
if !reflect.DeepEqual(started, []string{"a", "c"}) {
t.Errorf("wrong protocols started: %v", started)
}
select { // check that metadata has been set
case <-time.After(inactivityTimeout * 2): if p.ID() != remoteID {
case <-sub.Chan(): t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
t.Fatal("got activity event while connection was inactive")
case err := <-peerErr:
t.Fatal("peer error", err)
} }
if p.Name() != remote.ourName {
t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
}
close(stop)
t.Logf("disc reason: %v", <-disc)
} }
func TestNewPeer(t *testing.T) { func TestNewPeer(t *testing.T) {
name := "nodename"
caps := []Cap{{"foo", 2}, {"bar", 3}} caps := []Cap{{"foo", 2}, {"bar", 3}}
id := &peerId{} id := randomID()
p := NewPeer(id, caps) p := NewPeer(id, name, caps)
if p.ID() != id {
t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
}
if p.Name() != name {
t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
}
if !reflect.DeepEqual(p.Caps(), caps) { if !reflect.DeepEqual(p.Caps(), caps) {
t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
} }
if p.Identity() != id {
t.Errorf("Identity mismatch: got %v, expected %v", p.Identity(), id) p.Disconnect(DiscAlreadyConnected) // Should not hang
}
// Should not hang.
p.Disconnect(DiscAlreadyConnected)
} }
func TestEOFSignal(t *testing.T) { // expectMsg reads a message from r and verifies that its
rb := make([]byte, 10) // code and encoded RLP content match the provided values.
// If content is nil, the payload is discarded and not verified.
func expectMsg(r MsgReader, code uint64, content interface{}) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
}
if content == nil {
return msg.Discard()
} else {
contentEnc, err := rlp.EncodeToBytes(content)
if err != nil {
panic("content encode error: " + err.Error())
}
// skip over list header in encoded value. this is temporary.
contentEncR := bytes.NewReader(contentEnc)
if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
panic("content must encode as RLP list")
}
contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
// empty reader actualContent, err := ioutil.ReadAll(msg.Payload)
eof := make(chan struct{}, 1) if err != nil {
sig := &eofSignal{new(bytes.Buffer), 0, eof} return err
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
} }
select { if !bytes.Equal(actualContent, contentEnc) {
case <-eof: return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc)
default:
t.Error("EOF chan not signaled")
} }
// count before error
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaa"), 4, eof}
if n, err := sig.Read(rb); n != 8 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// error before count
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaa"), 999, eof}
if n, err := sig.Read(rb); n != 4 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
if n, err := sig.Read(rb); n != 0 || err != io.EOF {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
default:
t.Error("EOF chan not signaled")
}
// no signal if neither occurs
eof = make(chan struct{}, 1)
sig = &eofSignal{bytes.NewBufferString("aaaaaaaaaaaaaaaaaaaaa"), 999, eof}
if n, err := sig.Read(rb); n != 10 || err != nil {
t.Errorf("Read returned unexpected values: (%v, %v)", n, err)
}
select {
case <-eof:
t.Error("unexpected EOF signal")
default:
} }
return nil
} }

View File

@ -1,10 +1,5 @@
package p2p package p2p
import (
"bytes"
"time"
)
// Protocol represents a P2P subprotocol implementation. // Protocol represents a P2P subprotocol implementation.
type Protocol struct { type Protocol struct {
// Name should contain the official protocol name, // Name should contain the official protocol name,
@ -32,42 +27,6 @@ func (p Protocol) cap() Cap {
return Cap{p.Name, p.Version} return Cap{p.Name, p.Version}
} }
const (
baseProtocolVersion = 2
baseProtocolLength = uint64(16)
baseProtocolMaxMsgSize = 10 * 1024 * 1024
)
const (
// devp2p message codes
handshakeMsg = 0x00
discMsg = 0x01
pingMsg = 0x02
pongMsg = 0x03
getPeersMsg = 0x04
peersMsg = 0x05
)
// handshake is the structure of a handshake list.
type handshake struct {
Version uint64
ID string
Caps []Cap
ListenPort uint64
NodeID []byte
}
func (h *handshake) String() string {
return h.ID
}
func (h *handshake) Pubkey() []byte {
return h.NodeID
}
func (h *handshake) PrivKey() []byte {
return nil
}
// Cap is the structure of a peer capability. // Cap is the structure of a peer capability.
type Cap struct { type Cap struct {
Name string Name string
@ -83,210 +42,3 @@ type capsByName []Cap
func (cs capsByName) Len() int { return len(cs) } func (cs capsByName) Len() int { return len(cs) }
func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name } func (cs capsByName) Less(i, j int) bool { return cs[i].Name < cs[j].Name }
func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] } func (cs capsByName) Swap(i, j int) { cs[i], cs[j] = cs[j], cs[i] }
type baseProtocol struct {
rw MsgReadWriter
peer *Peer
}
func runBaseProtocol(peer *Peer, rw MsgReadWriter) error {
bp := &baseProtocol{rw, peer}
errc := make(chan error, 1)
go func() { errc <- rw.WriteMsg(bp.handshakeMsg()) }()
if err := bp.readHandshake(); err != nil {
return err
}
// handle write error
if err := <-errc; err != nil {
return err
}
// run main loop
go func() {
for {
if err := bp.handle(rw); err != nil {
errc <- err
break
}
}
}()
return bp.loop(errc)
}
var pingTimeout = 2 * time.Second
func (bp *baseProtocol) loop(quit <-chan error) error {
ping := time.NewTimer(pingTimeout)
activity := bp.peer.activity.Subscribe(time.Time{})
lastActive := time.Time{}
defer ping.Stop()
defer activity.Unsubscribe()
getPeersTick := time.NewTicker(10 * time.Second)
defer getPeersTick.Stop()
err := EncodeMsg(bp.rw, getPeersMsg)
for err == nil {
select {
case err = <-quit:
return err
case <-getPeersTick.C:
err = EncodeMsg(bp.rw, getPeersMsg)
case event := <-activity.Chan():
ping.Reset(pingTimeout)
lastActive = event.(time.Time)
case t := <-ping.C:
if lastActive.Add(pingTimeout * 2).Before(t) {
err = newPeerError(errPingTimeout, "")
} else if lastActive.Add(pingTimeout).Before(t) {
err = EncodeMsg(bp.rw, pingMsg)
}
}
}
return err
}
func (bp *baseProtocol) handle(rw MsgReadWriter) error {
msg, err := rw.ReadMsg()
if err != nil {
return err
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
// make sure that the payload has been fully consumed
defer msg.Discard()
switch msg.Code {
case handshakeMsg:
return newPeerError(errProtocolBreach, "extra handshake received")
case discMsg:
var reason [1]DiscReason
if err := msg.Decode(&reason); err != nil {
return err
}
return discRequestedError(reason[0])
case pingMsg:
return EncodeMsg(bp.rw, pongMsg)
case pongMsg:
case getPeersMsg:
peers := bp.peerList()
// this is dangerous. the spec says that we should _delay_
// sending the response if no new information is available.
// this means that would need to send a response later when
// new peers become available.
//
// TODO: add event mechanism to notify baseProtocol for new peers
if len(peers) > 0 {
return EncodeMsg(bp.rw, peersMsg, peers...)
}
case peersMsg:
var peers []*peerAddr
if err := msg.Decode(&peers); err != nil {
return err
}
for _, addr := range peers {
bp.peer.Debugf("received peer suggestion: %v", addr)
bp.peer.newPeerAddr <- addr
}
default:
return newPeerError(errInvalidMsgCode, "unknown message code %v", msg.Code)
}
return nil
}
func (bp *baseProtocol) readHandshake() error {
// read and handle remote handshake
msg, err := bp.rw.ReadMsg()
if err != nil {
return err
}
if msg.Code != handshakeMsg {
return newPeerError(errProtocolBreach, "first message must be handshake, got %x", msg.Code)
}
if msg.Size > baseProtocolMaxMsgSize {
return newPeerError(errMisc, "message too big")
}
var hs handshake
if err := msg.Decode(&hs); err != nil {
return err
}
// validate handshake info
if hs.Version != baseProtocolVersion {
return newPeerError(errP2PVersionMismatch, "Require protocol %d, received %d\n",
baseProtocolVersion, hs.Version)
}
if len(hs.NodeID) == 0 {
return newPeerError(errPubkeyMissing, "")
}
if len(hs.NodeID) != 64 {
return newPeerError(errPubkeyInvalid, "require 512 bit, got %v", len(hs.NodeID)*8)
}
if da := bp.peer.dialAddr; da != nil {
// verify that the peer we wanted to connect to
// actually holds the target public key.
if da.Pubkey != nil && !bytes.Equal(da.Pubkey, hs.NodeID) {
return newPeerError(errPubkeyForbidden, "dial address pubkey mismatch")
}
}
pa := newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
if err := bp.peer.pubkeyHook(pa); err != nil {
return newPeerError(errPubkeyForbidden, "%v", err)
}
// TODO: remove Caps with empty name
var addr *peerAddr
if hs.ListenPort != 0 {
addr = newPeerAddr(bp.peer.conn.RemoteAddr(), hs.NodeID)
addr.Port = hs.ListenPort
}
bp.peer.setHandshakeInfo(&hs, addr, hs.Caps)
bp.peer.startSubprotocols(hs.Caps)
return nil
}
func (bp *baseProtocol) handshakeMsg() Msg {
var (
port uint64
caps []interface{}
)
if bp.peer.ourListenAddr != nil {
port = bp.peer.ourListenAddr.Port
}
for _, proto := range bp.peer.protocols {
caps = append(caps, proto.cap())
}
return NewMsg(handshakeMsg,
baseProtocolVersion,
bp.peer.ourID.String(),
caps,
port,
bp.peer.ourID.Pubkey()[1:],
)
}
func (bp *baseProtocol) peerList() []interface{} {
peers := bp.peer.otherPeers()
ds := make([]interface{}, 0, len(peers))
for _, p := range peers {
p.infolock.Lock()
addr := p.listenAddr
p.infolock.Unlock()
// filter out this peer and peers that are not listening or
// have not completed the handshake.
// TODO: track previously sent peers and exclude them as well.
if p == bp.peer || addr == nil {
continue
}
ds = append(ds, addr)
}
ourAddr := bp.peer.ourListenAddr
if ourAddr != nil && !ourAddr.IP.IsLoopback() && !ourAddr.IP.IsUnspecified() {
ds = append(ds, ourAddr)
}
return ds
}

View File

@ -1,167 +0,0 @@
package p2p
import (
"fmt"
"net"
"reflect"
"sync"
"testing"
"github.com/ethereum/go-ethereum/crypto"
)
type peerId struct {
privKey, pubkey []byte
}
func (self *peerId) String() string {
return fmt.Sprintf("test peer %x", self.Pubkey()[:4])
}
func (self *peerId) Pubkey() (pubkey []byte) {
pubkey = self.pubkey
if len(pubkey) == 0 {
pubkey = crypto.GenerateNewKeyPair().PublicKey
self.pubkey = pubkey
}
return
}
func (self *peerId) PrivKey() (privKey []byte) {
privKey = self.privKey
if len(privKey) == 0 {
privKey = crypto.GenerateNewKeyPair().PublicKey
self.privKey = privKey
}
return
}
func newTestPeer() (peer *Peer) {
peer = NewPeer(&peerId{}, []Cap{})
peer.pubkeyHook = func(*peerAddr) error { return nil }
peer.ourID = &peerId{}
peer.listenAddr = &peerAddr{}
peer.otherPeers = func() []*Peer { return nil }
return
}
func TestBaseProtocolPeers(t *testing.T) {
peerList := []*peerAddr{
{IP: net.ParseIP("1.2.3.4"), Port: 2222, Pubkey: []byte{}},
{IP: net.ParseIP("5.6.7.8"), Port: 3333, Pubkey: []byte{}},
}
listenAddr := &peerAddr{IP: net.ParseIP("1.3.5.7"), Port: 1111, Pubkey: []byte{}}
rw1, rw2 := MsgPipe()
defer rw1.Close()
wg := new(sync.WaitGroup)
// run matcher, close pipe when addresses have arrived
numPeers := len(peerList) + 1
addrChan := make(chan *peerAddr)
wg.Add(1)
go func() {
i := 0
for got := range addrChan {
var want *peerAddr
switch {
case i < len(peerList):
want = peerList[i]
case i == len(peerList):
want = listenAddr // listenAddr should be the last thing sent
}
t.Logf("got peer %d/%d: %v", i+1, numPeers, got)
if !reflect.DeepEqual(want, got) {
t.Errorf("mismatch: got %+v, want %+v", got, want)
}
i++
if i == numPeers {
break
}
}
if i != numPeers {
t.Errorf("wrong number of peers received: got %d, want %d", i, numPeers)
}
rw1.Close()
wg.Done()
}()
// run first peer (in background)
peer1 := newTestPeer()
peer1.ourListenAddr = listenAddr
peer1.otherPeers = func() []*Peer {
pl := make([]*Peer, len(peerList))
for i, addr := range peerList {
pl[i] = &Peer{listenAddr: addr}
}
return pl
}
wg.Add(1)
go func() {
runBaseProtocol(peer1, rw1)
wg.Done()
}()
// run second peer
peer2 := newTestPeer()
peer2.newPeerAddr = addrChan // feed peer suggestions into matcher
if err := runBaseProtocol(peer2, rw2); err != ErrPipeClosed {
t.Errorf("peer2 terminated with unexpected error: %v", err)
}
// terminate matcher
close(addrChan)
wg.Wait()
}
func TestBaseProtocolDisconnect(t *testing.T) {
peer := NewPeer(&peerId{}, nil)
peer.ourID = &peerId{}
peer.pubkeyHook = func(*peerAddr) error { return nil }
rw1, rw2 := MsgPipe()
done := make(chan struct{})
go func() {
if err := expectMsg(rw2, handshakeMsg); err != nil {
t.Error(err)
}
err := EncodeMsg(rw2, handshakeMsg,
baseProtocolVersion,
"",
[]interface{}{},
0,
make([]byte, 64),
)
if err != nil {
t.Error(err)
}
if err := expectMsg(rw2, getPeersMsg); err != nil {
t.Error(err)
}
if err := EncodeMsg(rw2, discMsg, DiscQuitting); err != nil {
t.Error(err)
}
close(done)
}()
if err := runBaseProtocol(peer, rw1); err == nil {
t.Errorf("base protocol returned without error")
} else if reason, ok := err.(discRequestedError); !ok || reason != DiscQuitting {
t.Errorf("base protocol returned wrong error: %v", err)
}
<-done
}
func expectMsg(r MsgReader, code uint64) error {
msg, err := r.ReadMsg()
if err != nil {
return err
}
if err := msg.Discard(); err != nil {
return err
}
if msg.Code != code {
return fmt.Errorf("wrong message code: got %d, expected %d", msg.Code, code)
}
return nil
}

View File

@ -2,37 +2,56 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"runtime"
"sync" "sync"
"time" "time"
"github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
const ( const (
outboundAddressPoolSize = 500
defaultDialTimeout = 10 * time.Second defaultDialTimeout = 10 * time.Second
refreshPeersInterval = 30 * time.Second
portMappingUpdateInterval = 15 * time.Minute portMappingUpdateInterval = 15 * time.Minute
portMappingTimeout = 20 * time.Minute portMappingTimeout = 20 * time.Minute
) )
var srvlog = logger.NewLogger("P2P Server") var srvlog = logger.NewLogger("P2P Server")
// MakeName creates a node name that follows the ethereum convention
// for such names. It adds the operation system name and Go runtime version
// the name.
func MakeName(name, version string) string {
return fmt.Sprintf("%s/v%s/%s/%s", name, version, runtime.GOOS, runtime.Version())
}
// Server manages all peer connections. // Server manages all peer connections.
// //
// The fields of Server are used as configuration parameters. // The fields of Server are used as configuration parameters.
// You should set them before starting the Server. Fields may not be // You should set them before starting the Server. Fields may not be
// modified while the server is running. // modified while the server is running.
type Server struct { type Server struct {
// This field must be set to a valid client identity. // This field must be set to a valid secp256k1 private key.
Identity ClientIdentity PrivateKey *ecdsa.PrivateKey
// MaxPeers is the maximum number of peers that can be // MaxPeers is the maximum number of peers that can be
// connected. It must be greater than zero. // connected. It must be greater than zero.
MaxPeers int MaxPeers int
// Name sets the node name of this server.
// Use MakeName to create a name that follows existing conventions.
Name string
// Bootstrap nodes are used to establish connectivity
// with the rest of the network.
BootstrapNodes []discover.Node
// Protocols should contain the protocols supported // Protocols should contain the protocols supported
// by the server. Matching protocols are launched for // by the server. Matching protocols are launched for
// each peer. // each peer.
@ -62,22 +81,23 @@ type Server struct {
// If NoDial is true, the server will not dial any peers. // If NoDial is true, the server will not dial any peers.
NoDial bool NoDial bool
// Hook for testing. This is useful because we can inhibit // Hooks for testing. These are useful because we can inhibit
// the whole protocol stack. // the whole protocol stack.
newPeerFunc peerFunc handshakeFunc
newPeerHook
lock sync.RWMutex lock sync.RWMutex
running bool running bool
listener net.Listener listener net.Listener
laddr *net.TCPAddr // real listen addr laddr *net.TCPAddr // real listen addr
peers []*Peer peers map[discover.NodeID]*Peer
peerSlots chan int
peerCount int ntab *discover.Table
quit chan struct{} quit chan struct{}
wg sync.WaitGroup loopWG sync.WaitGroup // {dial,listen,nat}Loop
peerConnect chan *peerAddr peerWG sync.WaitGroup // active peer goroutines
peerDisconnect chan *Peer peerConnect chan *discover.Node
} }
// NAT is implemented by NAT traversal methods. // NAT is implemented by NAT traversal methods.
@ -90,7 +110,8 @@ type NAT interface {
String() string String() string
} }
type peerFunc func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer type handshakeFunc func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (discover.NodeID, []byte, error)
type newPeerHook func(*Peer)
// Peers returns all connected peers. // Peers returns all connected peers.
func (srv *Server) Peers() (peers []*Peer) { func (srv *Server) Peers() (peers []*Peer) {
@ -107,18 +128,15 @@ func (srv *Server) Peers() (peers []*Peer) {
// PeerCount returns the number of connected peers. // PeerCount returns the number of connected peers.
func (srv *Server) PeerCount() int { func (srv *Server) PeerCount() int {
srv.lock.RLock() srv.lock.RLock()
defer srv.lock.RUnlock() n := len(srv.peers)
return srv.peerCount srv.lock.RUnlock()
return n
} }
// SuggestPeer injects an address into the outbound address pool. // SuggestPeer creates a connection to the given Node if it
func (srv *Server) SuggestPeer(ip net.IP, port int, nodeID []byte) { // is not already connected.
addr := &peerAddr{ip, uint64(port), nodeID} func (srv *Server) SuggestPeer(ip net.IP, port int, id discover.NodeID) {
select { srv.peerConnect <- &discover.Node{ID: id, Addr: &net.UDPAddr{IP: ip, Port: port}}
case srv.peerConnect <- addr:
default: // don't block
srvlog.Warnf("peer suggestion %v ignored", addr)
}
} }
// Broadcast sends an RLP-encoded message to all connected peers. // Broadcast sends an RLP-encoded message to all connected peers.
@ -152,47 +170,47 @@ func (srv *Server) Start() (err error) {
} }
srvlog.Infoln("Starting Server") srvlog.Infoln("Starting Server")
// initialize fields // initialize all the fields
if srv.Identity == nil { if srv.PrivateKey == nil {
return fmt.Errorf("Server.Identity must be set to a non-nil identity") return fmt.Errorf("Server.PrivateKey must be set to a non-nil key")
} }
if srv.MaxPeers <= 0 { if srv.MaxPeers <= 0 {
return fmt.Errorf("Server.MaxPeers must be > 0") return fmt.Errorf("Server.MaxPeers must be > 0")
} }
srv.quit = make(chan struct{}) srv.quit = make(chan struct{})
srv.peers = make([]*Peer, srv.MaxPeers) srv.peers = make(map[discover.NodeID]*Peer)
srv.peerSlots = make(chan int, srv.MaxPeers) srv.peerConnect = make(chan *discover.Node)
srv.peerConnect = make(chan *peerAddr, outboundAddressPoolSize)
srv.peerDisconnect = make(chan *Peer) if srv.handshakeFunc == nil {
if srv.newPeerFunc == nil { srv.handshakeFunc = encHandshake
srv.newPeerFunc = newServerPeer
} }
if srv.Blacklist == nil { if srv.Blacklist == nil {
srv.Blacklist = NewBlacklist() srv.Blacklist = NewBlacklist()
} }
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if srv.ListenAddr != "" { if srv.ListenAddr != "" {
if err := srv.startListening(); err != nil { if err := srv.startListening(); err != nil {
return err return err
} }
} }
// dial stuff
dt, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr)
if err != nil {
return err
}
srv.ntab = dt
if srv.Dialer == nil {
srv.Dialer = &net.Dialer{Timeout: defaultDialTimeout}
}
if !srv.NoDial { if !srv.NoDial {
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.dialLoop() go srv.dialLoop()
} }
if srv.NoDial && srv.ListenAddr == "" { if srv.NoDial && srv.ListenAddr == "" {
srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.") srvlog.Warnln("I will be kind-of useless, neither dialing nor listening.")
} }
// make all slots available
for i := range srv.peers {
srv.peerSlots <- i
}
// note: discLoop is not part of WaitGroup
go srv.discLoop()
srv.running = true srv.running = true
return nil return nil
} }
@ -205,10 +223,10 @@ func (srv *Server) startListening() error {
srv.ListenAddr = listener.Addr().String() srv.ListenAddr = listener.Addr().String()
srv.laddr = listener.Addr().(*net.TCPAddr) srv.laddr = listener.Addr().(*net.TCPAddr)
srv.listener = listener srv.listener = listener
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.listenLoop() go srv.listenLoop()
if !srv.laddr.IP.IsLoopback() && srv.NAT != nil { if !srv.laddr.IP.IsLoopback() && srv.NAT != nil {
srv.wg.Add(1) srv.loopWG.Add(1)
go srv.natLoop(srv.laddr.Port) go srv.natLoop(srv.laddr.Port)
} }
return nil return nil
@ -225,57 +243,41 @@ func (srv *Server) Stop() {
srv.running = false srv.running = false
srv.lock.Unlock() srv.lock.Unlock()
srvlog.Infoln("Stopping server") srvlog.Infoln("Stopping Server")
srv.ntab.Close()
if srv.listener != nil { if srv.listener != nil {
// this unblocks listener Accept // this unblocks listener Accept
srv.listener.Close() srv.listener.Close()
} }
close(srv.quit) close(srv.quit)
for _, peer := range srv.Peers() { srv.loopWG.Wait()
// No new peers can be added at this point because dialLoop and
// listenLoop are down. It is safe to call peerWG.Wait because
// peerWG.Add is not called outside of those loops.
for _, peer := range srv.peers {
peer.Disconnect(DiscQuitting) peer.Disconnect(DiscQuitting)
} }
srv.wg.Wait() srv.peerWG.Wait()
// wait till they actually disconnect
// this is checked by claiming all peerSlots.
// slots become available as the peers disconnect.
for i := 0; i < cap(srv.peerSlots); i++ {
<-srv.peerSlots
}
// terminate discLoop
close(srv.peerDisconnect)
}
func (srv *Server) discLoop() {
for peer := range srv.peerDisconnect {
srv.removePeer(peer)
}
} }
// main loop for adding connections via listening // main loop for adding connections via listening
func (srv *Server) listenLoop() { func (srv *Server) listenLoop() {
defer srv.wg.Done() defer srv.loopWG.Done()
srvlog.Infoln("Listening on", srv.listener.Addr()) srvlog.Infoln("Listening on", srv.listener.Addr())
for { for {
select {
case slot := <-srv.peerSlots:
srvlog.Debugf("grabbed slot %v for listening", slot)
conn, err := srv.listener.Accept() conn, err := srv.listener.Accept()
if err != nil { if err != nil {
srv.peerSlots <- slot
return
}
srvlog.Debugf("Accepted conn %v (slot %d)\n", conn.RemoteAddr(), slot)
srv.addPeer(conn, nil, slot)
case <-srv.quit:
return return
} }
srvlog.Debugf("Accepted conn %v\n", conn.RemoteAddr())
srv.peerWG.Add(1)
go srv.startPeer(conn, nil)
} }
} }
func (srv *Server) natLoop(port int) { func (srv *Server) natLoop(port int) {
defer srv.wg.Done() defer srv.loopWG.Done()
for { for {
srv.updatePortMapping(port) srv.updatePortMapping(port)
select { select {
@ -314,108 +316,131 @@ func (srv *Server) removePortMapping(port int) {
} }
func (srv *Server) dialLoop() { func (srv *Server) dialLoop() {
defer srv.wg.Done() defer srv.loopWG.Done()
var ( refresh := time.NewTicker(refreshPeersInterval)
suggest chan *peerAddr defer refresh.Stop()
slot *int
slots = srv.peerSlots srv.ntab.Bootstrap(srv.BootstrapNodes)
) go srv.findPeers()
dialed := make(chan *discover.Node)
dialing := make(map[discover.NodeID]bool)
// TODO: limit number of active dials
// TODO: ensure only one findPeers goroutine is running
// TODO: pause findPeers when we're at capacity
for { for {
select { select {
case i := <-slots: case <-refresh.C:
// we need a peer in slot i, slot reserved
slot = &i
// now we can watch for candidate peers in the next loop
suggest = srv.peerConnect
// do not consume more until candidate peer is found
slots = nil
case desc := <-suggest: go srv.findPeers()
// candidate peer found, will dial out asyncronously
// if connection fails slot will be released case dest := <-srv.peerConnect:
srvlog.DebugDetailf("dial %v (%v)", desc, *slot) srv.lock.Lock()
go srv.dialPeer(desc, *slot) _, isconnected := srv.peers[dest.ID]
// we can watch if more peers needed in the next loop srv.lock.Unlock()
slots = srv.peerSlots if isconnected || dialing[dest.ID] {
// until then we dont care about candidate peers continue
suggest = nil }
dialing[dest.ID] = true
srv.peerWG.Add(1)
go func() {
srv.dialNode(dest)
// at this point, the peer has been added
// or discarded. either way, we're not dialing it anymore.
dialed <- dest
}()
case dest := <-dialed:
delete(dialing, dest.ID)
case <-srv.quit: case <-srv.quit:
// give back the currently reserved slot // TODO: maybe wait for active dials
if slot != nil {
srv.peerSlots <- *slot
}
return return
} }
} }
} }
// connect to peer via dial out func (srv *Server) dialNode(dest *discover.Node) {
func (srv *Server) dialPeer(desc *peerAddr, slot int) { srvlog.Debugf("Dialing %v\n", dest.Addr)
srvlog.Debugf("Dialing %v (slot %d)\n", desc, slot) conn, err := srv.Dialer.Dial("tcp", dest.Addr.String())
conn, err := srv.Dialer.Dial(desc.Network(), desc.String())
if err != nil { if err != nil {
srvlog.DebugDetailf("dial error: %v", err) srvlog.DebugDetailf("dial error: %v", err)
srv.peerSlots <- slot
return return
} }
go srv.addPeer(conn, desc, slot) srv.startPeer(conn, dest)
} }
// creates the new peer object and inserts it into its slot func (srv *Server) findPeers() {
func (srv *Server) addPeer(conn net.Conn, desc *peerAddr, slot int) *Peer { far := srv.ntab.Self()
for i := range far {
far[i] = ^far[i]
}
closeToSelf := srv.ntab.Lookup(srv.ntab.Self())
farFromSelf := srv.ntab.Lookup(far)
for i := 0; i < len(closeToSelf) || i < len(farFromSelf); i++ {
if i < len(closeToSelf) {
srv.peerConnect <- closeToSelf[i]
}
if i < len(farFromSelf) {
srv.peerConnect <- farFromSelf[i]
}
}
}
func (srv *Server) startPeer(conn net.Conn, dest *discover.Node) {
// TODO: I/O timeout, handle/store session token
remoteID, _, err := srv.handshakeFunc(conn, srv.PrivateKey, dest)
if err != nil {
conn.Close()
srvlog.Debugf("Encryption Handshake with %v failed: %v", conn.RemoteAddr(), err)
return
}
ourID := srv.ntab.Self()
p := newPeer(conn, srv.Protocols, srv.Name, &ourID, &remoteID)
if ok, reason := srv.addPeer(remoteID, p); !ok {
p.Disconnect(reason)
return
}
srv.newPeerHook(p)
p.run()
srv.removePeer(p)
}
func (srv *Server) addPeer(id discover.NodeID, p *Peer) (bool, DiscReason) {
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() defer srv.lock.Unlock()
if !srv.running { switch {
conn.Close() case !srv.running:
srv.peerSlots <- slot // release slot return false, DiscQuitting
return nil case len(srv.peers) >= srv.MaxPeers:
return false, DiscTooManyPeers
case srv.peers[id] != nil:
return false, DiscAlreadyConnected
case srv.Blacklist.Exists(id[:]):
return false, DiscUselessPeer
case id == srv.ntab.Self():
return false, DiscSelf
} }
peer := srv.newPeerFunc(srv, conn, desc) srvlog.Debugf("Adding %v\n", p)
peer.slot = slot srv.peers[id] = p
srv.peers[slot] = peer return true, 0
srv.peerCount++
go func() { peer.loop(); srv.peerDisconnect <- peer }()
return peer
} }
// removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot // removes peer: sending disconnect msg, stop peer, remove rom list/table, release slot
func (srv *Server) removePeer(peer *Peer) { func (srv *Server) removePeer(p *Peer) {
srvlog.Debugf("Removing %v\n", p)
srv.lock.Lock() srv.lock.Lock()
defer srv.lock.Unlock() delete(srv.peers, *p.remoteID)
srvlog.Debugf("Removing %v (slot %v)\n", peer, peer.slot) srv.lock.Unlock()
if srv.peers[peer.slot] != peer { srv.peerWG.Done()
srvlog.Warnln("Invalid peer to remove:", peer)
return
}
// remove from list and index
srv.peerCount--
srv.peers[peer.slot] = nil
// release slot to signal need for a new peer, last!
srv.peerSlots <- peer.slot
} }
func (srv *Server) verifyPeer(addr *peerAddr) error {
if srv.Blacklist.Exists(addr.Pubkey) {
return errors.New("blacklisted")
}
if bytes.Equal(srv.Identity.Pubkey()[1:], addr.Pubkey) {
return newPeerError(errPubkeyForbidden, "not allowed to connect to srv")
}
srv.lock.RLock()
defer srv.lock.RUnlock()
for _, peer := range srv.peers {
if peer != nil {
id := peer.Identity()
if id != nil && bytes.Equal(id.Pubkey(), addr.Pubkey) {
return errors.New("already connected")
}
}
}
return nil
}
// TODO replace with "Set"
type Blacklist interface { type Blacklist interface {
Get([]byte) (bool, error) Get([]byte) (bool, error)
Put([]byte) error Put([]byte) error

View File

@ -2,19 +2,28 @@ package p2p
import ( import (
"bytes" "bytes"
"crypto/ecdsa"
"io" "io"
"math/rand"
"net" "net"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/p2p/discover"
) )
func startTestServer(t *testing.T, pf peerFunc) *Server { func startTestServer(t *testing.T, pf newPeerHook) *Server {
server := &Server{ server := &Server{
Identity: &peerId{}, Name: "test",
MaxPeers: 10, MaxPeers: 10,
ListenAddr: "127.0.0.1:0", ListenAddr: "127.0.0.1:0",
newPeerFunc: pf, PrivateKey: newkey(),
newPeerHook: pf,
handshakeFunc: func(io.ReadWriter, *ecdsa.PrivateKey, *discover.Node) (id discover.NodeID, st []byte, err error) {
return randomID(), nil, err
},
} }
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err) t.Fatalf("Could not start server: %v", err)
@ -27,16 +36,11 @@ func TestServerListen(t *testing.T) {
// start the test server // start the test server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
if conn == nil { if p == nil {
t.Error("peer func called with nil conn") t.Error("peer func called with nil conn")
} }
if dialAddr != nil { connected <- p
t.Error("peer func called with non-nil dialAddr")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
}) })
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
@ -50,9 +54,9 @@ func TestServerListen(t *testing.T) {
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.LocalAddr().String() != conn.RemoteAddr().String() { if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.LocalAddr(), conn.RemoteAddr()) peer.LocalAddr(), conn.RemoteAddr())
} }
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not accept within one second") t.Error("server did not accept within one second")
@ -62,7 +66,7 @@ func TestServerListen(t *testing.T) {
func TestServerDial(t *testing.T) { func TestServerDial(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
// run a fake TCP server to handle the connection. // run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("could not setup listener: %v") t.Fatalf("could not setup listener: %v")
@ -72,41 +76,33 @@ func TestServerDial(t *testing.T) {
go func() { go func() {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
t.Error("acccept error:", err) t.Error("accept error:", err)
return
} }
conn.Close() conn.Close()
accepted <- conn accepted <- conn
}() }()
// start the test server // start the server
connected := make(chan *Peer) connected := make(chan *Peer)
srv := startTestServer(t, func(srv *Server, conn net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) { connected <- p })
if conn == nil {
t.Error("peer func called with nil conn")
}
peer := newPeer(conn, nil, dialAddr)
connected <- peer
return peer
})
defer close(connected) defer close(connected)
defer srv.Stop() defer srv.Stop()
// tell the server to connect. // tell the server to connect
connAddr := newPeerAddr(listener.Addr(), nil) tcpAddr := listener.Addr().(*net.TCPAddr)
connAddr := &discover.Node{Addr: &net.UDPAddr{IP: tcpAddr.IP, Port: tcpAddr.Port}}
srv.peerConnect <- connAddr srv.peerConnect <- connAddr
select { select {
case conn := <-accepted: case conn := <-accepted:
select { select {
case peer := <-connected: case peer := <-connected:
if peer.conn.RemoteAddr().String() != conn.LocalAddr().String() { if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v", t.Errorf("peer started with wrong conn: got %v, want %v",
peer.conn.RemoteAddr(), conn.LocalAddr()) peer.RemoteAddr(), conn.LocalAddr())
}
if peer.dialAddr != connAddr {
t.Errorf("peer started with wrong dialAddr: got %v, want %v",
peer.dialAddr, connAddr)
} }
// TODO: validate more fields
case <-time.After(1 * time.Second): case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second") t.Error("server did not launch peer within one second")
} }
@ -118,16 +114,16 @@ func TestServerDial(t *testing.T) {
func TestServerBroadcast(t *testing.T) { func TestServerBroadcast(t *testing.T) {
defer testlog(t).detach() defer testlog(t).detach()
var connected sync.WaitGroup var connected sync.WaitGroup
srv := startTestServer(t, func(srv *Server, c net.Conn, dialAddr *peerAddr) *Peer { srv := startTestServer(t, func(p *Peer) {
peer := newPeer(c, []Protocol{discard}, dialAddr) p.protocols = []Protocol{discard}
peer.startSubprotocols([]Cap{discard.cap()}) p.startSubprotocols([]Cap{discard.cap()})
connected.Done() connected.Done()
return peer
}) })
defer srv.Stop() defer srv.Stop()
// dial a bunch of conns // create a few peers
var conns = make([]net.Conn, 8) var conns = make([]net.Conn, 8)
connected.Add(len(conns)) connected.Add(len(conns))
deadline := time.Now().Add(3 * time.Second) deadline := time.Now().Add(3 * time.Second)
@ -159,3 +155,18 @@ func TestServerBroadcast(t *testing.T) {
} }
} }
} }
func newkey() *ecdsa.PrivateKey {
key, err := crypto.GenerateKey()
if err != nil {
panic("couldn't generate key: " + err.Error())
}
return key
}
func randomID() (id discover.NodeID) {
for i := range id {
id[i] = byte(rand.Intn(255))
}
return id
}

View File

@ -15,7 +15,7 @@ func testlog(t *testing.T) testLogger {
return l return l
} }
func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugLevel } func (testLogger) GetLogLevel() logger.LogLevel { return logger.DebugDetailLevel }
func (testLogger) SetLogLevel(logger.LogLevel) {} func (testLogger) SetLogLevel(logger.LogLevel) {}
func (l testLogger) LogPrint(level logger.LogLevel, msg string) { func (l testLogger) LogPrint(level logger.LogLevel, msg string) {

View File

@ -1,40 +0,0 @@
// +build none
package main
import (
"fmt"
"log"
"net"
"os"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"github.com/ethereum/go-ethereum/logger"
"github.com/ethereum/go-ethereum/p2p"
)
func main() {
logger.AddLogSystem(logger.NewStdLogSystem(os.Stdout, log.LstdFlags, logger.DebugLevel))
pub, _ := secp256k1.GenerateKeyPair()
srv := p2p.Server{
MaxPeers: 10,
Identity: p2p.NewSimpleClientIdentity("test", "1.0", "", string(pub)),
ListenAddr: ":30303",
NAT: p2p.PMP(net.ParseIP("10.0.0.1")),
}
if err := srv.Start(); err != nil {
fmt.Println("could not start server:", err)
os.Exit(1)
}
// add seed peers
seed, err := net.ResolveTCPAddr("tcp", "poc-7.ethdev.com:30303")
if err != nil {
fmt.Println("couldn't resolve:", err)
os.Exit(1)
}
srv.SuggestPeer(seed.IP, seed.Port, nil)
select {}
}