lnwallet: update state machine to use new lnwire.Sig everywhere

This commit is contained in:
Olaoluwa Osuntokun 2018-01-30 19:55:39 -08:00
parent aa2e91f7c4
commit 9c483c38b1
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
3 changed files with 80 additions and 56 deletions

View File

@ -2557,8 +2557,8 @@ func genRemoteHtlcSigJobs(keyRing *CommitmentKeyRing,
// new commitment to the remote party. The commit diff returned contains all // new commitment to the remote party. The commit diff returned contains all
// information necessary for retransmission. // information necessary for retransmission.
func (lc *LightningChannel) createCommitDiff( func (lc *LightningChannel) createCommitDiff(
newCommit *commitment, commitSig *btcec.Signature, newCommit *commitment, commitSig lnwire.Sig,
htlcSigs []*btcec.Signature) (*channeldb.CommitDiff, error) { htlcSigs []lnwire.Sig) (*channeldb.CommitDiff, error) {
// First, we need to convert the funding outpoint into the ID that's // First, we need to convert the funding outpoint into the ID that's
// used on the wire to identify this channel. We'll use this shortly // used on the wire to identify this channel. We'll use this shortly
@ -2673,10 +2673,15 @@ func (lc *LightningChannel) createCommitDiff(
// itself, while the second parameter is a slice of all HTLC signatures (if // itself, while the second parameter is a slice of all HTLC signatures (if
// any). The HTLC signatures are sorted according to the BIP 69 order of the // any). The HTLC signatures are sorted according to the BIP 69 order of the
// HTLC's on the commitment transaction. // HTLC's on the commitment transaction.
func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Signature, error) { func (lc *LightningChannel) SignNextCommitment() (lnwire.Sig, []lnwire.Sig, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
var (
sig lnwire.Sig
htlcSigs []lnwire.Sig
)
// If we're awaiting for an ACK to a commitment signature, or if we // If we're awaiting for an ACK to a commitment signature, or if we
// don't yet have the initial next revocation point of the remote // don't yet have the initial next revocation point of the remote
// party, then we're unable to create new states. Each time we create a // party, then we're unable to create new states. Each time we create a
@ -2684,7 +2689,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
commitPoint := lc.channelState.RemoteNextRevocation commitPoint := lc.channelState.RemoteNextRevocation
if lc.remoteCommitChain.hasUnackedCommitment() || commitPoint == nil { if lc.remoteCommitChain.hasUnackedCommitment() || commitPoint == nil {
return nil, nil, ErrNoWindow return sig, htlcSigs, ErrNoWindow
} }
// Determine the last update on the remote log that has been locked in. // Determine the last update on the remote log that has been locked in.
@ -2698,7 +2703,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
err := lc.validateCommitmentSanity(remoteACKedIndex, err := lc.validateCommitmentSanity(remoteACKedIndex,
lc.localUpdateLog.logIndex, false, true, true) lc.localUpdateLog.logIndex, false, true, true)
if err != nil { if err != nil {
return nil, nil, err return sig, htlcSigs, err
} }
// Grab the next commitment point for the remote party. This will be // Grab the next commitment point for the remote party. This will be
@ -2719,7 +2724,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
remoteACKedIndex, remoteHtlcIndex, keyRing, remoteACKedIndex, remoteHtlcIndex, keyRing,
) )
if err != nil { if err != nil {
return nil, nil, err return sig, htlcSigs, err
} }
walletLog.Tracef("ChannelPoint(%v): extending remote chain to height %v, "+ walletLog.Tracef("ChannelPoint(%v): extending remote chain to height %v, "+
@ -2744,7 +2749,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
lc.localChanCfg, lc.remoteChanCfg, newCommitView, lc.localChanCfg, lc.remoteChanCfg, newCommitView,
) )
if err != nil { if err != nil {
return nil, nil, err return sig, htlcSigs, err
} }
lc.sigPool.SubmitSignBatch(sigBatch) lc.sigPool.SubmitSignBatch(sigBatch)
@ -2755,12 +2760,12 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
rawSig, err := lc.signer.SignOutputRaw(newCommitView.txn, lc.signDesc) rawSig, err := lc.signer.SignOutputRaw(newCommitView.txn, lc.signDesc)
if err != nil { if err != nil {
close(cancelChan) close(cancelChan)
return nil, nil, err return sig, htlcSigs, err
} }
sig, err := btcec.ParseSignature(rawSig, btcec.S256()) sig, err = lnwire.NewSigFromRawSignature(rawSig)
if err != nil { if err != nil {
close(cancelChan) close(cancelChan)
return nil, nil, err return sig, htlcSigs, err
} }
// We'll need to send over the signatures to the remote party in the // We'll need to send over the signatures to the remote party in the
@ -2772,7 +2777,7 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
// With the jobs sorted, we'll now iterate through all the responses to // With the jobs sorted, we'll now iterate through all the responses to
// gather each of the signatures in order. // gather each of the signatures in order.
htlcSigs := make([]*btcec.Signature, 0, len(sigBatch)) htlcSigs = make([]lnwire.Sig, 0, len(sigBatch))
for _, htlcSigJob := range sigBatch { for _, htlcSigJob := range sigBatch {
select { select {
case jobResp := <-htlcSigJob.resp: case jobResp := <-htlcSigJob.resp:
@ -2780,12 +2785,12 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
// active jobs. // active jobs.
if jobResp.err != nil { if jobResp.err != nil {
close(cancelChan) close(cancelChan)
return nil, nil, err return sig, htlcSigs, err
} }
htlcSigs = append(htlcSigs, jobResp.sig) htlcSigs = append(htlcSigs, jobResp.sig)
case <-lc.quit: case <-lc.quit:
return nil, nil, fmt.Errorf("channel shutting down") return sig, htlcSigs, fmt.Errorf("channel shutting down")
} }
} }
@ -2794,10 +2799,10 @@ func (lc *LightningChannel) SignNextCommitment() (*btcec.Signature, []*btcec.Sig
// can retransmit it if necessary. // can retransmit it if necessary.
commitDiff, err := lc.createCommitDiff(newCommitView, sig, htlcSigs) commitDiff, err := lc.createCommitDiff(newCommitView, sig, htlcSigs)
if err != nil { if err != nil {
return nil, nil, err return sig, htlcSigs, err
} }
if lc.channelState.AppendRemoteCommitChain(commitDiff); err != nil { if lc.channelState.AppendRemoteCommitChain(commitDiff); err != nil {
return nil, nil, err return sig, htlcSigs, err
} }
// TODO(roasbeef): check that one eclair bug // TODO(roasbeef): check that one eclair bug
@ -3128,13 +3133,13 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
// commitment state. The jobs generated are fully populated, and can be sent // commitment state. The jobs generated are fully populated, and can be sent
// directly into the pool of workers. // directly into the pool of workers.
func genHtlcSigValidationJobs(localCommitmentView *commitment, func genHtlcSigValidationJobs(localCommitmentView *commitment,
keyRing *CommitmentKeyRing, htlcSigs []*btcec.Signature, keyRing *CommitmentKeyRing, htlcSigs []lnwire.Sig,
localChanCfg, remoteChanCfg *channeldb.ChannelConfig) []verifyJob { localChanCfg, remoteChanCfg *channeldb.ChannelConfig) ([]verifyJob, error) {
// If this new commitment state doesn't have any HTLC's that are to be // If this new commitment state doesn't have any HTLC's that are to be
// signed, then we'll return a nil slice. // signed, then we'll return a nil slice.
if len(htlcSigs) == 0 { if len(htlcSigs) == 0 {
return nil return nil, nil
} }
txHash := localCommitmentView.txn.TxHash() txHash := localCommitmentView.txn.TxHash()
@ -3154,7 +3159,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment,
// to validate each signature within the worker pool. // to validate each signature within the worker pool.
i := 0 i := 0
for index := range localCommitmentView.txn.TxOut { for index := range localCommitmentView.txn.TxOut {
var sigHash func() ([]byte, error) var (
sigHash func() ([]byte, error)
sig *btcec.Signature
err error
)
outputIndex := int32(index) outputIndex := int32(index)
switch { switch {
@ -3197,7 +3206,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment,
// With the sighash generated, we'll also store the // With the sighash generated, we'll also store the
// signature so it can be written to disk if this state // signature so it can be written to disk if this state
// is valid. // is valid.
htlc.sig = htlcSigs[i] sig, err = htlcSigs[i].ToSignature()
if err != nil {
return nil, err
}
htlc.sig = sig
// Otherwise, if this is an outgoing HTLC, then we'll need to // Otherwise, if this is an outgoing HTLC, then we'll need to
// generate a timeout transaction so we can verify the // generate a timeout transaction so we can verify the
@ -3239,7 +3252,11 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment,
// With the sighash generated, we'll also store the // With the sighash generated, we'll also store the
// signature so it can be written to disk if this state // signature so it can be written to disk if this state
// is valid. // is valid.
htlc.sig = htlcSigs[i] sig, err = htlcSigs[i].ToSignature()
if err != nil {
return nil, err
}
htlc.sig = sig
default: default:
continue continue
@ -3247,14 +3264,14 @@ func genHtlcSigValidationJobs(localCommitmentView *commitment,
verifyJobs = append(verifyJobs, verifyJob{ verifyJobs = append(verifyJobs, verifyJob{
pubKey: keyRing.RemoteHtlcKey, pubKey: keyRing.RemoteHtlcKey,
sig: htlcSigs[i], sig: sig,
sigHash: sigHash, sigHash: sigHash,
}) })
i++ i++
} }
return verifyJobs return verifyJobs, nil
} }
// InvalidCommitSigError is a struct that implements the error interface to // InvalidCommitSigError is a struct that implements the error interface to
@ -3292,8 +3309,8 @@ var _ error = (*InvalidCommitSigError)(nil)
// to our local commitment chain. Once we send a revocation for our prior // to our local commitment chain. Once we send a revocation for our prior
// state, then this newly added commitment becomes our current accepted channel // state, then this newly added commitment becomes our current accepted channel
// state. // state.
func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature, func (lc *LightningChannel) ReceiveNewCommitment(commitSig lnwire.Sig,
htlcSigs []*btcec.Signature) error { htlcSigs []lnwire.Sig) error {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
@ -3366,8 +3383,14 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature,
// As an optimization, we'll generate a series of jobs for the worker // As an optimization, we'll generate a series of jobs for the worker
// pool to verify each of the HTLc signatures presented. Once // pool to verify each of the HTLc signatures presented. Once
// generated, we'll submit these jobs to the worker pool. // generated, we'll submit these jobs to the worker pool.
verifyJobs := genHtlcSigValidationJobs(localCommitmentView, verifyJobs, err := genHtlcSigValidationJobs(
keyRing, htlcSigs, lc.localChanCfg, lc.remoteChanCfg) localCommitmentView, keyRing, htlcSigs, lc.localChanCfg,
lc.remoteChanCfg,
)
if err != nil {
return err
}
cancelChan := make(chan struct{}) cancelChan := make(chan struct{})
verifyResps := lc.sigPool.SubmitVerifyBatch(verifyJobs, cancelChan) verifyResps := lc.sigPool.SubmitVerifyBatch(verifyJobs, cancelChan)
@ -3379,7 +3402,11 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature,
Y: lc.remoteChanCfg.MultiSigKey.Y, Y: lc.remoteChanCfg.MultiSigKey.Y,
Curve: btcec.S256(), Curve: btcec.S256(),
} }
if !commitSig.Verify(sigHash, &verifyKey) { cSig, err := commitSig.ToSignature()
if err != nil {
return err
}
if !cSig.Verify(sigHash, &verifyKey) {
close(cancelChan) close(cancelChan)
// If we fail to validate their commitment signature, we'll // If we fail to validate their commitment signature, we'll
@ -3390,7 +3417,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature,
localCommitTx.Serialize(&txBytes) localCommitTx.Serialize(&txBytes)
return &InvalidCommitSigError{ return &InvalidCommitSigError{
commitHeight: nextHeight, commitHeight: nextHeight,
commitSig: commitSig.Serialize(), commitSig: commitSig.ToSignatureBytes(),
sigHash: sigHash, sigHash: sigHash,
commitTx: txBytes.Bytes(), commitTx: txBytes.Bytes(),
} }
@ -3414,7 +3441,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSig *btcec.Signature,
// The signature checks out, so we can now add the new commitment to // The signature checks out, so we can now add the new commitment to
// our local commitment chain. // our local commitment chain.
localCommitmentView.sig = commitSig.Serialize() localCommitmentView.sig = commitSig.ToSignatureBytes()
lc.localCommitChain.addCommitment(localCommitmentView) lc.localCommitChain.addCommitment(localCommitmentView)
// If we are not channel initiator, then the commitment just received // If we are not channel initiator, then the commitment just received

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"crypto/sha256" "crypto/sha256"
"io/ioutil" "io/ioutil"
"math/big"
"math/rand" "math/rand"
"os" "os"
"reflect" "reflect"
@ -2768,20 +2767,20 @@ func TestChanSyncOweCommitment(t *testing.T) {
t.Fatalf("expected a CommitSig message, instead have %v", t.Fatalf("expected a CommitSig message, instead have %v",
spew.Sdump(aliceMsgsToSend[4])) spew.Sdump(aliceMsgsToSend[4]))
} }
if !commitSigMsg.CommitSig.IsEqual(aliceSig) { if commitSigMsg.CommitSig != aliceSig {
t.Fatalf("commit sig msgs don't match: expected %x got %x", t.Fatalf("commit sig msgs don't match: expected %x got %x",
aliceSig.Serialize(), commitSigMsg.CommitSig.Serialize()) aliceSig, commitSigMsg.CommitSig)
} }
if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) { if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) {
t.Fatalf("wrong number of htlc sigs: expected %v, got %v", t.Fatalf("wrong number of htlc sigs: expected %v, got %v",
len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs)) len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs))
} }
for i, htlcSig := range commitSigMsg.HtlcSigs { for i, htlcSig := range commitSigMsg.HtlcSigs {
if !htlcSig.IsEqual(aliceHtlcSigs[i]) { if htlcSig != aliceHtlcSigs[i] {
t.Fatalf("htlc sig msgs don't match: "+ t.Fatalf("htlc sig msgs don't match: "+
"expected %x got %x", "expected %x got %x",
aliceHtlcSigs[i].Serialize(), aliceHtlcSigs[i],
htlcSig.Serialize()) htlcSig)
} }
} }
} }
@ -3228,20 +3227,19 @@ func TestChanSyncOweRevocationAndCommit(t *testing.T) {
t.Fatalf("expected bob to re-send commit sig, instead sending: %v", t.Fatalf("expected bob to re-send commit sig, instead sending: %v",
spew.Sdump(bobMsgsToSend[1])) spew.Sdump(bobMsgsToSend[1]))
} }
if !bobReCommitSigMsg.CommitSig.IsEqual(bobSig) { if bobReCommitSigMsg.CommitSig != bobSig {
t.Fatalf("commit sig msgs don't match: expected %x got %x", t.Fatalf("commit sig msgs don't match: expected %x got %x",
bobSig.Serialize(), bobReCommitSigMsg.CommitSig.Serialize()) bobSig, bobReCommitSigMsg.CommitSig)
} }
if len(bobReCommitSigMsg.HtlcSigs) != len(bobHtlcSigs) { if len(bobReCommitSigMsg.HtlcSigs) != len(bobHtlcSigs) {
t.Fatalf("wrong number of htlc sigs: expected %v, got %v", t.Fatalf("wrong number of htlc sigs: expected %v, got %v",
len(bobHtlcSigs), len(bobReCommitSigMsg.HtlcSigs)) len(bobHtlcSigs), len(bobReCommitSigMsg.HtlcSigs))
} }
for i, htlcSig := range bobReCommitSigMsg.HtlcSigs { for i, htlcSig := range bobReCommitSigMsg.HtlcSigs {
if !htlcSig.IsEqual(aliceHtlcSigs[i]) { if htlcSig != aliceHtlcSigs[i] {
t.Fatalf("htlc sig msgs don't match: "+ t.Fatalf("htlc sig msgs don't match: "+
"expected %x got %x", "expected %x got %x",
bobHtlcSigs[i].Serialize(), bobHtlcSigs[i], htlcSig)
htlcSig.Serialize())
} }
} }
} }
@ -3426,21 +3424,20 @@ func TestChanSyncOweRevocationAndCommitForceTransition(t *testing.T) {
t.Fatalf("revocation msgs don't match: expected %v, got %v", t.Fatalf("revocation msgs don't match: expected %v, got %v",
bobRevocation, bobReRevoke) bobRevocation, bobReRevoke)
} }
if !bobReCommitSigMsg.CommitSig.IsEqual(bobSigMsg.CommitSig) { if bobReCommitSigMsg.CommitSig != bobSigMsg.CommitSig {
t.Fatalf("commit sig msgs don't match: expected %x got %x", t.Fatalf("commit sig msgs don't match: expected %x got %x",
bobSigMsg.CommitSig.Serialize(), bobSigMsg.CommitSig,
bobReCommitSigMsg.CommitSig.Serialize()) bobReCommitSigMsg.CommitSig)
} }
if len(bobReCommitSigMsg.HtlcSigs) != len(bobSigMsg.HtlcSigs) { if len(bobReCommitSigMsg.HtlcSigs) != len(bobSigMsg.HtlcSigs) {
t.Fatalf("wrong number of htlc sigs: expected %v, got %v", t.Fatalf("wrong number of htlc sigs: expected %v, got %v",
len(bobSigMsg.HtlcSigs), len(bobReCommitSigMsg.HtlcSigs)) len(bobSigMsg.HtlcSigs), len(bobReCommitSigMsg.HtlcSigs))
} }
for i, htlcSig := range bobReCommitSigMsg.HtlcSigs { for i, htlcSig := range bobReCommitSigMsg.HtlcSigs {
if htlcSig.IsEqual(bobSigMsg.HtlcSigs[i]) { if htlcSig != bobSigMsg.HtlcSigs[i] {
t.Fatalf("htlc sig msgs don't match: "+ t.Fatalf("htlc sig msgs don't match: "+
"expected %x got %x", "expected %x got %x",
bobSigMsg.HtlcSigs[i].Serialize(), bobSigMsg.HtlcSigs[i], htlcSig)
htlcSig.Serialize())
} }
} }
@ -3598,20 +3595,19 @@ func TestChannelRetransmissionFeeUpdate(t *testing.T) {
t.Fatalf("expected a CommitSig message, instead have %v", t.Fatalf("expected a CommitSig message, instead have %v",
spew.Sdump(aliceMsgsToSend[1])) spew.Sdump(aliceMsgsToSend[1]))
} }
if !commitSigMsg.CommitSig.IsEqual(aliceSig) { if commitSigMsg.CommitSig != aliceSig {
t.Fatalf("commit sig msgs don't match: expected %x got %x", t.Fatalf("commit sig msgs don't match: expected %x got %x",
aliceSig.Serialize(), commitSigMsg.CommitSig.Serialize()) aliceSig, commitSigMsg.CommitSig)
} }
if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) { if len(commitSigMsg.HtlcSigs) != len(aliceHtlcSigs) {
t.Fatalf("wrong number of htlc sigs: expected %v, got %v", t.Fatalf("wrong number of htlc sigs: expected %v, got %v",
len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs)) len(aliceHtlcSigs), len(commitSigMsg.HtlcSigs))
} }
for i, htlcSig := range commitSigMsg.HtlcSigs { for i, htlcSig := range commitSigMsg.HtlcSigs {
if !htlcSig.IsEqual(aliceHtlcSigs[i]) { if htlcSig != aliceHtlcSigs[i] {
t.Fatalf("htlc sig msgs don't match: "+ t.Fatalf("htlc sig msgs don't match: "+
"expected %x got %x", "expected %x got %x",
aliceHtlcSigs[i].Serialize(), aliceHtlcSigs[i], htlcSig)
htlcSig.Serialize())
} }
} }
@ -4127,7 +4123,7 @@ func TestInvalidCommitSigError(t *testing.T) {
// Before the signature gets to Bob, we'll mutate it, such that the // Before the signature gets to Bob, we'll mutate it, such that the
// signature is now actually invalid. // signature is now actually invalid.
aliceSig.R.Add(aliceSig.R, new(big.Int).SetInt64(1)) aliceSig[0] ^= 88
// Bob should reject this new state, and return the proper error. // Bob should reject this new state, and return the proper error.
err = bobChannel.ReceiveNewCommitment(aliceSig, aliceHtlcSigs) err = bobChannel.ReceiveNewCommitment(aliceSig, aliceHtlcSigs)

View File

@ -5,6 +5,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/roasbeef/btcd/btcec" "github.com/roasbeef/btcd/btcec"
"github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcd/wire"
) )
@ -90,7 +91,7 @@ type signJobResp struct {
// sig is the generated signature for a particular signJob In the case // sig is the generated signature for a particular signJob In the case
// of an error during signature generation, then this value sent will // of an error during signature generation, then this value sent will
// be nil. // be nil.
sig *btcec.Signature sig lnwire.Sig
// err is the error that occurred when executing the specified // err is the error that occurred when executing the specified
// signature job. In the case that no error occurred, this value will // signature job. In the case that no error occurred, this value will
@ -185,7 +186,7 @@ func (s *sigPool) poolWorker() {
if err != nil { if err != nil {
select { select {
case sigMsg.resp <- signJobResp{ case sigMsg.resp <- signJobResp{
sig: nil, sig: lnwire.Sig{},
err: err, err: err,
}: }:
continue continue
@ -196,7 +197,7 @@ func (s *sigPool) poolWorker() {
} }
} }
sig, err := btcec.ParseSignature(rawSig, btcec.S256()) sig, err := lnwire.NewSigFromRawSignature(rawSig)
select { select {
case sigMsg.resp <- signJobResp{ case sigMsg.resp <- signJobResp{
sig: sig, sig: sig,