diff --git a/discovery/gossiper.go b/discovery/gossiper.go index d8300005..040696ad 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -397,11 +397,11 @@ type channelUpdateID struct { flags uint16 } -// deDupedAnnouncements de-duplicates announcements that have been -// added to the batch. Internally, announcements are stored in three maps +// deDupedAnnouncements de-duplicates announcements that have been added to the +// batch. Internally, announcements are stored in three maps // (one each for channel announcements, channel updates, and node -// announcements). These maps keep track of unique announcements and -// ensure no announcements are duplicated. +// announcements). These maps keep track of unique announcements and ensure no +// announcements are duplicated. type deDupedAnnouncements struct { // channelAnnouncements are identified by the short channel id field. channelAnnouncements map[lnwire.ShortChannelID]lnwire.Message @@ -411,59 +411,80 @@ type deDupedAnnouncements struct { // nodeAnnouncements are identified by the Vertex field. nodeAnnouncements map[routing.Vertex]lnwire.Message + + sync.Mutex } -// Reset operates on deDupedAnnouncements to reset storage of announcements +// Reset operates on deDupedAnnouncements to reset the storage of +// announcements. func (d *deDupedAnnouncements) Reset() { + d.Lock() + defer d.Unlock() + + d.reset() +} + +// reset is the private version of the Reset method. We have this so we can +// call this method within method that are already holding the lock. +func (d *deDupedAnnouncements) reset() { // Storage of each type of announcement (channel anouncements, channel // updates, node announcements) is set to an empty map where the - // approprate key points to the corresponding lnwire.Message. + // appropriate key points to the corresponding lnwire.Message. d.channelAnnouncements = make(map[lnwire.ShortChannelID]lnwire.Message) d.channelUpdates = make(map[channelUpdateID]lnwire.Message) d.nodeAnnouncements = make(map[routing.Vertex]lnwire.Message) } -// AddMsg adds a new message to the current batch. -func (d *deDupedAnnouncements) AddMsg(message lnwire.Message) { - // Depending on the message type (channel announcement, channel - // update, or node announcement), the message is added to the - // corresponding map in deDupedAnnouncements. Because each - // identifying key can have at most one value, the announcements - // are de-duplicated, with newer ones replacing older ones. +// addMsg adds a new message to the current batch. +func (d *deDupedAnnouncements) addMsg(message lnwire.Message) { + // Depending on the message type (channel announcement, channel update, + // or node announcement), the message is added to the corresponding map + // in deDupedAnnouncements. Because each identifying key can have at + // most one value, the announcements are de-duplicated, with newer ones + // replacing older ones. switch msg := message.(type) { + + // Channel announcements are identified by the short channel id field. case *lnwire.ChannelAnnouncement: - // Channel announcements are identified by the short channel - // id field. d.channelAnnouncements[msg.ShortChannelID] = msg + + // Channel updates are identified by the (short channel id, flags) + // tuple. case *lnwire.ChannelUpdate: - // Channel updates are identified by the (short channel id, - // flags) tuple. channelUpdateID := channelUpdateID{ msg.ShortChannelID, msg.Flags, } d.channelUpdates[channelUpdateID] = msg + + // Node announcements are identified by the Vertex field. Use the + // NodeID to create the corresponding Vertex. case *lnwire.NodeAnnouncement: - // Node announcements are identified by the Vertex field. - // Use the NodeID to create the corresponding Vertex. vertex := routing.NewVertex(msg.NodeID) d.nodeAnnouncements[vertex] = msg } } -// AddMsgs is a helper method to add multiple messages to the -// announcement batch. -func (d *deDupedAnnouncements) AddMsgs(msgs []lnwire.Message) { +// AddMsgs is a helper method to add multiple messages to the announcement +// batch. +func (d *deDupedAnnouncements) AddMsgs(msgs ...lnwire.Message) { + d.Lock() + defer d.Unlock() + for _, msg := range msgs { - d.AddMsg(msg) + d.addMsg(msg) } } -// Batch returns the set of de-duplicated announcements to be sent out -// during the next announcement epoch, in the order of channel announcements, -// channel updates, and node announcements. -func (d *deDupedAnnouncements) Batch() []lnwire.Message { +// Emit returns the set of de-duplicated announcements to be sent out during +// the next announcement epoch, in the order of channel announcements, channel +// updates, and node announcements. Additionally, the set of stored messages +// are reset. +func (d *deDupedAnnouncements) Emit() []lnwire.Message { + d.Lock() + defer d.Unlock() + // Get the total number of announcements. numAnnouncements := len(d.channelAnnouncements) + len(d.channelUpdates) + len(d.nodeAnnouncements) @@ -487,6 +508,8 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message { announcements = append(announcements, message) } + d.reset() + // Return the array of lnwire.messages. return announcements } @@ -500,11 +523,6 @@ func (d *deDupedAnnouncements) Batch() []lnwire.Message { func (d *AuthenticatedGossiper) networkHandler() { defer d.wg.Done() - // TODO(roasbeef): changes for spec compliance - // * buffer recv'd node ann until after chan ann that includes is - // created - // * can use mostly empty struct in db as place holder - // Initialize empty deDupedAnnouncements to store announcement batch. announcements := deDupedAnnouncements{} announcements.Reset() diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index f8ae0d46..a8a0140f 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -866,7 +866,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create remote channel announcement: %v", err) } - announcements.AddMsg(ca) + announcements.AddMsgs(ca) if len(announcements.channelAnnouncements) != 1 { t.Fatal("new channel announcement not stored in batch") } @@ -879,7 +879,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create remote channel announcement: %v", err) } - announcements.AddMsg(ca2) + announcements.AddMsgs(ca2) if len(announcements.channelAnnouncements) != 1 { t.Fatal("channel announcement not replaced in batch") } @@ -891,7 +891,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create update announcement: %v", err) } - announcements.AddMsg(ua) + announcements.AddMsgs(ua) if len(announcements.channelUpdates) != 1 { t.Fatal("new channel update not stored in batch") } @@ -902,7 +902,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create update announcement: %v", err) } - announcements.AddMsg(ua2) + announcements.AddMsgs(ua2) if len(announcements.channelUpdates) != 1 { t.Fatal("channel update not replaced in batch") } @@ -913,7 +913,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create node announcement: %v", err) } - announcements.AddMsg(na) + announcements.AddMsgs(na) if len(announcements.nodeAnnouncements) != 1 { t.Fatal("new node announcement not stored in batch") } @@ -923,7 +923,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create node announcement: %v", err) } - announcements.AddMsg(na2) + announcements.AddMsgs(na2) if len(announcements.nodeAnnouncements) != 2 { t.Fatal("second node announcement not stored in batch") } @@ -934,7 +934,7 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create node announcement: %v", err) } - announcements.AddMsg(na3) + announcements.AddMsgs(na3) if len(announcements.nodeAnnouncements) != 2 { t.Fatal("second node announcement not replaced in batch") } @@ -946,14 +946,14 @@ func TestDeDuplicatedAnnouncements(t *testing.T) { if err != nil { t.Fatalf("can't create node announcement: %v", err) } - announcements.AddMsg(na4) + announcements.AddMsgs(na4) if len(announcements.nodeAnnouncements) != 2 { t.Fatal("second node announcement not replaced again in batch") } // Ensure that announcement batch delivers channel announcements, // channel updates, and node announcements in proper order. - batch := announcements.Batch() + batch := announcements.Emit() if len(batch) != 4 { t.Fatal("announcement batch incorrect length") }