diff --git a/discovery/gossiper.go b/discovery/gossiper.go
index 26ce951a..ec81fd24 100644
--- a/discovery/gossiper.go
+++ b/discovery/gossiper.go
@@ -16,6 +16,7 @@ import (
 	"github.com/lightningnetwork/lnd/channeldb"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/multimutex"
 	"github.com/lightningnetwork/lnd/routing"
 	"github.com/roasbeef/btcd/btcec"
 	"github.com/roasbeef/btcd/chaincfg/chainhash"
@@ -186,6 +187,12 @@ type AuthenticatedGossiper struct {
 	// selfKey is the identity public key of the backing Lighting node.
 	selfKey *btcec.PublicKey

+	// channelMtx is used to restrict database access to one
+	// goroutine per channel ID. This is done to ensure that when
+	// the gossiper is handling an announcement, the DB state stays
+	// consistent from when the DB is first read until it is written.
+	channelMtx *multimutex.Mutex
+
 	sync.Mutex
 }

@@ -206,6 +213,7 @@ func New(cfg Config, selfKey *btcec.PublicKey) (*AuthenticatedGossiper, error) {
 		prematureAnnouncements:  make(map[uint32][]*networkMsg),
 		prematureChannelUpdates: make(map[uint64][]*networkMsg),
 		waitingProofs:           storage,
+		channelMtx:              multimutex.NewMutex(),
 	}, nil
 }

@@ -972,11 +980,6 @@ func (d *AuthenticatedGossiper) networkHandler() {
 				}
 			}

-			// If we're able to broadcast the current batch
-			// successfully, then we reset the batch for a new
-			// round of announcements.
-			announcements.Reset()
-
 		// The retransmission timer has ticked which indicates that we
 		// should check if we need to prune or re-broadcast any of our
 		// personal channels. This addresses the case of "zombie" channels and
@@ -1376,6 +1379,14 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n
 		// present in this channel are not present in the database, a
 		// partial node will be added to represent each node while we
 		// wait for a node announcement.
+		//
+		// Before we add the edge to the database, we obtain
+		// the mutex for this channel ID. We do this to ensure
+		// no other goroutine has read the database and is now
+		// making decisions based on that state, before it
+		// writes to the DB.
+		d.channelMtx.Lock(msg.ShortChannelID.ToUint64())
+		defer d.channelMtx.Unlock(msg.ShortChannelID.ToUint64())
 		if err := d.cfg.Router.AddEdge(edge); err != nil {
 			// If the edge was rejected due to already being known,
 			// then it may be that case that this new message has a
@@ -1516,6 +1527,13 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n
 		// Get the node pub key as far as we don't have it in channel
 		// update announcement message. We'll need this to properly
 		// verify message signature.
+		//
+		// We make sure to obtain the mutex for this channel ID
+		// before we access the database. This ensures the state
+		// we read from the database has not changed between this
+		// point and when we call UpdateEdge() later.
+		d.channelMtx.Lock(msg.ShortChannelID.ToUint64())
+		defer d.channelMtx.Unlock(msg.ShortChannelID.ToUint64())
 		chanInfo, _, _, err := d.cfg.Router.GetChannelByID(msg.ShortChannelID)
 		if err != nil {
 			switch err {
@@ -1679,6 +1697,12 @@ func (d *AuthenticatedGossiper) processNetworkAnnouncement(nMsg *networkMsg) []n
 		// Ensure that we know of a channel with the target channel ID
 		// before proceeding further.
+		//
+		// We must acquire the mutex for this channel ID before getting
+		// the channel from the database, to ensure what we read does
+		// not change before we call AddProof() later.
+		d.channelMtx.Lock(msg.ShortChannelID.ToUint64())
+		defer d.channelMtx.Unlock(msg.ShortChannelID.ToUint64())
 		chanInfo, e1, e2, err := d.cfg.Router.GetChannelByID(
 			msg.ShortChannelID)
 		if err != nil {
diff --git a/lnd_test.go b/lnd_test.go
index 37a5e30a..6d1968de 100644
--- a/lnd_test.go
+++ b/lnd_test.go
@@ -2366,12 +2366,25 @@ func testMultiHopPayments(net *lntest.NetworkHarness, t *harnessTest) {
 	}

 	// Wait for all nodes to have seen all channels.
+	nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
+	nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
 	for _, chanPoint := range networkChans {
-		for _, node := range []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave} {
+		for i, node := range nodes {
+			txid, e := chainhash.NewHash(chanPoint.FundingTxid)
+			if e != nil {
+				t.Fatalf("unable to create sha hash: %v", e)
+			}
+			point := wire.OutPoint{
+				Hash:  *txid,
+				Index: chanPoint.OutputIndex,
+			}
+
 			ctxt, _ = context.WithTimeout(ctxb, timeout)
 			err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
 			if err != nil {
-				t.Fatalf("timeout waiting for channel open: %v", err)
+				t.Fatalf("%s(%d): timeout waiting for "+
+					"channel(%s) open: %v", nodeNames[i],
+					node.NodeID, point, err)
 			}
 		}
 	}
@@ -2551,17 +2564,27 @@ func testPrivateChannels(net *lntest.NetworkHarness, t *harnessTest) {
 	// Wait for all nodes to have seen all these channels, as they
 	// are all public.
 	nodes := []*lntest.HarnessNode{net.Alice, net.Bob, carol, dave}
+	nodeNames := []string{"Alice", "Bob", "Carol", "Dave"}
 	for _, chanPoint := range networkChans {
-		for _, node := range nodes {
+		for i, node := range nodes {
+			txid, e := chainhash.NewHash(chanPoint.FundingTxid)
+			if e != nil {
+				t.Fatalf("unable to create sha hash: %v", e)
+			}
+			point := wire.OutPoint{
+				Hash:  *txid,
+				Index: chanPoint.OutputIndex,
+			}
+
 			ctxt, _ = context.WithTimeout(ctxb, timeout)
 			err = node.WaitForNetworkChannelOpen(ctxt, chanPoint)
 			if err != nil {
-				t.Fatalf("timeout waiting for channel open: %v",
-					err)
+				t.Fatalf("%s(%d): timeout waiting for "+
+					"channel(%s) open: %v", nodeNames[i],
+					node.NodeID, point, err)
 			}
 		}
 	}
-
 	// Now create a _private_ channel directly between Carol and
 	// Alice of 100k.
 	if err := net.ConnectNodes(ctxb, carol, net.Alice); err != nil {
@@ -3150,14 +3173,23 @@ func testRevokedCloseRetribution(net *lntest.NetworkHarness, t *harnessTest) {
 	// Next query for Bob's channel state, as we sent 3 payments of 10k
 	// satoshis each, Bob should now see his balance as being 30k satoshis.
-	time.Sleep(time.Millisecond * 200)
-	bobChan, err := getBobChanInfo()
+	var bobChan *lnrpc.ActiveChannel
+	var predErr error
+	err = lntest.WaitPredicate(func() bool {
+		bobChan, err = getBobChanInfo()
+		if err != nil {
+			t.Fatalf("unable to get bob's channel info: %v", err)
+		}
+		if bobChan.LocalBalance != 30000 {
+			predErr = fmt.Errorf("bob's balance is incorrect, "+
+				"got %v, expected %v", bobChan.LocalBalance,
+				30000)
+			return false
+		}
+		return true
+	}, time.Second*15)
 	if err != nil {
-		t.Fatalf("unable to get bob's channel info: %v", err)
-	}
-	if bobChan.LocalBalance != 30000 {
-		t.Fatalf("bob's balance is incorrect, got %v, expected %v",
-			bobChan.LocalBalance, 30000)
+		t.Fatalf("%v", predErr)
 	}

 	// Grab Bob's current commitment height (update number), we'll later
diff --git a/multimutex/multimutex.go b/multimutex/multimutex.go
new file mode 100644
index 00000000..e37c88d5
--- /dev/null
+++ b/multimutex/multimutex.go
@@ -0,0 +1,96 @@
+package multimutex
+
+import (
+	"fmt"
+	"sync"
+)
+
+// cntMutex is a struct that wraps a counter and a mutex, and is used
+// to keep track of the number of goroutines waiting for access to the
+// mutex, such that we can forget about it when the counter is zero.
+type cntMutex struct {
+	cnt int
+	sync.Mutex
+}
+
+// Mutex is a struct that keeps track of a set of mutexes, each
+// identified by an ID. It can be used to make sure only one
+// goroutine holds the mutex for a given ID at a time.
+type Mutex struct {
+	// mutexes is a map of IDs to a cntMutex. The cntMutex for
+	// a given ID will hold the mutex to be used by all
+	// callers requesting access for the ID, in addition to
+	// the count of callers.
+	mutexes map[uint64]*cntMutex
+
+	// mapMtx is used to synchronize concurrent access
+	// to the mutexes map.
+	mapMtx sync.Mutex
+}
+
+// NewMutex creates a new Mutex.
+func NewMutex() *Mutex {
+	return &Mutex{
+		mutexes: make(map[uint64]*cntMutex),
+	}
+}
+
+// Lock locks the mutex for the given ID. If the mutex is already
+// locked for this ID, Lock blocks until the mutex is available.
+func (c *Mutex) Lock(id uint64) {
+	c.mapMtx.Lock()
+	mtx, ok := c.mutexes[id]
+	if ok {
+		// If the mutex already existed in the map, we
+		// increment its counter, to indicate that there
+		// is now one more goroutine waiting for it.
+		mtx.cnt++
+	} else {
+		// If it was not in the map, it means no other
+		// goroutine has locked the mutex for this ID,
+		// and we can create a new mutex with count 1
+		// and add it to the map.
+		mtx = &cntMutex{
+			cnt: 1,
+		}
+		c.mutexes[id] = mtx
+	}
+	c.mapMtx.Unlock()
+
+	// Acquire the mutex for this ID.
+	mtx.Lock()
+}
+
+// Unlock unlocks the mutex for the given ID. It is a run-time
+// error if the mutex is not locked for the ID on entry to Unlock.
+func (c *Mutex) Unlock(id uint64) {
+	// Since we are done with all the work for this
+	// ID, we update the map to reflect that.
+	c.mapMtx.Lock()
+
+	mtx, ok := c.mutexes[id]
+	if !ok {
+		// The mutex not existing in the map means
+		// an unlock for an ID not currently locked
+		// was attempted.
+		panic(fmt.Sprintf("double unlock for id %v",
+			id))
+	}
+
+	// Decrement the counter. If the count goes to
+	// zero, it means this caller was the last one
+	// to wait for the mutex, and we can delete it
+	// from the map. We can do this safely since we
+	// are under the mapMtx, meaning that all other
+	// goroutines waiting for the mutex already
+	// have incremented it, or will create a new
+	// mutex when they get the mapMtx.
+	mtx.cnt--
+	if mtx.cnt == 0 {
+		delete(c.mutexes, id)
+	}
+	c.mapMtx.Unlock()
+
+	// Unlock the mutex for this ID.
+	mtx.Unlock()
+}
diff --git a/routing/router.go b/routing/router.go
index baa653cf..cef4f9c0 100644
--- a/routing/router.go
+++ b/routing/router.go
@@ -15,6 +15,7 @@ import (
 	"github.com/lightningnetwork/lnd/htlcswitch"
 	"github.com/lightningnetwork/lnd/lnwallet"
 	"github.com/lightningnetwork/lnd/lnwire"
+	"github.com/lightningnetwork/lnd/multimutex"
 	"github.com/lightningnetwork/lnd/routing/chainview"
 	"github.com/roasbeef/btcd/btcec"
 	"github.com/roasbeef/btcd/wire"
@@ -162,97 +163,6 @@ func newRouteTuple(amt lnwire.MilliSatoshi, dest []byte) routeTuple {
 	return r
 }

-// cntMutex is a struct that wraps a counter and a mutex, and is used
-// to keep track of the number of goroutines waiting for access to the
-// mutex, such that we can forget about it when the counter is zero.
-type cntMutex struct {
-	cnt int
-	sync.Mutex
-}
-
-// mutexForID is a struct that keeps track of a set of mutexes with
-// a given ID. It can be used for making sure only one goroutine
-// gets given the mutex per ID. Here it is currently used to making
-// sure we only process one ChannelEdgePolicy per channelID at a
-// given time.
-type mutexForID struct {
-	// mutexes is a map of IDs to a cntMutex. The cntMutex for
-	// a given ID will hold the mutex to be used by all
-	// callers requesting access for the ID, in addition to
-	// the count of callers.
-	mutexes map[uint64]*cntMutex
-
-	// mapMtx is used to give synchronize concurrent access
-	// to the mutexes map.
-	mapMtx sync.Mutex
-}
-
-func newMutexForID() *mutexForID {
-	return &mutexForID{
-		mutexes: make(map[uint64]*cntMutex),
-	}
-}
-
-// Lock locks the mutex by the given ID. If the mutex is already
-// locked by this ID, Lock blocks until the mutex is available.
-func (c *mutexForID) Lock(id uint64) {
-	c.mapMtx.Lock()
-	mtx, ok := c.mutexes[id]
-	if ok {
-		// If the mutex already existed in the map, we
-		// increment its counter, to indicate that there
-		// now is one more goroutine waiting for it.
-		mtx.cnt++
-	} else {
-		// If it was not in the map, it means no other
-		// goroutine has locked the mutex for this ID,
-		// and we can create a new mutex with count 1
-		// and add it to the map.
-		mtx = &cntMutex{
-			cnt: 1,
-		}
-		c.mutexes[id] = mtx
-	}
-	c.mapMtx.Unlock()
-
-	// Acquire the mutex for this ID.
-	mtx.Lock()
-}
-
-// Unlock unlocks the mutex by the given ID. It is a run-time
-// error if the mutex is not locked by the ID on entry to Unlock.
-func (c *mutexForID) Unlock(id uint64) {
-	// Since we are done with all the work for this
-	// update, we update the map to reflect that.
-	c.mapMtx.Lock()
-
-	mtx, ok := c.mutexes[id]
-	if !ok {
-		// The mutex not existing in the map means
-		// an unlock for an ID not currently locked
-		// was attempted.
-		panic(fmt.Sprintf("double unlock for id %v",
-			id))
-	}
-
-	// Decrement the counter. If the count goes to
-	// zero, it means this caller was the last one
-	// to wait for the mutex, and we can delete it
-	// from the map. We can do this safely since we
-	// are under the mapMtx, meaning that all other
-	// goroutines waiting for the mutex already
-	// have incremented it, or will create a new
-	// mutex when they get the mapMtx.
-	mtx.cnt--
-	if mtx.cnt == 0 {
-		delete(c.mutexes, id)
-	}
-	c.mapMtx.Unlock()
-
-	// Unlock the mutex for this ID.
-	mtx.Unlock()
-}
-
 // ChannelRouter is the layer 3 router within the Lightning stack. Below the
 // ChannelRouter is the HtlcSwitch, and below that is the Bitcoin blockchain
 // itself. The primary role of the ChannelRouter is to respond to queries for
@@ -325,7 +235,7 @@ type ChannelRouter struct {
 	// channelEdgeMtx is a mutex we use to make sure we process only one
 	// ChannelEdgePolicy at a time for a given channelID, to ensure
 	// consistency between the various database accesses.
-	channelEdgeMtx *mutexForID
+	channelEdgeMtx *multimutex.Mutex

 	sync.RWMutex

@@ -355,7 +265,7 @@ func New(cfg Config) (*ChannelRouter, error) {
 		topologyClients:   make(map[uint64]*topologyClient),
 		ntfnClientUpdates: make(chan *topologyClientUpdate),
 		missionControl:    newMissionControl(cfg.Graph, selfNode),
-		channelEdgeMtx:    newMutexForID(),
+		channelEdgeMtx:    multimutex.NewMutex(),
 		selfNode:          selfNode,
 		routeCache:        make(map[routeTuple][]*Route),
 		quit:              make(chan struct{}),
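
For reviewers unfamiliar with the new package, here is a minimal, self-contained sketch (not part of the patch; the main wrapper and the IDs are made up for illustration) of the semantics multimutex.Mutex provides: goroutines locking the same ID run their critical sections one at a time, goroutines locking different IDs proceed independently, and the map entry for an ID is reclaimed once its last waiter unlocks.

	package main

	import (
		"fmt"
		"sync"

		"github.com/lightningnetwork/lnd/multimutex"
	)

	func main() {
		mtx := multimutex.NewMutex()

		var wg sync.WaitGroup
		// The two goroutines locking ID 7 are serialized against
		// each other; the goroutine locking ID 8 is never blocked
		// by them.
		for _, id := range []uint64{7, 7, 8} {
			wg.Add(1)
			go func(id uint64) {
				defer wg.Done()
				mtx.Lock(id)
				defer mtx.Unlock(id)
				fmt.Printf("holding the lock for id %d\n", id)
			}(id)
		}
		wg.Wait()
	}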
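All three gossiper call sites touched above follow the same read-check-write shape: with the per-channel mutex held, the database read and the eventual write are atomic with respect to other announcements for that channel. A schematic sketch of the pattern follows; handleChanUpdate, validateUpdate, and policyFromUpdate are illustrative stand-ins and not lnd's actual API, while channelMtx, GetChannelByID, and UpdateEdge mirror the names used in the diff.

	// handleChanUpdate sketches the guarded read-check-write pattern
	// used by the gossiper call sites in this patch. The helper names
	// below are hypothetical stand-ins for the real call sites.
	func (d *AuthenticatedGossiper) handleChanUpdate(msg *lnwire.ChannelUpdate) error {
		cid := msg.ShortChannelID.ToUint64()

		// Serialize all DB access for this channel ID, so no other
		// goroutine can change the channel's state between our read
		// below and our write at the end.
		d.channelMtx.Lock(cid)
		defer d.channelMtx.Unlock(cid)

		// Read: fetch the channel's current state.
		chanInfo, _, _, err := d.cfg.Router.GetChannelByID(msg.ShortChannelID)
		if err != nil {
			return err
		}

		// Check: validate the update against the state we just read
		// (stand-in for the real signature verification).
		if err := validateUpdate(chanInfo, msg); err != nil {
			return err
		}

		// Write: the state read above is still current, so applying
		// the update cannot race a concurrent announcement for the
		// same channel (stand-in for building the new edge policy).
		return d.cfg.Router.UpdateEdge(policyFromUpdate(chanInfo, msg))
	}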