From 1328e61c004c78e8089c9de76ba815a4fdd51a81 Mon Sep 17 00:00:00 2001 From: Jim Posen Date: Mon, 23 Oct 2017 15:50:26 -0700 Subject: [PATCH] htlcswitch: Change circuit map keys to (channel ID, HTLC ID). This changes the circuit map internals and API to reference circuits by a primary key of (channel ID, HTLC ID) instead of paymnet hash. This is because each circuit has a unique offered HTLC, but there may be multiple circuits for a payment hash with different source or destination channels. --- htlcswitch/circuit.go | 213 ++++++++++++++++++++++--------------- htlcswitch/link.go | 19 ++++ htlcswitch/packet.go | 8 ++ htlcswitch/switch.go | 111 +++++++++---------- htlcswitch/switch_test.go | 65 ++++++++++- lnwire/short_channel_id.go | 9 ++ 6 files changed, 274 insertions(+), 151 deletions(-) diff --git a/htlcswitch/circuit.go b/htlcswitch/circuit.go index ce091146..cdab7b54 100644 --- a/htlcswitch/circuit.go +++ b/htlcswitch/circuit.go @@ -1,135 +1,174 @@ package htlcswitch import ( - "bytes" - "crypto/sha256" - "encoding/hex" + "fmt" "sync" "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/lnwire" ) -// circuitKey uniquely identifies an active circuit between two open channels. -// Currently, the payment hash is used to uniquely identify each circuit. -type circuitKey [sha256.Size]byte - -// String returns the string representation of the circuitKey. -func (k *circuitKey) String() string { - return hex.EncodeToString(k[:]) -} - -// paymentCircuit is used by the htlc switch subsystem to determine the -// forwards/backwards path for the settle/fail HTLC messages. A payment circuit -// will be created once a channel link forwards the htlc add request and -// removed when we receive settle/fail htlc message. -type paymentCircuit struct { +// PaymentCircuit is used by the HTLC switch subsystem to determine the +// backwards path for the settle/fail HTLC messages. A payment circuit +// will be created once a channel link forwards the HTLC add request and +// removed when we receive a settle/fail HTLC message. +type PaymentCircuit struct { // PaymentHash used as unique identifier of payment. - PaymentHash circuitKey + PaymentHash [32]byte - // Src identifies the channel from which add htlc request is came from - // and to which settle/fail htlc request will be returned back. Once + // IncomingChanID identifies the channel from which add HTLC request came + // and to which settle/fail HTLC request will be returned back. Once // the switch forwards the settle/fail message to the src the circuit // is considered to be completed. - Src lnwire.ShortChannelID + IncomingChanID lnwire.ShortChannelID - // Dest identifies the channel to which we propagate the htlc add - // update and from which we are expecting to receive htlc settle/fail + // IncomingHTLCID is the ID in the update_add_htlc message we received from + // the incoming channel, which will be included in any settle/fail messages + // we send back. + IncomingHTLCID uint64 + + // OutgoingChanID identifies the channel to which we propagate the HTLC add + // update and from which we are expecting to receive HTLC settle/fail // request back. - Dest lnwire.ShortChannelID + OutgoingChanID lnwire.ShortChannelID + + // OutgoingHTLCID is the ID in the update_add_htlc message we sent to the + // outgoing channel. + OutgoingHTLCID uint64 // ErrorEncrypter is used to re-encrypt the onion failure before // sending it back to the originator of the payment. ErrorEncrypter ErrorEncrypter - - // RefCount is used to count the circuits with the same circuit key. - RefCount int } -// newPaymentCircuit creates new payment circuit instance. -func newPaymentCircuit(src, dest lnwire.ShortChannelID, key circuitKey, - e ErrorEncrypter) *paymentCircuit { - - return &paymentCircuit{ - Src: src, - Dest: dest, - PaymentHash: key, - RefCount: 1, - ErrorEncrypter: e, - } +// circuitKey is a channel ID, HTLC ID tuple used as an identifying key for a +// payment circuit. The circuit map is keyed with the idenitifer for the +// outgoing HTLC +type circuitKey struct { + chanID lnwire.ShortChannelID + htlcID uint64 } -// isEqual checks the equality of two payment circuits. -func (a *paymentCircuit) isEqual(b *paymentCircuit) bool { - return bytes.Equal(a.PaymentHash[:], b.PaymentHash[:]) && - a.Src == b.Src && - a.Dest == b.Dest +// String returns a string representation of the circuitKey. +func (k *circuitKey) String() string { + return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.chanID, k.htlcID) } -// circuitMap is a data structure that implements thread safe storage of -// circuits. Each circuit key (payment hash) may have several of circuits -// corresponding to it due to the possibility of repeated payment hashes. +// CircuitMap is a data structure that implements thread safe storage of +// circuit routing information. The switch consults a circuit map to determine +// where to forward HTLC update messages. Each circuit is stored with it's +// outgoing HTLC as the primary key because, each offered HTLC has at most one +// received HTLC, but there may be multiple offered or received HTLCs with the +// same payment hash. Circuits are also indexed to provide fast lookups by +// payment hash. // // TODO(andrew.shvv) make it persistent -type circuitMap struct { - sync.RWMutex - circuits map[circuitKey]*paymentCircuit +type CircuitMap struct { + mtx sync.RWMutex + circuits map[circuitKey]*PaymentCircuit + hashIndex map[[32]byte]map[PaymentCircuit]struct{} } -// newCircuitMap creates a new instance of the circuitMap. -func newCircuitMap() *circuitMap { - return &circuitMap{ - circuits: make(map[circuitKey]*paymentCircuit), +// NewCircuitMap creates a new instance of the CircuitMap. +func NewCircuitMap() *CircuitMap { + return &CircuitMap{ + circuits: make(map[circuitKey]*PaymentCircuit), + hashIndex: make(map[[32]byte]map[PaymentCircuit]struct{}), } } -// add adds a new active payment circuit to the circuitMap. -func (m *circuitMap) add(circuit *paymentCircuit) error { - m.Lock() - defer m.Unlock() +// LookupByHTLC looks up the payment circuit by the outgoing channel and HTLC +// IDs. Returns nil if there is no such circuit. +func (cm *CircuitMap) LookupByHTLC(chanID lnwire.ShortChannelID, htlcID uint64) *PaymentCircuit { + cm.mtx.RLock() - // Examine the circuit map to see if this circuit is already in use or - // not. If so, then we'll simply increment the reference count. - // Otherwise, we'll create a new circuit from scratch. - // - // TODO(roasbeef): include dest+src+amt in key - if c, ok := m.circuits[circuit.PaymentHash]; ok { - c.RefCount++ - return nil + key := circuitKey{ + chanID: chanID, + htlcID: htlcID, + } + circuit := cm.circuits[key] + + cm.mtx.RUnlock() + return circuit +} + +// LookupByPaymentHash looks up and returns any payment circuits with a given +// payment hash. +func (cm *CircuitMap) LookupByPaymentHash(hash [32]byte) []*PaymentCircuit { + cm.mtx.RLock() + + var circuits []*PaymentCircuit + if circuitSet, ok := cm.hashIndex[hash]; ok { + circuits = make([]*PaymentCircuit, 0, len(circuitSet)) + for circuit := range circuitSet { + circuits = append(circuits, &circuit) + } } - m.circuits[circuit.PaymentHash] = circuit + cm.mtx.RUnlock() + return circuits +} +// Add adds a new active payment circuit to the CircuitMap. +func (cm *CircuitMap) Add(circuit *PaymentCircuit) error { + cm.mtx.Lock() + + key := circuitKey{ + chanID: circuit.OutgoingChanID, + htlcID: circuit.OutgoingHTLCID, + } + cm.circuits[key] = circuit + + // Add circuit to the hash index. + if _, ok := cm.hashIndex[circuit.PaymentHash]; !ok { + cm.hashIndex[circuit.PaymentHash] = make(map[PaymentCircuit]struct{}) + } + cm.hashIndex[circuit.PaymentHash][*circuit] = struct{}{} + + cm.mtx.Unlock() return nil } -// remove destroys the target circuit by removing it from the circuit map. -func (m *circuitMap) remove(key circuitKey) (*paymentCircuit, error) { - m.Lock() - defer m.Unlock() +// Remove destroys the target circuit by removing it from the circuit map. +func (cm *CircuitMap) Remove(chanID lnwire.ShortChannelID, htlcID uint64) error { + cm.mtx.Lock() + defer cm.mtx.Unlock() - if circuit, ok := m.circuits[key]; ok { - if circuit.RefCount--; circuit.RefCount == 0 { - delete(m.circuits, key) - } + // Look up circuit so that pointer can be matched in the hash index. + key := circuitKey{ + chanID: chanID, + htlcID: htlcID, + } + circuit, found := cm.circuits[key] + if !found { + return errors.Errorf("Can't find circuit for HTLC %v", key) + } + delete(cm.circuits, key) - return circuit, nil + // Remove circuit from hash index. + circuitsWithHash, ok := cm.hashIndex[circuit.PaymentHash] + if !ok { + return errors.Errorf("Can't find circuit in hash index for HTLC %v", + key) } - return nil, errors.Errorf("can't find circuit"+ - " for key %v", key) + if _, ok = circuitsWithHash[*circuit]; !ok { + return errors.Errorf("Can't find circuit in hash index for HTLC %v", + key) + } + + delete(circuitsWithHash, *circuit) + if len(circuitsWithHash) == 0 { + delete(cm.hashIndex, circuit.PaymentHash) + } + return nil } // pending returns number of circuits which are waiting for to be completed // (settle/fail responses to be received). -func (m *circuitMap) pending() int { - m.RLock() - defer m.RUnlock() - - var length int - for _, circuits := range m.circuits { - length += circuits.RefCount - } - - return length +func (cm *CircuitMap) pending() int { + cm.mtx.RLock() + count := len(cm.circuits) + cm.mtx.RUnlock() + return count } diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 762ef1f7..c04ef158 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -700,6 +700,8 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { failPkt := &htlcPacket{ src: l.ShortChanID(), + dest: pkt.src, + destID: pkt.srcID, payHash: htlc.PaymentHash, amount: htlc.Amount, isObfuscated: isObfuscated, @@ -720,6 +722,20 @@ func (l *channelLink) handleDownStreamPkt(pkt *htlcPacket, isReProcess bool) { "local_log_index=%v, batch_size=%v", htlc.PaymentHash[:], index, l.batchCounter+1) + // If packet was forwarded from another channel link then we should + // create circuit (remember the path) in order to forward settle/fail + // packet back. + if pkt.src != (lnwire.ShortChannelID{}) { + l.cfg.Switch.addCircuit(&PaymentCircuit{ + PaymentHash: htlc.PaymentHash, + IncomingChanID: pkt.src, + IncomingHTLCID: pkt.srcID, + OutgoingChanID: pkt.dest, + OutgoingHTLCID: index, + ErrorEncrypter: pkt.obfuscator, + }) + } + htlc.ID = index l.cfg.Peer.SendMessage(htlc) @@ -1180,6 +1196,7 @@ func (l *channelLink) processLockedInHtlcs( case lnwallet.Settle: settlePacket := &htlcPacket{ src: l.ShortChanID(), + srcID: pd.ParentIndex, payHash: pd.RHash, amount: pd.Amount, htlc: &lnwire.UpdateFufillHTLC{ @@ -1202,6 +1219,7 @@ func (l *channelLink) processLockedInHtlcs( // continue to propagate it. failPacket := &htlcPacket{ src: l.ShortChanID(), + srcID: pd.HtlcIndex, payHash: pd.RHash, amount: pd.Amount, isObfuscated: false, @@ -1573,6 +1591,7 @@ func (l *channelLink) processLockedInHtlcs( updatePacket := &htlcPacket{ src: l.ShortChanID(), + srcID: pd.HtlcIndex, dest: fwdInfo.NextHop, htlc: addMsg, obfuscator: obfuscator, diff --git a/htlcswitch/packet.go b/htlcswitch/packet.go index 1f68e0f9..c0a4af4c 100644 --- a/htlcswitch/packet.go +++ b/htlcswitch/packet.go @@ -27,6 +27,14 @@ type htlcPacket struct { // of the target link. src lnwire.ShortChannelID + // destID is the ID of the HTLC in the destination channel. This will be set + // when forwarding a settle or fail update back to the original source. + destID uint64 + + // srcID is the ID of the HTLC in the source channel. This will be set when + // forwarding any HTLC update message. + srcID uint64 + // amount is the value of the HTLC that is being created or modified. amount lnwire.MilliSatoshi diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 73791e78..4cfcf5a4 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -129,7 +129,7 @@ type Switch struct { // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. - circuits *circuitMap + circuits *CircuitMap // links is a map of channel id and channel link which manages // this channel. @@ -167,7 +167,7 @@ type Switch struct { func New(cfg Config) *Switch { return &Switch{ cfg: &cfg, - circuits: newCircuitMap(), + circuits: NewCircuitMap(), linkIndex: make(map[lnwire.ChannelID]ChannelLink), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}), @@ -481,7 +481,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { } source.HandleSwitchPacket(&htlcPacket{ - src: packet.src, + dest: packet.src, + destID: packet.srcID, payHash: htlc.PaymentHash, isObfuscated: true, htlc: &lnwire.UpdateFailHTLC{ @@ -529,7 +530,8 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { } source.HandleSwitchPacket(&htlcPacket{ - src: packet.src, + dest: packet.src, + destID: packet.srcID, payHash: htlc.PaymentHash, isObfuscated: true, htlc: &lnwire.UpdateFailHTLC{ @@ -544,38 +546,6 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { return err } - // If packet was forwarded from another channel link than we - // should create circuit (remember the path) in order to - // forward settle/fail packet back. - if err := s.circuits.add(newPaymentCircuit( - source.ShortChanID(), - destination.ShortChanID(), - htlc.PaymentHash, - packet.obfuscator, - )); err != nil { - failure := lnwire.NewTemporaryChannelFailure(nil) - reason, err := packet.obfuscator.EncryptFirstHop(failure) - if err != nil { - err := errors.Errorf("unable to obfuscate "+ - "error: %v", err) - log.Error(err) - return err - } - - source.HandleSwitchPacket(&htlcPacket{ - src: packet.src, - payHash: htlc.PaymentHash, - isObfuscated: true, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - }) - err = errors.Errorf("unable to add circuit: "+ - "%v", err) - log.Error(err) - return err - } - // Send the packet to the destination channel link which // manages the channel. destination.HandleSwitchPacket(packet) @@ -585,37 +555,49 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // payment circuit by forwarding the settle msg to the channel from // which htlc add packet was initially received. case *lnwire.UpdateFufillHTLC, *lnwire.UpdateFailHTLC: - // Exit if we can't find and remove the active circuit to - // continue propagating the fail over. - circuit, err := s.circuits.remove(packet.payHash) + if packet.dest == (lnwire.ShortChannelID{}) { + // Use circuit map to find the link to forward settle/fail to. + circuit := s.circuits.LookupByHTLC(packet.src, packet.srcID) + if circuit == nil { + err := errors.Errorf("Unable to find source channel for HTLC "+ + "settle/fail: channel ID = %s, HTLC ID = %d, "+ + "payment hash = %x", packet.src, packet.srcID, + packet.payHash[:]) + log.Error(err) + return err + } + + // Remove circuit since we are about to complete the HTLC. + err := s.circuits.Remove(packet.src, packet.srcID) + if err != nil { + log.Warnf("Failed to close completed onion circuit for %x: "+ + "%s<->%s", packet.payHash[:], circuit.IncomingChanID, + circuit.OutgoingChanID) + } else { + log.Debugf("Closed completed onion circuit for %x: %s<->%s", + packet.payHash[:], circuit.IncomingChanID, + circuit.OutgoingChanID) + } + + // Obfuscate the error message for fail updates before sending back + // through the circuit. + if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated { + htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( + htlc.Reason) + } + + packet.dest = circuit.IncomingChanID + packet.destID = circuit.IncomingHTLCID + } + + source, err := s.getLinkByShortID(packet.dest) if err != nil { - err := errors.Errorf("unable to remove "+ - "circuit for payment hash: %v", packet.payHash) + err := errors.Errorf("Unable to get source channel link to "+ + "forward HTLC settle/fail: %v", err) log.Error(err) return err } - // If this is failure than we need to obfuscate the error. - if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok && !packet.isObfuscated { - htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( - htlc.Reason, - ) - } - - // Propagating settle/fail htlc back to src of add htlc packet. - source, err := s.getLinkByShortID(circuit.Src) - if err != nil { - err := errors.Errorf("unable to get source "+ - "channel link to forward settle/fail htlc: %v", - err) - log.Error(err) - return err - } - - log.Debugf("Closing completed onion "+ - "circuit for %x: %v<->%v", packet.payHash[:], - circuit.Src, circuit.Dest) - source.HandleSwitchPacket(packet) return nil @@ -1109,3 +1091,8 @@ func (s *Switch) numPendingPayments() int { return l } + +// addCircuit adds a circuit to the switch's in-memory mapping. +func (s *Switch) addCircuit(circuit *PaymentCircuit) { + s.circuits.Add(circuit) +} diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 3a06846b..fb5f384f 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -57,6 +57,7 @@ func TestSwitchForward(t *testing.T) { rhash := fastsha256.Sum256(preimage[:]) packet := &htlcPacket{ src: aliceChannelLink.ShortChanID(), + srcID: 0, dest: bobChannelLink.ShortChanID(), obfuscator: newMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ @@ -70,6 +71,15 @@ func TestSwitchForward(t *testing.T) { t.Fatal(err) } + s.addCircuit(&PaymentCircuit{ + PaymentHash: packet.payHash, + IncomingChanID: packet.src, + IncomingHTLCID: 0, + OutgoingChanID: packet.dest, + OutgoingHTLCID: 0, + ErrorEncrypter: packet.obfuscator, + }) + select { case <-bobChannelLink.packets: break @@ -145,8 +155,9 @@ func TestSkipIneligibleLinksMultiHopForward(t *testing.T) { preimage := [sha256.Size]byte{1} rhash := fastsha256.Sum256(preimage[:]) packet = &htlcPacket{ - src: aliceChannelLink.ShortChanID(), - dest: bobChannelLink.ShortChanID(), + src: aliceChannelLink.ShortChanID(), + srcID: 0, + dest: bobChannelLink.ShortChanID(), htlc: &lnwire.UpdateAddHTLC{ PaymentHash: rhash, Amount: 1, @@ -234,6 +245,7 @@ func TestSwitchCancel(t *testing.T) { rhash := fastsha256.Sum256(preimage[:]) request := &htlcPacket{ src: aliceChannelLink.ShortChanID(), + srcID: 0, dest: bobChannelLink.ShortChanID(), obfuscator: newMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ @@ -247,6 +259,15 @@ func TestSwitchCancel(t *testing.T) { t.Fatal(err) } + s.addCircuit(&PaymentCircuit{ + PaymentHash: request.payHash, + IncomingChanID: request.src, + IncomingHTLCID: 0, + OutgoingChanID: request.dest, + OutgoingHTLCID: 0, + ErrorEncrypter: request.obfuscator, + }) + select { case <-bobChannelLink.packets: break @@ -263,6 +284,7 @@ func TestSwitchCancel(t *testing.T) { // request should be forwarder back to alice channel link. request = &htlcPacket{ src: bobChannelLink.ShortChanID(), + srcID: 0, payHash: rhash, amount: 1, isObfuscated: true, @@ -316,6 +338,7 @@ func TestSwitchAddSamePayment(t *testing.T) { rhash := fastsha256.Sum256(preimage[:]) request := &htlcPacket{ src: aliceChannelLink.ShortChanID(), + srcID: 0, dest: bobChannelLink.ShortChanID(), obfuscator: newMockObfuscator(), htlc: &lnwire.UpdateAddHTLC{ @@ -329,6 +352,15 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal(err) } + s.addCircuit(&PaymentCircuit{ + PaymentHash: request.payHash, + IncomingChanID: request.src, + IncomingHTLCID: 0, + OutgoingChanID: request.dest, + OutgoingHTLCID: 0, + ErrorEncrypter: request.obfuscator, + }) + select { case <-bobChannelLink.packets: break @@ -340,11 +372,31 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal("wrong amount of circuits") } + request = &htlcPacket{ + src: aliceChannelLink.ShortChanID(), + srcID: 1, + dest: bobChannelLink.ShortChanID(), + obfuscator: newMockObfuscator(), + htlc: &lnwire.UpdateAddHTLC{ + PaymentHash: rhash, + Amount: 1, + }, + } + // Handle the request and checks that bob channel link received it. if err := s.forward(request); err != nil { t.Fatal(err) } + s.addCircuit(&PaymentCircuit{ + PaymentHash: request.payHash, + IncomingChanID: request.src, + IncomingHTLCID: 1, + OutgoingChanID: request.dest, + OutgoingHTLCID: 1, + ErrorEncrypter: request.obfuscator, + }) + if s.circuits.pending() != 2 { t.Fatal("wrong amount of circuits") } @@ -376,6 +428,15 @@ func TestSwitchAddSamePayment(t *testing.T) { t.Fatal("wrong amount of circuits") } + request = &htlcPacket{ + src: bobChannelLink.ShortChanID(), + srcID: 1, + payHash: rhash, + amount: 1, + isObfuscated: true, + htlc: &lnwire.UpdateFailHTLC{}, + } + // Handle the request and checks that payment circuit works properly. if err := s.forward(request); err != nil { t.Fatal(err) diff --git a/lnwire/short_channel_id.go b/lnwire/short_channel_id.go index b190766b..36d38a3e 100644 --- a/lnwire/short_channel_id.go +++ b/lnwire/short_channel_id.go @@ -1,5 +1,9 @@ package lnwire +import ( + "fmt" +) + // ShortChannelID represents the set of data which is needed to retrieve all // necessary data to validate the channel existence. type ShortChannelID struct { @@ -37,3 +41,8 @@ func (c *ShortChannelID) ToUint64() uint64 { return ((uint64(c.BlockHeight) << 40) | (uint64(c.TxIndex) << 16) | (uint64(c.TxPosition))) } + +// String generates a human-readable representation of the channel ID. +func (c ShortChannelID) String() string { + return fmt.Sprintf("%d:%d:%d", c.BlockHeight, c.TxIndex, c.TxPosition) +}