// Source: gecko/network/network_test.go (1053 lines, 19 KiB, Go)

// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package network
import (
"errors"
"net"
"sync"
"testing"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"github.com/ava-labs/gecko/ids"
"github.com/ava-labs/gecko/snow/networking/router"
"github.com/ava-labs/gecko/snow/validators"
"github.com/ava-labs/gecko/utils"
"github.com/ava-labs/gecko/utils/hashing"
"github.com/ava-labs/gecko/utils/logging"
"github.com/ava-labs/gecko/version"
)
// errClosed is returned by the in-memory test listener and test connections
// once they have been closed.
var errClosed = errors.New("closed")

// errRefused is returned by the test dialer when no listener is registered for
// the dialed address or when that listener's inbound queue cannot take another
// connection.
var errRefused = errors.New("connection refused")
// testListener is an in-memory net.Listener. Connections queued on inbound
// (for example by testDialer.Dial) are handed out by Accept until Close is
// called.
type testListener struct {
	// addr is returned verbatim by Addr.
	addr net.Addr
	// inbound holds connections waiting to be accepted.
	inbound chan net.Conn
	// once guards close(closed) so Close is idempotent.
	once sync.Once
	// closed is closed by Close to unblock and fail Accept.
	closed chan struct{}
}

// Accept blocks until a connection is queued on inbound or the listener is
// closed. Once the listener has been closed it always returns errClosed, even
// if connections are still queued, matching net.Listener behavior after Close.
func (l *testListener) Accept() (net.Conn, error) {
	// Check closed first: if both channels were ready, the select below would
	// pick one at random and could hand out a connection after Close.
	select {
	case <-l.closed:
		return nil, errClosed
	default:
	}

	select {
	case c := <-l.inbound:
		return c, nil
	case <-l.closed:
		return nil, errClosed
	}
}

// Close marks the listener closed, unblocking any pending Accept. It may be
// called more than once and always returns nil.
func (l *testListener) Close() error {
	l.once.Do(func() { close(l.closed) })
	return nil
}

// Addr returns the address the listener was constructed with.
func (l *testListener) Addr() net.Addr { return l.addr }
// testDialer connects to testListeners in memory. outbounds maps the String()
// form of a utils.IPDesc to the listener serving that address; addr is used as
// the dialing side's address on the connections it creates.
type testDialer struct {
	addr      net.Addr
	outbounds map[string]*testListener
}

// Dial creates a connected pair of testConns, queues the server end on the
// listener registered for ip and returns the client end. It returns errRefused
// when no listener is registered for ip, when that listener has been closed,
// or when the listener's inbound queue is full.
func (d *testDialer) Dial(ip utils.IPDesc) (net.Conn, error) {
	outbound, ok := d.outbounds[ip.String()]
	if !ok {
		return nil, errRefused
	}

	// A real dial to a closed listener is refused; mirror that rather than
	// queuing a server end that is unlikely to ever be accepted.
	select {
	case <-outbound.closed:
		return nil, errRefused
	default:
	}

	server := &testConn{
		pendingReads:  make(chan []byte, 1<<10),
		pendingWrites: make(chan []byte, 1<<10),
		closed:        make(chan struct{}),
		local:         outbound.addr,
		remote:        d.addr,
	}
	// The client end shares the server's channels with directions swapped, so
	// whatever one side writes the other side reads.
	client := &testConn{
		pendingReads:  server.pendingWrites,
		pendingWrites: server.pendingReads,
		closed:        make(chan struct{}),
		local:         d.addr,
		remote:        outbound.addr,
	}

	// Never block the dialer: a full accept queue is treated as a refusal.
	select {
	case outbound.inbound <- server:
		return client, nil
	default:
		return nil, errRefused
	}
}
// testConn is one end of an in-memory, message-buffered connection. Write
// copies each message onto pendingWrites; Read drains pendingReads, keeping any
// unread remainder of a message in partialRead for the next call. Closing one
// end does not close its peer.
type testConn struct {
	partialRead   []byte
	pendingReads  chan []byte
	pendingWrites chan []byte
	closed        chan struct{}
	once          sync.Once

	local, remote net.Addr
}

// isClosed reports whether Close has been called on this end.
func (c *testConn) isClosed() bool {
	select {
	case <-c.closed:
		return true
	default:
		return false
	}
}

// Read copies buffered data into b, blocking until at least one message is
// available. After Close it always returns errClosed, as net.Conn does.
func (c *testConn) Read(b []byte) (int, error) {
	// Check closed first so a closed conn fails deterministically even when
	// data is still buffered.
	if c.isClosed() {
		return 0, errClosed
	}
	for len(c.partialRead) == 0 {
		select {
		case msg, ok := <-c.pendingReads:
			if !ok {
				return 0, errClosed
			}
			c.partialRead = msg
		case <-c.closed:
			return 0, errClosed
		}
	}
	n := copy(b, c.partialRead)
	c.partialRead = c.partialRead[n:]
	if len(c.partialRead) == 0 {
		// Drop the reference so the consumed message can be collected.
		c.partialRead = nil
	}
	return n, nil
}

