diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index b55e292b..27392298 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -1,6 +1,10 @@ package lnwire -import "io" +import ( + "io" + + "github.com/roasbeef/btcd/btcec" +) // ChannelReestablish is a message sent between peers that have an existing // open channel upon connection reestablishment. This message allows both sides @@ -43,6 +47,19 @@ type ChannelReestablish struct { // the message sent a revocation for a prior state, but the sender of // the message never fully processed it. RemoteCommitTailHeight uint64 + + // LastRemoteCommitSecret is the last commitment secret that the + // receiving node has sent to the sending party. This will be the + // secret of the last revoked commitment transaction. Including this + // provides proof that the sending node at least knows of this state, + // as they couldn't have produced it if it wasn't sent, as the value + // can be authenticated by querying the shachain or the receiving + // party. + LastRemoteCommitSecret [32]byte + + // LocalUnrevokedCommitPoint is the commitment point used in the + // current un-revoked commitment transaction of the sending party. + LocalUnrevokedCommitPoint *btcec.PublicKey } // A compile time check to ensure ChannelReestablish implements the @@ -54,11 +71,24 @@ var _ Message = (*ChannelReestablish)(nil) // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { - return writeElements(w, + err := writeElements(w, a.ChanID, a.NextLocalCommitHeight, a.RemoteCommitTailHeight, ) + if err != nil { + return err + } + + // If the commit point wasn't sent, then we won't write out any of the + // remaining fields as they're optional. + if a.LocalUnrevokedCommitPoint == nil { + return nil + } + + // Otherwise, we'll write out the remaining elements. + return writeElements(w, a.LastRemoteCommitSecret[:], + a.LocalUnrevokedCommitPoint) } // Decode deserializes a serialized ChannelReestablish stored in the passed @@ -66,11 +96,40 @@ func (a *ChannelReestablish) Encode(w io.Writer, pver uint32) error { // // This is part of the lnwire.Message interface. func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { - return readElements(r, + err := readElements(r, &a.ChanID, &a.NextLocalCommitHeight, &a.RemoteCommitTailHeight, ) + if err != nil { + return err + } + + // This message has to currently defined optional fields. As a result, + // we'll only proceed if there's still bytes remaining within the + // reader. + // + // We'll manually parse out the optional fields in order to be able to + // still utilize the io.Reader interface. + + // We'll first attempt to read the optional commit secret, if we're at + // the EOF, then this means the field wasn't included so we can exit + // early. + var buf [32]byte + _, err = io.ReadFull(r, buf[:32]) + if err == io.EOF { + return nil + } else if err != nil { + return err + } + + // If the field is present, then we'll copy it over and proceed. + copy(a.LastRemoteCommitSecret[:], buf[:]) + + // We'll conclude by parsing out the commitment point. We don't check + // the error in this case, as it hey included the commit secret, then + // they MUST also include the commit point. + return readElement(r, &a.LocalUnrevokedCommitPoint) } // MsgType returns the integer uniquely identifying this message type on the @@ -97,5 +156,11 @@ func (a *ChannelReestablish) MaxPayloadLength(pver uint32) uint32 { // RemoteCommitTailHeight - 8 bytes length += 8 + // LastRemoteCommitSecret - 32 bytes + length += 32 + + // LocalUnrevokedCommitPoint - 33 bytes + length += 33 + return length } diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index d6b7528a..8c033f5a 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -451,6 +451,31 @@ func TestLightningWireProtocol(t *testing.T) { return } + v[0] = reflect.ValueOf(req) + }, + MsgChannelReestablish: func(v []reflect.Value, r *rand.Rand) { + req := ChannelReestablish{ + NextLocalCommitHeight: uint64(r.Int63()), + RemoteCommitTailHeight: uint64(r.Int63()), + } + + // With a 50/50 probability, we'll include the + // additional fields so we can test our ability to + // properly parse, and write out the optional fields. + if r.Int()%2 == 0 { + _, err := r.Read(req.LastRemoteCommitSecret[:]) + if err != nil { + t.Fatalf("unable to read commit secret: %v", err) + return + } + + req.LocalUnrevokedCommitPoint, err = randPubKey() + if err != nil { + t.Fatalf("unable to generate key: %v", err) + return + } + } + v[0] = reflect.ValueOf(req) }, }