diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 998aa50a..622de36c 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -366,7 +366,7 @@ func TestChannelStateTransition(t *testing.T) { // Add some HTLCs which were added during this new state transition. // Half of the HTLCs are incoming, while the other half are outgoing. var ( - htlcs []*HTLC + htlcs []HTLC htlcAmt lnwire.MilliSatoshi ) for i := uint32(0); i < 10; i++ { @@ -374,36 +374,48 @@ func TestChannelStateTransition(t *testing.T) { if i > 5 { incoming = true } - htlc := &HTLC{ + htlc := HTLC{ Signature: testSig.Serialize(), Incoming: incoming, Amt: 10, RHash: key, RefundTimeout: i, OutputIndex: int32(i * 3), + LogIndex: uint64(i * 2), + HtlcIndex: uint64(i), } + htlc.OnionBlob = make([]byte, 10) + copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10)) htlcs = append(htlcs, htlc) htlcAmt += htlc.Amt } - // TODO(roasbeef): ensure that expiry matches - // Create a new channel delta which includes the above HTLCs, some // balance updates, and an increment of the current commitment height. // Additionally, modify the signature and commitment transaction. newSequence := uint32(129498) newSig := bytes.Repeat([]byte{3}, 71) - newTx := channel.CommitTx.Copy() + newTx := channel.LocalCommitment.CommitTx.Copy() newTx.TxIn[0].Sequence = newSequence - delta := &ChannelDelta{ - LocalBalance: lnwire.MilliSatoshi(1e8), - RemoteBalance: lnwire.MilliSatoshi(1e8), - Htlcs: htlcs, - UpdateNum: 1, + commitment := ChannelCommitment{ + CommitHeight: 1, + LocalLogIndex: 2, + LocalHtlcIndex: 1, + RemoteLogIndex: 2, + RemoteHtlcIndex: 1, + LocalBalance: lnwire.MilliSatoshi(1e8), + RemoteBalance: lnwire.MilliSatoshi(1e8), + CommitFee: 55, + FeePerKw: 99, + CommitTx: newTx, + CommitSig: newSig, + Htlcs: htlcs, } - // First update the local node's broadcastable state. - if err := channel.UpdateCommitment(newTx, newSig, delta); err != nil { + // First update the local node's broadcastable state and also add a + // CommitDiff remote node's as well in order to simulate a proper state + // transition. + if err := channel.UpdateCommitment(&commitment); err != nil { t.Fatalf("unable to update commitment: %v", err) } @@ -414,44 +426,81 @@ func TestChannelStateTransition(t *testing.T) { if err != nil { t.Fatalf("unable to fetch updated channel: %v", err) } - if !bytes.Equal(updatedChannel[0].CommitSig, newSig) { - t.Fatalf("sigs don't match %x vs %x", - updatedChannel[0].CommitSig, newSig) - } - if updatedChannel[0].CommitTx.TxIn[0].Sequence != newSequence { - t.Fatalf("sequence numbers don't match: %v vs %v", - updatedChannel[0].CommitTx.TxIn[0].Sequence, newSequence) - } - if updatedChannel[0].LocalBalance != delta.LocalBalance { - t.Fatalf("local balances don't match: %v vs %v", - updatedChannel[0].LocalBalance, delta.LocalBalance) - } - if updatedChannel[0].RemoteBalance != delta.RemoteBalance { - t.Fatalf("remote balances don't match: %v vs %v", - updatedChannel[0].RemoteBalance, delta.RemoteBalance) - } - if updatedChannel[0].NumUpdates != uint64(delta.UpdateNum) { - t.Fatalf("update # doesn't match: %v vs %v", - updatedChannel[0].NumUpdates, delta.UpdateNum) - } + assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) numDiskUpdates, err := updatedChannel[0].CommitmentHeight() if err != nil { t.Fatalf("unable to read commitment height from disk: %v", err) } - if numDiskUpdates != uint64(delta.UpdateNum) { + if numDiskUpdates != uint64(commitment.CommitHeight) { t.Fatalf("num disk updates doesn't match: %v vs %v", - numDiskUpdates, delta.UpdateNum) + numDiskUpdates, commitment.CommitHeight) } - for i := 0; i < len(updatedChannel[0].Htlcs); i++ { - originalHTLC := updatedChannel[0].Htlcs[i] - diskHTLC := channel.Htlcs[i] - if !reflect.DeepEqual(originalHTLC, diskHTLC) { - t.Fatalf("htlc's dont match: %v vs %v", - spew.Sdump(originalHTLC), - spew.Sdump(diskHTLC)) - } + + // Attempting to query for a commitment diff should return + // ErrNoPendingCommit as we haven't yet created a new state for them. + _, err = channel.RemoteCommitChainTip() + if err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) } + // To simulate us extending a new state to the remote party, we'll also + // create a new commit diff for them. + remoteCommit := commitment + remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) + remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) + remoteCommit.CommitHeight = 1 + commitDiff := &CommitDiff{ + Commitment: remoteCommit, + CommitSig: &lnwire.CommitSig{ + ChanID: lnwire.ChannelID(key), + CommitSig: testSig, + HtlcSigs: []*btcec.Signature{ + testSig, + testSig, + }, + }, + LogUpdates: []LogUpdate{ + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + }, + }, + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + }, + }, + }, + } + copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{1}, 32)) + copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{2}, 32)) + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatalf("unable to add to commit chain: %v", err) + } + + // The commitment tip should now match the the commitment that we just + // inserted. + diskCommitDiff, err := channel.RemoteCommitChainTip() + if err != nil { + t.Fatalf("unable to fetch commit diff: %v", err) + } + if !reflect.DeepEqual(commitDiff, diskCommitDiff) { + t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), + spew.Sdump(diskCommitDiff)) + } + + // We'll save the old remote commitment as this will be added to the + // revocation log shortly. + oldRemoteCommit := channel.RemoteCommitment + // Next, write to the log which tracks the necessary revocation state // needed to rectify any fishy behavior by the remote party. Modify the // current uncollapsed revocation state to simulate a state transition @@ -462,37 +511,28 @@ func TestChannelStateTransition(t *testing.T) { t.Fatalf("unable to generate key: %v", err) } channel.RemoteNextRevocation = newPriv.PubKey() - if err := channel.AppendToRevocationLog(delta); err != nil { + if err := channel.AdvanceCommitChainTail(); err != nil { t.Fatalf("unable to append to revocation log: %v", err) } + // At this point, the remote commit chain shuold be nil, and the posted + // remote commitment should match the one we added as a diff above. + if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) + } + // We should be able to fetch the channel delta created above by it's // update number with all the state properly reconstructed. - diskDelta, err := channel.FindPreviousState(uint64(delta.UpdateNum)) + diskPrevCommit, err := channel.FindPreviousState( + oldRemoteCommit.CommitHeight, + ) if err != nil { t.Fatalf("unable to fetch past delta: %v", err) } // The two deltas (the original vs the on-disk version) should // identical, and all HTLC data should properly be retained. - if delta.LocalBalance != diskDelta.LocalBalance { - t.Fatal("local balances don't match") - } - if delta.RemoteBalance != diskDelta.RemoteBalance { - t.Fatal("remote balances don't match") - } - if delta.UpdateNum != diskDelta.UpdateNum { - t.Fatal("update number doesn't match") - } - for i := 0; i < len(delta.Htlcs); i++ { - originalHTLC := delta.Htlcs[i] - diskHTLC := diskDelta.Htlcs[i] - if !reflect.DeepEqual(originalHTLC, diskHTLC) { - t.Fatalf("htlc's dont match: %v vs %v", - spew.Sdump(originalHTLC), - spew.Sdump(diskHTLC)) - } - } + assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) // The state number recovered from the tail of the revocation log // should be identical to this current state. @@ -500,36 +540,31 @@ func TestChannelStateTransition(t *testing.T) { if err != nil { t.Fatalf("unable to retrieve log: %v", err) } - if logTail.UpdateNum != delta.UpdateNum { + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") } - // Next modify the delta slightly, then create a new entry within the - // revocation log. - delta.UpdateNum = 2 - delta.LocalBalance -= htlcAmt - delta.RemoteBalance += htlcAmt - delta.Htlcs = nil - if err := channel.AppendToRevocationLog(delta); err != nil { + oldRemoteCommit = channel.RemoteCommitment + + // Next modify the posted diff commitment slightly, then create a new + // commitment diff and advance the tail. + commitDiff.Commitment.CommitHeight = 2 + commitDiff.Commitment.LocalBalance -= htlcAmt + commitDiff.Commitment.RemoteBalance += htlcAmt + commitDiff.LogUpdates = []LogUpdate{} + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatal("unable to add to commit chain: %v", err) + } + if err := channel.AdvanceCommitChainTail(); err != nil { t.Fatalf("unable to append to revocation log: %v", err) } // Once again, fetch the state and ensure it has been properly updated. - diskDelta, err = channel.FindPreviousState(uint64(delta.UpdateNum)) + prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) if err != nil { t.Fatalf("unable to fetch past delta: %v", err) } - if len(diskDelta.Htlcs) != 0 { - t.Fatalf("expected %v htlcs, got %v", 0, len(diskDelta.Htlcs)) - } - if delta.LocalBalance != 1e8-htlcAmt { - t.Fatalf("mismatched balances, expected %v got %v", 1e8-htlcAmt, - delta.LocalBalance) - } - if delta.RemoteBalance != 1e8+htlcAmt { - t.Fatalf("mismatched balances, expected %v got %v", 1e8+htlcAmt, - delta.RemoteBalance) - } + assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) // Once again, state number recovered from the tail of the revocation // log should be identical to this current state. @@ -537,7 +572,7 @@ func TestChannelStateTransition(t *testing.T) { if err != nil { t.Fatalf("unable to retrieve log: %v", err) } - if logTail.UpdateNum != delta.UpdateNum { + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { t.Fatal("update number doesn't match") } @@ -579,7 +614,7 @@ func TestChannelStateTransition(t *testing.T) { // Attempting to find previous states on the channel should fail as the // revocation log has been deleted. - _, err = updatedChannel[0].FindPreviousState(uint64(delta.UpdateNum)) + _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) if err == nil { t.Fatal("revocation log search should've failed") }