diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 92ab9330..d5f0541b 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -116,8 +116,7 @@ func validateInvoice(i *Invoice) error { // insertion will be aborted and rejected due to the strict policy banning any // duplicate payment hashes. func (d *DB) AddInvoice(i *Invoice) error { - err := validateInvoice(i) - if err != nil { + if err := validateInvoice(i); err != nil { return err } return d.Update(func(tx *bolt.Tx) error { diff --git a/channeldb/payments.go b/channeldb/payments.go index 5c91300d..b10749c2 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -1,51 +1,55 @@ package channeldb import ( + "bytes" + "encoding/binary" + "github.com/boltdb/bolt" + "github.com/roasbeef/btcd/wire" "github.com/roasbeef/btcutil" "io" - "github.com/roasbeef/btcd/wire" - "github.com/boltdb/bolt" - "encoding/binary" - "bytes" + "time" ) var ( - // invoiceBucket is the name of the bucket within the database that - // stores all data related to payments. - // Within the payments bucket, each invoice is keyed by its invoice ID + // invoiceBucket is the name of the bucket within + // the database that stores all data related to payments. + // Within the payments bucket, each invoice is keyed + // by its invoice ID // which is a monotonically increasing uint64. - // BoltDB sequence feature is used for generating monotonically increasing - // id. + // BoltDB sequence feature is used for generating + // monotonically increasing id. paymentBucket = []byte("payments") ) +// OutgoingPayment represents payment from given node. type OutgoingPayment struct { Invoice - // Total fee paid - Fee btcutil.Amount - // Path including starting and ending nodes - Path [][]byte - TimeLockLength uint64 - // We probably need both RHash and Preimage - // because we start knowing only RHash - RHash [32]byte -} -func validatePayment(p *OutgoingPayment) error { - err := validateInvoice(&p.Invoice) - if err != nil { - return err - } - return nil + // Total fee paid. + Fee btcutil.Amount + + // Path including starting and ending nodes. + Path [][33]byte + + // Timelock length. + TimeLockLength uint32 + + // RHash value used for payment. + // We need RHash because we start payment knowing only RHash + RHash [32]byte + + // Timestamp is time when payment was created. + Timestamp time.Time } // AddPayment adds payment to DB. // There is no checking that payment with the same hash already exist. func (db *DB) AddPayment(p *OutgoingPayment) error { - err := validatePayment(p) + err := validateInvoice(&p.Invoice) if err != nil { return err } + // We serialize before writing to database // so no db access in the case of serialization errors b := new(bytes.Buffer) @@ -54,15 +58,17 @@ func (db *DB) AddPayment(p *OutgoingPayment) error { return err } paymentBytes := b.Bytes() - return db.Update(func (tx *bolt.Tx) error { + return db.Update(func(tx *bolt.Tx) error { payments, err := tx.CreateBucketIfNotExists(paymentBucket) if err != nil { return err } + paymentId, err := payments.NextSequence() if err != nil { return err } + // We use BigEndian for keys because // it orders keys in ascending order paymentIdBytes := make([]byte, 8) @@ -78,12 +84,12 @@ func (db *DB) AddPayment(p *OutgoingPayment) error { // FetchAllPayments returns all outgoing payments in DB. func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { var payments []*OutgoingPayment - err := db.View(func (tx *bolt.Tx) error { + err := db.View(func(tx *bolt.Tx) error { bucket := tx.Bucket(paymentBucket) if bucket == nil { - return nil + return ErrNoPaymentsCreated } - err := bucket.ForEach(func (k, v []byte) error { + err := bucket.ForEach(func(k, v []byte) error { // Value can be nil if it is a sub-backet // so simply ignore it. if v == nil { @@ -109,11 +115,12 @@ func (db *DB) FetchAllPayments() ([]*OutgoingPayment, error) { // If payments bucket does not exist it will create // new bucket without error. func (db *DB) DeleteAllPayments() error { - return db.Update(func (tx *bolt.Tx) error { + return db.Update(func(tx *bolt.Tx) error { err := tx.DeleteBucket(paymentBucket) if err != nil && err != bolt.ErrBucketNotFound { return err } + _, err = tx.CreateBucket(paymentBucket) if err != nil { return err @@ -127,6 +134,7 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { if err != nil { return err } + // Serialize fee. feeBytes := make([]byte, 8) byteOrder.PutUint64(feeBytes, uint64(p.Fee)) @@ -134,43 +142,60 @@ func serializeOutgoingPayment(w io.Writer, p *OutgoingPayment) error { if err != nil { return err } + // Serialize path. pathLen := uint32(len(p.Path)) pathLenBytes := make([]byte, 4) + // Write length of the path byteOrder.PutUint32(pathLenBytes, pathLen) _, err = w.Write(pathLenBytes) if err != nil { return err } + // Serialize each element of the path for i := uint32(0); i < pathLen; i++ { - err := wire.WriteVarBytes(w, 0, p.Path[i]) + _, err := w.Write(p.Path[i][:]) if err != nil { return err } } + // Serialize TimeLockLength - timeLockLengthBytes := make([]byte, 8) - byteOrder.PutUint64(timeLockLengthBytes, p.TimeLockLength) + timeLockLengthBytes := make([]byte, 4) + byteOrder.PutUint32(timeLockLengthBytes, p.TimeLockLength) _, err = w.Write(timeLockLengthBytes) if err != nil { return err } + // Serialize RHash _, err = w.Write(p.RHash[:]) if err != nil { return err } + + // Serialize Timestamp. + tBytes, err := p.Timestamp.MarshalBinary() + if err != nil { + return err + } + err = wire.WriteVarBytes(w, 0, tBytes) + if err != nil { + return err + } return nil } func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { p := &OutgoingPayment{} + // Deserialize invoice inv, err := deserializeInvoice(r) if err != nil { return nil, err } p.Invoice = *inv + // Deserialize fee feeBytes := make([]byte, 8) _, err = r.Read(feeBytes) @@ -178,6 +203,7 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { return nil, err } p.Fee = btcutil.Amount(byteOrder.Uint64(feeBytes)) + // Deserialize path pathLenBytes := make([]byte, 4) _, err = r.Read(pathLenBytes) @@ -185,27 +211,38 @@ func deserializeOutgoingPayment(r io.Reader) (*OutgoingPayment, error) { return nil, err } pathLen := byteOrder.Uint32(pathLenBytes) - path := make([][]byte, pathLen) - for i := uint32(0); i