diff --git a/htlcswitch/mailbox.go b/htlcswitch/mailbox.go index 9cbd07ad..64924fe3 100644 --- a/htlcswitch/mailbox.go +++ b/htlcswitch/mailbox.go @@ -1,23 +1,40 @@ package htlcswitch import ( + "container/list" + "errors" "sync" + "sync/atomic" + "time" "github.com/lightningnetwork/lnd/lnwire" ) -// mailBox is an interface which represents a concurrent-safe, in-order +// ErrMailBoxShuttingDown is returned when the mailbox is interrupted by a +// shutdown request. +var ErrMailBoxShuttingDown = errors.New("mailbox is shutting down") + +// MailBox is an interface which represents a concurrent-safe, in-order // delivery queue for messages from the network and also from the main switch. // This struct servers as a buffer between incoming messages, and messages to // the handled by the link. Each of the mutating methods within this interface // should be implemented in a non-blocking manner. -type mailBox interface { +type MailBox interface { // AddMessage appends a new message to the end of the message queue. AddMessage(msg lnwire.Message) error // AddPacket appends a new message to the end of the packet queue. AddPacket(pkt *htlcPacket) error + // HasPacket queries the packets for a circuit key, this is used to drop + // packets bound for the switch that already have a queued response. + HasPacket(CircuitKey) bool + + // AckPacket removes a packet from the mailboxes in-memory replay + // buffer. This will prevent a packet from being delivered after a link + // restarts if the switch has remained online. + AckPacket(CircuitKey) error + // MessageOutBox returns a channel that any new messages ready for // delivery will be sent on. MessageOutBox() chan lnwire.Message @@ -26,6 +43,12 @@ type mailBox interface { // delivery will be sent on. PacketOutBox() chan *htlcPacket + // Clears any pending wire messages from the inbox. + ResetMessages() error + + // Reset the packet head to point at the first element in the list. + ResetPackets() error + // Start starts the mailbox and any goroutines it needs to operate // properly. Start() error @@ -34,20 +57,28 @@ type mailBox interface { Stop() error } -// memoryMailBox is an implementation of the mailBox struct backed by purely +// memoryMailBox is an implementation of the MailBox struct backed by purely // in-memory queues. type memoryMailBox struct { - wireMessages []lnwire.Message + started uint32 + stopped uint32 + + wireMessages *list.List + wireHead *list.Element wireMtx sync.Mutex wireCond *sync.Cond messageOutbox chan lnwire.Message + msgReset chan chan struct{} - htlcPkts []*htlcPacket + htlcPkts *list.List + pktIndex map[CircuitKey]*list.Element + pktHead *list.Element pktMtx sync.Mutex pktCond *sync.Cond pktOutbox chan *htlcPacket + pktReset chan chan struct{} wg sync.WaitGroup quit chan struct{} @@ -56,9 +87,14 @@ type memoryMailBox struct { // newMemoryMailBox creates a new instance of the memoryMailBox. func newMemoryMailBox() *memoryMailBox { box := &memoryMailBox{ - quit: make(chan struct{}), + wireMessages: list.New(), + htlcPkts: list.New(), messageOutbox: make(chan lnwire.Message), pktOutbox: make(chan *htlcPacket), + msgReset: make(chan chan struct{}, 1), + pktReset: make(chan chan struct{}, 1), + pktIndex: make(map[CircuitKey]*list.Element), + quit: make(chan struct{}), } box.wireCond = sync.NewCond(&box.wireMtx) box.pktCond = sync.NewCond(&box.pktMtx) @@ -66,12 +102,12 @@ func newMemoryMailBox() *memoryMailBox { return box } -// A compile time assertion to ensure that memoryMailBox meets the mailBox +// A compile time assertion to ensure that memoryMailBox meets the MailBox // interface. -var _ mailBox = (*memoryMailBox)(nil) +var _ MailBox = (*memoryMailBox)(nil) // courierType is an enum that reflects the distinct types of messages a -// mailBox can handle. Each type will be placed in an isolated mail box and +// MailBox can handle. Each type will be placed in an isolated mail box and // will have a dedicated goroutine for delivering the messages. type courierType uint8 @@ -85,8 +121,12 @@ const ( // Start starts the mailbox and any goroutines it needs to operate properly. // -// NOTE: This method is part of the mailBox interface. +// NOTE: This method is part of the MailBox interface. func (m *memoryMailBox) Start() error { + if !atomic.CompareAndSwapUint32(&m.started, 0, 1) { + return nil + } + m.wg.Add(2) go m.mailCourier(wireCourier) go m.mailCourier(pktCourier) @@ -94,10 +134,90 @@ func (m *memoryMailBox) Start() error { return nil } +// ResetMessages blocks until all buffered wire messages are cleared. +func (m *memoryMailBox) ResetMessages() error { + msgDone := make(chan struct{}) + select { + case m.msgReset <- msgDone: + return m.signalUntilReset(wireCourier, msgDone) + case <-m.quit: + return ErrMailBoxShuttingDown + } +} + +// ResetPackets blocks until the head of packets buffer is reset, causing the +// packets to be redelivered in order. +func (m *memoryMailBox) ResetPackets() error { + pktDone := make(chan struct{}) + select { + case m.pktReset <- pktDone: + return m.signalUntilReset(pktCourier, pktDone) + case <-m.quit: + return ErrMailBoxShuttingDown + } +} + +// signalUntilReset strobes the condition variable for the specified inbox type +// until receiving a response that the mailbox has processed a reset. +func (m *memoryMailBox) signalUntilReset(cType courierType, + done chan struct{}) error { + + for { + switch cType { + case wireCourier: + m.wireCond.Signal() + case pktCourier: + m.pktCond.Signal() + } + + select { + case <-time.After(time.Millisecond): + continue + case <-done: + return nil + case <-m.quit: + return ErrMailBoxShuttingDown + } + } +} + +// AckPacket removes the packet identified by it's incoming circuit key from the +// queue of packets to be delivered. +// +// NOTE: It is safe to call this method multiple times for the same circuit key. +func (m *memoryMailBox) AckPacket(inKey CircuitKey) error { + m.pktCond.L.Lock() + entry, ok := m.pktIndex[inKey] + if !ok { + m.pktCond.L.Unlock() + return nil + } + + m.htlcPkts.Remove(entry) + delete(m.pktIndex, inKey) + m.pktCond.L.Unlock() + + return nil +} + +// HasPacket queries the packets for a circuit key, this is used to drop packets +// bound for the switch that already have a queued response. +func (m *memoryMailBox) HasPacket(inKey CircuitKey) bool { + m.pktCond.L.Lock() + _, ok := m.pktIndex[inKey] + m.pktCond.L.Unlock() + + return ok +} + // Stop signals the mailbox and its goroutines for a graceful shutdown. // -// NOTE: This method is part of the mailBox interface. +// NOTE: This method is part of the MailBox interface. func (m *memoryMailBox) Stop() error { + if !atomic.CompareAndSwapUint32(&m.stopped, 0, 1) { + return nil + } + close(m.quit) m.wireCond.Signal() @@ -121,10 +241,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) { switch cType { case wireCourier: m.wireCond.L.Lock() - for len(m.wireMessages) == 0 { + for m.wireMessages.Front() == nil { m.wireCond.Wait() select { + case msgDone := <-m.msgReset: + m.wireMessages.Init() + close(msgDone) case <-m.quit: m.wireCond.L.Unlock() return @@ -134,10 +257,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) { case pktCourier: m.pktCond.L.Lock() - for len(m.htlcPkts) == 0 { + for m.pktHead == nil { m.pktCond.Wait() select { + case pktDone := <-m.pktReset: + m.pktHead = m.htlcPkts.Front() + close(pktDone) case <-m.quit: m.pktCond.L.Unlock() return @@ -155,13 +281,11 @@ func (m *memoryMailBox) mailCourier(cType courierType) { ) switch cType { case wireCourier: - nextMsg = m.wireMessages[0] - m.wireMessages[0] = nil // Set to nil to prevent GC leak. - m.wireMessages = m.wireMessages[1:] + entry := m.wireMessages.Front() + nextMsg = m.wireMessages.Remove(entry).(lnwire.Message) case pktCourier: - nextPkt = m.htlcPkts[0] - m.htlcPkts[0] = nil // Set to nil to prevent GC leak. - m.htlcPkts = m.htlcPkts[1:] + nextPkt = m.pktHead.Value.(*htlcPacket) + m.pktHead = m.pktHead.Next() } // Now that we're done with the condition, we can unlock it to @@ -173,13 +297,17 @@ func (m *memoryMailBox) mailCourier(cType courierType) { m.pktCond.L.Unlock() } - // With the next message obtained, we'll now select to attempt // to deliver the message. If we receive a kill signal, then // we'll bail out. switch cType { case wireCourier: select { case m.messageOutbox <- nextMsg: + case msgDone := <-m.msgReset: + m.wireCond.L.Lock() + m.wireMessages.Init() + m.wireCond.L.Unlock() + close(msgDone) case <-m.quit: return } @@ -187,6 +315,11 @@ func (m *memoryMailBox) mailCourier(cType courierType) { case pktCourier: select { case m.pktOutbox <- nextPkt: + case pktDone := <-m.pktReset: + m.pktCond.L.Lock() + m.pktHead = m.htlcPkts.Front() + m.pktCond.L.Unlock() + close(pktDone) case <-m.quit: return } @@ -197,13 +330,13 @@ func (m *memoryMailBox) mailCourier(cType courierType) { // AddMessage appends a new message to the end of the message queue. // -// NOTE: This method is safe for concrete use and part of the mailBox +// NOTE: This method is safe for concrete use and part of the MailBox // interface. func (m *memoryMailBox) AddMessage(msg lnwire.Message) error { // First, we'll lock the condition, and add the message to the end of // the wire message inbox. m.wireCond.L.Lock() - m.wireMessages = append(m.wireMessages, msg) + m.wireMessages.PushBack(msg) m.wireCond.L.Unlock() // With the message added, we signal to the mailCourier that there are @@ -215,13 +348,22 @@ func (m *memoryMailBox) AddMessage(msg lnwire.Message) error { // AddPacket appends a new message to the end of the packet queue. // -// NOTE: This method is safe for concrete use and part of the mailBox +// NOTE: This method is safe for concrete use and part of the MailBox // interface. func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error { // First, we'll lock the condition, and add the packet to the end of // the htlc packet inbox. m.pktCond.L.Lock() - m.htlcPkts = append(m.htlcPkts, pkt) + if _, ok := m.pktIndex[pkt.inKey()]; ok { + m.pktCond.L.Unlock() + return nil + } + + entry := m.htlcPkts.PushBack(pkt) + m.pktIndex[pkt.inKey()] = entry + if m.pktHead == nil { + m.pktHead = entry + } m.pktCond.L.Unlock() // With the packet added, we signal to the mailCourier that there are @@ -234,7 +376,7 @@ func (m *memoryMailBox) AddPacket(pkt *htlcPacket) error { // MessageOutBox returns a channel that any new messages ready for delivery // will be sent on. // -// NOTE: This method is part of the mailBox interface. +// NOTE: This method is part of the MailBox interface. func (m *memoryMailBox) MessageOutBox() chan lnwire.Message { return m.messageOutbox } @@ -242,7 +384,7 @@ func (m *memoryMailBox) MessageOutBox() chan lnwire.Message { // PacketOutBox returns a channel that any new packets ready for delivery will // be sent on. // -// NOTE: This method is part of the mailBox interface. +// NOTE: This method is part of the MailBox interface. func (m *memoryMailBox) PacketOutBox() chan *htlcPacket { return m.pktOutbox }