From 9c483c38b18d2b21658df5ebd75603bcf3a628ac Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Tue, 30 Jan 2018 19:55:39 -0800 Subject: [PATCH] lnwallet: update state machine to use new lnwire.Sig everywhere --- lnwallet/channel.go | 87 ++++++++++++++++++++++++++-------------- lnwallet/channel_test.go | 42 +++++++++---------- lnwallet/sigpool.go | 7 ++-- 3 files changed, 80 insertions(+), 56 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index cab7ccd2..36859952 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2557,8 +2557,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing, // new commitment to the remote party. The commit diff returned contains all // information necessary for retransmission. func (lc *LightningChannel) createCommitDiff( - newCommit *commitment, commitSig *btcec.Signature, - htlcSigs []*btcec.Signature) (*channeldb.CommitDiff, error) { + newCommit *commitment, commitSig lnwire.Sig, + htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) { // First, we need to convert the funding outpoint into the ID that's // used on the wire to identify this channel. We'll use this shortly @@ -2673,10 +2673,15 @@ func (lc *LightningChannel) createCommitDiff( // itself, while the second parameter is a slice of all HTLC signatures (if // any). The HTLC signatures are sorted according to the BIP 69 order of the // HTLC's on the commitment transaction. -func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Signature, error) { +func (lc *LightningChannel) SignNextCommitment() (lnwire.Sig, []lnwire.Sig, error) { lc.Lock() defer lc.Unlock() + var ( + sig lnwire.Sig + htlcSigs []lnwire.Sig + ) + // If we're awaiting for an ACK to a commitment signature, or if we // don't yet have the initial next revocation point of the remote // party, then we're unable to create new states. Each time we create a @@ -2684,7 +2689,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig commitPoint := lc.channelState.RemoteNextRevocation if lc.remoteCommitChain.hasUnackedCommitment() || commitPoint == nil { - return nil, nil, ErrNoWindow + return sig, htlcSigs, ErrNoWindow } // Determine the last update on the remote log that has been locked in. @@ -2698,7 +2703,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig err := lc.validateCommitmentSanity(remoteACKedIndex, lc.localUpdateLog.logIndex, false, true, true) if err != nil { - return nil, nil, err + return sig, htlcSigs, err } // Grab the next commitment point for the remote party. This will be @@ -2719,7 +2724,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig remoteACKedIndex, remoteHtlcIndex, keyRing, ) if err != nil { - return nil, nil, err + return sig, htlcSigs, err } walletLog.Tracef("ChannelPoint(%v): extending remote chain to height %v, "+ @@ -2744,7 +2749,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig lc.localChanCfg, lc.remoteChanCfg, newCommitView, ) if err != nil { - return nil, nil, err + return sig, htlcSigs, err } lc.sigPool.SubmitSignBatch(sigBatch) @@ -2755,12 +2760,12 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig rawSig, err := lc.signer.SignOutputRaw(newCommitView.txn, lc.signDesc) if err != nil { close(cancelChan) - return nil, nil, err + return sig, htlcSigs, err } - sig, err := btcec.ParseSignature(rawSig, btcec.S256()) + sig, err = lnwire.NewSigFromRawSignature(rawSig) if err != nil { close(cancelChan) - return nil, nil, err + return sig, htlcSigs, err } // We'll need to send over the signatures to the remote party in the @@ -2772,7 +2777,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig // With the jobs sorted, we'll now iterate through all the responses to // gather each of the signatures in order. - htlcSigs := make([]*btcec.Signature, 0, len(sigBatch)) + htlcSigs = make([]lnwire.Sig, 0, len(sigBatch)) for _, htlcSigJob := range sigBatch { select { case jobResp := <-htlcSigJob.resp: @@ -2780,12 +2785,12 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig // active jobs. if jobResp.err != nil { close(cancelChan) - return nil, nil, err + return sig, htlcSigs, err } htlcSigs = append(htlcSigs, jobResp.sig) case <-lc.quit: - return nil, nil, fmt.Errorf("channel shutting down") + return sig, htlcSigs, fmt.Errorf("channel shutting down") } } @@ -2794,10 +2799,10 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig // can retransmit it if necessary. commitDiff, err := lc.createCommitDiff(newCommitView, sig, htlcSigs) if err != nil { - return nil, nil, err + return sig, htlcSigs, err } if lc.channelState.AppendRemoteCommitChain(commitDiff); err != nil { - return nil, nil, err + return sig, htlcSigs, err } // TODO(roasbeef): check that one eclair bug @@ -3128,13 +3133,13 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // commitment state. The jobs generated are fully populated, and can be sent // directly into the pool of workers. func genHtlcSigValidationJobs(localCommitmentView *commitment, - keyRing *CommitmentKeyRing, htlcSigs []*btcec.Signature, - localChanCfg, remoteChanCfg *channeldb.ChannelConfig) []verifyJob { + keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig, + localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]verifyJob, error) { // If this new commitment state doesn't have any HTLC's that are to be // signed, then we'll return a nil slice. if len(htlcSigs) == 0 { - return nil + return nil, nil } txHash := localCommitmentView.txn.TxHash() @@ -3154,7 +3159,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, // to validate each signature within the worker pool. i := 0 for index := range localCommitmentView.txn.TxOut { - var sigHash func() ([]byte, error) + var ( + sigHash func() ([]byte, error) + sig *btcec.Signature + err error + ) outputIndex := int32(index) switch { @@ -3197,7 +3206,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, // With the sighash generated, we'll also store the // signature so it can be written to disk if this state // is valid. - htlc.sig = htlcSigs[i] + sig, err = htlcSigs[i].ToSignature() + if err != nil { + return nil, err + } + htlc.sig = sig // Otherwise, if this is an outgoing HTLC, then we'll need to // generate a timeout transaction so we can verify the @@ -3239,7 +3252,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, // With the sighash generated, we'll also store the // signature so it can be written to disk if this state // is valid. - htlc.sig = htlcSigs[i] + sig, err = htlcSigs[i].ToSignature() + if err != nil { + return nil, err + } + htlc.sig = sig default: continue @@ -3247,14 +3264,14 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment, verifyJobs = append(verifyJobs, verifyJob{ pubKey: keyRing.RemoteHtlcKey, - sig: htlcSigs[i], + sig: sig, sigHash: sigHash, }) i++ } - return verifyJobs + return verifyJobs, nil } // InvalidCommitSigError is a struct that implements the error interface to @@ -3292,8 +3309,8 @@ var _ error = (*InvalidCommitSigError)(nil) // to our local commitment chain. Once we send a revocation for our prior // state, then this newly added commitment becomes our current accepted channel // state. -func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, - htlcSigs []*btcec.Signature) error { +func (lc *LightningChannel) ReceiveNewCommitment(commitSig lnwire.Sig, + htlcSigs []lnwire.Sig) error { lc.Lock() defer lc.Unlock() @@ -3366,8 +3383,14 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, // As an optimization, we'll generate a series of jobs for the worker // pool to verify each of the HTLc signatures presented. Once // generated, we'll submit these jobs to the worker pool. - verifyJobs := genHtlcSigValidationJobs(localCommitmentView, - keyRing, htlcSigs, lc.localChanCfg, lc.remoteChanCfg) + verifyJobs, err := genHtlcSigValidationJobs( + localCommitmentView, keyRing, htlcSigs, lc.localChanCfg, + lc.remoteChanCfg, + ) + if err != nil { + return err + } + cancelChan := make(chan struct{}) verifyResps := lc.sigPool.SubmitVerifyBatch(verifyJobs, cancelChan) @@ -3379,7 +3402,11 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, Y: lc.remoteChanCfg.MultiSigKey.Y, Curve: btcec.S256(), } - if !commitSig.Verify(sigHash, &verifyKey) { + cSig, err := commitSig.ToSignature() + if err != nil { + return err + } + if !cSig.Verify(sigHash, &verifyKey) { close(cancelChan) // If we fail to validate their commitment signature, we'll @@ -3390,7 +3417,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, localCommitTx.Serialize(&txBytes) return &InvalidCommitSigError{ commitHeight: nextHeight, - commitSig: commitSig.Serialize(), + commitSig: commitSig.ToSignatureBytes(), sigHash: sigHash, commitTx: txBytes.Bytes(), } @@ -3414,7 +3441,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, // The signature checks out, so we can now add the new commitment to // our local commitment chain. - localCommitmentView.sig = commitSig.Serialize() + localCommitmentView.sig = commitSig.ToSignatureBytes() lc.localCommitChain.addCommitment(localCommitmentView) // If we are not channel initiator, then the commitment just received diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index cf37de6e..be8407ee 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -4,7 +4,6 @@ import ( "bytes" "crypto/sha256" "io/ioutil" - "math/big" "math/rand" "os" "reflect" @@ -2768,20 +2767,20 @@ func TestChanSyncOweCommitment(t *testing.T) { t.Fatalf("expected a CommitSig message, instead have %v", spew.Sdump(aliceMsgsToSend[4])) } - if !commitSigMsg.CommitSig.IsEqual(aliceSig) { + if commitSigMsg.CommitSig != aliceSig { t.Fatalf("commit sig msgs don't match: expected %x got %x", - aliceSig.Serialize(), commitSigMsg.CommitSig.Serialize()) + aliceSig, commitSigMsg.CommitSig) } if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) { t.Fatalf("wrong number of htlc sigs: expected %v, got %v", len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs)) } for i, htlcSig := range commitSigMsg.HtlcSigs { - if !htlcSig.IsEqual(aliceHtlcSigs[i]) { + if htlcSig != aliceHtlcSigs[i] { t.Fatalf("htlc sig msgs don't match: "+ "expected %x got %x", - aliceHtlcSigs[i].Serialize(), - htlcSig.Serialize()) + aliceHtlcSigs[i], + htlcSig) } } } @@ -3228,20 +3227,19 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) { t.Fatalf("expected bob to re-send commit sig, instead sending: %v", spew.Sdump(bobMsgsToSend[1])) } - if !bobReCommitSigMsg.CommitSig.IsEqual(bobSig) { + if bobReCommitSigMsg.CommitSig != bobSig { t.Fatalf("commit sig msgs don't match: expected %x got %x", - bobSig.Serialize(), bobReCommitSigMsg.CommitSig.Serialize()) + bobSig, bobReCommitSigMsg.CommitSig) } if len(bobReCommitSigMsg.HtlcSigs) != len(bobHtlcSigs) { t.Fatalf("wrong number of htlc sigs: expected %v, got %v", len(bobHtlcSigs), len(bobReCommitSigMsg.HtlcSigs)) } for i, htlcSig := range bobReCommitSigMsg.HtlcSigs { - if !htlcSig.IsEqual(aliceHtlcSigs[i]) { + if htlcSig != aliceHtlcSigs[i] { t.Fatalf("htlc sig msgs don't match: "+ "expected %x got %x", - bobHtlcSigs[i].Serialize(), - htlcSig.Serialize()) + bobHtlcSigs[i], htlcSig) } } } @@ -3426,21 +3424,20 @@ func TestChanSyncOweRevocationAndCommitForceTransition(t *testing.T) { t.Fatalf("revocation msgs don't match: expected %v, got %v", bobRevocation, bobReRevoke) } - if !bobReCommitSigMsg.CommitSig.IsEqual(bobSigMsg.CommitSig) { + if bobReCommitSigMsg.CommitSig != bobSigMsg.CommitSig { t.Fatalf("commit sig msgs don't match: expected %x got %x", - bobSigMsg.CommitSig.Serialize(), - bobReCommitSigMsg.CommitSig.Serialize()) + bobSigMsg.CommitSig, + bobReCommitSigMsg.CommitSig) } if len(bobReCommitSigMsg.HtlcSigs) != len(bobSigMsg.HtlcSigs) { t.Fatalf("wrong number of htlc sigs: expected %v, got %v", len(bobSigMsg.HtlcSigs), len(bobReCommitSigMsg.HtlcSigs)) } for i, htlcSig := range bobReCommitSigMsg.HtlcSigs { - if htlcSig.IsEqual(bobSigMsg.HtlcSigs[i]) { + if htlcSig != bobSigMsg.HtlcSigs[i] { t.Fatalf("htlc sig msgs don't match: "+ "expected %x got %x", - bobSigMsg.HtlcSigs[i].Serialize(), - htlcSig.Serialize()) + bobSigMsg.HtlcSigs[i], htlcSig) } } @@ -3598,20 +3595,19 @@ func TestChannelRetransmissionFeeUpdate(t *testing.T) { t.Fatalf("expected a CommitSig message, instead have %v", spew.Sdump(aliceMsgsToSend[1])) } - if !commitSigMsg.CommitSig.IsEqual(aliceSig) { + if commitSigMsg.CommitSig != aliceSig { t.Fatalf("commit sig msgs don't match: expected %x got %x", - aliceSig.Serialize(), commitSigMsg.CommitSig.Serialize()) + aliceSig, commitSigMsg.CommitSig) } if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) { t.Fatalf("wrong number of htlc sigs: expected %v, got %v", len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs)) } for i, htlcSig := range commitSigMsg.HtlcSigs { - if !htlcSig.IsEqual(aliceHtlcSigs[i]) { + if htlcSig != aliceHtlcSigs[i] { t.Fatalf("htlc sig msgs don't match: "+ "expected %x got %x", - aliceHtlcSigs[i].Serialize(), - htlcSig.Serialize()) + aliceHtlcSigs[i], htlcSig) } } @@ -4127,7 +4123,7 @@ func TestInvalidCommitSigError(t *testing.T) { // Before the signature gets to Bob, we'll mutate it, such that the // signature is now actually invalid. - aliceSig.R.Add(aliceSig.R, new(big.Int).SetInt64(1)) + aliceSig[0] ^= 88 // Bob should reject this new state, and return the proper error. err = bobChannel.ReceiveNewCommitment(aliceSig, aliceHtlcSigs) diff --git a/lnwallet/sigpool.go b/lnwallet/sigpool.go index 4ee58f5e..a93c9630 100644 --- a/lnwallet/sigpool.go +++ b/lnwallet/sigpool.go @@ -5,6 +5,7 @@ import ( "sync" "sync/atomic" + "github.com/lightningnetwork/lnd/lnwire" "github.com/roasbeef/btcd/btcec" "github.com/roasbeef/btcd/wire" ) @@ -90,7 +91,7 @@ type signJobResp struct { // sig is the generated signature for a particular signJob In the case // of an error during signature generation, then this value sent will // be nil. - sig *btcec.Signature + sig lnwire.Sig // err is the error that occurred when executing the specified // signature job. In the case that no error occurred, this value will @@ -185,7 +186,7 @@ func (s *sigPool) poolWorker() { if err != nil { select { case sigMsg.resp <- signJobResp{ - sig: nil, + sig: lnwire.Sig{}, err: err, }: continue @@ -196,7 +197,7 @@ func (s *sigPool) poolWorker() { } } - sig, err := btcec.ParseSignature(rawSig, btcec.S256()) + sig, err := lnwire.NewSigFromRawSignature(rawSig) select { case sigMsg.resp <- signJobResp{ sig: sig,