From 3892a62f28ac54c33bad7780273654ab79434d7f Mon Sep 17 00:00:00 2001 From: StephenButtolph Date: Fri, 22 May 2020 01:45:56 -0400 Subject: [PATCH] Added more race tests --- network/network_test.go | 768 ++++++++++++++++++++++++++++++++++++++-- network/peer.go | 2 +- 2 files changed, 732 insertions(+), 38 deletions(-) diff --git a/network/network_test.go b/network/network_test.go index 0e07ff3..e7f118c 100644 --- a/network/network_test.go +++ b/network/network_test.go @@ -16,6 +16,7 @@ import ( "github.com/ava-labs/gecko/snow/networking/router" "github.com/ava-labs/gecko/snow/validators" "github.com/ava-labs/gecko/utils" + "github.com/ava-labs/gecko/utils/hashing" "github.com/ava-labs/gecko/utils/logging" "github.com/ava-labs/gecko/version" ) @@ -28,23 +29,22 @@ var ( type testListener struct { addr net.Addr inbound chan net.Conn + once sync.Once + closed chan struct{} } func (l *testListener) Accept() (net.Conn, error) { - if c, ok := <-l.inbound; ok { + select { + case c := <-l.inbound: return c, nil + case _, _ = <-l.closed: + return nil, errClosed } - return nil, errClosed } -func (l *testListener) Close() (err error) { - defer func() { - if r := recover(); r != nil { - err = errClosed - } - }() - close(l.inbound) - return +func (l *testListener) Close() error { + l.once.Do(func() { close(l.closed) }) + return nil } func (l *testListener) Addr() net.Addr { return l.addr } @@ -62,34 +62,47 @@ func (d *testDialer) Dial(ip utils.IPDesc) (net.Conn, error) { server := &testConn{ pendingReads: make(chan []byte, 1<<10), pendingWrites: make(chan []byte, 1<<10), + closed: make(chan struct{}), local: outbound.addr, remote: d.addr, } client := &testConn{ pendingReads: server.pendingWrites, pendingWrites: server.pendingReads, + closed: make(chan struct{}), local: d.addr, remote: outbound.addr, } - outbound.inbound <- server - return client, nil + + select { + case outbound.inbound <- server: + return client, nil + default: + return nil, errRefused + } } type testConn struct { partialRead []byte pendingReads chan []byte pendingWrites chan []byte + closed chan struct{} + once sync.Once local, remote net.Addr } func (c *testConn) Read(b []byte) (int, error) { for len(c.partialRead) == 0 { - read, ok := <-c.pendingReads - if !ok { + select { + case read, ok := <-c.pendingReads: + if !ok { + return 0, errClosed + } + c.partialRead = read + case _, _ = <-c.closed: return 0, errClosed } - c.partialRead = read } copy(b, c.partialRead) @@ -101,29 +114,22 @@ func (c *testConn) Read(b []byte) (int, error) { return len(b), nil } -func (c *testConn) Write(b []byte) (length int, err error) { - defer func() { - if r := recover(); r != nil { - err = errClosed - } - }() - +func (c *testConn) Write(b []byte) (int, error) { newB := make([]byte, len(b)) copy(newB, b) - c.pendingWrites <- newB - length = len(b) - return + + select { + case c.pendingWrites <- newB: + case _, _ = <-c.closed: + return 0, errClosed + } + + return len(b), nil } -func (c *testConn) Close() (err error) { - defer func() { - if r := recover(); r != nil { - err = errClosed - } - }() - close(c.pendingReads) - close(c.pendingWrites) - return +func (c *testConn) Close() error { + c.once.Do(func() { close(c.closed) }) + return nil } func (c *testConn) LocalAddr() net.Addr { return c.local } @@ -146,11 +152,11 @@ func (h *testHandler) Disconnected(id ids.ShortID) bool { func TestNewDefaultNetwork(t *testing.T) { log := logging.NoLog{} - id := ids.ShortEmpty ip := utils.IPDesc{ IP: net.IPv6loopback, Port: 0, } + id := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip.String()))) networkID := uint32(0) appVersion := version.NewDefaultVersion("app", 0, 1, 0) versionParser := version.NewDefaultParser() @@ -161,6 +167,7 @@ func TestNewDefaultNetwork(t *testing.T) { Port: 0, }, inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), } caller := &testDialer{ addr: &net.TCPAddr{ @@ -206,16 +213,16 @@ func TestEstablishConnection(t *testing.T) { appVersion := version.NewDefaultVersion("app", 0, 1, 0) versionParser := version.NewDefaultParser() - id0 := ids.NewShortID([20]byte{0}) ip0 := utils.IPDesc{ IP: net.IPv6loopback, Port: 0, } - id1 := ids.NewShortID([20]byte{1}) + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) ip1 := utils.IPDesc{ IP: net.IPv6loopback, Port: 1, } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) listener0 := &testListener{ addr: &net.TCPAddr{ @@ -223,6 +230,7 @@ func TestEstablishConnection(t *testing.T) { Port: 0, }, inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), } caller0 := &testDialer{ addr: &net.TCPAddr{ @@ -237,6 +245,7 @@ func TestEstablishConnection(t *testing.T) { Port: 1, }, inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), } caller1 := &testDialer{ addr: &net.TCPAddr{ @@ -327,4 +336,689 @@ func TestEstablishConnection(t *testing.T) { wg0.Wait() wg1.Wait() + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) +} + +func TestDoubleTrack(t *testing.T) { + log := logging.NoLog{} + networkID := uint32(0) + appVersion := version.NewDefaultVersion("app", 0, 1, 0) + versionParser := version.NewDefaultParser() + + ip0 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 0, + } + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) + ip1 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 1, + } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) + + listener0 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller0 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + outbounds: make(map[string]*testListener), + } + listener1 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller1 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + outbounds: make(map[string]*testListener), + } + + caller0.outbounds[ip1.String()] = listener1 + caller1.outbounds[ip0.String()] = listener0 + + serverUpgrader := NewIPUpgrader() + clientUpgrader := NewIPUpgrader() + + vdrs := validators.NewSet() + handler := router.Router(nil) + + net0 := NewDefaultNetwork( + log, + id0, + ip0, + networkID, + appVersion, + versionParser, + listener0, + caller0, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net0) + + net1 := NewDefaultNetwork( + log, + id1, + ip1, + networkID, + appVersion, + versionParser, + listener1, + caller1, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net1) + + var ( + wg0 sync.WaitGroup + wg1 sync.WaitGroup + ) + wg0.Add(1) + wg1.Add(1) + + h0 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id0) { + wg0.Done() + } + return false + }, + } + h1 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id1) { + wg1.Done() + } + return false + }, + } + + net0.RegisterHandler(h0) + net1.RegisterHandler(h1) + + net0.Track(ip1) + net0.Track(ip1) + + go func() { + err := net0.Dispatch() + assert.Error(t, err) + }() + go func() { + err := net1.Dispatch() + assert.Error(t, err) + }() + + wg0.Wait() + wg1.Wait() + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) +} + +func TestDoubleClose(t *testing.T) { + log := logging.NoLog{} + networkID := uint32(0) + appVersion := version.NewDefaultVersion("app", 0, 1, 0) + versionParser := version.NewDefaultParser() + + ip0 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 0, + } + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) + ip1 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 1, + } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) + + listener0 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller0 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + outbounds: make(map[string]*testListener), + } + listener1 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller1 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + outbounds: make(map[string]*testListener), + } + + caller0.outbounds[ip1.String()] = listener1 + caller1.outbounds[ip0.String()] = listener0 + + serverUpgrader := NewIPUpgrader() + clientUpgrader := NewIPUpgrader() + + vdrs := validators.NewSet() + handler := router.Router(nil) + + net0 := NewDefaultNetwork( + log, + id0, + ip0, + networkID, + appVersion, + versionParser, + listener0, + caller0, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net0) + + net1 := NewDefaultNetwork( + log, + id1, + ip1, + networkID, + appVersion, + versionParser, + listener1, + caller1, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net1) + + var ( + wg0 sync.WaitGroup + wg1 sync.WaitGroup + ) + wg0.Add(1) + wg1.Add(1) + + h0 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id0) { + wg0.Done() + } + return false + }, + } + h1 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id1) { + wg1.Done() + } + return false + }, + } + + net0.RegisterHandler(h0) + net1.RegisterHandler(h1) + + net0.Track(ip1) + + go func() { + err := net0.Dispatch() + assert.Error(t, err) + }() + go func() { + err := net1.Dispatch() + assert.Error(t, err) + }() + + wg0.Wait() + wg1.Wait() + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) + + err = net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) +} + +func TestRemoveHandlers(t *testing.T) { + log := logging.NoLog{} + networkID := uint32(0) + appVersion := version.NewDefaultVersion("app", 0, 1, 0) + versionParser := version.NewDefaultParser() + + ip0 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 0, + } + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) + ip1 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 1, + } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) + + listener0 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller0 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + outbounds: make(map[string]*testListener), + } + listener1 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller1 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + outbounds: make(map[string]*testListener), + } + + caller0.outbounds[ip1.String()] = listener1 + caller1.outbounds[ip0.String()] = listener0 + + serverUpgrader := NewIPUpgrader() + clientUpgrader := NewIPUpgrader() + + vdrs := validators.NewSet() + handler := router.Router(nil) + + net0 := NewDefaultNetwork( + log, + id0, + ip0, + networkID, + appVersion, + versionParser, + listener0, + caller0, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net0) + + net1 := NewDefaultNetwork( + log, + id1, + ip1, + networkID, + appVersion, + versionParser, + listener1, + caller1, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net1) + + var ( + wg0 sync.WaitGroup + wg1 sync.WaitGroup + ) + wg0.Add(1) + wg1.Add(1) + + h0 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id0) { + wg0.Done() + } + return false + }, + } + h1 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id1) { + wg1.Done() + } + return false + }, + } + + net0.RegisterHandler(h0) + net1.RegisterHandler(h1) + + net0.Track(ip1) + + go func() { + err := net0.Dispatch() + assert.Error(t, err) + }() + go func() { + err := net1.Dispatch() + assert.Error(t, err) + }() + + wg0.Wait() + wg1.Wait() + + h3 := &testHandler{ + connected: func(id ids.ShortID) bool { + assert.Equal(t, id0, id) + return true + }, + } + h4 := &testHandler{ + connected: func(id ids.ShortID) bool { + return id.Equals(id0) + }, + } + + net0.RegisterHandler(h3) + net1.RegisterHandler(h4) + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) +} + +func TestTrackConnected(t *testing.T) { + log := logging.NoLog{} + networkID := uint32(0) + appVersion := version.NewDefaultVersion("app", 0, 1, 0) + versionParser := version.NewDefaultParser() + + ip0 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 0, + } + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) + ip1 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 1, + } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) + + listener0 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller0 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + outbounds: make(map[string]*testListener), + } + listener1 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller1 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + outbounds: make(map[string]*testListener), + } + + caller0.outbounds[ip1.String()] = listener1 + caller1.outbounds[ip0.String()] = listener0 + + serverUpgrader := NewIPUpgrader() + clientUpgrader := NewIPUpgrader() + + vdrs := validators.NewSet() + handler := router.Router(nil) + + net0 := NewDefaultNetwork( + log, + id0, + ip0, + networkID, + appVersion, + versionParser, + listener0, + caller0, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net0) + + net1 := NewDefaultNetwork( + log, + id1, + ip1, + networkID, + appVersion, + versionParser, + listener1, + caller1, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net1) + + var ( + wg0 sync.WaitGroup + wg1 sync.WaitGroup + ) + wg0.Add(1) + wg1.Add(1) + + h0 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id0) { + wg0.Done() + } + return false + }, + } + h1 := &testHandler{ + connected: func(id ids.ShortID) bool { + if !id.Equals(id1) { + wg1.Done() + } + return false + }, + } + + net0.RegisterHandler(h0) + net1.RegisterHandler(h1) + + net0.Track(ip1) + + go func() { + err := net0.Dispatch() + assert.Error(t, err) + }() + go func() { + err := net1.Dispatch() + assert.Error(t, err) + }() + + wg0.Wait() + wg1.Wait() + + net0.Track(ip1) + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) +} + +func TestTrackConnectedRace(t *testing.T) { + log := logging.NoLog{} + networkID := uint32(0) + appVersion := version.NewDefaultVersion("app", 0, 1, 0) + versionParser := version.NewDefaultParser() + + ip0 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 0, + } + id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String()))) + ip1 := utils.IPDesc{ + IP: net.IPv6loopback, + Port: 1, + } + id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String()))) + + listener0 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller0 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 0, + }, + outbounds: make(map[string]*testListener), + } + listener1 := &testListener{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + inbound: make(chan net.Conn, 1<<10), + closed: make(chan struct{}), + } + caller1 := &testDialer{ + addr: &net.TCPAddr{ + IP: net.IPv6loopback, + Port: 1, + }, + outbounds: make(map[string]*testListener), + } + + caller0.outbounds[ip1.String()] = listener1 + caller1.outbounds[ip0.String()] = listener0 + + serverUpgrader := NewIPUpgrader() + clientUpgrader := NewIPUpgrader() + + vdrs := validators.NewSet() + handler := router.Router(nil) + + net0 := NewDefaultNetwork( + log, + id0, + ip0, + networkID, + appVersion, + versionParser, + listener0, + caller0, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net0) + + net1 := NewDefaultNetwork( + log, + id1, + ip1, + networkID, + appVersion, + versionParser, + listener1, + caller1, + serverUpgrader, + clientUpgrader, + vdrs, + handler, + ) + assert.NotNil(t, net1) + + net0.Track(ip1) + + go func() { + err := net0.Dispatch() + assert.Error(t, err) + }() + go func() { + err := net1.Dispatch() + assert.Error(t, err) + }() + + err := net0.Close() + assert.NoError(t, err) + + err = net1.Close() + assert.NoError(t, err) } diff --git a/network/peer.go b/network/peer.go index 6882209..8b5c39f 100644 --- a/network/peer.go +++ b/network/peer.go @@ -364,7 +364,7 @@ func (p *peer) peerList(msg Msg) { if !ip.Equal(p.net.ip) && !ip.IsZero() && (p.net.allowPrivateIPs || !ip.IsPrivate()) { - // TODO: this is a vulnerability, perhaps only try to connect once? + // TODO: only try to connect once p.net.Track(ip) } }