From 6ac2444172cbb39e14ded8214422dc56eac52f95 Mon Sep 17 00:00:00 2001 From: Jonathan Claudius Date: Tue, 9 May 2023 11:45:10 -0400 Subject: [PATCH] sdk: assert no negative numGuardians in quorum calculation (#2892) * sdk: assert no negative numGuardians in quorum calculation * sdk: fix formating on quorum tests --- sdk/vaa/quorum.go | 3 +++ sdk/vaa/quorum_test.go | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sdk/vaa/quorum.go b/sdk/vaa/quorum.go index 730ea15a9..4d2526755 100644 --- a/sdk/vaa/quorum.go +++ b/sdk/vaa/quorum.go @@ -5,5 +5,8 @@ package vaa // The canonical source is the calculation in the contracts (solana/bridge/src/processor.rs and // ethereum/contracts/Wormhole.sol), and this needs to match the implementation in the contracts. func CalculateQuorum(numGuardians int) int { + if numGuardians < 0 { + panic("Invalid numGuardians is less than zero") + } return ((numGuardians * 2) / 3) + 1 } diff --git a/sdk/vaa/quorum_test.go b/sdk/vaa/quorum_test.go index 929dbaef5..1d6381209 100644 --- a/sdk/vaa/quorum_test.go +++ b/sdk/vaa/quorum_test.go @@ -10,9 +10,11 @@ func TestCalculateQuorum(t *testing.T) { type Test struct { numGuardians int quorumResult int + shouldPanic bool } tests := []Test{ + // Positive Test Cases {numGuardians: 0, quorumResult: 1}, {numGuardians: 1, quorumResult: 1}, {numGuardians: 2, quorumResult: 2}, @@ -36,12 +38,20 @@ func TestCalculateQuorum(t *testing.T) { {numGuardians: 50, quorumResult: 34}, {numGuardians: 100, quorumResult: 67}, {numGuardians: 1000, quorumResult: 667}, + + // Negative Test Cases + {numGuardians: -1, quorumResult: 1, shouldPanic: true}, + {numGuardians: -1000, quorumResult: 1, shouldPanic: true}, } for _, tc := range tests { t.Run("", func(t *testing.T) { - num := CalculateQuorum(tc.numGuardians) - assert.Equal(t, tc.quorumResult, num) + if tc.shouldPanic { + assert.Panics(t, func() { CalculateQuorum(tc.numGuardians) }, "The code did not panic") + } else { + num := CalculateQuorum(tc.numGuardians) + assert.Equal(t, tc.quorumResult, num) + } }) } }