diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 18f8d2ba8..14ca2cd3d 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -34,6 +34,9 @@ var ( errPeersUnavailable = errors.New("no peers available or all peers tried for block download process") errAlreadyInPool = errors.New("hash already in pool") errBlockNumberOverflow = errors.New("received block which overflows") + errCancelHashFetch = errors.New("hash fetching cancelled (requested)") + errCancelBlockFetch = errors.New("block downloading cancelled (requested)") + errNoSyncActive = errors.New("no sync active") ) type hashCheckFn func(common.Hash) bool @@ -74,6 +77,7 @@ type Downloader struct { newPeerCh chan *peer hashCh chan hashPack blockCh chan blockPack + cancelCh chan struct{} } func New(hasBlock hashCheckFn, getBlock getBlockFn) *Downloader { @@ -129,6 +133,9 @@ func (d *Downloader) Synchronise(id string, hash common.Hash) error { } defer atomic.StoreInt32(&d.synchronising, 0) + // Create cancel channel for aborting midflight + d.cancelCh = make(chan struct{}) + // Abort if the queue still contains some leftover data if _, cached := d.queue.Size(); cached > 0 && d.queue.GetHeadBlock() != nil { return errPendingQueue @@ -161,7 +168,6 @@ func (d *Downloader) Has(hash common.Hash) bool { } func (d *Downloader) getFromPeer(p *peer, hash common.Hash, ignoreInitial bool) (err error) { - d.activePeer = p.id defer func() { // reset on error @@ -191,6 +197,42 @@ func (d *Downloader) getFromPeer(p *peer, hash common.Hash, ignoreInitial bool) return nil } +// Cancel cancels all of the operations and resets the queue. It returns true +// if the cancel operation was completed. +func (d *Downloader) Cancel() bool { + hs, bs := d.queue.Size() + // If we're not syncing just return. + if atomic.LoadInt32(&d.synchronising) == 0 && hs == 0 && bs == 0 { + return false + } + + close(d.cancelCh) + + // clean up +hashDone: + for { + select { + case <-d.hashCh: + default: + break hashDone + } + } + +blockDone: + for { + select { + case <-d.blockCh: + default: + break blockDone + } + } + + // reset the queue + d.queue.Reset() + + return true +} + // XXX Make synchronous func (d *Downloader) startFetchingHashes(p *peer, h common.Hash, ignoreInitial bool) error { glog.V(logger.Debug).Infof("Downloading hashes (%x) from %s", h[:4], p.id) @@ -217,6 +259,8 @@ func (d *Downloader) startFetchingHashes(p *peer, h common.Hash, ignoreInitial b out: for { select { + case <-d.cancelCh: + return errCancelHashFetch case hashPack := <-d.hashCh: // Make sure the active peer is giving us the hashes if hashPack.peerId != activePeer.id { @@ -305,6 +349,8 @@ func (d *Downloader) startFetchingBlocks(p *peer) error { out: for { select { + case <-d.cancelCh: + return errCancelBlockFetch case blockPack := <-d.blockCh: // If the peer was previously banned and failed to deliver it's pack // in a reasonable time frame, ignore it's message. @@ -394,11 +440,23 @@ out: // Deliver a chunk to the downloader. This is usually done through the BlocksMsg by // the protocol handler. -func (d *Downloader) DeliverChunk(id string, blocks []*types.Block) { +func (d *Downloader) DeliverChunk(id string, blocks []*types.Block) error { + // Make sure the downloader is active + if atomic.LoadInt32(&d.synchronising) == 0 { + return errNoSyncActive + } + d.blockCh <- blockPack{id, blocks} + + return nil } func (d *Downloader) AddHashes(id string, hashes []common.Hash) error { + // Make sure the downloader is active + if atomic.LoadInt32(&d.synchronising) == 0 { + return errNoSyncActive + } + // make sure that the hashes that are being added are actually from the peer // that's the current active peer. hashes that have been received from other // peers are dropped and ignored. diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 8ccc4d1a5..d0f8d4c8f 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -182,6 +182,49 @@ func TestTaking(t *testing.T) { } } +func TestInactiveDownloader(t *testing.T) { + targetBlocks := 1000 + hashes := createHashes(0, targetBlocks) + blocks := createBlocksFromHashSet(createHashSet(hashes)) + tester := newTester(t, hashes, nil) + + err := tester.downloader.AddHashes("bad peer 001", hashes) + if err != errNoSyncActive { + t.Error("expected no sync error, got", err) + } + + err = tester.downloader.DeliverChunk("bad peer 001", blocks) + if err != errNoSyncActive { + t.Error("expected no sync error, got", err) + } +} + +func TestCancel(t *testing.T) { + minDesiredPeerCount = 4 + blockTtl = 1 * time.Second + + targetBlocks := 1000 + hashes := createHashes(0, targetBlocks) + blocks := createBlocksFromHashes(hashes) + tester := newTester(t, hashes, blocks) + + tester.newPeer("peer1", big.NewInt(10000), hashes[0]) + + err := tester.sync("peer1", hashes[0]) + if err != nil { + t.Error("download error", err) + } + + if !tester.downloader.Cancel() { + t.Error("cancel operation unsuccessfull") + } + + hashSize, blockSize := tester.downloader.queue.Size() + if hashSize > 0 || blockSize > 0 { + t.Error("block (", blockSize, ") or hash (", hashSize, ") not 0") + } +} + func TestThrottling(t *testing.T) { minDesiredPeerCount = 4 blockTtl = 1 * time.Second diff --git a/eth/sync.go b/eth/sync.go index c49f5209d..d955eaa50 100644 --- a/eth/sync.go +++ b/eth/sync.go @@ -63,6 +63,9 @@ func (pm *ProtocolManager) processBlocks() error { max := int(math.Min(float64(len(blocks)), float64(blockProcAmount))) _, err := pm.chainman.InsertChain(blocks[:max]) if err != nil { + // cancel download process + pm.downloader.Cancel() + return err } blocks = blocks[max:]