// Write queues a copy of b for the peer, blocking while the queue is full.
// After Close it always returns errClosed.
func (c *testConn) Write(b []byte) (int, error) {
	if c.isClosed() {
		return 0, errClosed
	}
	// Copy b: the caller may reuse its buffer once Write returns, but the peer
	// reads this slice later.
	msg := make([]byte, len(b))
	copy(msg, b)
	select {
	case c.pendingWrites <- msg:
		return len(b), nil
	case <-c.closed:
		return 0, errClosed
	}
}

// Close closes this end, unblocking pending Reads and Writes. It may be called
// more than once and always returns nil.
func (c *testConn) Close() error {
	c.once.Do(func() { close(c.closed) })
	return nil
}

// LocalAddr returns the address this end was constructed with.
func (c *testConn) LocalAddr() net.Addr { return c.local }

// RemoteAddr returns the peer address this end was constructed with.
func (c *testConn) RemoteAddr() net.Addr { return c.remote }

// SetDeadline is a no-op; deadlines are ignored by testConn.
func (c *testConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op; deadlines are ignored by testConn.
func (c *testConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op; deadlines are ignored by testConn.
func (c *testConn) SetWriteDeadline(time.Time) error { return nil }
// testHandler is a handler stub used with RegisterHandler whose callbacks are
// optional: an unset callback behaves as if it returned false.
type testHandler struct {
	connected    func(ids.ShortID) bool
	disconnected func(ids.ShortID) bool
}

// Connected forwards to h.connected, reporting false when it is unset.
func (h *testHandler) Connected(id ids.ShortID) bool {
	if h.connected == nil {
		return false
	}
	return h.connected(id)
}

// Disconnected forwards to h.disconnected, reporting false when it is unset.
func (h *testHandler) Disconnected(id ids.ShortID) bool {
	if h.disconnected == nil {
		return false
	}
	return h.disconnected(id)
}
// TestNewDefaultNetwork builds a single network, closes it concurrently with
// Dispatch, and checks that Close succeeds and Dispatch returns an error.
func TestNewDefaultNetwork(t *testing.T) {
	log := logging.NoLog{}
	ip := utils.IPDesc{
		IP:   net.IPv6loopback,
		Port: 0,
	}
	id := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip.String())))
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	listener := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	// Named nw rather than net so the net package is not shadowed.
	nw := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id,
		ip,
		networkID,
		appVersion,
		versionParser,
		listener,
		caller,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, nw)

	// closeDone lets the test wait for the Close goroutine, so its assertion
	// cannot run after the test has returned.
	closeDone := make(chan struct{})
	go func() {
		defer close(closeDone)
		err := nw.Close()
		assert.NoError(t, err)
	}()

	err := nw.Dispatch()
	assert.Error(t, err)
	<-closeDone
}
// TestEstablishConnection wires two networks to each other's in-memory
// listener, has net0 track net1, and waits until each side's handler sees a
// connection to a peer other than itself.
func TestEstablishConnection(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	var (
		wg0 sync.WaitGroup
		wg1 sync.WaitGroup
	)
	wg0.Add(1)
	wg1.Add(1)

	// Each handler ignores its own node's ID and signals on any other peer.
	// NOTE(review): the self-ID filter suggests RegisterHandler also calls
	// Connected with the local ID — confirm.
	h0 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id0) {
				wg0.Done()
			}
			return false
		},
	}
	h1 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id1) {
				wg1.Done()
			}
			return false
		},
	}

	net0.RegisterHandler(h0)
	net1.RegisterHandler(h1)

	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	wg0.Wait()
	wg1.Wait()

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return.
	dispatchers.Wait()
}
// TestDoubleTrack is TestEstablishConnection with net0 tracking net1 twice.
// Each handler's WaitGroup expects exactly one Done, so a duplicate Connected
// call for the same peer would panic with a negative counter.
func TestDoubleTrack(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	var (
		wg0 sync.WaitGroup
		wg1 sync.WaitGroup
	)
	wg0.Add(1)
	wg1.Add(1)

	// Each handler ignores its own node's ID and signals on any other peer.
	h0 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id0) {
				wg0.Done()
			}
			return false
		},
	}
	h1 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id1) {
				wg1.Done()
			}
			return false
		},
	}

	net0.RegisterHandler(h0)
	net1.RegisterHandler(h1)

	net0.Track(ip1)
	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	wg0.Wait()
	wg1.Wait()

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return.
	dispatchers.Wait()
}
// TestDoubleClose connects two networks, then closes each one twice and checks
// that every Close call returns nil.
func TestDoubleClose(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	var (
		wg0 sync.WaitGroup
		wg1 sync.WaitGroup
	)
	wg0.Add(1)
	wg1.Add(1)

	// Each handler ignores its own node's ID and signals on any other peer.
	h0 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id0) {
				wg0.Done()
			}
			return false
		},
	}
	h1 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id1) {
				wg1.Done()
			}
			return false
		},
	}

	net0.RegisterHandler(h0)
	net1.RegisterHandler(h1)

	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	wg0.Wait()
	wg1.Wait()

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// A second Close on each network must also succeed.
	err = net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return.
	dispatchers.Wait()
}
// TestRemoveHandlers connects two networks, then registers handlers whose
// Connected callbacks return true, and finally closes both networks.
// NOTE(review): returning true presumably tells RegisterHandler not to keep
// the handler — confirm against RegisterHandler.
func TestRemoveHandlers(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	var (
		wg0 sync.WaitGroup
		wg1 sync.WaitGroup
	)
	wg0.Add(1)
	wg1.Add(1)

	// Each handler ignores its own node's ID and signals on any other peer.
	h0 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id0) {
				wg0.Done()
			}
			return false
		},
	}
	h1 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id1) {
				wg1.Done()
			}
			return false
		},
	}

	net0.RegisterHandler(h0)
	net1.RegisterHandler(h1)

	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	wg0.Wait()
	wg1.Wait()

	// h3 asserts it is only ever called with net0's own ID and returns true.
	// h4 returns true only for net0's ID.
	h3 := &testHandler{
		connected: func(id ids.ShortID) bool {
			assert.Equal(t, id0, id)
			return true
		},
	}
	h4 := &testHandler{
		connected: func(id ids.ShortID) bool {
			return id.Equals(id0)
		},
	}

	net0.RegisterHandler(h3)
	net1.RegisterHandler(h4)

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return.
	dispatchers.Wait()
}
// TestTrackConnected connects two networks and then has net0 track net1 again
// while already connected, before closing both networks.
func TestTrackConnected(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	var (
		wg0 sync.WaitGroup
		wg1 sync.WaitGroup
	)
	wg0.Add(1)
	wg1.Add(1)

	// Each handler ignores its own node's ID and signals on any other peer.
	h0 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id0) {
				wg0.Done()
			}
			return false
		},
	}
	h1 := &testHandler{
		connected: func(id ids.ShortID) bool {
			if !id.Equals(id1) {
				wg1.Done()
			}
			return false
		},
	}

	net0.RegisterHandler(h0)
	net1.RegisterHandler(h1)

	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	wg0.Wait()
	wg1.Wait()

	// Track a peer that is already connected.
	net0.Track(ip1)

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return.
	dispatchers.Wait()
}
// TestTrackConnectedRace has net0 track net1, starts both Dispatch loops and
// closes both networks immediately, without waiting for the connection, so
// that Track, Dispatch and Close race with each other.
func TestTrackConnectedRace(t *testing.T) {
	log := logging.NoLog{}
	networkID := uint32(0)
	appVersion := version.NewDefaultVersion("app", 0, 1, 0)
	versionParser := version.NewDefaultParser()

	ip0 := utils.IPDesc{IP: net.IPv6loopback, Port: 0}
	id0 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip0.String())))
	ip1 := utils.IPDesc{IP: net.IPv6loopback, Port: 1}
	id1 := ids.NewShortID(hashing.ComputeHash160Array([]byte(ip1.String())))

	listener0 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller0 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 0},
		outbounds: make(map[string]*testListener),
	}
	listener1 := &testListener{
		addr:    &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		inbound: make(chan net.Conn, 1<<10),
		closed:  make(chan struct{}),
	}
	caller1 := &testDialer{
		addr:      &net.TCPAddr{IP: net.IPv6loopback, Port: 1},
		outbounds: make(map[string]*testListener),
	}
	// Each dialer reaches the other node's listener.
	caller0.outbounds[ip1.String()] = listener1
	caller1.outbounds[ip0.String()] = listener0

	serverUpgrader := NewIPUpgrader()
	clientUpgrader := NewIPUpgrader()

	vdrs := validators.NewSet()
	handler := router.Router(nil)

	net0 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id0,
		ip0,
		networkID,
		appVersion,
		versionParser,
		listener0,
		caller0,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net0)

	net1 := NewDefaultNetwork(
		prometheus.NewRegistry(),
		log,
		id1,
		ip1,
		networkID,
		appVersion,
		versionParser,
		listener1,
		caller1,
		serverUpgrader,
		clientUpgrader,
		vdrs,
		vdrs,
		handler,
	)
	assert.NotNil(t, net1)

	net0.Track(ip1)

	// dispatchers tracks the Dispatch goroutines so their assertions finish
	// before the test returns.
	var dispatchers sync.WaitGroup
	dispatchers.Add(2)
	go func() {
		defer dispatchers.Done()
		err := net0.Dispatch()
		assert.Error(t, err)
	}()
	go func() {
		defer dispatchers.Done()
		err := net1.Dispatch()
		assert.Error(t, err)
	}()

	err := net0.Close()
	assert.NoError(t, err)
	err = net1.Close()
	assert.NoError(t, err)

	// Closing each network is expected to make its Dispatch return, even if
	// Dispatch only starts after Close.
	dispatchers.Wait()
}