diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go index b80182a45..8d707da5c 100644 --- a/eth/fetcher/fetcher.go +++ b/eth/fetcher/fetcher.go @@ -20,6 +20,7 @@ const ( fetchTimeout = 5 * time.Second // Maximum alloted time to return an explicitly requested block maxUncleDist = 7 // Maximum allowed backward distance from the chain head maxQueueDist = 32 // Maximum allowed distance from the chain head to queue + announceLimit = 256 // Maximum number of unique blocks a peer may have announced ) var ( @@ -74,6 +75,7 @@ type Fetcher struct { quit chan struct{} // Announce states + announces map[string]int // Per peer announce counts to prevent memory exhaustion announced map[common.Hash][]*announce // Announced blocks, scheduled for fetching fetching map[common.Hash]*announce // Announced blocks, currently fetching @@ -98,6 +100,7 @@ func New(getBlock blockRetrievalFn, validateBlock blockValidatorFn, broadcastBlo filter: make(chan chan []*types.Block), done: make(chan common.Hash), quit: make(chan struct{}), + announces: make(map[string]int), announced: make(map[common.Hash][]*announce), fetching: make(map[common.Hash]*announce), queue: prque.New(), @@ -189,8 +192,7 @@ func (f *Fetcher) loop() { // Clean up any expired block fetches for hash, announce := range f.fetching { if time.Since(announce.time) > fetchTimeout { - delete(f.announced, hash) - delete(f.fetching, hash) + f.forgetBlock(hash) } } // Import any queued blocks that could potentially fit @@ -217,10 +219,17 @@ func (f *Fetcher) loop() { return case notification := <-f.notify: - // A block was announced, schedule if it's not yet downloading + // A block was announced, make sure the peer isn't DOSing us + count := f.announces[notification.origin] + 1 + if count > announceLimit { + glog.V(logger.Debug).Infof("Peer %s: exceeded outstanding announces (%d)", notification.origin, announceLimit) + break + } + // All is well, schedule the announce if block's not yet downloading if _, ok := f.fetching[notification.hash]; ok { break } + f.announces[notification.origin] = count f.announced[notification.hash] = append(f.announced[notification.hash], notification) if len(f.announced) == 1 { f.reschedule(fetch) @@ -232,8 +241,7 @@ func (f *Fetcher) loop() { case hash := <-f.done: // A pending import finished, remove all traces of the notification - delete(f.announced, hash) - delete(f.fetching, hash) + f.forgetBlock(hash) delete(f.queued, hash) case <-fetch.C: @@ -242,12 +250,15 @@ func (f *Fetcher) loop() { for hash, announces := range f.announced { if time.Since(announces[0].time) > arriveTimeout-gatherSlack { + // Pick a random peer to retrieve from, reset all others announce := announces[rand.Intn(len(announces))] + f.forgetBlock(hash) + + // If the block still didn't arrive, queue for fetching if f.getBlock(hash) == nil { request[announce.origin] = append(request[announce.origin], hash) f.fetching[hash] = announce } - delete(f.announced, hash) } } // Send out all block requests @@ -285,7 +296,7 @@ func (f *Fetcher) loop() { if f.getBlock(hash) == nil { explicit = append(explicit, block) } else { - delete(f.fetching, hash) + f.forgetBlock(hash) } } else { download = append(download, block) @@ -377,3 +388,24 @@ func (f *Fetcher) insert(peer string, block *types.Block) { go f.broadcastBlock(block, false) }() } + +// forgetBlock removes all traces of a block from the fetcher's internal state. +func (f *Fetcher) forgetBlock(hash common.Hash) { + // Remove all pending announces and decrement DOS counters + for _, announce := range f.announced[hash] { + f.announces[announce.origin]-- + if f.announces[announce.origin] == 0 { + delete(f.announces, announce.origin) + } + } + delete(f.announced, hash) + + // Remove any pending fetches and decrement the DOS counters + if announce := f.fetching[hash]; announce != nil { + f.announces[announce.origin]-- + if f.announces[announce.origin] == 0 { + delete(f.announces, announce.origin) + } + delete(f.fetching, hash) + } +} diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go index 0d069ac65..d594d830c 100644 --- a/eth/fetcher/fetcher_test.go +++ b/eth/fetcher/fetcher_test.go @@ -395,3 +395,46 @@ func TestDistantDiscarding(t *testing.T) { t.Fatalf("fetcher queued future block") } } + +// Tests that a peer is unable to use unbounded memory with sending infinite +// block announcements to a node, but that even in the face of such an attack, +// the fetcher remains operational. +func TestAnnounceMemoryExhaustionAttack(t *testing.T) { + tester := newTester() + + // Create a valid chain and an infinite junk chain + hashes := createHashes(announceLimit+2*maxQueueDist, knownHash) + blocks := createBlocksFromHashes(hashes) + valid := tester.makeFetcher(blocks) + + attack := createHashes(announceLimit+2*maxQueueDist, unknownHash) + attacker := tester.makeFetcher(nil) + + // Feed the tester a huge hashset from the attacker, and a limited from the valid peer + for i := 0; i < len(attack); i++ { + if i < maxQueueDist { + tester.fetcher.Notify("valid", hashes[len(hashes)-1-i], time.Now().Add(arriveTimeout/2), valid) + } + tester.fetcher.Notify("attacker", attack[i], time.Now().Add(arriveTimeout/2), attacker) + } + if len(tester.fetcher.announced) != announceLimit+maxQueueDist { + t.Fatalf("queued announce count mismatch: have %d, want %d", len(tester.fetcher.announced), announceLimit+maxQueueDist) + } + // Wait for synchronisation to complete and check success for the valid peer + time.Sleep(2 * arriveTimeout) + if imported := len(tester.blocks); imported != maxQueueDist { + t.Fatalf("partial synchronised block mismatch: have %v, want %v", imported, maxQueueDist) + } + // Feed the remaining valid hashes to ensure DOS protection state remains clean + for i := len(hashes) - maxQueueDist; i >= 0; { + for j := 0; j < maxQueueDist && i >= 0; j++ { + tester.fetcher.Notify("valid", hashes[i], time.Now().Add(time.Millisecond), valid) + i-- + } + time.Sleep(256 * time.Millisecond) + } + time.Sleep(256 * time.Millisecond) + if imported := len(tester.blocks); imported != len(hashes) { + t.Fatalf("fully synchronised block mismatch: have %v, want %v", imported, len(hashes)) + } +}