From b4c644c99a6e7c73c6904755c6ec47841bc39485 Mon Sep 17 00:00:00 2001 From: Joseph Poon Date: Tue, 5 Jan 2016 08:53:42 -0800 Subject: [PATCH] Added Error message type to wire protocol --- lnwire/error_generic.go | 74 ++++++++++++++++++++++++++++++++++++ lnwire/error_generic_test.go | 32 ++++++++++++++++ lnwire/htlc_settleaccept.go | 6 --- lnwire/lnwire.go | 33 ++++++++++++++++ lnwire/message.go | 5 +++ 5 files changed, 144 insertions(+), 6 deletions(-) create mode 100644 lnwire/error_generic.go create mode 100644 lnwire/error_generic_test.go diff --git a/lnwire/error_generic.go b/lnwire/error_generic.go new file mode 100644 index 00000000..acf4f610 --- /dev/null +++ b/lnwire/error_generic.go @@ -0,0 +1,74 @@ +package lnwire + +import ( + "fmt" + "io" +) + +//Multiple Clearing Requests are possible by putting this inside an array of +//clearing requests +type ErrorGeneric struct { + //We can use a different data type for this if necessary... + ChannelID uint64 + //Some kind of message + //Max length 8192 + Problem string +} + +func (c *ErrorGeneric) Decode(r io.Reader, pver uint32) error { + //ChannelID(8) + //Problem + err := readElements(r, + &c.ChannelID, + &c.Problem, + ) + if err != nil { + return err + } + + return nil +} + +//Creates a new ErrorGeneric +func NewErrorGeneric() *ErrorGeneric { + return &ErrorGeneric{} +} + +//Serializes the item from the ErrorGeneric struct +//Writes the data to w +func (c *ErrorGeneric) Encode(w io.Writer, pver uint32) error { + err := writeElements(w, + c.ChannelID, + c.Problem, + ) + if err != nil { + return err + } + + return nil +} + +func (c *ErrorGeneric) Command() uint32 { + return CmdErrorGeneric +} + +func (c *ErrorGeneric) MaxPayloadLength(uint32) uint32 { + //8+8192 + return 8208 +} + +//Makes sure the struct data is valid (e.g. no negatives or invalid pkscripts) +func (c *ErrorGeneric) Validate() error { + if len(c.Problem) > 8192 { + return fmt.Errorf("Problem string length too long") + } + //We're good! + return nil +} + +func (c *ErrorGeneric) String() string { + return fmt.Sprintf("\n--- Begin ErrorGeneric ---\n") + + fmt.Sprintf("ChannelID:\t%d\n", c.ChannelID) + + fmt.Sprintf("Problem:\t%s\n", c.Problem) + + fmt.Sprintf("--- End ErrorGeneric ---\n") +} diff --git a/lnwire/error_generic_test.go b/lnwire/error_generic_test.go new file mode 100644 index 00000000..9ca6df27 --- /dev/null +++ b/lnwire/error_generic_test.go @@ -0,0 +1,32 @@ +package lnwire + +import ( + "testing" +) + +var ( + errorGeneric = &ErrorGeneric{ + ChannelID: uint64(12345678), + Problem: "Hello world!", + } + errorGenericSerializedString = "0000000000bc614e000c48656c6c6f20776f726c6421" + errorGenericSerializedMessage = "0709110b00000fa0000000160000000000bc614e000c48656c6c6f20776f726c6421" +) + +func TestErrorGenericEncodeDecode(t *testing.T) { + //All of these types being passed are of the message interface type + //Test serialization, runs: message.Encode(b, 0) + //Returns bytes + //Compares the expected serialized string from the original + s := SerializeTest(t, errorGeneric, errorGenericSerializedString, filename) + + //Test deserialization, runs: message.Decode(s, 0) + //Makes sure the deserialized struct is the same as the original + newMessage := NewErrorGeneric() + DeserializeTest(t, s, newMessage, errorGeneric) + + //Test message using Message interface + //Serializes into buf: WriteMessage(buf, message, uint32(1), wire.TestNet3) + //Deserializes into msg: _, msg, _ , err := ReadMessage(buf, uint32(1), wire.TestNet3) + MessageSerializeDeserializeTest(t, errorGeneric, errorGenericSerializedMessage) +} diff --git a/lnwire/htlc_settleaccept.go b/lnwire/htlc_settleaccept.go index 186128c9..3ec2aed5 100644 --- a/lnwire/htlc_settleaccept.go +++ b/lnwire/htlc_settleaccept.go @@ -18,12 +18,6 @@ type HTLCSettleAccept struct { func (c *HTLCSettleAccept) Decode(r io.Reader, pver uint32) error { //ChannelID(8) //StagingID(8) - //Expiry(4) - //Amount(4) - //NextHop(20) - //ContractType(1) - //RedemptionHashes (numOfHashes * 20 + numOfHashes) - //Blob(2+blobsize) err := readElements(r, &c.ChannelID, &c.StagingID, diff --git a/lnwire/lnwire.go b/lnwire/lnwire.go index 49583545..ada7eee8 100644 --- a/lnwire/lnwire.go +++ b/lnwire/lnwire.go @@ -202,6 +202,21 @@ func writeElement(w io.Writer, element interface{}) error { return err } return nil + case string: + strlen := len(e) + if strlen > 65535 { + return fmt.Errorf("String too long!") + } + //Write the size (2-bytes) + err = writeElement(w, uint16(strlen)) + if err != nil { + return err + } + //Write the data + _, err = w.Write([]byte(e)) + if err != nil { + return err + } case []*wire.TxIn: //Append the unsigned(!!!) txins //Write the size (1-byte) @@ -453,6 +468,24 @@ func readElement(r io.Reader, element interface{}) error { return fmt.Errorf("EOF: Signature length mismatch.") } return nil + case *string: + //Get the string length first + var strlen uint16 + err = readElement(r, &strlen) + if err != nil { + return err + } + //Read the string for the length + l := io.LimitReader(r, int64(strlen)) + b, err := ioutil.ReadAll(l) + if len(b) != int(strlen) { + return fmt.Errorf("EOF: String length mismatch.") + } + *e = string(b) + if err != nil { + return err + } + return nil case *[]*wire.TxIn: //Read the size (1-byte number of txins) var numScripts uint8 diff --git a/lnwire/message.go b/lnwire/message.go index 4413ac45..48a937cb 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -57,6 +57,9 @@ const ( //Commitments CmdCommitSignature = uint32(2000) CmdCommitRevocation = uint32(2010) + + //Error + CmdErrorGeneric = uint32(4000) ) //Every message has these functions: @@ -103,6 +106,8 @@ func makeEmptyMessage(command uint32) (Message, error) { msg = &CommitSignature{} case CmdCommitRevocation: msg = &CommitRevocation{} + case CmdErrorGeneric: + msg = &ErrorGeneric{} default: return nil, fmt.Errorf("unhandled command [%d]", command) }