htlcswitch: face race condition in unit tests by returning invoice

In this commit we modify the primary InvoiceRegistry interface within
the package to instead return a direct value for LookupInvoice rather
than a pointer. This fixes an existing race condition wherein a caller
could modify or read the value of the returned invoice.
This commit is contained in:
Olaoluwa Osuntokun 2017-11-11 16:09:14 -08:00
parent 010815e280
commit b6f64932c2
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
7 changed files with 23 additions and 14 deletions

View File

@ -12,7 +12,7 @@ import (
type InvoiceDatabase interface { type InvoiceDatabase interface {
// LookupInvoice attempts to look up an invoice according to it's 32 // LookupInvoice attempts to look up an invoice according to it's 32
// byte payment hash. // byte payment hash.
LookupInvoice(chainhash.Hash) (*channeldb.Invoice, error) LookupInvoice(chainhash.Hash) (channeldb.Invoice, error)
// SettleInvoice attempts to mark an invoice corresponding to the // SettleInvoice attempts to mark an invoice corresponding to the
// passed payment hash as fully settled. // passed payment hash as fully settled.

View File

@ -978,7 +978,7 @@ func TestChannelLinkMultiHopUnknownPaymentHash(t *testing.T) {
invoice.Terms.PaymentPreimage[0] ^= byte(255) invoice.Terms.PaymentPreimage[0] ^= byte(255)
// Check who is last in the route and add invoice to server registry. // Check who is last in the route and add invoice to server registry.
if err := n.carolServer.registry.AddInvoice(invoice); err != nil { if err := n.carolServer.registry.AddInvoice(*invoice); err != nil {
t.Fatalf("unable to add invoice in carol registry: %v", err) t.Fatalf("unable to add invoice in carol registry: %v", err)
} }
@ -1955,7 +1955,7 @@ func TestChannelRetransmission(t *testing.T) {
// TODO(andrew.shvv) Will be removed if we move the notification center // TODO(andrew.shvv) Will be removed if we move the notification center
// to the channel link itself. // to the channel link itself.
var invoice *channeldb.Invoice var invoice channeldb.Invoice
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
select { select {
case <-time.After(time.Millisecond * 200): case <-time.After(time.Millisecond * 200):

View File

@ -397,22 +397,22 @@ var _ ChannelLink = (*mockChannelLink)(nil)
type mockInvoiceRegistry struct { type mockInvoiceRegistry struct {
sync.Mutex sync.Mutex
invoices map[chainhash.Hash]*channeldb.Invoice invoices map[chainhash.Hash]channeldb.Invoice
} }
func newMockRegistry() *mockInvoiceRegistry { func newMockRegistry() *mockInvoiceRegistry {
return &mockInvoiceRegistry{ return &mockInvoiceRegistry{
invoices: make(map[chainhash.Hash]*channeldb.Invoice), invoices: make(map[chainhash.Hash]channeldb.Invoice),
} }
} }
func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) { func (i *mockInvoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
i.Lock() i.Lock()
defer i.Unlock() defer i.Unlock()
invoice, ok := i.invoices[rHash] invoice, ok := i.invoices[rHash]
if !ok { if !ok {
return nil, errors.New("can't find mock invoice") return channeldb.Invoice{}, errors.New("can't find mock invoice")
} }
return invoice, nil return invoice, nil
@ -428,11 +428,12 @@ func (i *mockInvoiceRegistry) SettleInvoice(rhash chainhash.Hash) error {
} }
invoice.Terms.Settled = true invoice.Terms.Settled = true
i.invoices[rhash] = invoice
return nil return nil
} }
func (i *mockInvoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error { func (i *mockInvoiceRegistry) AddInvoice(invoice channeldb.Invoice) error {
i.Lock() i.Lock()
defer i.Unlock() defer i.Unlock()

View File

@ -549,7 +549,7 @@ func (n *threeHopNetwork) makePayment(sendingPeer, receivingPeer Peer,
rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:]) rhash = fastsha256.Sum256(invoice.Terms.PaymentPreimage[:])
// Check who is last in the route and add invoice to server registry. // Check who is last in the route and add invoice to server registry.
if err := receiver.registry.AddInvoice(invoice); err != nil { if err := receiver.registry.AddInvoice(*invoice); err != nil {
paymentErr <- err paymentErr <- err
return &paymentResponse{ return &paymentResponse{
rhash: rhash, rhash: rhash,

View File

@ -98,7 +98,7 @@ func (i *invoiceRegistry) AddInvoice(invoice *channeldb.Invoice) error {
// lookupInvoice looks up an invoice by its payment hash (R-Hash), if found // lookupInvoice looks up an invoice by its payment hash (R-Hash), if found
// then we're able to pull the funds pending within an HTLC. // then we're able to pull the funds pending within an HTLC.
// TODO(roasbeef): ignore if settled? // TODO(roasbeef): ignore if settled?
func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoice, error) { func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (channeldb.Invoice, error) {
// First check the in-memory debug invoice index to see if this is an // First check the in-memory debug invoice index to see if this is an
// existing invoice added for debugging. // existing invoice added for debugging.
i.RLock() i.RLock()
@ -107,12 +107,17 @@ func (i *invoiceRegistry) LookupInvoice(rHash chainhash.Hash) (*channeldb.Invoic
// If found, then simply return the invoice directly. // If found, then simply return the invoice directly.
if ok { if ok {
return invoice, nil return *invoice, nil
} }
// Otherwise, we'll check the database to see if there's an existing // Otherwise, we'll check the database to see if there's an existing
// matching invoice. // matching invoice.
return i.cdb.LookupInvoice(rHash) invoice, err := i.cdb.LookupInvoice(rHash)
if err != nil {
return channeldb.Invoice{}, err
}
return *invoice, nil
} }
// SettleInvoice attempts to mark an invoice as settled. If the invoice is a // SettleInvoice attempts to mark an invoice as settled. If the invoice is a

View File

@ -3359,7 +3359,10 @@ func TestChanSyncUnableToSync(t *testing.T) {
} }
} }
// TestChanAvailableBandwidth... // TestChanAvailableBandwidth tests the accuracy of the AvailableBalance()
// method. The value returned from this message should reflect the value
// returned within the commitment state of a channel after the transition is
// initiated.
func TestChanAvailableBandwidth(t *testing.T) { func TestChanAvailableBandwidth(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -2029,7 +2029,7 @@ func (r *rpcServer) LookupInvoice(ctx context.Context,
return spew.Sdump(invoice) return spew.Sdump(invoice)
})) }))
rpcInvoice, err := createRPCInvoice(invoice) rpcInvoice, err := createRPCInvoice(&invoice)
if err != nil { if err != nil {
return nil, err return nil, err
} }