diff --git a/core/blockchain.go b/core/blockchain.go index 2c6ff24f9..79648869c 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -1268,12 +1268,14 @@ func (self *BlockChain) InsertChain(chain types.Blocks) (int, error) { // event about them func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { var ( - newChain types.Blocks - commonBlock *types.Block - oldStart = oldBlock - newStart = newBlock - deletedTxs types.Transactions - deletedLogs vm.Logs + newChain types.Blocks + oldChain types.Blocks + commonBlock *types.Block + oldStart = oldBlock + newStart = newBlock + deletedTxs types.Transactions + deletedLogs vm.Logs + deletedLogsByHash = make(map[common.Hash]vm.Logs) // collectLogs collects the logs that were generated during the // processing of the block that corresponds with the given hash. // These logs are later announced as deleted. @@ -1282,6 +1284,8 @@ func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { receipts := GetBlockReceipts(self.chainDb, h) for _, receipt := range receipts { deletedLogs = append(deletedLogs, receipt.Logs...) + + deletedLogsByHash[h] = receipt.Logs } } ) @@ -1290,6 +1294,7 @@ func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { if oldBlock.NumberU64() > newBlock.NumberU64() { // reduce old chain for oldBlock = oldBlock; oldBlock != nil && oldBlock.NumberU64() != newBlock.NumberU64(); oldBlock = self.GetBlock(oldBlock.ParentHash()) { + oldChain = append(oldChain, oldBlock) deletedTxs = append(deletedTxs, oldBlock.Transactions()...) collectLogs(oldBlock.Hash()) @@ -1313,6 +1318,8 @@ func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { commonBlock = oldBlock break } + + oldChain = append(oldChain, oldBlock) newChain = append(newChain, newBlock) deletedTxs = append(deletedTxs, oldBlock.Transactions()...) collectLogs(oldBlock.Hash()) @@ -1369,6 +1376,14 @@ func (self *BlockChain) reorg(oldBlock, newBlock *types.Block) error { go self.eventMux.Post(RemovedLogsEvent{deletedLogs}) } + if len(oldChain) > 0 { + go func() { + for _, block := range oldChain { + self.eventMux.Post(ChainSideEvent{Block: block, Logs: deletedLogsByHash[block.Hash()]}) + } + }() + } + return nil } diff --git a/core/blockchain_test.go b/core/blockchain_test.go index 1bb5f646d..a7d3ea56e 100644 --- a/core/blockchain_test.go +++ b/core/blockchain_test.go @@ -25,6 +25,7 @@ import ( "runtime" "strconv" "testing" + "time" "github.com/ethereum/ethash" "github.com/ethereum/go-ethereum/common" @@ -1006,3 +1007,82 @@ func TestLogReorgs(t *testing.T) { t.Error("expected logs") } } + +func TestReorgSideEvent(t *testing.T) { + var ( + db, _ = ethdb.NewMemDatabase() + key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291") + addr1 = crypto.PubkeyToAddress(key1.PublicKey) + genesis = WriteGenesisBlockForTesting(db, GenesisAccount{addr1, big.NewInt(10000000000000)}) + ) + + evmux := &event.TypeMux{} + blockchain, _ := NewBlockChain(db, FakePow{}, evmux) + + chain, _ := GenerateChain(genesis, db, 3, func(i int, gen *BlockGen) { + if i == 2 { + gen.OffsetTime(9) + } + }) + if _, err := blockchain.InsertChain(chain); err != nil { + t.Fatalf("failed to insert chain: %v", err) + } + + replacementBlocks, _ := GenerateChain(genesis, db, 4, func(i int, gen *BlockGen) { + tx, err := types.NewContractCreation(gen.TxNonce(addr1), new(big.Int), big.NewInt(1000000), new(big.Int), nil).SignECDSA(key1) + if err != nil { + t.Fatalf("failed to create tx: %v", err) + } + gen.AddTx(tx) + }) + + subs := evmux.Subscribe(ChainSideEvent{}) + if _, err := blockchain.InsertChain(replacementBlocks); err != nil { + t.Fatalf("failed to insert chain: %v", err) + } + + // first two block of the secondary chain are for a brief moment considered + // side chains because up to that point the first one is considered the + // heavier chain. + expectedSideHashes := map[common.Hash]bool{ + replacementBlocks[0].Hash(): true, + replacementBlocks[1].Hash(): true, + chain[0].Hash(): true, + chain[1].Hash(): true, + chain[2].Hash(): true, + } + + i := 0 + + const timeoutDura = 10 * time.Second + timeout := time.NewTimer(timeoutDura) +done: + for { + select { + case ev := <-subs.Chan(): + block := ev.Data.(ChainSideEvent).Block + if _, ok := expectedSideHashes[block.Hash()]; !ok { + t.Errorf("%d: didn't expect %x to be in side chain", i, block.Hash()) + } + i++ + + if i == len(expectedSideHashes) { + timeout.Stop() + + break done + } + timeout.Reset(timeoutDura) + + case <-timeout.C: + t.Fatal("Timeout. Possibly not all blocks were triggered for sideevent") + } + } + + // make sure no more events are fired + select { + case e := <-subs.Chan(): + t.Errorf("unexectped event fired: %v", e) + case <-time.After(250 * time.Millisecond): + } + +}