use std::{ future::Future, mem, pin::Pin, sync::Once, task::{Context, Poll}, time::Duration, }; use color_eyre::eyre::Result; use ed25519_zebra::*; use futures::stream::{FuturesUnordered, StreamExt}; use rand::thread_rng; use tokio::sync::broadcast::{channel, RecvError, Sender}; use tower::{Service, ServiceExt}; use tower_batch::{Batch, BatchControl}; // ============ service impl ============ pub struct Ed25519Verifier { batch: batch::Verifier, // This uses a "broadcast" channel, which is an mpmc channel. Tokio also // provides a spmc channel, "watch", but it only keeps the latest value, so // using it would require thinking through whether it was possible for // results from one batch to be mixed with another. tx: Sender>, } #[allow(clippy::new_without_default)] impl Ed25519Verifier { pub fn new() -> Self { let batch = batch::Verifier::default(); // XXX(hdevalence) what's a reasonable choice here? let (tx, _) = channel(10); Self { tx, batch } } } pub type Ed25519Item = batch::Item; impl<'msg> Service> for Ed25519Verifier { type Response = (); type Error = Error; type Future = Pin> + Send + 'static>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: BatchControl) -> Self::Future { match req { BatchControl::Item(item) => { tracing::trace!("got item"); self.batch.queue(item); let mut rx = self.tx.subscribe(); Box::pin(async move { match rx.recv().await { Ok(result) => result, Err(RecvError::Lagged(_)) => { tracing::warn!( "missed channel updates for the correct signature batch!" ); Err(Error::InvalidSignature) } Err(RecvError::Closed) => panic!("verifier was dropped without flushing"), } }) } BatchControl::Flush => { tracing::trace!("got flush command"); let batch = mem::take(&mut self.batch); let _ = self.tx.send(batch.verify(thread_rng())); Box::pin(async { Ok(()) }) } } } } impl Drop for Ed25519Verifier { fn drop(&mut self) { // We need to flush the current batch in case there are still any pending futures. let batch = mem::take(&mut self.batch); let _ = self.tx.send(batch.verify(thread_rng())); } } // =============== testing code ======== static LOGGER_INIT: Once = Once::new(); fn install_tracing() { use tracing_error::ErrorLayer; use tracing_subscriber::prelude::*; use tracing_subscriber::{fmt, EnvFilter}; LOGGER_INIT.call_once(|| { let fmt_layer = fmt::layer().with_target(false); let filter_layer = EnvFilter::try_from_default_env() .or_else(|_| EnvFilter::try_new("info")) .unwrap(); tracing_subscriber::registry() .with(filter_layer) .with(fmt_layer) .with(ErrorLayer::default()) .init(); }) } async fn sign_and_verify(mut verifier: V, n: usize) -> Result<(), V::Error> where V: Service, { let mut results = FuturesUnordered::new(); for i in 0..n { let span = tracing::trace_span!("sig", i); let sk = SigningKey::new(thread_rng()); let vk_bytes = VerificationKeyBytes::from(&sk); let msg = b"BatchVerifyTest"; let sig = sk.sign(&msg[..]); verifier.ready_and().await?; results.push(span.in_scope(|| verifier.call((vk_bytes, sig, msg).into()))) } while let Some(result) = results.next().await { result?; } Ok(()) } #[tokio::test] async fn batch_flushes_on_max_items() -> Result<()> { use tokio::time::timeout; install_tracing(); // Use a very long max_latency and a short timeout to check that // flushing is happening based on hitting max_items. let verifier = Batch::new(Ed25519Verifier::new(), 10, Duration::from_secs(1000)); timeout(Duration::from_secs(1), sign_and_verify(verifier, 100)).await? } #[tokio::test] async fn batch_flushes_on_max_latency() -> Result<()> { use tokio::time::timeout; install_tracing(); // Use a very high max_items and a short timeout to check that // flushing is happening based on hitting max_latency. let verifier = Batch::new(Ed25519Verifier::new(), 100, Duration::from_millis(500)); timeout(Duration::from_secs(1), sign_and_verify(verifier, 10)).await? }