From 6951051416ff686612e652f0045e5966516f388a Mon Sep 17 00:00:00 2001 From: aniketfuryrocks Date: Tue, 21 Mar 2023 15:53:59 +0530 Subject: [PATCH] bather --- src/batcher.rs | 100 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 101 insertions(+) create mode 100644 src/batcher.rs diff --git a/src/batcher.rs b/src/batcher.rs new file mode 100644 index 00000000..047a3858 --- /dev/null +++ b/src/batcher.rs @@ -0,0 +1,100 @@ +/// Maintain a virtual len, returning slices according to batch_size +/// Removes batched elements on drop +/// +/// Prevents re-sizing of vector after every batch +pub struct Batcher<'a, T> { + batch_from: &'a mut Vec, + batch_size: usize, + pointer: usize, + strategy: BatcherStrategy +} + +pub enum BatcherStrategy { + Start, + End +} + + +impl<'a, T> Batcher<'a, T> { + pub fn new(batch_from: &'a mut Vec, batch_size: usize, strategy: BatcherStrategy) -> Self { + Self { + pointer: match strategy { + BatcherStrategy::Start => 0, + BatcherStrategy::End => batch_from.len() + }, + batch_from, + batch_size, + strategy + } + } + + pub fn next_batch(&mut self) -> Option<&[T]> { + let range = match self.strategy { + BatcherStrategy::Start => { + let new_start = self.pointer + self.batch_size; + + if new_start >= self.batch_from.len() { + return None; + } + + let range = self.pointer .. new_start; + self.pointer = new_start; + range + }, + BatcherStrategy::End => { + let Some(new_len) = self.pointer.checked_sub(self.batch_size) else { + return None; + }; + + let range = new_len.. self.pointer; + self.pointer = new_len; + range + + }, + }; + + Some(&self.batch_from[range]) + } +} + + +impl<'a, T> Drop for Batcher<'a, T> { + fn drop(&mut self) { + let range = match self.strategy { + BatcherStrategy::Start => 0..self.pointer, + BatcherStrategy::End => self.pointer..self.batch_from.len(), + }; + + self.batch_from.drain(range); + } +} + +#[cfg(test)] +mod tests { + use super::{BatcherStrategy, Batcher}; + + #[test] + fn start() { + let mut elements = vec![1, 2, 3, 4, 5, 6, 7]; + let mut batcher = Batcher::new(&mut elements, 2, BatcherStrategy::Start); + assert_eq!(Some(&[1, 2][..]), batcher.next_batch()); + assert_eq!(Some(&[3, 4][..]), batcher.next_batch()); + assert_eq!(Some(&[5, 6][..]), batcher.next_batch()); + assert_eq!(None, batcher.next_batch()); + drop(batcher); + assert_eq!(&[7][..], &elements); + } + + #[test] + fn end() { + let mut elements = vec![1, 2, 3, 4, 5, 6, 7]; + let mut batcher = Batcher::new(&mut elements, 2, BatcherStrategy::End); + assert_eq!(Some(&[6, 7][..]), batcher.next_batch()); + assert_eq!(Some(&[4, 5][..]), batcher.next_batch()); + assert_eq!(Some(&[2, 3][..]), batcher.next_batch()); + assert_eq!(None, batcher.next_batch()); + drop(batcher); + assert_eq!(&[1][..], &elements); + } + +} diff --git a/src/lib.rs b/src/lib.rs index 5e0b4c6c..25a3ae04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub mod errors; pub mod rpc; pub mod tpu_manager; pub mod workers; +pub mod batcher; #[from_env] pub const DEFAULT_RPC_ADDR: &str = "http://0.0.0.0:8899";