diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 121c50b3..c6412bfd 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -2,12 +2,13 @@ package lnwire import ( "crypto/sha256" + "encoding/binary" + "fmt" "io" "bytes" "github.com/go-errors/errors" - "github.com/roasbeef/btcutil" ) // FailureMessage represents the onion failure object identified by its unique @@ -18,7 +19,7 @@ type FailureMessage interface { // failureMessageLength is the size of the failure message plus the size of // padding. The FailureMessage message should always be EXACLTY this size. -const failureMessageLength = 128 +const failureMessageLength = 256 const ( // FlagBadOnion error flag describes an unparseable, encrypted by @@ -826,38 +827,43 @@ func (f *FailFinalIncorrectHtlcAmount) Encode(w io.Writer, pver uint32) error { // DecodeFailure decodes, validates, and parses the lnwire onion failure, for // the provided protocol version. func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { - // Start processing the failure message by reading the code. - var code uint16 - if err := readElement(r, &code); err != nil { - return nil, err - } - - // Create the empty failure by given code and populate the failure with - // additional data if needed. - failure, err := makeEmptyOnionError(FailCode(code)) - if err != nil { - return nil, err - } - - // Read the failure length, check its size and read the failure message - // in order to check padding afterwards. + // First, we'll parse out the encapsulated failure message itself. This + // is a 2 byte length followed by the payload itself. var failureLength uint16 if err := readElement(r, &failureLength); err != nil { return nil, err } if failureLength > failureMessageLength { - return nil, errors.New("failure message is too long") + return nil, errors.New(fmt.Sprintf("failure message is too "+ + "long: %v", failureLength)) } - failureData := make([]byte, failureLength) if _, err := io.ReadFull(r, failureData); err != nil { return nil, err } - failureReader := bytes.NewReader(failureData) + dataReader := bytes.NewReader(failureData) + + // Once we have the failure data, we can obtain the failure code from + // the first two bytes of the buffer. + var codeBytes [2]byte + if _, err := io.ReadFull(dataReader, codeBytes[:]); err != nil { + return nil, err + } + failCode := FailCode(binary.BigEndian.Uint16(codeBytes[:])) + + // Create the empty failure by given code and populate the failure with + // additional data if needed. + failure, err := makeEmptyOnionError(failCode) + if err != nil { + return nil, err + } + + // Finally, if this failure has a payload, then we'll read that now as + // well. switch f := failure.(type) { case Serializable: - if err := f.Decode(failureReader, pver); err != nil { + if err := f.Decode(dataReader, pver); err != nil { return nil, err } } @@ -870,6 +876,19 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { var failureMessageBuffer bytes.Buffer + // First, we'll write out the error code itself into the failure + // buffer. + var codeBytes [2]byte + code := uint16(failure.Code()) + binary.BigEndian.PutUint16(codeBytes[:], code) + _, err := failureMessageBuffer.Write(codeBytes[:]) + if err != nil { + return err + } + + // Next, some message have an additional message payload, if this is + // one of those types, then we'll also encode the error payload as + // well. switch failure := failure.(type) { case Serializable: if err := failure.Encode(&failureMessageBuffer, pver); err != nil { @@ -877,16 +896,19 @@ func EncodeFailure(w io.Writer, failure FailureMessage, pver uint32) error { } } + // The combined size of this message must be below the max allowed + // failure message length. failureMessage := failureMessageBuffer.Bytes() if len(failureMessage) > failureMessageLength { - return errors.New("failure message exceed max available size") + return errors.New(fmt.Sprintf("failure message exceed max "+ + "available size: %v", len(failureMessage))) } - code := uint16(failure.Code()) + // Finally, we'll add some padding in order to ensure that all failure + // messages are fixed size. pad := make([]byte, failureMessageLength-len(failureMessage)) return writeElements(w, - code, uint16(len(failureMessage)), failureMessage, uint16(len(pad)), diff --git a/lnwire/update_fail_htlc.go b/lnwire/update_fail_htlc.go index 9b70c881..562199b5 100644 --- a/lnwire/update_fail_htlc.go +++ b/lnwire/update_fail_htlc.go @@ -12,8 +12,8 @@ type OpaqueReason []byte // next commitment transaction, with the UpdateFailHTLC propagated backwards in // the route to fully undo the HTLC. type UpdateFailHTLC struct { - // ChanIDPoint is the particular active channel that this UpdateFailHTLC - // is bound to. + // ChanIDPoint is the particular active channel that this + // UpdateFailHTLC is bound to. ChanID ChannelID // ID references which HTLC on the remote node's commitment transaction @@ -79,7 +79,7 @@ func (c *UpdateFailHTLC) MaxPayloadLength(uint32) uint32 { length += 2 // Length of the Reason - length += 166 + length += 292 return length }