diff --git a/cmd/bootnode/main.go b/cmd/bootnode/main.go index abecac3d8..9b5ba1936 100644 --- a/cmd/bootnode/main.go +++ b/cmd/bootnode/main.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" ) func main() { @@ -39,6 +40,7 @@ func main() { nodeKeyFile = flag.String("nodekey", "", "private key filename") nodeKeyHex = flag.String("nodekeyhex", "", "private key as hex (for testing)") natdesc = flag.String("nat", "none", "port mapping mechanism (any|none|upnp|pmp|extip:)") + netrestrict = flag.String("netrestrict", "", "restrict network communication to the given IP networks (CIDR masks)") runv5 = flag.Bool("v5", false, "run a v5 topic discovery bootnode") nodeKey *ecdsa.PrivateKey @@ -81,12 +83,20 @@ func main() { os.Exit(0) } + var restrictList *netutil.Netlist + if *netrestrict != "" { + restrictList, err = netutil.ParseNetlist(*netrestrict) + if err != nil { + utils.Fatalf("-netrestrict: %v", err) + } + } + if *runv5 { - if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { + if _, err := discv5.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { utils.Fatalf("%v", err) } } else { - if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, ""); err != nil { + if _, err := discover.ListenUDP(nodeKey, *listenAddr, natm, "", restrictList); err != nil { utils.Fatalf("%v", err) } } diff --git a/cmd/bzzd/main.go b/cmd/bzzd/main.go index b2f14a4a9..a3e87dc8a 100644 --- a/cmd/bzzd/main.go +++ b/cmd/bzzd/main.go @@ -96,6 +96,7 @@ func init() { utils.BootnodesFlag, utils.KeyStoreDirFlag, utils.ListenPortFlag, + utils.NetrestrictFlag, utils.MaxPeersFlag, utils.NATFlag, utils.NodeKeyFileFlag, diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 13d771790..a275d8aa5 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -148,6 +148,7 @@ participating. utils.NatspecEnabledFlag, utils.NoDiscoverFlag, utils.DiscoveryV5Flag, + utils.NetrestrictFlag, utils.NodeKeyFileFlag, utils.NodeKeyHexFlag, utils.RPCEnabledFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 3bb625387..5c09e44ec 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -45,6 +45,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/pow" "github.com/ethereum/go-ethereum/rpc" @@ -366,10 +367,16 @@ var ( Name: "v5disc", Usage: "Enables the experimental RLPx V5 (Topic Discovery) mechanism", } + NetrestrictFlag = cli.StringFlag{ + Name: "netrestrict", + Usage: "Restricts network communication to the given IP networks (CIDR masks)", + } + WhisperEnabledFlag = cli.BoolFlag{ Name: "shh", Usage: "Enable Whisper", } + // ATM the url is left to the user and deployment to JSpathFlag = cli.StringFlag{ Name: "jspath", @@ -693,6 +700,14 @@ func MakeNode(ctx *cli.Context, name, gitCommit string) *node.Node { config.MaxPeers = 0 config.ListenAddr = ":0" } + if netrestrict := ctx.GlobalString(NetrestrictFlag.Name); netrestrict != "" { + list, err := netutil.ParseNetlist(netrestrict) + if err != nil { + Fatalf("Option %q: %v", NetrestrictFlag.Name, err) + } + config.NetRestrict = list + } + stack, err := node.New(config) if err != nil { Fatalf("Failed to create the protocol stack: %v", err) diff --git a/node/config.go b/node/config.go index 8d85b7ff8..8d75e441b 100644 --- a/node/config.go +++ b/node/config.go @@ -34,6 +34,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" ) var ( @@ -103,6 +104,10 @@ type Config struct { // Listener address for the V5 discovery protocol UDP traffic. DiscoveryV5Addr string + // Restrict communication to white listed IP networks. + // The whitelist only applies when non-nil. + NetRestrict *netutil.Netlist + // BootstrapNodes used to establish connectivity with the rest of the network. BootstrapNodes []*discover.Node diff --git a/node/node.go b/node/node.go index d49ae3a45..4b56fba4c 100644 --- a/node/node.go +++ b/node/node.go @@ -165,6 +165,7 @@ func (n *Node) Start() error { TrustedNodes: n.config.TrusterNodes(), NodeDatabase: n.config.NodeDB(), ListenAddr: n.config.ListenAddr, + NetRestrict: n.config.NetRestrict, NAT: n.config.NAT, Dialer: n.config.Dialer, NoDial: n.config.NoDial, diff --git a/p2p/dial.go b/p2p/dial.go index 691b8539e..57fba136a 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -19,6 +19,7 @@ package p2p import ( "container/heap" "crypto/rand" + "errors" "fmt" "net" "time" @@ -26,6 +27,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( @@ -48,6 +50,7 @@ const ( type dialstate struct { maxDynDials int ntab discoverTable + netrestrict *netutil.Netlist lookupRunning bool dialing map[discover.NodeID]connFlag @@ -100,10 +103,11 @@ type waitExpireTask struct { time.Duration } -func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int) *dialstate { +func newDialState(static []*discover.Node, ntab discoverTable, maxdyn int, netrestrict *netutil.Netlist) *dialstate { s := &dialstate{ maxDynDials: maxdyn, ntab: ntab, + netrestrict: netrestrict, static: make(map[discover.NodeID]*dialTask), dialing: make(map[discover.NodeID]connFlag), randomNodes: make([]*discover.Node, maxdyn/2), @@ -128,12 +132,9 @@ func (s *dialstate) removeStatic(n *discover.Node) { func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now time.Time) []task { var newtasks []task - isDialing := func(id discover.NodeID) bool { - _, found := s.dialing[id] - return found || peers[id] != nil || s.hist.contains(id) - } addDial := func(flag connFlag, n *discover.Node) bool { - if isDialing(n.ID) { + if err := s.checkDial(n, peers); err != nil { + glog.V(logger.Debug).Infof("skipping dial candidate %x@%v:%d: %v", n.ID[:8], n.IP, n.TCP, err) return false } s.dialing[n.ID] = flag @@ -159,7 +160,12 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now // Create dials for static nodes if they are not connected. for id, t := range s.static { - if !isDialing(id) { + err := s.checkDial(t.dest, peers) + switch err { + case errNotWhitelisted, errSelf: + glog.V(logger.Debug).Infof("removing static dial candidate %x@%v:%d: %v", t.dest.ID[:8], t.dest.IP, t.dest.TCP, err) + delete(s.static, t.dest.ID) + case nil: s.dialing[id] = t.flags newtasks = append(newtasks, t) } @@ -202,6 +208,31 @@ func (s *dialstate) newTasks(nRunning int, peers map[discover.NodeID]*Peer, now return newtasks } +var ( + errSelf = errors.New("is self") + errAlreadyDialing = errors.New("already dialing") + errAlreadyConnected = errors.New("already connected") + errRecentlyDialed = errors.New("recently dialed") + errNotWhitelisted = errors.New("not contained in netrestrict whitelist") +) + +func (s *dialstate) checkDial(n *discover.Node, peers map[discover.NodeID]*Peer) error { + _, dialing := s.dialing[n.ID] + switch { + case dialing: + return errAlreadyDialing + case peers[n.ID] != nil: + return errAlreadyConnected + case s.ntab != nil && n.ID == s.ntab.Self().ID: + return errSelf + case s.netrestrict != nil && !s.netrestrict.Contains(n.IP): + return errNotWhitelisted + case s.hist.contains(n.ID): + return errRecentlyDialed + } + return nil +} + func (s *dialstate) taskDone(t task, now time.Time) { switch t := t.(type) { case *dialTask: diff --git a/p2p/dial_test.go b/p2p/dial_test.go index 05d9b7562..c850233db 100644 --- a/p2p/dial_test.go +++ b/p2p/dial_test.go @@ -25,6 +25,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/netutil" ) func init() { @@ -86,7 +87,7 @@ func (t fakeTable) ReadRandomNodes(buf []*discover.Node) int { return copy(buf, // This test checks that dynamic dials are launched from discovery results. func TestDialStateDynDial(t *testing.T) { runDialTest(t, dialtest{ - init: newDialState(nil, fakeTable{}, 5), + init: newDialState(nil, fakeTable{}, 5, nil), rounds: []round{ // A discovery query is launched. { @@ -233,7 +234,7 @@ func TestDialStateDynDialFromTable(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(nil, table, 10), + init: newDialState(nil, table, 10, nil), rounds: []round{ // 5 out of 8 of the nodes returned by ReadRandomNodes are dialed. { @@ -313,6 +314,36 @@ func TestDialStateDynDialFromTable(t *testing.T) { }) } +// This test checks that candidates that do not match the netrestrict list are not dialed. +func TestDialStateNetRestrict(t *testing.T) { + // This table always returns the same random nodes + // in the order given below. + table := fakeTable{ + {ID: uintID(1), IP: net.ParseIP("127.0.0.1")}, + {ID: uintID(2), IP: net.ParseIP("127.0.0.2")}, + {ID: uintID(3), IP: net.ParseIP("127.0.0.3")}, + {ID: uintID(4), IP: net.ParseIP("127.0.0.4")}, + {ID: uintID(5), IP: net.ParseIP("127.0.2.5")}, + {ID: uintID(6), IP: net.ParseIP("127.0.2.6")}, + {ID: uintID(7), IP: net.ParseIP("127.0.2.7")}, + {ID: uintID(8), IP: net.ParseIP("127.0.2.8")}, + } + restrict := new(netutil.Netlist) + restrict.Add("127.0.2.0/24") + + runDialTest(t, dialtest{ + init: newDialState(nil, table, 10, restrict), + rounds: []round{ + { + new: []task{ + &dialTask{flags: dynDialedConn, dest: table[4]}, + &discoverTask{}, + }, + }, + }, + }) +} + // This test checks that static dials are launched. func TestDialStateStaticDial(t *testing.T) { wantStatic := []*discover.Node{ @@ -324,7 +355,7 @@ func TestDialStateStaticDial(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, fakeTable{}, 0), + init: newDialState(wantStatic, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -405,7 +436,7 @@ func TestDialStateCache(t *testing.T) { } runDialTest(t, dialtest{ - init: newDialState(wantStatic, fakeTable{}, 0), + init: newDialState(wantStatic, fakeTable{}, 0, nil), rounds: []round{ // Static dials are launched for the nodes that // aren't yet connected. @@ -467,7 +498,7 @@ func TestDialStateCache(t *testing.T) { func TestDialResolve(t *testing.T) { resolved := discover.NewNode(uintID(1), net.IP{127, 0, 55, 234}, 3333, 4444) table := &resolveMock{answer: resolved} - state := newDialState(nil, table, 0) + state := newDialState(nil, table, 0, nil) // Check that the task is generated with an incomplete ID. dest := discover.NewNode(uintID(1), nil, 0, 0) diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 1a2405740..102c7c2d1 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -146,6 +146,7 @@ func fillBucket(tab *Table, ld int) (last *Node) { func nodeAtDistance(base common.Hash, ld int) (n *Node) { n = new(Node) n.sha = hashAtDistance(base, ld) + n.IP = net.IP{10, 0, 2, byte(ld)} copy(n.ID[:], n.sha[:]) // ensure the node still has a unique ID return n } diff --git a/p2p/discover/udp.go b/p2p/discover/udp.go index 74758b6fd..e09c63ffb 100644 --- a/p2p/discover/udp.go +++ b/p2p/discover/udp.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -126,8 +127,16 @@ func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint { return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} } -func nodeFromRPC(rn rpcNode) (*Node, error) { - // TODO: don't accept localhost, LAN addresses from internet hosts +func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { + if rn.UDP <= 1024 { + return nil, errors.New("low port") + } + if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + return nil, err + } + if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { + return nil, errors.New("not contained in netrestrict whitelist") + } n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) err := n.validateComplete() return n, err @@ -151,6 +160,7 @@ type conn interface { // udp implements the RPC protocol. type udp struct { conn conn + netrestrict *netutil.Netlist priv *ecdsa.PrivateKey ourEndpoint rpcEndpoint @@ -201,7 +211,7 @@ type reply struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) { +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, error) { addr, err := net.ResolveUDPAddr("udp", laddr) if err != nil { return nil, err @@ -210,7 +220,7 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP if err != nil { return nil, err } - tab, _, err := newUDP(priv, conn, natm, nodeDBPath) + tab, _, err := newUDP(priv, conn, natm, nodeDBPath, netrestrict) if err != nil { return nil, err } @@ -218,13 +228,14 @@ func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBP return tab, nil } -func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp, error) { +func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Table, *udp, error) { udp := &udp{ - conn: c, - priv: priv, - closing: make(chan struct{}), - gotreply: make(chan reply), - addpending: make(chan *pending), + conn: c, + priv: priv, + netrestrict: netrestrict, + closing: make(chan struct{}), + gotreply: make(chan reply), + addpending: make(chan *pending), } realaddr := c.LocalAddr().(*net.UDPAddr) if natm != nil { @@ -281,9 +292,12 @@ func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node reply := r.(*neighbors) for _, rn := range reply.Nodes { nreceived++ - if n, err := nodeFromRPC(rn); err == nil { - nodes = append(nodes, n) + n, err := t.nodeFromRPC(toaddr, rn) + if err != nil { + glog.V(logger.Detail).Infof("invalid neighbor node (%v) from %v: %v", rn.IP, toaddr, err) + continue } + nodes = append(nodes, n) } return nreceived >= bucketSize }) @@ -479,13 +493,6 @@ func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, return packet, nil } -func isTemporaryError(err error) bool { - tempErr, ok := err.(interface { - Temporary() bool - }) - return ok && tempErr.Temporary() || isPacketTooBig(err) -} - // readLoop runs in its own goroutine. it handles incoming UDP packets. func (t *udp) readLoop() { defer t.conn.Close() @@ -495,7 +502,7 @@ func (t *udp) readLoop() { buf := make([]byte, 1280) for { nbytes, from, err := t.conn.ReadFromUDP(buf) - if isTemporaryError(err) { + if netutil.IsTemporaryError(err) { // Ignore temporary read errors. glog.V(logger.Debug).Infof("Temporary read error: %v", err) continue @@ -602,6 +609,9 @@ func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte // Send neighbors in chunks with at most maxNeighbors per packet // to stay below the 1280 byte limit. for i, n := range closest { + if netutil.CheckRelayIP(from.IP, n.IP) != nil { + continue + } p.Nodes = append(p.Nodes, nodeToRPC(n)) if len(p.Nodes) == maxNeighbors || i == len(closest)-1 { t.send(from, neighborsPacket, p) diff --git a/p2p/discover/udp_test.go b/p2p/discover/udp_test.go index f43bf3726..53cfac6f9 100644 --- a/p2p/discover/udp_test.go +++ b/p2p/discover/udp_test.go @@ -43,56 +43,6 @@ func init() { spew.Config.DisableMethods = true } -// This test checks that isPacketTooBig correctly identifies -// errors that result from receiving a UDP packet larger -// than the supplied receive buffer. -func TestIsPacketTooBig(t *testing.T) { - listener, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer listener.Close() - sender, err := net.Dial("udp", listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - defer sender.Close() - - sendN := 1800 - recvN := 300 - for i := 0; i < 20; i++ { - go func() { - buf := make([]byte, sendN) - for i := range buf { - buf[i] = byte(i) - } - sender.Write(buf) - }() - - buf := make([]byte, recvN) - listener.SetDeadline(time.Now().Add(1 * time.Second)) - n, _, err := listener.ReadFrom(buf) - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - continue - } - if !isPacketTooBig(err) { - t.Fatal("unexpected read error:", spew.Sdump(err)) - } - continue - } - if n != recvN { - t.Fatalf("short read: %d, want %d", n, recvN) - } - for i := range buf { - if buf[i] != byte(i) { - t.Fatalf("error in pattern") - break - } - } - } -} - // shared test variables var ( futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) @@ -118,9 +68,9 @@ func newUDPTest(t *testing.T) *udpTest { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{1, 2, 3, 4}, Port: 30303}, + remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, } - test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "") + test.table, test.udp, _ = newUDP(test.localkey, test.pipe, nil, "", nil) return test } @@ -362,8 +312,9 @@ func TestUDP_findnodeMultiReply(t *testing.T) { // check that the sent neighbors are all returned by findnode select { case result := <-resultc: - if !reflect.DeepEqual(result, list) { - t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, list) + want := append(list[:2], list[3:]...) + if !reflect.DeepEqual(result, want) { + t.Errorf("neighbors mismatch:\n got: %v\n want: %v", result, want) } case err := <-errc: t.Errorf("findnode error: %v", err) diff --git a/p2p/discv5/net.go b/p2p/discv5/net.go index 7ad6f1e5b..d1c48904e 100644 --- a/p2p/discv5/net.go +++ b/p2p/discv5/net.go @@ -31,6 +31,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -45,6 +46,7 @@ const ( bucketRefreshInterval = 1 * time.Minute seedCount = 30 seedMaxAge = 5 * 24 * time.Hour + lowPort = 1024 ) const testTopic = "foo" @@ -62,8 +64,9 @@ func debugLog(s string) { // Network manages the table and all protocol interaction. type Network struct { - db *nodeDB // database of known nodes - conn transport + db *nodeDB // database of known nodes + conn transport + netrestrict *netutil.Netlist closed chan struct{} // closed when loop is done closeReq chan struct{} // 'request to close' @@ -132,7 +135,7 @@ type timeoutEvent struct { node *Node } -func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string) (*Network, error) { +func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, dbPath string, netrestrict *netutil.Netlist) (*Network, error) { ourID := PubkeyID(&ourPubkey) var db *nodeDB @@ -147,6 +150,7 @@ func newNetwork(conn transport, ourPubkey ecdsa.PublicKey, natm nat.Interface, d net := &Network{ db: db, conn: conn, + netrestrict: netrestrict, tab: tab, topictab: newTopicTable(db, tab.self), ticketStore: newTicketStore(), @@ -684,16 +688,22 @@ func (net *Network) internNodeFromDB(dbn *Node) *Node { return n } -func (net *Network) internNodeFromNeighbours(rn rpcNode) (n *Node, err error) { +func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) { if rn.ID == net.tab.self.ID { return nil, errors.New("is self") } + if rn.UDP <= lowPort { + return nil, errors.New("low port") + } n = net.nodes[rn.ID] if n == nil { // We haven't seen this node before. - n, err = nodeFromRPC(rn) - n.state = unknown + n, err = nodeFromRPC(sender, rn) + if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) { + return n, errors.New("not contained in netrestrict whitelist") + } if err == nil { + n.state = unknown net.nodes[n.ID] = n } return n, err @@ -1095,7 +1105,7 @@ func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) net.conn.sendNeighbours(n, results) return n.state, nil case neighborsPacket: - err := net.handleNeighboursPacket(n, pkt.data.(*neighbors)) + err := net.handleNeighboursPacket(n, pkt) return n.state, err case neighboursTimeout: if n.pendingNeighbours != nil { @@ -1182,17 +1192,18 @@ func rlpHash(x interface{}) (h common.Hash) { return h } -func (net *Network) handleNeighboursPacket(n *Node, req *neighbors) error { +func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error { if n.pendingNeighbours == nil { return errNoQuery } net.abortTimedEvent(n, neighboursTimeout) + req := pkt.data.(*neighbors) nodes := make([]*Node, len(req.Nodes)) for i, rn := range req.Nodes { - nn, err := net.internNodeFromNeighbours(rn) + nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn) if err != nil { - glog.V(logger.Debug).Infof("invalid neighbour from %x: %v", n.ID[:8], err) + glog.V(logger.Debug).Infof("invalid neighbour (%v) from %x@%v: %v", rn.IP, n.ID[:8], pkt.remoteAddr, err) continue } nodes[i] = nn diff --git a/p2p/discv5/net_test.go b/p2p/discv5/net_test.go index 422daa33b..327457c7c 100644 --- a/p2p/discv5/net_test.go +++ b/p2p/discv5/net_test.go @@ -28,7 +28,7 @@ import ( func TestNetwork_Lookup(t *testing.T) { key, _ := crypto.GenerateKey() - network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "") + network, err := newNetwork(lookupTestnet, key.PublicKey, nil, "", nil) if err != nil { t.Fatal(err) } @@ -40,7 +40,7 @@ func TestNetwork_Lookup(t *testing.T) { // t.Fatalf("lookup on empty table returned %d results: %#v", len(results), results) // } // seed table with initial node (otherwise lookup will terminate immediately) - seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{}, 256, 999)} + seeds := []*Node{NewNode(lookupTestnet.dists[256][0], net.IP{10, 0, 2, 99}, lowPort+256, 999)} if err := network.SetFallbackNodes(seeds); err != nil { t.Fatal(err) } @@ -272,13 +272,13 @@ func (tn *preminedTestnet) sendFindnode(to *Node, target NodeID) { func (tn *preminedTestnet) sendFindnodeHash(to *Node, target common.Hash) { // current log distance is encoded in port number // fmt.Println("findnode query at dist", toaddr.Port) - if to.UDP == 0 { - panic("query to node at distance 0") + if to.UDP <= lowPort { + panic("query to node at or below distance 0") } next := to.UDP - 1 var result []rpcNode - for i, id := range tn.dists[to.UDP] { - result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) + for i, id := range tn.dists[to.UDP-lowPort] { + result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) } injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) } @@ -296,14 +296,14 @@ func (tn *preminedTestnet) send(to *Node, ptype nodeEvent, data interface{}) (ha // ignored case findnodeHashPacket: // current log distance is encoded in port number - // fmt.Println("findnode query at dist", toaddr.Port) - if to.UDP == 0 { - panic("query to node at distance 0") + // fmt.Println("findnode query at dist", toaddr.Port-lowPort) + if to.UDP <= lowPort { + panic("query to node at or below distance 0") } next := to.UDP - 1 var result []rpcNode - for i, id := range tn.dists[to.UDP] { - result = append(result, nodeToRPC(NewNode(id, net.ParseIP("127.0.0.1"), next, uint16(i)+1))) + for i, id := range tn.dists[to.UDP-lowPort] { + result = append(result, nodeToRPC(NewNode(id, net.ParseIP("10.0.2.99"), next, uint16(i)+1+lowPort))) } injectResponse(tn.net, to, neighborsPacket, &neighbors{Nodes: result}) default: @@ -328,8 +328,11 @@ func (tn *preminedTestnet) sendTopicRegister(to *Node, topics []Topic, idx int, panic("sendTopicRegister called") } -func (*preminedTestnet) Close() {} -func (*preminedTestnet) localAddr() *net.UDPAddr { return new(net.UDPAddr) } +func (*preminedTestnet) Close() {} + +func (*preminedTestnet) localAddr() *net.UDPAddr { + return &net.UDPAddr{IP: net.ParseIP("10.0.1.1"), Port: 40000} +} // mine generates a testnet struct literal with nodes at // various distances to the given target. diff --git a/p2p/discv5/sim_test.go b/p2p/discv5/sim_test.go index 2e232fbaa..cb64d7fa0 100644 --- a/p2p/discv5/sim_test.go +++ b/p2p/discv5/sim_test.go @@ -290,7 +290,7 @@ func (s *simulation) launchNode(log bool) *Network { addr := &net.UDPAddr{IP: ip, Port: 30303} transport := &simTransport{joinTime: time.Now(), sender: id, senderAddr: addr, sim: s, priv: key} - net, err := newNetwork(transport, key.PublicKey, nil, "") + net, err := newNetwork(transport, key.PublicKey, nil, "", nil) if err != nil { panic("cannot launch new node: " + err.Error()) } diff --git a/p2p/discv5/udp.go b/p2p/discv5/udp.go index 46d3200bf..a6114e032 100644 --- a/p2p/discv5/udp.go +++ b/p2p/discv5/udp.go @@ -29,6 +29,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/rlp" ) @@ -198,8 +199,10 @@ func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool { return e1.UDP == e2.UDP && e1.TCP == e2.TCP && bytes.Equal(e1.IP, e2.IP) } -func nodeFromRPC(rn rpcNode) (*Node, error) { - // TODO: don't accept localhost, LAN addresses from internet hosts +func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) { + if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + return nil, err + } n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP) err := n.validateComplete() return n, err @@ -235,12 +238,12 @@ type udp struct { } // ListenUDP returns a new table that listens for UDP packets on laddr. -func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Network, error) { +func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) { transport, err := listenUDP(priv, laddr) if err != nil { return nil, err } - net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath) + net, err := newNetwork(transport, priv.PublicKey, natm, nodeDBPath, netrestrict) if err != nil { return nil, err } @@ -327,6 +330,9 @@ func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) return } for i, result := range nodes { + if netutil.CheckRelayIP(remote.IP, result.IP) != nil { + continue + } p.Nodes = append(p.Nodes, nodeToRPC(result)) if len(p.Nodes) == maxTopicNodes || i == len(nodes)-1 { t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p) @@ -385,7 +391,7 @@ func (t *udp) readLoop() { buf := make([]byte, 1280) for { nbytes, from, err := t.conn.ReadFromUDP(buf) - if isTemporaryError(err) { + if netutil.IsTemporaryError(err) { // Ignore temporary read errors. glog.V(logger.Debug).Infof("Temporary read error: %v", err) continue @@ -398,13 +404,6 @@ func (t *udp) readLoop() { } } -func isTemporaryError(err error) bool { - tempErr, ok := err.(interface { - Temporary() bool - }) - return ok && tempErr.Temporary() || isPacketTooBig(err) -} - func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error { pkt := ingressPacket{remoteAddr: from} if err := decodePacket(buf, &pkt); err != nil { diff --git a/p2p/discv5/udp_test.go b/p2p/discv5/udp_test.go index cacc0f004..98c737669 100644 --- a/p2p/discv5/udp_test.go +++ b/p2p/discv5/udp_test.go @@ -36,56 +36,6 @@ func init() { spew.Config.DisableMethods = true } -// This test checks that isPacketTooBig correctly identifies -// errors that result from receiving a UDP packet larger -// than the supplied receive buffer. -func TestIsPacketTooBig(t *testing.T) { - listener, err := net.ListenPacket("udp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer listener.Close() - sender, err := net.Dial("udp", listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - defer sender.Close() - - sendN := 1800 - recvN := 300 - for i := 0; i < 20; i++ { - go func() { - buf := make([]byte, sendN) - for i := range buf { - buf[i] = byte(i) - } - sender.Write(buf) - }() - - buf := make([]byte, recvN) - listener.SetDeadline(time.Now().Add(1 * time.Second)) - n, _, err := listener.ReadFrom(buf) - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Timeout() { - continue - } - if !isPacketTooBig(err) { - t.Fatal("unexpected read error:", spew.Sdump(err)) - } - continue - } - if n != recvN { - t.Fatalf("short read: %d, want %d", n, recvN) - } - for i := range buf { - if buf[i] != byte(i) { - t.Fatalf("error in pattern") - break - } - } - } -} - // shared test variables var ( futureExp = uint64(time.Now().Add(10 * time.Hour).Unix()) diff --git a/p2p/discv5/udp_windows.go b/p2p/discv5/udp_windows.go deleted file mode 100644 index 1ab9d655e..000000000 --- a/p2p/discv5/udp_windows.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright 2016 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -//+build windows - -package discv5 - -import ( - "net" - "os" - "syscall" -) - -const _WSAEMSGSIZE = syscall.Errno(10040) - -// reports whether err indicates that a UDP packet didn't -// fit the receive buffer. On Windows, WSARecvFrom returns -// code WSAEMSGSIZE and no data if this happens. -func isPacketTooBig(err error) bool { - if opErr, ok := err.(*net.OpError); ok { - if scErr, ok := opErr.Err.(*os.SyscallError); ok { - return scErr.Err == _WSAEMSGSIZE - } - return opErr.Err == _WSAEMSGSIZE - } - return false -} diff --git a/p2p/discv5/udp_notwindows.go b/p2p/netutil/error.go similarity index 75% rename from p2p/discv5/udp_notwindows.go rename to p2p/netutil/error.go index 4da18d0f6..cb21b9cd4 100644 --- a/p2p/discv5/udp_notwindows.go +++ b/p2p/netutil/error.go @@ -14,13 +14,12 @@ // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . -//+build !windows +package netutil -package discv5 - -// reports whether err indicates that a UDP packet didn't -// fit the receive buffer. There is no such error on -// non-Windows platforms. -func isPacketTooBig(err error) bool { - return false +// IsTemporaryError checks whether the given error should be considered temporary. +func IsTemporaryError(err error) bool { + tempErr, ok := err.(interface { + Temporary() bool + }) + return ok && tempErr.Temporary() || isPacketTooBig(err) } diff --git a/p2p/netutil/error_test.go b/p2p/netutil/error_test.go new file mode 100644 index 000000000..645e48f83 --- /dev/null +++ b/p2p/netutil/error_test.go @@ -0,0 +1,73 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package netutil + +import ( + "net" + "testing" + "time" +) + +// This test checks that isPacketTooBig correctly identifies +// errors that result from receiving a UDP packet larger +// than the supplied receive buffer. +func TestIsPacketTooBig(t *testing.T) { + listener, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + sender, err := net.Dial("udp", listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + defer sender.Close() + + sendN := 1800 + recvN := 300 + for i := 0; i < 20; i++ { + go func() { + buf := make([]byte, sendN) + for i := range buf { + buf[i] = byte(i) + } + sender.Write(buf) + }() + + buf := make([]byte, recvN) + listener.SetDeadline(time.Now().Add(1 * time.Second)) + n, _, err := listener.ReadFrom(buf) + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + continue + } + if !isPacketTooBig(err) { + t.Fatalf("unexpected read error: %v", err) + } + continue + } + if n != recvN { + t.Fatalf("short read: %d, want %d", n, recvN) + } + for i := range buf { + if buf[i] != byte(i) { + t.Fatalf("error in pattern") + break + } + } + } +} diff --git a/p2p/netutil/net.go b/p2p/netutil/net.go new file mode 100644 index 000000000..3c3715788 --- /dev/null +++ b/p2p/netutil/net.go @@ -0,0 +1,166 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +// Package netutil contains extensions to the net package. +package netutil + +import ( + "errors" + "net" + "strings" +) + +var lan4, lan6, special4, special6 Netlist + +func init() { + // Lists from RFC 5735, RFC 5156, + // https://www.iana.org/assignments/iana-ipv4-special-registry/ + lan4.Add("0.0.0.0/8") // "This" network + lan4.Add("10.0.0.0/8") // Private Use + lan4.Add("172.16.0.0/12") // Private Use + lan4.Add("192.168.0.0/16") // Private Use + lan6.Add("fe80::/10") // Link-Local + lan6.Add("fc00::/7") // Unique-Local + special4.Add("192.0.0.0/29") // IPv4 Service Continuity + special4.Add("192.0.0.9/32") // PCP Anycast + special4.Add("192.0.0.170/32") // NAT64/DNS64 Discovery + special4.Add("192.0.0.171/32") // NAT64/DNS64 Discovery + special4.Add("192.0.2.0/24") // TEST-NET-1 + special4.Add("192.31.196.0/24") // AS112 + special4.Add("192.52.193.0/24") // AMT + special4.Add("192.88.99.0/24") // 6to4 Relay Anycast + special4.Add("192.175.48.0/24") // AS112 + special4.Add("198.18.0.0/15") // Device Benchmark Testing + special4.Add("198.51.100.0/24") // TEST-NET-2 + special4.Add("203.0.113.0/24") // TEST-NET-3 + special4.Add("255.255.255.255/32") // Limited Broadcast + + // http://www.iana.org/assignments/iana-ipv6-special-registry/ + special6.Add("100::/64") + special6.Add("2001::/32") + special6.Add("2001:1::1/128") + special6.Add("2001:2::/48") + special6.Add("2001:3::/32") + special6.Add("2001:4:112::/48") + special6.Add("2001:5::/32") + special6.Add("2001:10::/28") + special6.Add("2001:20::/28") + special6.Add("2001:db8::/32") + special6.Add("2002::/16") +} + +// Netlist is a list of IP networks. +type Netlist []net.IPNet + +// ParseNetlist parses a comma-separated list of CIDR masks. +// Whitespace and extra commas are ignored. +func ParseNetlist(s string) (*Netlist, error) { + ws := strings.NewReplacer(" ", "", "\n", "", "\t", "") + masks := strings.Split(ws.Replace(s), ",") + l := make(Netlist, 0) + for _, mask := range masks { + if mask == "" { + continue + } + _, n, err := net.ParseCIDR(mask) + if err != nil { + return nil, err + } + l = append(l, *n) + } + return &l, nil +} + +// Add parses a CIDR mask and appends it to the list. It panics for invalid masks and is +// intended to be used for setting up static lists. +func (l *Netlist) Add(cidr string) { + _, n, err := net.ParseCIDR(cidr) + if err != nil { + panic(err) + } + *l = append(*l, *n) +} + +// Contains reports whether the given IP is contained in the list. +func (l *Netlist) Contains(ip net.IP) bool { + if l == nil { + return false + } + for _, net := range *l { + if net.Contains(ip) { + return true + } + } + return false +} + +// IsLAN reports whether an IP is a local network address. +func IsLAN(ip net.IP) bool { + if ip.IsLoopback() { + return true + } + if v4 := ip.To4(); v4 != nil { + return lan4.Contains(v4) + } + return lan6.Contains(ip) +} + +// IsSpecialNetwork reports whether an IP is located in a special-use network range +// This includes broadcast, multicast and documentation addresses. +func IsSpecialNetwork(ip net.IP) bool { + if ip.IsMulticast() { + return true + } + if v4 := ip.To4(); v4 != nil { + return special4.Contains(v4) + } + return special6.Contains(ip) +} + +var ( + errInvalid = errors.New("invalid IP") + errUnspecified = errors.New("zero address") + errSpecial = errors.New("special network") + errLoopback = errors.New("loopback address from non-loopback host") + errLAN = errors.New("LAN address from WAN host") +) + +// CheckRelayIP reports whether an IP relayed from the given sender IP +// is a valid connection target. +// +// There are four rules: +// - Special network addresses are never valid. +// - Loopback addresses are OK if relayed by a loopback host. +// - LAN addresses are OK if relayed by a LAN host. +// - All other addresses are always acceptable. +func CheckRelayIP(sender, addr net.IP) error { + if len(addr) != net.IPv4len && len(addr) != net.IPv6len { + return errInvalid + } + if addr.IsUnspecified() { + return errUnspecified + } + if IsSpecialNetwork(addr) { + return errSpecial + } + if addr.IsLoopback() && !sender.IsLoopback() { + return errLoopback + } + if IsLAN(addr) && !IsLAN(sender) { + return errLAN + } + return nil +} diff --git a/p2p/netutil/net_test.go b/p2p/netutil/net_test.go new file mode 100644 index 000000000..1ee1fcb4d --- /dev/null +++ b/p2p/netutil/net_test.go @@ -0,0 +1,173 @@ +// Copyright 2016 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package netutil + +import ( + "net" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" +) + +func TestParseNetlist(t *testing.T) { + var tests = []struct { + input string + wantErr error + wantList *Netlist + }{ + { + input: "", + wantList: &Netlist{}, + }, + { + input: "127.0.0.0/8", + wantErr: nil, + wantList: &Netlist{{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}}, + }, + { + input: "127.0.0.0/44", + wantErr: &net.ParseError{Type: "CIDR address", Text: "127.0.0.0/44"}, + }, + { + input: "127.0.0.0/16, 23.23.23.23/24,", + wantList: &Netlist{ + {IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(16, 32)}, + {IP: net.IP{23, 23, 23, 0}, Mask: net.CIDRMask(24, 32)}, + }, + }, + } + + for _, test := range tests { + l, err := ParseNetlist(test.input) + if !reflect.DeepEqual(err, test.wantErr) { + t.Errorf("%q: got error %q, want %q", test.input, err, test.wantErr) + continue + } + if !reflect.DeepEqual(l, test.wantList) { + spew.Dump(l) + spew.Dump(test.wantList) + t.Errorf("%q: got %v, want %v", test.input, l, test.wantList) + } + } +} + +func TestNilNetListContains(t *testing.T) { + var list *Netlist + checkContains(t, list.Contains, nil, []string{"1.2.3.4"}) +} + +func TestIsLAN(t *testing.T) { + checkContains(t, IsLAN, + []string{ // included + "0.0.0.0", + "0.2.0.8", + "127.0.0.1", + "10.0.1.1", + "10.22.0.3", + "172.31.252.251", + "192.168.1.4", + "fe80::f4a1:8eff:fec5:9d9d", + "febf::ab32:2233", + "fc00::4", + }, + []string{ // excluded + "192.0.2.1", + "1.0.0.0", + "172.32.0.1", + "fec0::2233", + }, + ) +} + +func TestIsSpecialNetwork(t *testing.T) { + checkContains(t, IsSpecialNetwork, + []string{ // included + "192.0.2.1", + "192.0.2.44", + "2001:db8:85a3:8d3:1319:8a2e:370:7348", + "255.255.255.255", + "224.0.0.22", // IPv4 multicast + "ff05::1:3", // IPv6 multicast + }, + []string{ // excluded + "192.0.3.1", + "1.0.0.0", + "172.32.0.1", + "fec0::2233", + }, + ) +} + +func checkContains(t *testing.T, fn func(net.IP) bool, inc, exc []string) { + for _, s := range inc { + if !fn(parseIP(s)) { + t.Error("returned false for included address", s) + } + } + for _, s := range exc { + if fn(parseIP(s)) { + t.Error("returned true for excluded address", s) + } + } +} + +func parseIP(s string) net.IP { + ip := net.ParseIP(s) + if ip == nil { + panic("invalid " + s) + } + return ip +} + +func TestCheckRelayIP(t *testing.T) { + tests := []struct { + sender, addr string + want error + }{ + {"127.0.0.1", "0.0.0.0", errUnspecified}, + {"192.168.0.1", "0.0.0.0", errUnspecified}, + {"23.55.1.242", "0.0.0.0", errUnspecified}, + {"127.0.0.1", "255.255.255.255", errSpecial}, + {"192.168.0.1", "255.255.255.255", errSpecial}, + {"23.55.1.242", "255.255.255.255", errSpecial}, + {"192.168.0.1", "127.0.2.19", errLoopback}, + {"23.55.1.242", "192.168.0.1", errLAN}, + + {"127.0.0.1", "127.0.2.19", nil}, + {"127.0.0.1", "192.168.0.1", nil}, + {"127.0.0.1", "23.55.1.242", nil}, + {"192.168.0.1", "192.168.0.1", nil}, + {"192.168.0.1", "23.55.1.242", nil}, + {"23.55.1.242", "23.55.1.242", nil}, + } + + for _, test := range tests { + err := CheckRelayIP(parseIP(test.sender), parseIP(test.addr)) + if err != test.want { + t.Errorf("%s from %s: got %q, want %q", test.addr, test.sender, err, test.want) + } + } +} + +func BenchmarkCheckRelayIP(b *testing.B) { + sender := parseIP("23.55.1.242") + addr := parseIP("23.55.1.2") + for i := 0; i < b.N; i++ { + CheckRelayIP(sender, addr) + } +} diff --git a/p2p/discover/udp_notwindows.go b/p2p/netutil/toobig_notwindows.go similarity index 91% rename from p2p/discover/udp_notwindows.go rename to p2p/netutil/toobig_notwindows.go index e9de83aa9..47b643857 100644 --- a/p2p/discover/udp_notwindows.go +++ b/p2p/netutil/toobig_notwindows.go @@ -16,9 +16,9 @@ //+build !windows -package discover +package netutil -// reports whether err indicates that a UDP packet didn't +// isPacketTooBig reports whether err indicates that a UDP packet didn't // fit the receive buffer. There is no such error on // non-Windows platforms. func isPacketTooBig(err error) bool { diff --git a/p2p/discover/udp_windows.go b/p2p/netutil/toobig_windows.go similarity index 93% rename from p2p/discover/udp_windows.go rename to p2p/netutil/toobig_windows.go index 66bbf9597..dfbb6d44f 100644 --- a/p2p/discover/udp_windows.go +++ b/p2p/netutil/toobig_windows.go @@ -16,7 +16,7 @@ //+build windows -package discover +package netutil import ( "net" @@ -26,7 +26,7 @@ import ( const _WSAEMSGSIZE = syscall.Errno(10040) -// reports whether err indicates that a UDP packet didn't +// isPacketTooBig reports whether err indicates that a UDP packet didn't // fit the receive buffer. On Windows, WSARecvFrom returns // code WSAEMSGSIZE and no data if this happens. func isPacketTooBig(err error) bool { diff --git a/p2p/server.go b/p2p/server.go index 7381127dc..cf9672e2d 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -30,6 +30,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/discv5" "github.com/ethereum/go-ethereum/p2p/nat" + "github.com/ethereum/go-ethereum/p2p/netutil" ) const ( @@ -101,6 +102,11 @@ type Config struct { // allowed to connect, even above the peer limit. TrustedNodes []*discover.Node + // Connectivity can be restricted to certain IP networks. + // If this option is set to a non-nil value, only hosts which match one of the + // IP networks contained in the list are considered. + NetRestrict *netutil.Netlist + // NodeDatabase is the path to the database containing the previously seen // live nodes in the network. NodeDatabase string @@ -356,7 +362,7 @@ func (srv *Server) Start() (err error) { // node table if srv.Discovery { - ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase) + ntab, err := discover.ListenUDP(srv.PrivateKey, srv.ListenAddr, srv.NAT, srv.NodeDatabase, srv.NetRestrict) if err != nil { return err } @@ -367,7 +373,7 @@ func (srv *Server) Start() (err error) { } if srv.DiscoveryV5 { - ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "") //srv.NodeDatabase) + ntab, err := discv5.ListenUDP(srv.PrivateKey, srv.DiscoveryV5Addr, srv.NAT, "", srv.NetRestrict) //srv.NodeDatabase) if err != nil { return err } @@ -381,7 +387,7 @@ func (srv *Server) Start() (err error) { if !srv.Discovery { dynPeers = 0 } - dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers) + dialer := newDialState(srv.StaticNodes, srv.ntab, dynPeers, srv.NetRestrict) // handshake srv.ourHandshake = &protoHandshake{Version: baseProtocolVersion, Name: srv.Name, ID: discover.PubkeyID(&srv.PrivateKey.PublicKey)} @@ -634,8 +640,19 @@ func (srv *Server) listenLoop() { } break } + + // Reject connections that do not match NetRestrict. + if srv.NetRestrict != nil { + if tcp, ok := fd.RemoteAddr().(*net.TCPAddr); ok && !srv.NetRestrict.Contains(tcp.IP) { + glog.V(logger.Debug).Infof("Rejected conn %v because it is not whitelisted in NetRestrict", fd.RemoteAddr()) + fd.Close() + slots <- struct{}{} + continue + } + } + fd = newMeteredConn(fd, true) - glog.V(logger.Debug).Infof("Accepted conn %v\n", fd.RemoteAddr()) + glog.V(logger.Debug).Infof("Accepted conn %v", fd.RemoteAddr()) // Spawn the handler. It will give the slot back when the connection // has been established. diff --git a/swarm/network/hive.go b/swarm/network/hive.go index f5ebdd008..f81761b97 100644 --- a/swarm/network/hive.go +++ b/swarm/network/hive.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/logger" "github.com/ethereum/go-ethereum/logger/glog" "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/p2p/netutil" "github.com/ethereum/go-ethereum/swarm/network/kademlia" "github.com/ethereum/go-ethereum/swarm/storage" ) @@ -288,6 +289,10 @@ func newNodeRecord(addr *peerAddr) *kademlia.NodeRecord { func (self *Hive) HandlePeersMsg(req *peersMsgData, from *peer) { var nrs []*kademlia.NodeRecord for _, p := range req.Peers { + if err := netutil.CheckRelayIP(from.remoteAddr.IP, p.IP); err != nil { + glog.V(logger.Detail).Infof("invalid peer IP %v from %v: %v", from.remoteAddr.IP, p.IP, err) + continue + } nrs = append(nrs, newNodeRecord(p)) } self.kad.Add(nrs)