diff --git a/htlcswitch/mailbox_test.go b/htlcswitch/mailbox_test.go index 3e36fd97..2d04ea69 100644 --- a/htlcswitch/mailbox_test.go +++ b/htlcswitch/mailbox_test.go @@ -25,6 +25,7 @@ func TestMailBoxCouriers(t *testing.T) { // We'll be adding 10 message of both types to the mailbox. const numPackets = 10 + const halfPackets = numPackets / 2 // We'll add a set of random packets to the mailbox. sentPackets := make([]*htlcPacket, numPackets) @@ -96,4 +97,53 @@ func TestMailBoxCouriers(t *testing.T) { t.Fatalf("recvd messages mismatched: expected %v, got %v", spew.Sdump(sentMessages), spew.Sdump(recvdMessages)) } + + // Now that we've received all of the intended msgs/pkts, ack back half + // of the packets. + for _, recvdPkt := range recvdPackets[:halfPackets] { + mailBox.AckPacket(recvdPkt.inKey()) + } + + // With the packets drained and partially acked, we reset the mailbox, + // simulating a link shutting down and then coming back up. + mailBox.ResetMessages() + mailBox.ResetPackets() + + // Now, we'll use the same alternating strategy to read from our + // mailbox. All wire messages are dropped on startup, but any unacked + // packets will be replayed in the same order they were delivered + // initially. + recvdPackets2 := make([]*htlcPacket, 0, halfPackets) + for i := 0; i < 2*halfPackets; i++ { + timeout := time.After(time.Second * 5) + if i%2 == 0 { + select { + case <-timeout: + t.Fatalf("didn't recv pkt after timeout") + case pkt := <-mailBox.PacketOutBox(): + recvdPackets2 = append(recvdPackets2, pkt) + } + } else { + select { + case <-mailBox.MessageOutBox(): + t.Fatalf("should not receive wire msg after reset") + default: + } + } + } + + // The number of packets we received should match the number of unacked + // packets left in the mailbox. + if halfPackets != len(recvdPackets2) { + t.Fatalf("expected %v packets instead got %v", halfPackets, + len(recvdPackets)) + } + + // Additionally, the set of packets should match exactly with the + // unacked packets, and we should have received the packets in the exact + // same ordering that we added. + if !reflect.DeepEqual(recvdPackets[halfPackets:], recvdPackets2) { + t.Fatalf("recvd packets mismatched: expected %v, got %v", + spew.Sdump(sentPackets), spew.Sdump(recvdPackets)) + } }