diff --git a/htlcswitch/iterator.go b/htlcswitch/iterator.go index 5c1603b1..8218b586 100644 --- a/htlcswitch/iterator.go +++ b/htlcswitch/iterator.go @@ -40,6 +40,10 @@ var ( // exitHop is a special "hop" which denotes that an incoming HTLC is // meant to pay finally to the receiving node. exitHop lnwire.ShortChannelID + + // sourceHop is a sentinel value denoting that an incoming HTLC is + // initiated by our own switch. + sourceHop lnwire.ShortChannelID ) // ForwardingInfo contains all the information that is necessary to forward and @@ -85,14 +89,20 @@ type HopIterator interface { // EncodeNextHop encodes the onion packet destined for the next hop // into the passed io.Writer. EncodeNextHop(w io.Writer) error + + // ExtractErrorEncrypter returns the ErrorEncrypter needed for this hop, + // along with a failure code to signal if the decoding was successful. + ExtractErrorEncrypter(ErrorEncrypterExtracter) (ErrorEncrypter, + lnwire.FailCode) } // sphinxHopIterator is the Sphinx implementation of hop iterator which uses // onion routing to encode the payment route in such a way so that node might // see only the next hop in the route.. type sphinxHopIterator struct { - // nextPacket is the decoded onion packet for the _next_ hop. - nextPacket *sphinx.OnionPacket + // ogPacket is the original packet from which the processed packet is + // derived. + ogPacket *sphinx.OnionPacket // processedPacket is the outcome of processing an onion packet. It // includes the information required to properly forward the packet to @@ -100,6 +110,17 @@ type sphinxHopIterator struct { processedPacket *sphinx.ProcessedPacket } +// makeSphinxHopIterator converts a processed packet returned from a sphinx +// router and converts it into an hop iterator for usage in the link. +func makeSphinxHopIterator(ogPacket *sphinx.OnionPacket, + packet *sphinx.ProcessedPacket) *sphinxHopIterator { + + return &sphinxHopIterator{ + ogPacket: ogPacket, + processedPacket: packet, + } +} + // A compile time check to ensure sphinxHopIterator implements the HopIterator // interface. var _ HopIterator = (*sphinxHopIterator)(nil) @@ -108,7 +129,7 @@ var _ HopIterator = (*sphinxHopIterator)(nil) // // NOTE: Part of the HopIterator interface. func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error { - return r.nextPacket.Encode(w) + return r.processedPacket.NextPacket.Encode(w) } // ForwardingInstructions returns the set of fields that detail exactly _how_ @@ -137,6 +158,18 @@ func (r *sphinxHopIterator) ForwardingInstructions() ForwardingInfo { } } +// ExtractErrorEncrypter decodes and returns the ErrorEncrypter for this hop, +// along with a failure code to signal if the decoding was successful. The +// ErrorEncrypter is used to encrypt errors back to the sender in the event that +// a payment fails. +// +// NOTE: Part of the HopIterator interface. +func (r *sphinxHopIterator) ExtractErrorEncrypter( + extracter ErrorEncrypterExtracter) (ErrorEncrypter, lnwire.FailCode) { + + return extracter(r.ogPacket) +} + // OnionProcessor is responsible for keeping all sphinx dependent parts inside // and expose only decoding function. With such approach we give freedom for // subsystems which wants to decode sphinx path to not be dependable from @@ -155,11 +188,22 @@ func NewOnionProcessor(router *sphinx.Router) *OnionProcessor { return &OnionProcessor{router} } +// Start spins up the onion processor's sphinx router. +func (p *OnionProcessor) Start() error { + return p.router.Start() +} + +// Stop shutsdown the onion processor's sphinx router. +func (p *OnionProcessor) Stop() error { + p.router.Stop() + return nil +} + // DecodeHopIterator attempts to decode a valid sphinx packet from the passed io.Reader // instance using the rHash as the associated data when checking the relevant // MACs during the decoding process. -func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte) (HopIterator, - lnwire.FailCode) { +func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte, + incomingCltv uint32) (HopIterator, lnwire.FailCode) { onionPkt := &sphinx.OnionPacket{} if err := onionPkt.Decode(r); err != nil { @@ -179,7 +223,9 @@ func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte) (HopIterat // associated data in order to thwart attempts a replay attacks. In the // case of a replay, an attacker is *forced* to use the same payment // hash twice, thereby losing their money entirely. - sphinxPacket, err := p.router.ProcessOnionPacket(onionPkt, rHash) + sphinxPacket, err := p.router.ProcessOnionPacket( + onionPkt, rHash, incomingCltv, + ) if err != nil { switch err { case sphinx.ErrInvalidOnionVersion: @@ -194,10 +240,160 @@ func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte) (HopIterat } } - return &sphinxHopIterator{ - nextPacket: sphinxPacket.NextPacket, - processedPacket: sphinxPacket, - }, lnwire.CodeNone + return makeSphinxHopIterator(onionPkt, sphinxPacket), lnwire.CodeNone +} + +// DecodeHopIteratorRequest encapsulates all date necessary to process an onion +// packet, perform sphinx replay detection, and schedule the entry for garbage +// collection. +type DecodeHopIteratorRequest struct { + OnionReader io.Reader + RHash []byte + IncomingCltv uint32 +} + +// DecodeHopIteratorResponse encapsulates the outcome of a batched sphinx onion +// processing. +type DecodeHopIteratorResponse struct { + HopIterator HopIterator + FailCode lnwire.FailCode +} + +// Result returns the (HopIterator, lnwire.FailCode) tuple, which should +// correspond to the index of a particular DecodeHopIteratorRequest. +// +// NOTE: The HopIterator should be considered invalid if the fail code is +// anything but lnwire.CodeNone. +func (r *DecodeHopIteratorResponse) Result() (HopIterator, lnwire.FailCode) { + return r.HopIterator, r.FailCode +} + +// DecodeHopIterators performs batched decoding and validation of incoming +// sphinx packets. For the same `id`, this method will return the same iterators +// and failcodes upon subsequent invocations. +// +// NOTE: In order for the responses to be valid, the caller must guarantee that +// the presented readers and rhashes *NEVER* deviate across invocations for the +// same id. +func (p *OnionProcessor) DecodeHopIterators(id []byte, + reqs []DecodeHopIteratorRequest) ([]DecodeHopIteratorResponse, error) { + + var ( + batchSize = len(reqs) + onionPkts = make([]sphinx.OnionPacket, batchSize) + resps = make([]DecodeHopIteratorResponse, batchSize) + ) + + tx := p.router.BeginTxn(id, batchSize) + + for i, req := range reqs { + onionPkt := &onionPkts[i] + resp := &resps[i] + + err := onionPkt.Decode(req.OnionReader) + switch err { + case nil: + // success + + case sphinx.ErrInvalidOnionVersion: + resp.FailCode = lnwire.CodeInvalidOnionVersion + continue + + case sphinx.ErrInvalidOnionKey: + resp.FailCode = lnwire.CodeInvalidOnionKey + continue + + default: + log.Errorf("unable to decode onion packet: %v", err) + resp.FailCode = lnwire.CodeInvalidOnionKey + continue + } + + err = tx.ProcessOnionPacket( + uint16(i), onionPkt, req.RHash, req.IncomingCltv, + ) + switch err { + case nil: + // success + + case sphinx.ErrInvalidOnionVersion: + resp.FailCode = lnwire.CodeInvalidOnionVersion + continue + + case sphinx.ErrInvalidOnionHMAC: + resp.FailCode = lnwire.CodeInvalidOnionHmac + continue + + case sphinx.ErrInvalidOnionKey: + resp.FailCode = lnwire.CodeInvalidOnionKey + continue + + default: + log.Errorf("unable to process onion packet: %v", err) + resp.FailCode = lnwire.CodeInvalidOnionKey + continue + } + } + + // With that batch created, we will now attempt to write the shared + // secrets to disk. This operation will returns the set of indices that + // were detected as replays, and the computed sphinx packets for all + // indices that did not fail the above loop. Only indices that are not + // in the replay set should be considered valid, as they are + // opportunistically computed. + packets, replays, err := tx.Commit() + if err != nil { + log.Errorf("unable to process onion packet batch %x: %v", + id, err) + + // If we failed to commit the batch to the secret share log, we + // will mark all not-yet-failed channels with a temporary + // channel failure and exit since we cannot proceed. + for i := range resps { + resp := &resps[i] + + // Skip any indexes that already failed onion decoding. + if resp.FailCode != lnwire.CodeNone { + continue + } + + log.Errorf("unable to process onion packet %x-%v", + id, i) + resp.FailCode = lnwire.CodeTemporaryChannelFailure + } + + // TODO(conner): return real errors to caller so link can fail? + return resps, err + } + + // Otherwise, the commit was successful. Now we will post process any + // remaining packets, additionally failing any that were included in the + // replay set. + for i := range resps { + resp := &resps[i] + + // Skip any indexes that already failed onion decoding. + if resp.FailCode != lnwire.CodeNone { + continue + } + + // If this index is contained in the replay set, mark it with a + // temporary channel failure error code. We infer that the + // offending error was due to a replayed packet because this + // index was found in the replay set. + if replays.Contains(uint16(i)) { + log.Errorf("unable to process onion packet: %v", + sphinx.ErrReplayedPacket) + resp.FailCode = lnwire.CodeTemporaryChannelFailure + continue + } + + // Finally, construct a hop iterator from our processed sphinx + // packet, simultaneously caching the original onion packet. + resp.HopIterator = makeSphinxHopIterator(&onionPkts[i], &packets[i]) + } + + return resps, nil } // ExtractErrorEncrypter takes an io.Reader which should contain the onion @@ -205,20 +401,8 @@ func (p *OnionProcessor) DecodeHopIterator(r io.Reader, rHash []byte) (HopIterat // ErrorEncrypter instance using the derived shared secret. In the case that en // error occurs, a lnwire failure code detailing the parsing failure will be // returned. -func (p *OnionProcessor) ExtractErrorEncrypter(r io.Reader) (ErrorEncrypter, lnwire.FailCode) { - - onionPkt := &sphinx.OnionPacket{} - if err := onionPkt.Decode(r); err != nil { - switch err { - case sphinx.ErrInvalidOnionVersion: - return nil, lnwire.CodeInvalidOnionVersion - case sphinx.ErrInvalidOnionKey: - return nil, lnwire.CodeInvalidOnionKey - default: - log.Errorf("unable to decode onion packet: %v", err) - return nil, lnwire.CodeInvalidOnionKey - } - } +func (p *OnionProcessor) ExtractErrorEncrypter(onionPkt *sphinx.OnionPacket) ( + ErrorEncrypter, lnwire.FailCode) { onionObfuscator, err := sphinx.NewOnionErrorEncrypter(p.router, onionPkt.EphemeralKey) @@ -238,5 +422,6 @@ func (p *OnionProcessor) ExtractErrorEncrypter(r io.Reader) (ErrorEncrypter, lnw return &SphinxErrorEncrypter{ OnionErrorEncrypter: onionObfuscator, + ogPacket: onionPkt, }, lnwire.CodeNone }