From 5fd1206da669c465d24bf3d1721bb9c4f63e1224 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Tue, 5 Dec 2017 21:23:23 -0800 Subject: [PATCH] htlcswitch/switch: modifies forward method to support async invocation --- htlcswitch/switch.go | 1045 +++++++++++++++++++++++++++++++++--------- 1 file changed, 824 insertions(+), 221 deletions(-) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 9086700b..f2c69f25 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -9,6 +9,7 @@ import ( "crypto/sha256" + "github.com/boltdb/bolt" "github.com/davecgh/go-spew/spew" "github.com/roasbeef/btcd/btcec" @@ -26,6 +27,15 @@ var ( // ErrChannelLinkNotFound is used when channel link hasn't been found. ErrChannelLinkNotFound = errors.New("channel link not found") + // ErrDuplicateAdd signals that the ADD htlc was already forwarded + // through the switch and is locked into another commitment txn. + ErrDuplicateAdd = errors.New("duplicate add HTLC detected") + + // ErrIncompleteForward is used when an htlc was already forwarded + // through the switch, but did not get locked into another commitment + // txn. + ErrIncompleteForward = errors.Errorf("incomplete forward detected") + // zeroPreimage is the empty preimage which is returned when we have // some errors. zeroPreimage [sha256.Size]byte @@ -39,6 +49,7 @@ type pendingPayment struct { amount lnwire.MilliSatoshi preimage chan [sha256.Size]byte + response chan *htlcPacket err chan error // deobfuscator is an serializable entity which is used if we received @@ -110,6 +121,15 @@ type Config struct { // forced unilateral closure of the channel initiated by a local // subsystem. LocalChannelClose func(pubKey []byte, request *ChanClose) + + // DB is the channeldb instance that will be used to back the switch's + // persistent circuit map. + DB *channeldb.DB + + // SwitchPackager provides access to the forwarding packages of all + // active channels. This gives the switch the ability to read arbitrary + // forwarding packages, and ack settles and fails contained within them. + SwitchPackager channeldb.FwdOperator } // Switch is the central messaging bus for all incoming/outgoing HTLCs. @@ -136,16 +156,24 @@ type Switch struct { // integer ID when it is created. pendingPayments map[uint64]*pendingPayment pendingMutex sync.RWMutex - nextPendingID uint64 + + paymentSequencer Sequencer // circuits is storage for payment circuits which are used to // forward the settle/fail htlc updates back to the add htlc initiator. - circuits *CircuitMap + circuits CircuitMap // links is a map of channel id and channel link which manages // this channel. linkIndex map[lnwire.ChannelID]ChannelLink + // mailMtx is a read/write mutex that protects the mailboxes map. + mailMtx sync.RWMutex + + // mailboxes is a map of channel id to mailboxes, which allows the + // switch to buffer messages for peers that have not come back online. + mailboxes map[lnwire.ShortChannelID]MailBox + // forwardingIndex is an index which is consulted by the switch when it // needs to locate the next hop to forward an incoming/outgoing HTLC // update to/from. @@ -185,11 +213,23 @@ type Switch struct { } // New creates the new instance of htlc switch. -func New(cfg Config) *Switch { +func New(cfg Config) (*Switch, error) { + circuitMap, err := NewCircuitMap(cfg.DB) + if err != nil { + return nil, err + } + + sequencer, err := NewPersistentSequencer(cfg.DB) + if err != nil { + return nil, err + } + return &Switch{ cfg: &cfg, - circuits: NewCircuitMap(), + circuits: circuitMap, + paymentSequencer: sequencer, linkIndex: make(map[lnwire.ChannelID]ChannelLink), + mailboxes: make(map[lnwire.ShortChannelID]MailBox), forwardingIndex: make(map[lnwire.ShortChannelID]ChannelLink), interfaceIndex: make(map[[33]byte]map[ChannelLink]struct{}), pendingPayments: make(map[uint64]*pendingPayment), @@ -198,7 +238,7 @@ func New(cfg Config) *Switch { resolutionMsgs: make(chan *resolutionMsg), linkControl: make(chan interface{}), quit: make(chan struct{}), - } + }, nil } // resolutionMsg is a struct that wraps an existing ResolutionMsg with a done @@ -246,15 +286,19 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, // able to retrieve it and return response to the user. payment := &pendingPayment{ err: make(chan error, 1), + response: make(chan *htlcPacket, 1), preimage: make(chan [sha256.Size]byte, 1), paymentHash: htlc.PaymentHash, amount: htlc.Amount, deobfuscator: deobfuscator, } + paymentID, err := s.paymentSequencer.NextID() + if err != nil { + return zeroPreimage, err + } + s.pendingMutex.Lock() - paymentID := s.nextPendingID - s.nextPendingID++ s.pendingPayments[paymentID] = payment s.pendingMutex.Unlock() @@ -262,10 +306,12 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, // this stage it means that packet haven't left boundaries of our // system and something wrong happened. packet := &htlcPacket{ + incomingChanID: sourceHop, incomingHTLCID: paymentID, destNode: nextNode, htlc: htlc, } + if err := s.forward(packet); err != nil { s.removePendingPayment(paymentID) return zeroPreimage, err @@ -274,7 +320,7 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, // Returns channels so that other subsystem might wait/skip the // waiting of handling of payment. var preimage [sha256.Size]byte - var err error + var response *htlcPacket select { case e := <-payment.err: @@ -284,6 +330,14 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, "while waiting for payment result") } + select { + case pkt := <-payment.response: + response = pkt + case <-s.quit: + return zeroPreimage, errors.New("htlc switch have been stopped " + + "while waiting for payment result") + } + select { case p := <-payment.preimage: preimage = p @@ -292,6 +346,24 @@ func (s *Switch) SendHTLC(nextNode [33]byte, htlc *lnwire.UpdateAddHTLC, "while waiting for payment result") } + // Remove circuit since we are about to complete an + // add/fail of this HTLC. + if teardownErr := s.teardownCircuit(response); teardownErr != nil { + log.Warnf("unable to teardown circuit %s: %v", + response.inKey(), teardownErr) + return preimage, err + } + + // Finally, if this response is contained in a forwarding package, ack + // the settle/fail so that we don't continue to retransmit the HTLC + // internally. + if response.destRef != nil { + if ackErr := s.ackSettleFail(*response.destRef); ackErr != nil { + log.Warnf("unable to ack settle/fail reference: %s: %v", + *response.destRef, ackErr) + } + } + return preimage, err } @@ -372,6 +444,192 @@ func (s *Switch) updateLinkPolicies(c *updatePoliciesCmd) error { // update. Also this function is used by channel links itself in order to // forward the update after it has been included in the channel. func (s *Switch) forward(packet *htlcPacket) error { + switch htlc := packet.htlc.(type) { + case *lnwire.UpdateAddHTLC: + circuit := newPaymentCircuit(&htlc.PaymentHash, packet) + actions, err := s.circuits.CommitCircuits(circuit) + if err != nil { + log.Errorf("unable to commit circuit in switch: %v", err) + return err + } + + // Drop duplicate packet if it has already been seen. + switch { + case len(actions.Drops) == 1: + return ErrDuplicateAdd + + case len(actions.Fails) == 1: + if packet.incomingChanID == sourceHop { + return err + } + + failure := lnwire.NewTemporaryChannelFailure(nil) + addErr := ErrIncompleteForward + + return s.failAddPacket(packet, failure, addErr) + } + + packet.circuit = circuit + } + + return s.route(packet) +} + +// ForwardPackets adds a list of packets to the switch for processing. Fails and +// settles are added on a first past, simultaneously constructing circuits for +// any adds. After persisting the circuits, another pass of the adds is given to +// forward them through the router. +// NOTE: This method guarantees that the returned err chan will eventually be +// closed. The receiver should read on the channel until receiving such a +// signal. +func (s *Switch) ForwardPackets(packets ...*htlcPacket) chan error { + + var ( + // fwdChan is a buffered channel used to receive err msgs from + // the htlcPlex when forwarding this batch. + fwdChan = make(chan error, len(packets)) + + // errChan is a buffered channel returned to the caller, that is + // proxied by the fwdChan. This method guarantees that errChan + // will be closed eventually to alert the receiver that it can + // stop reading from the channel. + errChan = make(chan error, len(packets)) + + // numSent keeps a running count of how many packets are + // forwarded to the switch, which determines how many responses + // we will wait for on the fwdChan.. + numSent int + ) + + // No packets, nothing to do. + if len(packets) == 0 { + close(errChan) + return errChan + } + + // Setup a barrier to prevent the background tasks from processing + // responses until this function returns to the user. + var wg sync.WaitGroup + wg.Add(1) + defer wg.Done() + + // Spawn a goroutine the proxy the errs back to the returned err chan. + // This is done to ensure the err chan returned to the caller closed + // properly, alerting the receiver of completion or shutdown. + s.wg.Add(1) + go s.proxyFwdErrs(&numSent, &wg, fwdChan, errChan) + + // Make a first pass over the packets, forwarding any settles or fails. + // As adds are found, we create a circuit and append it to our set of + // circuits to be written to disk. + var circuits []*PaymentCircuit + var addBatch []*htlcPacket + for _, packet := range packets { + switch htlc := packet.htlc.(type) { + case *lnwire.UpdateAddHTLC: + circuit := newPaymentCircuit(&htlc.PaymentHash, packet) + packet.circuit = circuit + circuits = append(circuits, circuit) + addBatch = append(addBatch, packet) + default: + s.routeAsync(packet, fwdChan) + numSent++ + } + } + + // If this batch did not contain any circuits to commit, we can return + // early. + if len(circuits) == 0 { + return errChan + } + + // Write any circuits that we found to disk. + actions, err := s.circuits.CommitCircuits(circuits...) + if err != nil { + log.Errorf("unable to commit circuits in switch: %v", err) + } + + // Split the htlc packets by comparing an in-order seek to the head of + // the added, dropped, or failed circuits. + // + // NOTE: This assumes each list is guaranteed to be a subsequence of the + // circuits, and that the union of the sets results in the original set + // of circuits. + var addedPackets, failedPackets []*htlcPacket + for _, packet := range addBatch { + switch { + case len(actions.Adds) > 0 && packet.circuit == actions.Adds[0]: + addedPackets = append(addedPackets, packet) + actions.Adds = actions.Adds[1:] + + case len(actions.Drops) > 0 && packet.circuit == actions.Drops[0]: + actions.Drops = actions.Drops[1:] + + case len(actions.Fails) > 0 && packet.circuit == actions.Fails[0]: + failedPackets = append(failedPackets, packet) + actions.Fails = actions.Fails[1:] + } + } + + // Now, forward any packets for circuits that were successfully added to + // the switch's circuit map. + for _, packet := range addedPackets { + s.routeAsync(packet, fwdChan) + numSent++ + } + + // Lastly, for any packets that failed, this implies that they were + // left in a half added state, which can happen when recovering from + // failures. + for _, packet := range failedPackets { + failure := lnwire.NewTemporaryChannelFailure(nil) + addErr := errors.Errorf("failing packet after detecting " + + "incomplete forward") + + // We don't handle the error here since this method always + // returns an error. + s.failAddPacket(packet, failure, addErr) + } + + return errChan +} + +// proxyFwdErrs transmits any errors received on `fwdChan` back to `errChan`, +// and guarantees that the `errChan` will be closed after 1) all errors have +// been sent, or 2) the switch has received a shutdown. The `errChan` should be +// buffered with at least the value of `num` after the barrier has been +// released. +// +// NOTE: The receiver of `errChan` should read until the channel closed, since +// this proxying guarantees that the close will happen. +func (s *Switch) proxyFwdErrs(num *int, wg *sync.WaitGroup, + fwdChan, errChan chan error) { + defer s.wg.Done() + defer func() { + close(errChan) + }() + + // Wait here until the outer function has finished persisting + // and routing the packets. This guarantees we don't read from num until + // the value is accurate. + wg.Wait() + + numSent := *num + for i := 0; i < numSent; i++ { + select { + case err := <-fwdChan: + errChan <- err + case <-s.quit: + log.Errorf("unable to forward htlc packet " + + "htlc switch was stopped") + return + } + } +} + +// route sends a single htlcPacket through the switch and synchronously awaits a +// response. +func (s *Switch) route(packet *htlcPacket) error { command := &plexPacket{ pkt: packet, err: make(chan error, 1), @@ -387,8 +645,24 @@ func (s *Switch) forward(packet *htlcPacket) error { case err := <-command.err: return err case <-s.quit: - return errors.New("unable to forward htlc packet htlc switch was " + - "stopped") + return errors.New("Htlc Switch was stopped") + } +} + +// routeAsync sends a packet through the htlc switch, using the provided err +// chan to propagate errors back to the caller. This method does not wait for +// a response before returning. +func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error) error { + command := &plexPacket{ + pkt: packet, + err: errChan, + } + + select { + case s.htlcPlex <- command: + return nil + case <-s.quit: + return errors.New("Htlc Switch was stopped") } } @@ -405,23 +679,23 @@ func (s *Switch) forward(packet *htlcPacket) error { // o <-settle-- o <--settle-- o // Alice Bob Carol // -func (s *Switch) handleLocalDispatch(packet *htlcPacket) error { +func (s *Switch) handleLocalDispatch(pkt *htlcPacket) error { // Pending payments use a special interpretation of the incomingChanID and // incomingHTLCID fields on packet where the channel ID is blank and the // HTLC ID is the payment ID. The switch basically views the users of the // node as a special channel that also offers a sequence of HTLCs. - payment, err := s.findPayment(packet.incomingHTLCID) + payment, err := s.findPayment(pkt.incomingHTLCID) if err != nil { return err } - switch htlc := packet.htlc.(type) { + switch htlc := pkt.htlc.(type) { // User have created the htlc update therefore we should find the // appropriate channel link and send the payment over this link. case *lnwire.UpdateAddHTLC: // Try to find links by node destination. - links, err := s.getLinks(packet.destNode) + links, err := s.getLinks(pkt.destNode) if err != nil { log.Errorf("unable to find links by destination %v", err) return &ForwardingError{ @@ -476,77 +750,25 @@ func (s *Switch) handleLocalDispatch(packet *htlcPacket) error { // manages then channel. // // TODO(roasbeef): should return with an error - packet.outgoingChanID = destination.ShortChanID() - destination.HandleSwitchPacket(packet) - return nil + pkt.outgoingChanID = destination.ShortChanID() + return destination.HandleSwitchPacket(pkt) // We've just received a settle update which means we can finalize the // user payment and return successful response. case *lnwire.UpdateFulfillHTLC: // Notify the user that his payment was successfully proceed. payment.err <- nil + payment.response <- pkt payment.preimage <- htlc.PaymentPreimage - s.removePendingPayment(packet.incomingHTLCID) + s.removePendingPayment(pkt.incomingHTLCID) // We've just received a fail update which means we can finalize the // user payment and return fail response. case *lnwire.UpdateFailHTLC: - var failure *ForwardingError - switch { - - // The payment never cleared the link, so we don't need to - // decrypt the error, simply decode it them report back to the - // user. - case packet.localFailure: - var userErr string - r := bytes.NewReader(htlc.Reason) - failureMsg, err := lnwire.DecodeFailure(r, 0) - if err != nil { - userErr = fmt.Sprintf("unable to decode onion failure, "+ - "htlc with hash(%x): %v", payment.paymentHash[:], err) - log.Error(userErr) - failureMsg = lnwire.NewTemporaryChannelFailure(nil) - } - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: failureMsg, - } - - // A payment had to be timed out on chain before it got past - // the first hop. In this case, we'll report a permanent - // channel failure as this means us, or the remote party had to - // go on chain. - case packet.isResolution && htlc.Reason == nil: - userErr := fmt.Sprintf("payment was resolved " + - "on-chain, then cancelled back") - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: lnwire.FailPermanentChannelFailure{}, - } - - // A regular multi-hop payment error that we'll need to - // decrypt. - default: - // We'll attempt to fully decrypt the onion encrypted - // error. If we're unable to then we'll bail early. - failure, err = payment.deobfuscator.DecryptError(htlc.Reason) - if err != nil { - userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+ - "htlc with hash(%x): %v", payment.paymentHash[:], err) - log.Error(userErr) - failure = &ForwardingError{ - ErrorSource: s.cfg.SelfKey, - ExtraMsg: userErr, - FailureMessage: lnwire.NewTemporaryChannelFailure(nil), - } - } - } - - payment.err <- failure + payment.err <- s.parseFailedPayment(payment, pkt, htlc) + payment.response <- pkt payment.preimage <- zeroPreimage - s.removePendingPayment(packet.incomingHTLCID) + s.removePendingPayment(pkt.incomingHTLCID) default: return errors.New("wrong update type") @@ -555,6 +777,73 @@ func (s *Switch) handleLocalDispatch(packet *htlcPacket) error { return nil } +// parseFailedPayment determines the appropriate failure message to return to +// a user initiated payment. The three cases handled are: +// 1) A local failure, which should already plaintext. +// 2) A resolution from the chain arbitrator, +// 3) A failure from the remote party, which will need to be decrypted using the +// payment deobfuscator. +func (s *Switch) parseFailedPayment(payment *pendingPayment, pkt *htlcPacket, + htlc *lnwire.UpdateFailHTLC) *ForwardingError { + + var failure *ForwardingError + + switch { + + // The payment never cleared the link, so we don't need to + // decrypt the error, simply decode it them report back to the + // user. + case pkt.localFailure: + var userErr string + r := bytes.NewReader(htlc.Reason) + failureMsg, err := lnwire.DecodeFailure(r, 0) + if err != nil { + userErr = fmt.Sprintf("unable to decode onion failure, "+ + "htlc with hash(%x): %v", payment.paymentHash[:], err) + log.Error(userErr) + failureMsg = lnwire.NewTemporaryChannelFailure(nil) + } + failure = &ForwardingError{ + ErrorSource: s.cfg.SelfKey, + ExtraMsg: userErr, + FailureMessage: failureMsg, + } + + // A payment had to be timed out on chain before it got past + // the first hop. In this case, we'll report a permanent + // channel failure as this means us, or the remote party had to + // go on chain. + case pkt.isResolution && htlc.Reason == nil: + userErr := fmt.Sprintf("payment was resolved " + + "on-chain, then cancelled back") + failure = &ForwardingError{ + ErrorSource: s.cfg.SelfKey, + ExtraMsg: userErr, + FailureMessage: lnwire.FailPermanentChannelFailure{}, + } + + // A regular multi-hop payment error that we'll need to + // decrypt. + default: + var err error + // We'll attempt to fully decrypt the onion encrypted + // error. If we're unable to then we'll bail early. + failure, err = payment.deobfuscator.DecryptError(htlc.Reason) + if err != nil { + userErr := fmt.Sprintf("unable to de-obfuscate onion failure, "+ + "htlc with hash(%x): %v", payment.paymentHash[:], err) + log.Error(userErr) + failure = &ForwardingError{ + ErrorSource: s.cfg.SelfKey, + ExtraMsg: userErr, + FailureMessage: lnwire.NewTemporaryChannelFailure(nil), + } + } + } + + return failure +} + // handlePacketForward is used in cases when we need forward the htlc update // from one channel link to another and be able to propagate the settle/fail // updates back. This behaviour is achieved by creation of payment circuits. @@ -565,46 +854,22 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // payment circuit within our internal state so we can properly forward // the ultimate settle message back latter. case *lnwire.UpdateAddHTLC: - if packet.incomingChanID == (lnwire.ShortChannelID{}) { - // A blank incomingChanID indicates that this is a - // pending user-initiated payment. + if packet.incomingChanID == sourceHop { + // A blank incomingChanID indicates that this is + // a pending user-initiated payment. return s.handleLocalDispatch(packet) } - source, err := s.getLinkByShortID(packet.incomingChanID) - if err != nil { - err := errors.Errorf("unable to find channel link "+ - "by channel point (%v): %v", packet.incomingChanID, err) - log.Error(err) - return err - } - targetLink, err := s.getLinkByShortID(packet.outgoingChanID) if err != nil { // If packet was forwarded from another channel link // than we should notify this link that some error // occurred. - failure := lnwire.FailUnknownNextPeer{} - reason, err := packet.obfuscator.EncryptFirstHop(failure) - if err != nil { - err := errors.Errorf("unable to obfuscate "+ - "error: %v", err) - log.Error(err) - return err - } - - source.HandleSwitchPacket(&htlcPacket{ - incomingChanID: packet.incomingChanID, - incomingHTLCID: packet.incomingHTLCID, - isRouted: true, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - }) - err = errors.Errorf("unable to find link with "+ + failure := &lnwire.FailUnknownNextPeer{} + addErr := errors.Errorf("unable to find link with "+ "destination %v", packet.outgoingChanID) - log.Error(err) - return err + + return s.failAddPacket(packet, failure, addErr) } interfaceLinks, _ := s.getLinks(targetLink.Peer().PubKey()) @@ -629,155 +894,277 @@ func (s *Switch) handlePacketForward(packet *htlcPacket) error { // over has insufficient capacity, then we'll cancel the htlc // as the payment cannot succeed. if destination == nil { - // If packet was forwarded from another - // channel link than we should notify this - // link that some error occurred. + // If packet was forwarded from another channel link + // than we should notify this link that some error + // occurred. failure := lnwire.NewTemporaryChannelFailure(nil) - reason, err := packet.obfuscator.EncryptFirstHop(failure) - if err != nil { - err := errors.Errorf("unable to obfuscate "+ - "error: %v", err) - log.Error(err) - return err - } - - source.HandleSwitchPacket(&htlcPacket{ - incomingChanID: packet.incomingChanID, - incomingHTLCID: packet.incomingHTLCID, - isRouted: true, - htlc: &lnwire.UpdateFailHTLC{ - Reason: reason, - }, - }) - - err = errors.Errorf("unable to find appropriate "+ + addErr := errors.Errorf("unable to find appropriate "+ "channel link insufficient capacity, need "+ "%v", htlc.Amount) - log.Error(err) - return err + + return s.failAddPacket(packet, failure, addErr) } // Send the packet to the destination channel link which // manages the channel. - destination.HandleSwitchPacket(packet) - return nil + packet.outgoingChanID = destination.ShortChanID() + return destination.HandleSwitchPacket(packet) - // We've just received a settle packet which means we can finalize the - // payment circuit by forwarding the settle msg to the channel from - // which htlc add packet was initially received. - case *lnwire.UpdateFulfillHTLC, *lnwire.UpdateFailHTLC: - if !packet.isRouted { - // Use circuit map to find the link to forward settle/fail to. - circuit := s.circuits.LookupByHTLC(packet.outgoingChanID, - packet.outgoingHTLCID) - if circuit == nil { - err := errors.Errorf("Unable to find target channel for HTLC "+ - "settle/fail: channel ID = %s, HTLC ID = %d", - packet.outgoingChanID, packet.outgoingHTLCID) - log.Error(err) - return err + case *lnwire.UpdateFailHTLC, *lnwire.UpdateFulfillHTLC: + + // If the source of this packet has not been set, use the + // circuit map to lookup the origin. + circuit, err := s.closeCircuit(packet) + if err != nil { + return err + } + + fail, isFail := htlc.(*lnwire.UpdateFailHTLC) + if isFail && !packet.hasSource { + switch { + case circuit.ErrorEncrypter == nil: + // No message to encrypt, locally sourced + // payment. + + case packet.isResolution: + // If this is a resolution message, then we'll need to encrypt + // it as it's actually internally sourced. + var err error + // TODO(roasbeef): don't need to pass actually? + failure := &lnwire.FailPermanentChannelFailure{} + fail.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop( + failure, + ) + if err != nil { + err = errors.Errorf("unable to obfuscate "+ + "error: %v", err) + log.Error(err) + } + + default: + // Otherwise, it's a forwarded error, so we'll perform a + // wrapper encryption as normal. + fail.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( + fail.Reason, + ) } - - // Remove the circuit since we are about to complete - // the HTLC. - err := s.circuits.Remove( - packet.outgoingChanID, - packet.outgoingHTLCID, - ) - if err != nil { - log.Warnf("Failed to close completed onion circuit for %x: "+ - "(%s, %d) <-> (%s, %d)", circuit.PaymentHash, - circuit.IncomingChanID, circuit.IncomingHTLCID, - circuit.OutgoingChanID, circuit.OutgoingHTLCID) - } else { - log.Debugf("Closed completed onion circuit for %x: "+ - "(%s, %d) <-> (%s, %d)", circuit.PaymentHash, - circuit.IncomingChanID, circuit.IncomingHTLCID, - circuit.OutgoingChanID, circuit.OutgoingHTLCID) - } - - packet.incomingChanID = circuit.IncomingChanID - packet.incomingHTLCID = circuit.IncomingHTLCID - + } else { // If this is an HTLC settle, and it wasn't from a // locally initiated HTLC, then we'll log a forwarding // event so we can flush it to disk later. // // TODO(roasbeef): only do this once link actually // fully settles? - _, isSettle := packet.htlc.(*lnwire.UpdateFulfillHTLC) - localHTLC := packet.incomingChanID == (lnwire.ShortChannelID{}) - if isSettle && !localHTLC { + localHTLC := packet.incomingChanID == sourceHop + if !localHTLC { s.fwdEventMtx.Lock() s.pendingFwdingEvents = append( s.pendingFwdingEvents, channeldb.ForwardingEvent{ Timestamp: time.Now(), - IncomingChanID: circuit.IncomingChanID, - OutgoingChanID: circuit.OutgoingChanID, - AmtIn: circuit.IncomingAmt, - AmtOut: circuit.OutgoingAmt, + IncomingChanID: circuit.Incoming.ChanID, + OutgoingChanID: circuit.Outgoing.ChanID, + AmtIn: circuit.IncomingAmount, + AmtOut: circuit.OutgoingAmount, }, ) s.fwdEventMtx.Unlock() } - - // Obfuscate the error message for fail updates before - // sending back through the circuit unless the payment - // was generated locally. - if circuit.ErrorEncrypter != nil { - if htlc, ok := htlc.(*lnwire.UpdateFailHTLC); ok { - // If this is a resolution message, - // then we'll need to encrypt it as - // it's actually internally sourced. - if packet.isResolution { - // TODO(roasbeef): don't need to pass actually? - failure := &lnwire.FailPermanentChannelFailure{} - htlc.Reason, err = circuit.ErrorEncrypter.EncryptFirstHop( - failure, - ) - if err != nil { - err := errors.Errorf("unable to obfuscate "+ - "error: %v", err) - log.Error(err) - } - } else { - // Otherwise, it's a forwarded - // error, so we'll perform a - // wrapper encryption as - // normal. - htlc.Reason = circuit.ErrorEncrypter.IntermediateEncrypt( - htlc.Reason, - ) - } - } - } } - // For local HTLCs we'll dispatch the settle event back to the - // caller, rather than to the peer that sent us the HTLC - // originally. - localHTLC := packet.incomingChanID == (lnwire.ShortChannelID{}) - if localHTLC { + // A blank IncomingChanID in a circuit indicates that it is a pending + // user-initiated payment. + if packet.incomingChanID == sourceHop { return s.handleLocalDispatch(packet) } - source, err := s.getLinkByShortID(packet.incomingChanID) - if err != nil { - err := errors.Errorf("Unable to get source channel "+ - "link to forward HTLC settle/fail: %v", err) - log.Error(err) - return err - } - - source.HandleSwitchPacket(packet) - return nil + // Check to see that the source link is online before removing + // the circuit. + sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID) + return sourceMailbox.AddPacket(packet) default: return errors.New("wrong update type") } } +// failAddPacket encrypts a fail packet back to an add packet's source. +// The ciphertext will be derived from the failure message proivded by context. +// This method returns the failErr if all other steps complete successfully. +func (s *Switch) failAddPacket(packet *htlcPacket, + failure lnwire.FailureMessage, failErr error) error { + + // Encrypt the failure so that the sender will be able to read the error + // message. Since we failed this packet, we use EncryptFirstHop to + // obfuscate the failure for their eyes only. + reason, err := packet.obfuscator.EncryptFirstHop(failure) + if err != nil { + err := errors.Errorf("unable to obfuscate "+ + "error: %v", err) + log.Error(err) + return err + } + + log.Error(failErr) + + // Route a fail packet back to the source link. + sourceMailbox := s.getOrCreateMailBox(packet.incomingChanID) + if err = sourceMailbox.AddPacket(&htlcPacket{ + incomingChanID: packet.incomingChanID, + incomingHTLCID: packet.incomingHTLCID, + circuit: packet.circuit, + htlc: &lnwire.UpdateFailHTLC{ + Reason: reason, + }, + }); err != nil { + err = errors.Errorf("source chanid=%v unable to "+ + "handle switch packet: %v", + packet.incomingChanID, err) + log.Error(err) + return err + } + + return failErr +} + +// closeCircuit accepts a settle or fail htlc and the associated htlc packet and +// attempts to determine the source that forwarded this htlc. This method will +// set the incoming chan and htlc ID of the given packet if the source was +// found, and will properly [re]encrypt any failure messages. +func (s *Switch) closeCircuit(pkt *htlcPacket) (*PaymentCircuit, error) { + + // If the packet has its source, that means it was failed locally by the + // outgoing link. We fail it here to make sure only one response makes + // it through the switch. + if pkt.hasSource { + circuit, err := s.circuits.FailCircuit(pkt.inKey()) + switch err { + + // Circuit successfully closed. + case nil: + return circuit, nil + + // Circuit was previously closed, but has not been deleted. We'll just + // drop this response until the circuit has been fully removed. + case ErrCircuitClosing: + return nil, err + + // Failed to close circuit because it does not exist. This is likely + // because the circuit was already successfully closed. Since + // this packet failed locally, there is no forwarding package + // entry to acknowledge. + case ErrUnknownCircuit: + return nil, err + + // Unexpected error. + default: + return nil, err + } + } + + // Otherwise, this is packet was received from the remote party. + // Use circuit map to find the incoming link to receive the settle/fail. + circuit, err := s.circuits.CloseCircuit(pkt.outKey()) + switch err { + + // Open circuit successfully closed. + case nil: + pkt.incomingChanID = circuit.Incoming.ChanID + pkt.incomingHTLCID = circuit.Incoming.HtlcID + pkt.circuit = circuit + pkt.sourceRef = &circuit.AddRef + + return circuit, nil + + // Circuit was previously closed, but has not been deleted. We'll just + // drop this response until the circuit has been removed. + case ErrCircuitClosing: + return nil, err + + // Failed to close circuit because it does not exist. This is likely + // because the circuit was already successfully closed. + case ErrUnknownCircuit: + err := errors.Errorf("Unable to find target channel "+ + "for HTLC settle/fail: channel ID = %s, "+ + "HTLC ID = %d", pkt.outgoingChanID, + pkt.outgoingHTLCID) + log.Error(err) + + // TODO(conner): ack settle/fail + if pkt.destRef != nil { + if err := s.ackSettleFail(*pkt.destRef); err != nil { + return nil, err + } + } + + return nil, err + + // Unexpected error. + default: + return nil, err + } +} + +func (s *Switch) ackSettleFail(settleFailRef channeldb.SettleFailRef) error { + return s.cfg.DB.Update(func(tx *bolt.Tx) error { + return s.cfg.SwitchPackager.AckSettleFails(tx, settleFailRef) + }) +} + +// teardownCircuit removes a pending or open circuit from the switch's circuit +// map and prints useful logging statements regarding the outcome. +func (s *Switch) teardownCircuit(pkt *htlcPacket) error { + var pktType string + switch htlc := pkt.htlc.(type) { + case *lnwire.UpdateFulfillHTLC: + pktType = "SETTLE" + case *lnwire.UpdateFailHTLC: + pktType = "FAIL" + default: + err := fmt.Errorf("cannot tear down packet of type: %T", htlc) + log.Errorf(err.Error()) + return err + } + + switch { + case pkt.circuit.HasKeystone(): + log.Debugf("Tearing down open circuit with %s pkt, removing circuit=%v "+ + "with keystone=%v", pktType, pkt.inKey(), pkt.outKey()) + + err := s.circuits.DeleteCircuits(pkt.inKey()) + if err != nil { + log.Warnf("Failed to tear down open circuit (%s, %d) <-> (%s, %d) "+ + "with payment_hash-%v using %s pkt", + pkt.incomingChanID, pkt.incomingHTLCID, + pkt.outgoingChanID, pkt.outgoingHTLCID, + pkt.circuit.PaymentHash, pktType) + return err + } + + log.Debugf("Closed completed %s circuit for %x: "+ + "(%s, %d) <-> (%s, %d)", pktType, pkt.circuit.PaymentHash, + pkt.incomingChanID, pkt.incomingHTLCID, + pkt.outgoingChanID, pkt.outgoingHTLCID) + + default: + log.Debugf("Tearing down incomplete circuit with %s for inkey=%v", + pktType, pkt.inKey()) + err := s.circuits.DeleteCircuits(pkt.inKey()) + if err != nil { + log.Warnf("Failed to tear down pending %s circuit for %x: "+ + "(%s, %d)", pktType, pkt.circuit.PaymentHash, + pkt.incomingChanID, pkt.incomingHTLCID) + return err + } + + log.Debugf("Removed pending onion circuit for %x: "+ + "(%s, %d)", pkt.circuit.PaymentHash, + pkt.incomingChanID, pkt.incomingHTLCID) + } + + return nil +} + // CloseLink creates and sends the close channel command to the target link // directing the specified closure type. If the closure type if CloseRegular, // then the last parameter should be the ideal fee-per-kw that will be used as @@ -918,7 +1305,10 @@ func (s *Switch) htlcForwarder() { // collect all the forwarding events since the last internal, // and write them out to our log. case <-fwdEventTicker.C: + s.wg.Add(1) go func() { + defer s.wg.Done() + if err := s.FlushForwardingEvents(); err != nil { log.Errorf("unable to flush "+ "forwarding events: %v", err) @@ -1029,9 +1419,151 @@ func (s *Switch) Start() error { s.wg.Add(1) go s.htlcForwarder() + if err := s.reforwardResponses(); err != nil { + log.Errorf("unable to reforward responses: %v", err) + return err + } + return nil } +// reforwardResponses for every known, non-pending channel, loads all associated +// forwarding packages and reforwards any Settle or Fail HTLCs found. This is +// used to resurrect the switch's mailboxes after a restart. +func (s *Switch) reforwardResponses() error { + activeChannels, err := s.cfg.DB.FetchAllChannels() + if err != nil { + return err + } + + for _, activeChannel := range activeChannels { + if activeChannel.IsPending { + continue + } + + shortChanID := activeChannel.ShortChanID + fwdPkgs, err := s.loadChannelFwdPkgs(shortChanID) + if err != nil { + return err + } + + s.reforwardSettleFails(fwdPkgs) + } + + return nil +} + +// loadChannelFwdPkgs loads all forwarding packages owned by the `source` short +// channel identifier. +func (s *Switch) loadChannelFwdPkgs( + source lnwire.ShortChannelID) ([]*channeldb.FwdPkg, error) { + + var fwdPkgs []*channeldb.FwdPkg + if err := s.cfg.DB.Update(func(tx *bolt.Tx) error { + var err error + fwdPkgs, err = s.cfg.SwitchPackager.LoadChannelFwdPkgs( + tx, source, + ) + return err + }); err != nil { + return nil, err + } + + return fwdPkgs, nil +} + +// reforwardSettleFails parses the Settle and Fail HTLCs from the list of +// forwarding packages, and reforwards those that have not been acknowledged. +// This is intended to occur on startup, in order to recover the switch's +// mailboxes, and to ensure that responses can be propagated in case the +// outgoing link never comes back online. +// +// NOTE: This should mimic the behavior processRemoteSettleFails. +func (s *Switch) reforwardSettleFails(fwdPkgs []*channeldb.FwdPkg) { + for _, fwdPkg := range fwdPkgs { + settleFails := lnwallet.PayDescsFromRemoteLogUpdates( + fwdPkg.Source, fwdPkg.Height, fwdPkg.SettleFails, + ) + + switchPackets := make([]*htlcPacket, 0, len(settleFails)) + for i, pd := range settleFails { + + // Skip any settles or fails that have already been + // acknowledged by the incoming link that originated the + // forwarded Add. + if fwdPkg.SettleFailFilter.Contains(uint16(i)) { + continue + } + + switch pd.EntryType { + + // A settle for an HTLC we previously forwarded HTLC has + // been received. So we'll forward the HTLC to the + // switch which will handle propagating the settle to + // the prior hop. + case lnwallet.Settle: + settlePacket := &htlcPacket{ + outgoingChanID: fwdPkg.Source, + outgoingHTLCID: pd.ParentIndex, + destRef: pd.DestRef, + htlc: &lnwire.UpdateFulfillHTLC{ + PaymentPreimage: pd.RPreimage, + }, + } + + // Add the packet to the batch to be forwarded, and + // notify the overflow queue that a spare spot has been + // freed up within the commitment state. + switchPackets = append(switchPackets, settlePacket) + + // A failureCode message for a previously forwarded HTLC has been + // received. As a result a new slot will be freed up in our + // commitment state, so we'll forward this to the switch so the + // backwards undo can continue. + case lnwallet.Fail: + // Fetch the reason the HTLC was cancelled so we can + // continue to propagate it. + failPacket := &htlcPacket{ + outgoingChanID: fwdPkg.Source, + outgoingHTLCID: pd.ParentIndex, + destRef: pd.DestRef, + htlc: &lnwire.UpdateFailHTLC{ + Reason: lnwire.OpaqueReason(pd.FailReason), + }, + } + + // Add the packet to the batch to be forwarded, and + // notify the overflow queue that a spare spot has been + // freed up within the commitment state. + switchPackets = append(switchPackets, failPacket) + } + } + + errChan := s.ForwardPackets(switchPackets...) + go handleBatchFwdErrs(errChan) + } +} + +// handleBatchFwdErrs waits on the given errChan until it is closed, logging the +// errors returned from any unsuccessful forwarding attempts. +func handleBatchFwdErrs(errChan chan error) { + for { + err, ok := <-errChan + if !ok { + // Err chan has been drained or switch is shutting down. + // Either way, return. + return + } + + if err == nil { + continue + } + + log.Errorf("unhandled error while reforwarding htlc "+ + "settle/fail over htlcswitch: %v", err) + } +} + // Stop gracefully stops all active helper goroutines, then waits until they've // exited. func (s *Switch) Stop() error { @@ -1043,6 +1575,11 @@ func (s *Switch) Stop() error { log.Infof("HTLC Switch shutting down") close(s.quit) + + for _, mailBox := range s.mailboxes { + mailBox.Stop() + } + s.wg.Wait() return nil @@ -1096,6 +1633,14 @@ func (s *Switch) addLink(link ChannelLink) error { } s.interfaceIndex[peerPub][link] = struct{}{} + // Get the mailbox for this link, which buffers packets in case there + // packets that we tried to deliver while this link was offline. + mailbox := s.getOrCreateMailBox(link.ShortChanID()) + + // Give the link its mailbox, we only need to start the mailbox if it + // wasn't previously found. + link.AttachMailBox(mailbox) + if err := link.Start(); err != nil { s.removeLink(link.ChanID()) return err @@ -1107,6 +1652,32 @@ func (s *Switch) addLink(link ChannelLink) error { return nil } +// getOrCreateMailBox returns the known mailbox for a particular short channel +// id, or creates one if the link has no existing mailbox. +func (s *Switch) getOrCreateMailBox(chanID lnwire.ShortChannelID) MailBox { + // Check to see if we have a mailbox already populated for this link. + s.mailMtx.RLock() + mailbox, ok := s.mailboxes[chanID] + if ok { + s.mailMtx.RUnlock() + return mailbox + } + s.mailMtx.RUnlock() + + // Otherwise, we will make a new one only if the mailbox still is not + // present after the exclusive mutex is acquired. + s.mailMtx.Lock() + mailbox, ok = s.mailboxes[chanID] + if !ok { + mailbox = newMemoryMailBox() + mailbox.Start() + s.mailboxes[chanID] = mailbox + } + s.mailMtx.Unlock() + + return mailbox +} + // getLinkCmd is a get link command wrapper, it is used to propagate handler // parameters and return handler error. type getLinkCmd struct { @@ -1361,15 +1932,47 @@ func (s *Switch) findPayment(paymentID uint64) (*pendingPayment, error) { return payment, nil } +// CircuitModifier returns a reference to subset of the interfaces provided by +// the circuit map, to allow links to open and close circuits. +func (s *Switch) CircuitModifier() CircuitModifier { + return s.circuits +} + // numPendingPayments is helper function which returns the overall number of // pending user payments. func (s *Switch) numPendingPayments() int { return len(s.pendingPayments) } -// addCircuit adds a circuit to the switch's in-memory mapping. -func (s *Switch) addCircuit(circuit *PaymentCircuit) { - s.circuits.Add(circuit) +// commitCircuits persistently adds a circuit to the switch's circuit map. +func (s *Switch) commitCircuits(circuits ...*PaymentCircuit) ( + *CircuitFwdActions, error) { + + return s.circuits.CommitCircuits(circuits...) +} + +// openCircuits preemptively writes the keystones for Adds that are about to be +// added to a commitment txn. +func (s *Switch) openCircuits(keystones ...Keystone) error { + return s.circuits.OpenCircuits(keystones...) +} + +// deleteCircuits persistently removes the circuit, and keystone if present, +// from the circuit map. +func (s *Switch) deleteCircuits(inKeys ...CircuitKey) error { + return s.circuits.DeleteCircuits(inKeys...) +} + +// lookupCircuit queries the in memory representation of the circuit map to +// retrieve a particular circuit. +func (s *Switch) lookupCircuit(inKey CircuitKey) *PaymentCircuit { + return s.circuits.LookupCircuit(inKey) +} + +// lookupOpenCircuit queries the in-memory representation of the circuit map for a +// circuit whose outgoing circuit key matches outKey. +func (s *Switch) lookupOpenCircuit(outKey CircuitKey) *PaymentCircuit { + return s.circuits.LookupOpenCircuit(outKey) } // FlushForwardingEvents flushes out the set of pending forwarding events to