diff --git a/brontide/noise.go b/brontide/noise.go index fdf91924..5a314404 100644 --- a/brontide/noise.go +++ b/brontide/noise.go @@ -4,7 +4,9 @@ import ( "crypto/cipher" "crypto/sha256" "encoding/binary" + "errors" "io" + "math" "golang.org/x/crypto/hkdf" @@ -19,6 +21,18 @@ const ( // exact same string for this value, along with prologue of the Bitcoin // network, then the initial handshake will fail. protocolName = "Noise_XK_secp256k1_ChaChaPoly_SHA256" + + // macSize is the length in bytes of the tags generated by poly1305. + macSize = 16 + + // lengthHeaderSize is the number of bytes used to prefix encode the + // length of a message payload. + lengthHeaderSize = 2 +) + +var ( + ErrMaxMessageLengthExceeded = errors.New("the generated payload exceeds " + + "the max allowed message length of (2^16)-1") ) // TODO(roasbeef): free buffer pool? @@ -533,12 +547,18 @@ func (b *BrontideMachine) split() { // must be used as the AD to the AEAD construction when being decrypted by the // other side. func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error { - // The full length of the packet includes the 16 byte MAC. - fullLength := uint64(len(p) + 16) + // The total length of each message payload including the MAC size + // payload exceed the largest number encodable within a 16-bit unsigned + // integer. + if len(p)+macSize > math.MaxUint16 { + return ErrMaxMessageLengthExceeded + } - // TODO(roasbeef): The Summit decided on 24 bits? - var pktLen [8]byte - binary.BigEndian.PutUint64(pktLen[:], fullLength) + // The full length of the packet includes the 16 byte MAC. + fullLength := uint16(len(p) + macSize) + + var pktLen [2]byte + binary.BigEndian.PutUint16(pktLen[:], fullLength) // First, write out the encrypted+MAC'd length prefix for the packet. cipherLen := b.sendCipher.Encrypt(nil, nil, pktLen[:]) @@ -560,7 +580,7 @@ func (b *BrontideMachine) WriteMessage(w io.Writer, p []byte) error { // ReadMessage attemps to read the next message from the passed io.Reader. In // the case of an authentication error, a non-nil error is returned. func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) { - var cipherLen [8 + 16]byte + var cipherLen [lengthHeaderSize + macSize]byte if _, err := io.ReadFull(r, cipherLen[:]); err != nil { return nil, err } @@ -573,7 +593,7 @@ func (b *BrontideMachine) ReadMessage(r io.Reader) ([]byte, error) { // Next, using the length read from the packet header, read the // encrypted packet itself. - pktLen := binary.BigEndian.Uint64(pktLenBytes) + pktLen := binary.BigEndian.Uint16(pktLenBytes) ciperText := make([]byte, pktLen) if _, err := io.ReadFull(r, ciperText[:]); err != nil { return nil, err diff --git a/brontide/noise_test.go b/brontide/noise_test.go index ad1507ba..cf3f1ac2 100644 --- a/brontide/noise_test.go +++ b/brontide/noise_test.go @@ -2,6 +2,7 @@ package brontide import ( "bytes" + "math" "net" "testing" @@ -97,6 +98,44 @@ func TestConnectionCorrectness(t *testing.T) { } } +func TestMaxPayloadLength(t *testing.T) { + b := BrontideMachine{} + b.split() + + // Create a payload that's juust over the maximum alloted payload + // length. + payloadToReject := make([]byte, math.MaxUint16+1) + + var buf bytes.Buffer + + // A write of the payload generated above to the state machine should + // be rejected as it's over the max payload length. + err := b.WriteMessage(&buf, payloadToReject) + if err != ErrMaxMessageLengthExceeded { + t.Fatalf("payload is over the max allowed length, the write " + + "should have been rejected") + } + + // Generate another payload which with the MAC acounted for, should be + // accepted as a valid payload. + payloadToAccept := make([]byte, math.MaxUint16-macSize) + if err := b.WriteMessage(&buf, payloadToAccept); err != nil { + t.Fatalf("write for payload was rejected, should have been " + + "accepted") + } + + // Generate a final payload which is juuust over the max payload length + // when the MAC is accounted for. + payloadToReject = make([]byte, math.MaxUint16-macSize+1) + + // This payload should be rejected. + err = b.WriteMessage(&buf, payloadToReject) + if err != ErrMaxMessageLengthExceeded { + t.Fatalf("payload is over the max allowed length, the write " + + "should have been rejected") + } +} + func TestNoiseIdentityHiding(t *testing.T) { // TODO(roasbeef): fin }