diff --git a/brontide/noise_test.go b/brontide/noise_test.go index 87bc72b4..de2fa078 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -13,16 +13,16 @@ import ( "github.com/roasbeef/btcd/btcec" ) -func establishTestConnection() (net.Conn, net.Conn, error) { +func establishTestConnection() (net.Conn, net.Conn, func(), error) { // First, generate the long-term private keys both ends of the // connection within our test. localPriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - return nil, nil, err + return nil, nil, nil, err } remotePriv, err := btcec.NewPrivateKey(btcec.S256()) if err != nil { - return nil, nil, err + return nil, nil, nil, err } // Having a port of ":0" means a random port, and interface will be @@ -32,7 +32,7 @@ func establishTestConnection() (net.Conn, net.Conn, error) { // Our listener will be local, and the connection remote. listener, err := NewListener(localPriv, addr) if err != nil { - return nil, nil, err + return nil, nil, nil, err } defer listener.Close() @@ -65,18 +65,23 @@ func establishTestConnection() (net.Conn, net.Conn, error) { select { case err := <-conErrChan: if err != nil { - return nil, nil, err + return nil, nil, nil, err } case err := <-lisErrChan: if err != nil { - return nil, nil, err + return nil, nil, nil, err } } localConn := <-lisChan remoteConn := <-connChan - return localConn, remoteConn, nil + cleanUp := func() { + localConn.Close() + remoteConn.Close() + } + + return localConn, remoteConn, cleanUp, nil } func TestConnectionCorrectness(t *testing.T) { @@ -85,10 +90,11 @@ func TestConnectionCorrectness(t *testing.T) { // Create a test connection, grabbing either side of the connection // into local variables. If the initial crypto handshake fails, then // we'll get a non-nil error here. - localConn, remoteConn, err := establishTestConnection() + localConn, remoteConn, cleanUp, err := establishTestConnection() if err != nil { t.Fatalf("unable to establish test connection: %v", err) } + defer cleanUp() // Test out some message full-message reads. for i := 0; i < 10; i++ { @@ -176,10 +182,11 @@ func TestWriteMessageChunking(t *testing.T) { // Create a test connection, grabbing either side of the connection // into local variables. If the initial crypto handshake fails, then // we'll get a non-nil error here. - localConn, remoteConn, err := establishTestConnection() + localConn, remoteConn, cleanUp, err := establishTestConnection() if err != nil { t.Fatalf("unable to establish test connection: %v", err) } + defer cleanUp() // Attempt to write a message which is over 3x the max allowed payload // size.