diff --git a/zpay32/invoice.go b/zpay32/invoice.go index 0856d3ea..4dbbfeec 100644 --- a/zpay32/invoice.go +++ b/zpay32/invoice.go @@ -265,9 +265,9 @@ func NewInvoice(net *chaincfg.Params, paymentHash [32]byte, return invoice, nil } -// Decode parses the provided encoded invoice, and returns a decoded Invoice in -// case it is valid by BOLT-0011. -func Decode(invoice string) (*Invoice, error) { +// Decode parses the provided encoded invoice and returns a decoded Invoice if +// it is valid by BOLT-0011 and matches the provided active network. +func Decode(invoice string, net *chaincfg.Params) (*Invoice, error) { decodedInvoice := Invoice{} // Decode the invoice using the modified bech32 decoder. @@ -276,9 +276,9 @@ func Decode(invoice string) (*Invoice, error) { return nil, err } - // We expect the human-readable part to at least have ln + two chars + // We expect the human-readable part to at least have ln + one char // encoding the network. - if len(hrp) < 4 { + if len(hrp) < 3 { return nil, fmt.Errorf("hrp too short") } @@ -288,24 +288,17 @@ func Decode(invoice string) (*Invoice, error) { } // The next characters should be a valid prefix for a segwit BIP173 - // address. This will also determine which network this invoice is - // meant for. - var net *chaincfg.Params - if strings.HasPrefix(hrp[2:], chaincfg.MainNetParams.Bech32HRPSegwit) { - net = &chaincfg.MainNetParams - } else if strings.HasPrefix(hrp[2:], chaincfg.TestNet3Params.Bech32HRPSegwit) { - net = &chaincfg.TestNet3Params - } else if strings.HasPrefix(hrp[2:], chaincfg.SimNetParams.Bech32HRPSegwit) { - net = &chaincfg.SimNetParams - } else { + // address that match the active network. + if !strings.HasPrefix(hrp[2:], net.Bech32HRPSegwit) { return nil, fmt.Errorf("unknown network") } decodedInvoice.Net = net - // Optionally, if there's anything left of the HRP, it encodes the - // payment amount. - if len(hrp) > 4 { - amount, err := decodeAmount(hrp[4:]) + // Optionally, if there's anything left of the HRP after ln + the segwit + // prefix, we try to decode this as the payment amount. + var netPrefixLength = len(net.Bech32HRPSegwit) + 2 + if len(hrp) > netPrefixLength { + amount, err := decodeAmount(hrp[netPrefixLength:]) if err != nil { return nil, err } @@ -573,11 +566,7 @@ func parseData(invoice *Invoice, data []byte, net *chaincfg.Params) error { // The rest are tagged parts. tagData := data[7:] - if err := parseTaggedFields(invoice, tagData, net); err != nil { - return err - } - - return nil + return parseTaggedFields(invoice, tagData, net) } // parseTimestamp converts a 35-bit timestamp (encoded in base32) to uint64.