diff --git a/container/ethereum.go b/container/ethereum.go index eff3f374..b50dd889 100644 --- a/container/ethereum.go +++ b/container/ethereum.go @@ -436,7 +436,7 @@ func (eth *ethereum) WaitForProposed(expectedAddress common.Address, timeout tim case <-timer.C: // FIXME: this event may be missed return errors.New("no result") case head := <-subCh: - if getProposer(head) == expectedAddress { + if GetProposer(head) == expectedAddress { return nil } } diff --git a/container/utils.go b/container/utils.go index 86a33e5b..d43e3a4f 100644 --- a/container/utils.go +++ b/container/utils.go @@ -93,7 +93,7 @@ func sigHash(header *types.Header) (hash common.Hash) { return hash } -func getProposer(header *types.Header) common.Address { +func GetProposer(header *types.Header) common.Address { if header == nil { return common.Address{} } diff --git a/istclient/client.go b/istclient/client.go index 7ed9d690..a3f5512e 100644 --- a/istclient/client.go +++ b/istclient/client.go @@ -19,6 +19,8 @@ package istclient import ( "context" "math/big" + "sort" + "strings" "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/common" @@ -120,12 +122,29 @@ func (ic *Client) ProposeValidator(ctx context.Context, address common.Address, return err } +type addresses []common.Address + +func (addrs addresses) Len() int { + return len(addrs) +} + +func (addrs addresses) Less(i, j int) bool { + return strings.Compare(addrs[i].String(), addrs[j].String()) < 0 +} + +func (addrs addresses) Swap(i, j int) { + addrs[i], addrs[j] = addrs[j], addrs[i] +} + func (ic *Client) GetValidators(ctx context.Context, blockNumbers *big.Int) ([]common.Address, error) { var r []common.Address err := ic.c.CallContext(ctx, &r, "istanbul_getValidators", toNumArg(blockNumbers)) if err == nil && r == nil { return nil, ethereum.NotFound } + + sort.Sort(addresses(r)) + return r, err } diff --git a/tests/general_consensus_test.go b/tests/general_consensus_test.go index 316428de..7e28c14a 100644 --- a/tests/general_consensus_test.go +++ b/tests/general_consensus_test.go @@ -152,7 +152,8 @@ var _ = Describe("TFS-01: General consensus", func() { if lastBlockTime != 0 { diff := header.Time.Int64() - lastBlockTime if diff > maxBlockPeriod { - errc <- errors.New("Invalid block period.") + errStr := fmt.Sprintf("Invaild block(%v) period, want:%v, got:%v", header.Number.Int64(), maxBlockPeriod, diff) + errc <- errors.New(errStr) return } } @@ -169,4 +170,92 @@ var _ = Describe("TFS-01: General consensus", func() { }) close(done) }, 60) + + It("TFS-01-05: Round robin proposer selection", func(done Done) { + var ( + timesOfBeProposer = 3 + targetBlockHeight = timesOfBeProposer * numberOfValidators + emptyProposer = common.Address{} + ) + + By("Wait for consensus progress", func() { + waitFor(blockchain.Validators(), func(geth container.Ethereum, wg *sync.WaitGroup) { + Expect(geth.WaitForBlockHeight(targetBlockHeight)).To(BeNil()) + wg.Done() + }) + }) + + By("Block proposer selection should follow round-robin policy", func() { + errc := make(chan error, len(blockchain.Validators())) + for _, geth := range blockchain.Validators() { + go func(geth container.Ethereum) { + c := geth.NewClient() + istClient := geth.NewIstanbulClient() + + // get initial validator set + vals, err := istClient.GetValidators(context.Background(), big.NewInt(0)) + if err != nil { + errc <- err + return + } + + lastProposerIdx := -1 + counts := make(map[common.Address]int, numberOfValidators) + // initial count map + for _, addr := range vals { + counts[addr] = 0 + } + for i := 1; i <= targetBlockHeight; i++ { + header, err := c.HeaderByNumber(context.Background(), big.NewInt(int64(i))) + if err != nil { + errc <- err + return + } + + p := container.GetProposer(header) + if p == emptyProposer { + errStr := fmt.Sprintf("Empty block(%v) proposer", header.Number.Int64()) + errc <- errors.New(errStr) + return + } + // count the times to be the proposer + if count, ok := counts[p]; ok { + counts[p] = count + 1 + } + // check if the proposer is valid + if lastProposerIdx == -1 { + for i, val := range vals { + if p == val { + lastProposerIdx = i + break + } + } + } else { + proposerIdx := (lastProposerIdx + 1) % len(vals) + if p != vals[proposerIdx] { + errStr := fmt.Sprintf("Invaild block(%v) proposer, want:%v, got:%v", header.Number.Int64(), vals[proposerIdx], p) + errc <- errors.New(errStr) + return + } + lastProposerIdx = proposerIdx + } + } + // check times to be proposer + for _, count := range counts { + if count != timesOfBeProposer { + errc <- errors.New("Wrong times to be proposer.") + return + } + } + errc <- nil + }(geth) + } + + for i := 0; i < len(blockchain.Validators()); i++ { + err := <-errc + Expect(err).To(BeNil()) + } + }) + close(done) + }, 120) })