diff --git a/src/crdt.rs b/src/crdt.rs index 98180847a..d6dbee9f2 100644 --- a/src/crdt.rs +++ b/src/crdt.rs @@ -602,12 +602,18 @@ impl Crdt { if blob_ix == ix { let num_retransmits = wblob.meta.num_retransmits; wblob.meta.num_retransmits += 1; - - if me.current_leader_id == me.id && - num_retransmits != 0 && - !num_retransmits.is_power_of_two() + // Setting the sender id to the requester id + // prevents the requester from retransmitting this response + // to other peers + let mut sender_id = from.id; + + // Allow retransmission of this response if the node + // is the leader and the number of repair requests equals + // a power of two + if me.current_leader_id == me.id + && (num_retransmits == 0 || num_retransmits.is_power_of_two()) { - return None; + sender_id = me.id } let out = blob_recycler.allocate(); @@ -619,7 +625,7 @@ impl Crdt { outblob.meta.size = sz; outblob.data[..sz].copy_from_slice(&wblob.data[..sz]); outblob.meta.set_addr(&from.repair_addr); - outblob.set_id(me.id).expect("blob set_id"); + outblob.set_id(sender_id).expect("blob set_id"); } return Some(out); @@ -1124,6 +1130,7 @@ mod tests { #[test] fn run_window_request_with_backoff() { let window = default_window(); + let mut me = ReplicatedData::new( KeyPair::new().pubkey(), "127.0.0.1:1234".parse().unwrap(), @@ -1133,11 +1140,21 @@ mod tests { "127.0.0.1:1238".parse().unwrap(), ); + let mock_peer = ReplicatedData::new( + KeyPair::new().pubkey(), + "127.0.0.1:1234".parse().unwrap(), + "127.0.0.1:1235".parse().unwrap(), + "127.0.0.1:1236".parse().unwrap(), + "127.0.0.1:1237".parse().unwrap(), + "127.0.0.1:1238".parse().unwrap(), + ); + me.current_leader_id = me.id; let recycler = BlobRecycler::default(); let num_requests: u32 = 64; - let rv = Crdt::run_window_request(&window, &me, &me, 0, &recycler); + // Simulate handling a repair request from mock_peer + let rv = Crdt::run_window_request(&window, &me, &mock_peer, 0, &recycler); assert!(rv.is_none()); let out = recycler.allocate(); out.write().unwrap().meta.size = 200; @@ -1145,17 +1162,18 @@ mod tests { let range: std::ops::Range = 0..num_requests; for i in range { - let rv = Crdt::run_window_request(&window, &me, &me, 0, &recycler); - - if i != 0 && !(i.is_power_of_two()) { - assert!(rv.is_none()); - continue; - } - + let rv = Crdt::run_window_request(&window, &me, &mock_peer, 0, &recycler); assert!(rv.is_some()); let v = rv.unwrap(); - //test we copied the blob - assert_eq!(v.read().unwrap().meta.size, 200); + let blob = v.read().unwrap(); + // Test we copied the blob + assert_eq!(blob.meta.size, 200); + + if i != 0 && !(i.is_power_of_two()) { + assert_eq!(blob.get_id().unwrap(), mock_peer.id); + } else { + assert_eq!(blob.get_id().unwrap(), me.id); + } } } }