diff --git a/go.mod b/go.mod
index 4636c8c..a8a8f39 100644
--- a/go.mod
+++ b/go.mod
@@ -33,6 +33,7 @@ require (
 	github.com/olekukonko/tablewriter v0.0.4 // indirect
 	github.com/pborman/uuid v1.2.0 // indirect
 	github.com/prometheus/client_golang v1.6.0
+	github.com/prometheus/common v0.9.1
 	github.com/prometheus/tsdb v0.10.0 // indirect
 	github.com/rjeczalik/notify v0.9.2 // indirect
 	github.com/rs/cors v1.7.0
diff --git a/go.sum b/go.sum
index d79e9a8..d46c985 100644
--- a/go.sum
+++ b/go.sum
@@ -8,8 +8,10 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE
 github.com/Shopify/sarama v1.26.1/go.mod h1:NbSGBSSndYaIhRcBtY9V0U7AyH+x71bG668AuWys/yU=
 github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
+github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM=
 github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
+github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1CELW+XaDYmOH4hkBN4/N9og/AsOv7E=
 github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
 github.com/allegro/bigcache v1.2.1 h1:hg1sY1raCwic3Vnsvje6TT7/pnZba83LeFck5NrFKSc=
 github.com/allegro/bigcache v1.2.1/go.mod h1:Cb/ax3seSYIx7SuZdm2G2xzfwmv3TPSk2ucNfQESPXM=
@@ -218,6 +220,7 @@ github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
 github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
 github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
 github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
+github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
 github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
 github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/status-im/keycard-go v0.0.0-20200402102358-957c09536969 h1:Oo2KZNP70KE0+IUJSidPj/BFS/RXNHmKIJOdckzml2E=
@@ -337,6 +340,7 @@ google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miE
 google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo=
 google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM=
 google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU=
+gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc=
 gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
 gopkg.in/bsm/ratelimit.v1 v1.0.0-20160220154919-db14e161995a/go.mod h1:KF9sEfUPAXdG8Oev9e99iLGnl2uJMjc5B+4y3O7x610=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
diff --git a/snow/networking/router/chain_router.go b/snow/networking/router/chain_router.go
index 4505bec..8cf708a 100644
--- a/snow/networking/router/chain_router.go
+++ b/snow/networking/router/chain_router.go
@@ -124,19 +124,12 @@ func (sr *ChainRouter) GetAcceptedFrontierFailed(validatorID ids.ShortID, chainI
 	sr.lock.RLock()
 	defer sr.lock.RUnlock()
 
+	sr.timeouts.Cancel(validatorID, chainID, requestID)
 	if chain, exists := sr.chains[chainID.Key()]; exists {
-		if !chain.GetAcceptedFrontierFailed(validatorID, requestID) {
-			sr.log.Debug("deferring GetAcceptedFrontier timeout due to a full queue on %s", chainID)
-			// Defer this call to later
-			sr.timeouts.Register(validatorID, chainID, requestID, func() {
-				sr.GetAcceptedFrontierFailed(validatorID, chainID, requestID)
-			})
-			return
-		}
+		chain.GetAcceptedFrontierFailed(validatorID, requestID)
 	} else {
 		sr.log.Error("GetAcceptedFrontierFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
 	}
-	sr.timeouts.Cancel(validatorID, chainID, requestID)
 }
 
 // GetAccepted routes an incoming GetAccepted request from the
@@ -176,18 +169,12 @@
 	sr.lock.RLock()
 	defer sr.lock.RUnlock()
 
+	sr.timeouts.Cancel(validatorID, chainID, requestID)
 	if chain, exists := sr.chains[chainID.Key()]; exists {
-		if !chain.GetAcceptedFailed(validatorID, requestID) {
-			sr.timeouts.Register(validatorID, chainID, requestID, func() {
-				sr.log.Debug("deferring GetAccepted timeout due to a full queue on %s", chainID)
-				sr.GetAcceptedFailed(validatorID, chainID, requestID)
-			})
-			return
-		}
+		chain.GetAcceptedFailed(validatorID, requestID)
 	} else {
 		sr.log.Error("GetAcceptedFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
 	}
-	sr.timeouts.Cancel(validatorID, chainID, requestID)
 }
 
 // GetAncestors routes an incoming GetAncestors message from the validator with ID [validatorID]
@@ -227,18 +214,12 @@
 	sr.lock.RLock()
 	defer sr.lock.RUnlock()
 
+	sr.timeouts.Cancel(validatorID, chainID, requestID)
 	if chain, exists := sr.chains[chainID.Key()]; exists {
-		if !chain.GetAncestorsFailed(validatorID, requestID) {
-			sr.timeouts.Register(validatorID, chainID, requestID, func() {
-				sr.log.Debug("deferring GetAncestors timeout due to a full queue on %s", chainID)
-				sr.GetAncestorsFailed(validatorID, chainID, requestID)
-			})
-			return
-		}
+		chain.GetAncestorsFailed(validatorID, requestID)
 	} else {
 		sr.log.Error("GetAncestorsFailed(%s, %s, %d, %d) dropped due to unknown chain", validatorID, chainID, requestID)
 	}
-	sr.timeouts.Cancel(validatorID, chainID, requestID)
 }
 
 // Get routes an incoming Get request from the validator with ID [validatorID]
@@ -278,18 +259,12 @@
 	sr.lock.RLock()
 	defer sr.lock.RUnlock()
 
+	sr.timeouts.Cancel(validatorID, chainID, requestID)
 	if chain, exists := sr.chains[chainID.Key()]; exists {
-		if !chain.GetFailed(validatorID, requestID) {
-			sr.timeouts.Register(validatorID, chainID, requestID, func() {
-				sr.log.Debug("deferring Get timeout due to a full queue on %s", chainID)
-				sr.GetFailed(validatorID, chainID, requestID)
-			})
-			return
-		}
+		chain.GetFailed(validatorID, requestID)
 	} else {
 		sr.log.Error("GetFailed(%s, %s, %d) dropped due to unknown chain", validatorID, chainID, requestID)
 	}
-	sr.timeouts.Cancel(validatorID, chainID, requestID)
 }
 
 // PushQuery routes an incoming PushQuery request from the validator with ID [validatorID]
@@ -341,18 +316,12 @@
 	sr.lock.RLock()
 	defer sr.lock.RUnlock()
 
+	sr.timeouts.Cancel(validatorID, chainID, requestID)
 	if chain, exists := sr.chains[chainID.Key()]; exists {
-		if !chain.QueryFailed(validatorID, requestID) {
-			sr.timeouts.Register(validatorID, chainID, requestID, func() {
-				sr.log.Debug("deferring Query timeout due to a full queue on %s", chainID)
-				sr.QueryFailed(validatorID, chainID, requestID)
-			})
-			return
-		}
+		chain.QueryFailed(validatorID, requestID)
 	} else {
 		sr.log.Error("QueryFailed(%s, %s, %d, %s) dropped due to unknown chain", validatorID, chainID, requestID)
 	}
-	sr.timeouts.Cancel(validatorID, chainID, requestID)
 }
 
 // Shutdown shuts down this router
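
Every *Failed route in chain_router.go now has the same shape: cancel the timeout first, then hand the failure to the chain unconditionally. The retry logic in the removed lines, which re-registered the timeout whenever the chain's queue was full, is obsolete because the handler (next file) queues these messages reliably, so its *Failed entry points no longer return a success flag. A runnable sketch of the new shape, with toy types standing in for the gecko ones (every name below is invented):

package main

import "fmt"

// Toy stand-ins for the gecko types; only the control flow is faithful.
type timeoutManager struct{ pending map[uint32]bool }

func (tm *timeoutManager) Cancel(requestID uint32) { delete(tm.pending, requestID) }

type chain struct{ name string }

// QueryFailed has no return value anymore: delivery is reliable, so the
// router never needs to re-register the timeout and call itself back later.
func (c *chain) QueryFailed(requestID uint32) {
	fmt.Printf("chain %s notified: QueryFailed(%d)\n", c.name, requestID)
}

type chainRouter struct {
	timeouts *timeoutManager
	chains   map[string]*chain
}

func (sr *chainRouter) queryFailed(chainID string, requestID uint32) {
	// Cancel first, unconditionally; this mirrors the line the diff moves to
	// the top of every *Failed method.
	sr.timeouts.Cancel(requestID)
	if ch, exists := sr.chains[chainID]; exists {
		ch.QueryFailed(requestID)
	} else {
		fmt.Printf("QueryFailed(%d) dropped due to unknown chain %s\n", requestID, chainID)
	}
}

func main() {
	sr := &chainRouter{
		timeouts: &timeoutManager{pending: map[uint32]bool{7: true, 8: true}},
		chains:   map[string]*chain{"X": {name: "X"}},
	}
	sr.queryFailed("X", 7) // known chain: cancel, then reliable notification
	sr.queryFailed("Y", 8) // unknown chain: cancel, log, drop
}

The unknown-chain branch only logs: with the timeout already cancelled, nothing is left outstanding for that request.
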
diff --git a/snow/networking/router/handler.go b/snow/networking/router/handler.go
index 8b40223..9d45baf 100644
--- a/snow/networking/router/handler.go
+++ b/snow/networking/router/handler.go
@@ -4,12 +4,14 @@
 package router
 
 import (
+	"sync"
 	"time"
 
+	"github.com/prometheus/client_golang/prometheus"
+
 	"github.com/ava-labs/gecko/ids"
 	"github.com/ava-labs/gecko/snow"
 	"github.com/ava-labs/gecko/snow/engine/common"
-	"github.com/prometheus/client_golang/prometheus"
 )
 
 // Handler passes incoming messages from the network to the consensus engine
@@ -17,12 +19,18 @@
 type Handler struct {
 	metrics
 
-	msgs    chan message
-	closed  chan struct{}
-	engine  common.Engine
-	msgChan <-chan common.Message
+	msgs             chan message
+	reliableMsgsSema chan struct{}
+	reliableMsgsLock sync.Mutex
+	reliableMsgs     []message
+	closed           chan struct{}
+	msgChan          <-chan common.Message
+
+	ctx    *snow.Context
+	engine common.Engine
 
 	toClose func()
+	closing bool
 }
 
 // Initialize this consensus handler
@@ -35,9 +43,12 @@ func (h *Handler) Initialize(
 ) {
 	h.metrics.Initialize(namespace, metrics)
 	h.msgs = make(chan message, bufferSize)
+	h.reliableMsgsSema = make(chan struct{}, 1)
 	h.closed = make(chan struct{})
-	h.engine = engine
 	h.msgChan = msgChan
+
+	h.ctx = engine.Context()
+	h.engine = engine
 }
 
 // Context of this Handler
@@ -46,37 +57,38 @@ func (h *Handler) Context() *snow.Context { return h.engine.Context() }
 // Dispatch waits for incoming messages from the network
 // and, when they arrive, sends them to the consensus engine
 func (h *Handler) Dispatch() {
-	log := h.Context().Log
 	defer func() {
-		log.Info("finished shutting down chain")
+		h.ctx.Log.Info("finished shutting down chain")
 		close(h.closed)
 	}()
-	closing := false
 	for {
 		select {
 		case msg, ok := <-h.msgs:
 			if !ok {
+				// the msgs channel has been closed, so this dispatcher should exit
 				return
 			}
+
 			h.metrics.pending.Dec()
-			if closing {
-				log.Debug("dropping message due to closing:\n%s", msg)
-				continue
-			}
-			if h.dispatchMsg(msg) {
-				closing = true
+			h.dispatchMsg(msg)
+		case <-h.reliableMsgsSema:
+			// get all the reliable messages
+			h.reliableMsgsLock.Lock()
+			msgs := h.reliableMsgs
+			h.reliableMsgs = nil
+			h.reliableMsgsLock.Unlock()
+
+			// fire all the reliable messages
+			for _, msg := range msgs {
+				h.metrics.pending.Dec()
+				h.dispatchMsg(msg)
 			}
 		case msg := <-h.msgChan:
-			if closing {
-				log.Debug("dropping internal message due to closing:\n%s", msg)
-				continue
-			}
-			if h.dispatchMsg(message{messageType: notifyMsg, notification: msg}) {
-				closing = true
-			}
+			// handle a message from the VM
+			h.dispatchMsg(message{messageType: notifyMsg, notification: msg})
 		}
-		if closing && h.toClose != nil {
+		if h.closing && h.toClose != nil {
 			go h.toClose()
 		}
 	}
 }
@@ -85,14 +97,19 @@
 // Dispatch a message to the consensus engine.
-// Returns true iff this consensus handler (and its associated engine) should shutdown
-// (due to receipt of a shutdown message)
-func (h *Handler) dispatchMsg(msg message) bool {
+// If the message indicates that this handler (and its engine) should shut down,
+// the handler is marked as closing and any subsequent messages are dropped.
+func (h *Handler) dispatchMsg(msg message) {
+	if h.closing {
+		h.ctx.Log.Debug("dropping message due to closing:\n%s", msg)
+		h.metrics.dropped.Inc()
+		return
+	}
+
 	startTime := time.Now()
-	ctx := h.engine.Context()
-	ctx.Lock.Lock()
-	defer ctx.Lock.Unlock()
+	h.ctx.Lock.Lock()
+	defer h.ctx.Lock.Unlock()
 
-	ctx.Log.Verbo("Forwarding message to consensus: %s", msg)
+	h.ctx.Log.Verbo("Forwarding message to consensus: %s", msg)
 
 	var (
 		err  error
 		done bool
@@ -159,9 +176,10 @@
 	}
 
 	if err != nil {
-		ctx.Log.Fatal("forcing chain to shutdown due to %s", err)
+		h.ctx.Log.Fatal("forcing chain to shutdown due to %s", err)
 	}
-	return done || err != nil
+
+	h.closing = done || err != nil
 }
 
 // GetAcceptedFrontier passes a GetAcceptedFrontier message received from the
@@ -187,8 +205,8 @@ func (h *Handler) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, co
 
 // GetAcceptedFrontierFailed passes a GetAcceptedFrontierFailed message received
 // from the network to the consensus engine.
-func (h *Handler) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestID uint32) bool {
-	return h.sendMsg(message{
+func (h *Handler) GetAcceptedFrontierFailed(validatorID ids.ShortID, requestID uint32) {
+	h.sendReliableMsg(message{
 		messageType: getAcceptedFrontierFailedMsg,
 		validatorID: validatorID,
 		requestID:   requestID,
@@ -219,14 +237,43 @@
 
 // GetAcceptedFailed passes a GetAcceptedFailed message received from the
 // network to the consensus engine.
-func (h *Handler) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32) bool {
-	return h.sendMsg(message{
+func (h *Handler) GetAcceptedFailed(validatorID ids.ShortID, requestID uint32) {
+	h.sendReliableMsg(message{
 		messageType: getAcceptedFailedMsg,
 		validatorID: validatorID,
 		requestID:   requestID,
 	})
 }
 
+// GetAncestors passes a GetAncestors message received from the network to the consensus engine.
+func (h *Handler) GetAncestors(validatorID ids.ShortID, requestID uint32, containerID ids.ID) bool {
+	return h.sendMsg(message{
+		messageType: getAncestorsMsg,
+		validatorID: validatorID,
+		requestID:   requestID,
+		containerID: containerID,
+	})
+}
+
+// MultiPut passes a MultiPut message received from the network to the consensus engine.
+func (h *Handler) MultiPut(validatorID ids.ShortID, requestID uint32, containers [][]byte) bool {
+	return h.sendMsg(message{
+		messageType: multiPutMsg,
+		validatorID: validatorID,
+		requestID:   requestID,
+		containers:  containers,
+	})
+}
+
+// GetAncestorsFailed passes a GetAncestorsFailed message to the consensus engine.
+func (h *Handler) GetAncestorsFailed(validatorID ids.ShortID, requestID uint32) {
+	h.sendReliableMsg(message{
+		messageType: getAncestorsFailedMsg,
+		validatorID: validatorID,
+		requestID:   requestID,
+	})
+}
+
 // Get passes a Get message received from the network to the consensus engine.
 func (h *Handler) Get(validatorID ids.ShortID, requestID uint32, containerID ids.ID) bool {
 	return h.sendMsg(message{
@@ -237,16 +284,6 @@
 	})
 }
 
-// GetAncestors passes a GetAncestors message received from the network to the consensus engine.
-func (h *Handler) GetAncestors(validatorID ids.ShortID, requestID uint32, containerID ids.ID) bool {
-	return h.sendMsg(message{
-		messageType: getAncestorsMsg,
-		validatorID: validatorID,
-		requestID:   requestID,
-		containerID: containerID,
-	})
-}
-
 // Put passes a Put message received from the network to the consensus engine.
 func (h *Handler) Put(validatorID ids.ShortID, requestID uint32, containerID ids.ID, container []byte) bool {
 	return h.sendMsg(message{
@@ -258,34 +295,15 @@
 	})
 }
 
-// MultiPut passes a MultiPut message received from the network to the consensus engine.
-func (h *Handler) MultiPut(validatorID ids.ShortID, requestID uint32, containers [][]byte) bool {
-	return h.sendMsg(message{
-		messageType: multiPutMsg,
-		validatorID: validatorID,
-		requestID:   requestID,
-		containers:  containers,
-	})
-}
-
 // GetFailed passes a GetFailed message to the consensus engine.
-func (h *Handler) GetFailed(validatorID ids.ShortID, requestID uint32) bool {
-	return h.sendMsg(message{
+func (h *Handler) GetFailed(validatorID ids.ShortID, requestID uint32) {
+	h.sendReliableMsg(message{
 		messageType: getFailedMsg,
 		validatorID: validatorID,
 		requestID:   requestID,
 	})
 }
 
-// GetAncestorsFailed passes a GetAncestorsFailed message to the consensus engine.
-func (h *Handler) GetAncestorsFailed(validatorID ids.ShortID, requestID uint32) bool {
-	return h.sendMsg(message{
-		messageType: getAncestorsFailedMsg,
-		validatorID: validatorID,
-		requestID:   requestID,
-	})
-}
-
 // PushQuery passes a PushQuery message received from the network to the consensus engine.
 func (h *Handler) PushQuery(validatorID ids.ShortID, requestID uint32, blockID ids.ID, block []byte) bool {
 	return h.sendMsg(message{
@@ -318,8 +336,8 @@
 }
 
 // QueryFailed passes a QueryFailed message received from the network to the consensus engine.
-func (h *Handler) QueryFailed(validatorID ids.ShortID, requestID uint32) bool {
-	return h.sendMsg(message{
+func (h *Handler) QueryFailed(validatorID ids.ShortID, requestID uint32) {
+	h.sendReliableMsg(message{
 		messageType: queryFailedMsg,
 		validatorID: validatorID,
 		requestID:   requestID,
@@ -341,8 +359,9 @@ func (h *Handler) Notify(msg common.Message) bool {
 
 // Shutdown shuts down the dispatcher
 func (h *Handler) Shutdown() {
-	h.metrics.pending.Inc()
-	h.msgs <- message{messageType: shutdownMsg}
+	h.sendReliableMsg(message{
+		messageType: shutdownMsg,
+	})
 }
 
 func (h *Handler) sendMsg(msg message) bool {
@@ -355,3 +374,15 @@
 		return false
 	}
 }
+
+func (h *Handler) sendReliableMsg(msg message) {
+	h.reliableMsgsLock.Lock()
+	defer h.reliableMsgsLock.Unlock()
+
+	h.metrics.pending.Inc()
+	h.reliableMsgs = append(h.reliableMsgs, msg)
+	select {
+	case h.reliableMsgsSema <- struct{}{}:
+	default:
+	}
+}
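
The mechanism that makes those failure messages reliable is the sendReliableMsg/Dispatch pair above: an unbounded, mutex-guarded slice plus a one-slot semaphore channel, instead of the bounded msgs channel that sendMsg can fail to write to. Pushing never blocks and never drops; the semaphore only coalesces wake-ups. A self-contained sketch of the pattern under invented names:

package main

import (
	"fmt"
	"sync"
)

// reliableQueue mirrors the handler's reliableMsgs fields: an unbounded slice
// guarded by a mutex, plus a capacity-1 semaphore channel like reliableMsgsSema.
type reliableQueue struct {
	lock sync.Mutex
	msgs []string
	sema chan struct{}
}

func newReliableQueue() *reliableQueue {
	return &reliableQueue{sema: make(chan struct{}, 1)}
}

// push appends the message and nudges the semaphore. The default branch makes
// the signal coalesce: if a wake-up is already pending, push returns immediately.
func (q *reliableQueue) push(msg string) {
	q.lock.Lock()
	defer q.lock.Unlock()
	q.msgs = append(q.msgs, msg)
	select {
	case q.sema <- struct{}{}:
	default:
	}
}

// drain takes the whole batch, exactly like the dispatcher's
// case <-h.reliableMsgsSema branch.
func (q *reliableQueue) drain() []string {
	q.lock.Lock()
	defer q.lock.Unlock()
	msgs := q.msgs
	q.msgs = nil
	return msgs
}

func main() {
	q := newReliableQueue()
	done := make(chan struct{})

	go func() { // stands in for Handler.Dispatch
		for range q.sema {
			for _, m := range q.drain() {
				fmt.Println("dispatched:", m)
			}
		}
		close(done)
	}()

	q.push("QueryFailed{requestID: 1}")
	q.push("GetFailed{requestID: 2}")

	close(q.sema) // ends the sketch; the real handler shuts down via a message
	<-done
}

The capacity-1 buffer is what keeps push non-blocking: any number of pushes collapses into at most one pending wake-up, and the dispatcher re-reads the whole slice on every wake-up, so no message can be stranded between the signal and the drain.
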
diff --git a/snow/networking/sender/sender.go b/snow/networking/sender/sender.go
index e81c5c0..92c02b8 100644
--- a/snow/networking/sender/sender.go
+++ b/snow/networking/sender/sender.go
@@ -31,17 +31,16 @@ func (s *Sender) Context() *snow.Context { return s.ctx }
 
 // GetAcceptedFrontier ...
 func (s *Sender) GetAcceptedFrontier(validatorIDs ids.ShortSet, requestID uint32) {
-	if validatorIDs.Contains(s.ctx.NodeID) {
-		validatorIDs.Remove(s.ctx.NodeID)
-		go s.router.GetAcceptedFrontier(s.ctx.NodeID, s.ctx.ChainID, requestID)
-	}
-	validatorList := validatorIDs.List()
-	for _, validatorID := range validatorList {
+	for _, validatorID := range validatorIDs.List() {
 		vID := validatorID
 		s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() {
 			s.router.GetAcceptedFrontierFailed(vID, s.ctx.ChainID, requestID)
 		})
 	}
+	if validatorIDs.Contains(s.ctx.NodeID) {
+		validatorIDs.Remove(s.ctx.NodeID)
+		go s.router.GetAcceptedFrontier(s.ctx.NodeID, s.ctx.ChainID, requestID)
+	}
 	s.sender.GetAcceptedFrontier(validatorIDs, s.ctx.ChainID, requestID)
 }
 
@@ -49,24 +48,23 @@
 func (s *Sender) AcceptedFrontier(validatorID ids.ShortID, requestID uint32, containerIDs ids.Set) {
 	if validatorID.Equals(s.ctx.NodeID) {
 		go s.router.AcceptedFrontier(validatorID, s.ctx.ChainID, requestID, containerIDs)
-		return
+	} else {
+		s.sender.AcceptedFrontier(validatorID, s.ctx.ChainID, requestID, containerIDs)
 	}
-	s.sender.AcceptedFrontier(validatorID, s.ctx.ChainID, requestID, containerIDs)
 }
 
 // GetAccepted ...
 func (s *Sender) GetAccepted(validatorIDs ids.ShortSet, requestID uint32, containerIDs ids.Set) {
-	if validatorIDs.Contains(s.ctx.NodeID) {
-		validatorIDs.Remove(s.ctx.NodeID)
-		go s.router.GetAccepted(s.ctx.NodeID, s.ctx.ChainID, requestID, containerIDs)
-	}
-	validatorList := validatorIDs.List()
-	for _, validatorID := range validatorList {
+	for _, validatorID := range validatorIDs.List() {
 		vID := validatorID
 		s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() {
 			s.router.GetAcceptedFailed(vID, s.ctx.ChainID, requestID)
 		})
 	}
+	if validatorIDs.Contains(s.ctx.NodeID) {
+		validatorIDs.Remove(s.ctx.NodeID)
+		go s.router.GetAccepted(s.ctx.NodeID, s.ctx.ChainID, requestID, containerIDs)
+	}
 	s.sender.GetAccepted(validatorIDs, s.ctx.ChainID, requestID, containerIDs)
 }
 
@@ -74,9 +72,9 @@
 func (s *Sender) Accepted(validatorID ids.ShortID, requestID uint32, containerIDs ids.Set) {
 	if validatorID.Equals(s.ctx.NodeID) {
 		go s.router.Accepted(validatorID, s.ctx.ChainID, requestID, containerIDs)
-		return
+	} else {
+		s.sender.Accepted(validatorID, s.ctx.ChainID, requestID, containerIDs)
 	}
-	s.sender.Accepted(validatorID, s.ctx.ChainID, requestID, containerIDs)
 }
 
 // Get sends a Get message to the consensus engine running on the specified
 // chain to the specified validator. The Get message signifies that this
 // consensus engine would like the recipient to send this consensus engine the
ContainerID: %s", validatorID, requestID, containerID) + + // Sending a Get to myself will always fail + if validatorID.Equals(s.ctx.NodeID) { + go s.router.GetFailed(validatorID, s.ctx.ChainID, requestID) + return + } + // Add a timeout -- if we don't get a response before the timeout expires, // send this consensus engine a GetFailed message s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { @@ -101,6 +106,7 @@ func (s *Sender) GetAncestors(validatorID ids.ShortID, requestID uint32, contain go s.router.GetAncestorsFailed(validatorID, s.ctx.ChainID, requestID) return } + s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { s.router.GetAncestorsFailed(validatorID, s.ctx.ChainID, requestID) }) @@ -130,6 +136,13 @@ func (s *Sender) MultiPut(validatorID ids.ShortID, requestID uint32, containers // their preferred frontier given the existence of the specified container. func (s *Sender) PushQuery(validatorIDs ids.ShortSet, requestID uint32, containerID ids.ID, container []byte) { s.ctx.Log.Verbo("Sending PushQuery to validators %v. RequestID: %d. ContainerID: %s", validatorIDs, requestID, containerID) + for _, validatorID := range validatorIDs.List() { + vID := validatorID + s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { + s.router.QueryFailed(vID, s.ctx.ChainID, requestID) + }) + } + // If one of the validators in [validatorIDs] is myself, send this message directly // to my own router rather than sending it over the network if validatorIDs.Contains(s.ctx.NodeID) { // One of the validators in [validatorIDs] was myself @@ -139,13 +152,7 @@ func (s *Sender) PushQuery(validatorIDs ids.ShortSet, requestID uint32, containe // If this were not a goroutine, then we would deadlock here when [handler].msgs is full go s.router.PushQuery(s.ctx.NodeID, s.ctx.ChainID, requestID, containerID, container) } - validatorList := validatorIDs.List() // Convert set to list for easier iteration - for _, validatorID := range validatorList { - vID := validatorID - s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { - s.router.QueryFailed(vID, s.ctx.ChainID, requestID) - }) - } + s.sender.PushQuery(validatorIDs, s.ctx.ChainID, requestID, containerID, container) } @@ -155,6 +162,14 @@ func (s *Sender) PushQuery(validatorIDs ids.ShortSet, requestID uint32, containe // their preferred frontier. func (s *Sender) PullQuery(validatorIDs ids.ShortSet, requestID uint32, containerID ids.ID) { s.ctx.Log.Verbo("Sending PullQuery. RequestID: %d. 
ContainerID: %s", requestID, containerID) + + for _, validatorID := range validatorIDs.List() { + vID := validatorID + s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { + s.router.QueryFailed(vID, s.ctx.ChainID, requestID) + }) + } + // If one of the validators in [validatorIDs] is myself, send this message directly // to my own router rather than sending it over the network if validatorIDs.Contains(s.ctx.NodeID) { // One of the validators in [validatorIDs] was myself @@ -164,13 +179,7 @@ func (s *Sender) PullQuery(validatorIDs ids.ShortSet, requestID uint32, containe // If this were not a goroutine, then we would deadlock when [handler].msgs is full go s.router.PullQuery(s.ctx.NodeID, s.ctx.ChainID, requestID, containerID) } - validatorList := validatorIDs.List() // Convert set to list for easier iteration - for _, validatorID := range validatorList { - vID := validatorID - s.timeouts.Register(validatorID, s.ctx.ChainID, requestID, func() { - s.router.QueryFailed(vID, s.ctx.ChainID, requestID) - }) - } + s.sender.PullQuery(validatorIDs, s.ctx.ChainID, requestID, containerID) } @@ -181,9 +190,9 @@ func (s *Sender) Chits(validatorID ids.ShortID, requestID uint32, votes ids.Set) // to my own router rather than sending it over the network if validatorID.Equals(s.ctx.NodeID) { go s.router.Chits(validatorID, s.ctx.ChainID, requestID, votes) - return + } else { + s.sender.Chits(validatorID, s.ctx.ChainID, requestID, votes) } - s.sender.Chits(validatorID, s.ctx.ChainID, requestID, votes) } // Gossip the provided container diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index 7760307..8be5e99 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -4,18 +4,20 @@ package sender import ( + "math/rand" "reflect" "sync" "testing" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/ava-labs/gecko/ids" "github.com/ava-labs/gecko/snow" "github.com/ava-labs/gecko/snow/engine/common" "github.com/ava-labs/gecko/snow/networking/router" "github.com/ava-labs/gecko/snow/networking/timeout" "github.com/ava-labs/gecko/utils/logging" - "github.com/prometheus/client_golang/prometheus" ) func TestSenderContext(t *testing.T) { @@ -82,3 +84,128 @@ func TestTimeout(t *testing.T) { t.Fatalf("Timeouts should have fired") } } + +func TestReliableMessages(t *testing.T) { + tm := timeout.Manager{} + tm.Initialize(50 * time.Millisecond) + go tm.Dispatch() + + chainRouter := router.ChainRouter{} + chainRouter.Initialize(logging.NoLog{}, &tm, time.Hour, time.Second) + + sender := Sender{} + sender.Initialize(snow.DefaultContextTest(), &ExternalSenderTest{}, &chainRouter, &tm) + + engine := common.EngineTest{T: t} + engine.Default(true) + + engine.ContextF = snow.DefaultContextTest + engine.GossipF = func() error { return nil } + + queriesToSend := 1000 + awaiting := make([]chan struct{}, queriesToSend) + for i := 0; i < queriesToSend; i++ { + awaiting[i] = make(chan struct{}, 1) + } + + engine.QueryFailedF = func(validatorID ids.ShortID, reqID uint32) error { + close(awaiting[int(reqID)]) + return nil + } + + handler := router.Handler{} + handler.Initialize( + &engine, + nil, + 1, + "", + prometheus.NewRegistry(), + ) + go handler.Dispatch() + + chainRouter.AddChain(&handler) + + go func() { + for i := 0; i < queriesToSend; i++ { + vdrIDs := ids.ShortSet{} + vdrIDs.Add(ids.NewShortID([20]byte{1})) + + sender.PullQuery(vdrIDs, uint32(i), ids.Empty) + time.Sleep(time.Duration(rand.Float64() * 
+
+func TestReliableMessagesToMyself(t *testing.T) {
+	tm := timeout.Manager{}
+	tm.Initialize(50 * time.Millisecond)
+	go tm.Dispatch()
+
+	chainRouter := router.ChainRouter{}
+	chainRouter.Initialize(logging.NoLog{}, &tm, time.Hour, time.Second)
+
+	sender := Sender{}
+	sender.Initialize(snow.DefaultContextTest(), &ExternalSenderTest{}, &chainRouter, &tm)
+
+	engine := common.EngineTest{T: t}
+	engine.Default(false)
+
+	engine.ContextF = snow.DefaultContextTest
+	engine.GossipF = func() error { return nil }
+	engine.CantPullQuery = false
+
+	queriesToSend := 2
+	awaiting := make([]chan struct{}, queriesToSend)
+	for i := 0; i < queriesToSend; i++ {
+		awaiting[i] = make(chan struct{}, 1)
+	}
+
+	engine.QueryFailedF = func(validatorID ids.ShortID, reqID uint32) error {
+		close(awaiting[int(reqID)])
+		return nil
+	}
+
+	handler := router.Handler{}
+	handler.Initialize(
+		&engine,
+		nil,
+		1,
+		"",
+		prometheus.NewRegistry(),
+	)
+	go handler.Dispatch()
+
+	chainRouter.AddChain(&handler)
+
+	go func() {
+		for i := 0; i < queriesToSend; i++ {
+			vdrIDs := ids.ShortSet{}
+			vdrIDs.Add(engine.Context().NodeID)
+
+			sender.PullQuery(vdrIDs, uint32(i), ids.Empty)
+			time.Sleep(time.Duration(rand.Float64() * float64(time.Microsecond)))
+		}
+	}()
+
+	go func() {
+		for {
+			chainRouter.Gossip()
+			time.Sleep(time.Duration(rand.Float64() * float64(time.Microsecond)))
+		}
+	}()
+
+	for _, await := range awaiting {
+		<-await
+	}
+}
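
Both tests synchronize on the close-a-channel idiom: one channel per expected QueryFailed, closed exactly once by the QueryFailedF callback, with the main goroutine ranging over all of them so the test only finishes when every failure notification was delivered. A minimal sketch of just that idiom (invented names):

package main

import "fmt"

func main() {
	events := 3
	awaiting := make([]chan struct{}, events)
	for i := range awaiting {
		awaiting[i] = make(chan struct{}, 1)
	}

	for i := 0; i < events; i++ {
		i := i
		go func() { close(awaiting[i]) }() // plays the role of QueryFailedF
	}

	for _, await := range awaiting {
		<-await // a receive on a closed channel returns immediately
	}
	fmt.Println("all events observed")
}
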