From 3be905109cc468b83d64637e8cb215387e9046b5 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Mon, 6 Nov 2017 16:00:17 -0800 Subject: [PATCH] routing: add RestartRouter method to testCtx --- routing/router_test.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/routing/router_test.go b/routing/router_test.go index c13c11c3..8a94a0c7 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -31,6 +31,37 @@ type testCtx struct { chainView *mockChainView } +func (c *testCtx) RestartRouter() error { + // First, we'll reset the chainView's state as it doesn't persist the + // filter between restarts. + c.chainView.Reset() + + // With the chainView reset, we'll now re-create the router itself, and + // start it. + router, err := New(Config{ + Graph: c.graph, + Chain: c.chain, + ChainView: c.chainView, + SendToSwitch: func(_ *btcec.PublicKey, + _ *lnwire.UpdateAddHTLC, _ *sphinx.Circuit) ([32]byte, error) { + return [32]byte{}, nil + }, + ChannelPruneExpiry: time.Hour * 24, + GraphPruneInterval: time.Hour * 2, + }) + if err != nil { + return fmt.Errorf("unable to create router %v", err) + } + if err := router.Start(); err != nil { + return fmt.Errorf("unable to start router: %v", err) + } + + // Finally, we'll swap out the pointer in the testCtx with this fresh + // instance of the router. + c.router = router + return nil +} + func createTestCtx(startingHeight uint32, testGraph ...string) (*testCtx, func(), error) { var ( graph *channeldb.ChannelGraph