use super::{ error::{Closed, Error, ServiceError}, message::Message, }; use futures_core::ready; use pin_project::pin_project; use std::sync::{Arc, Mutex}; use std::{ future::Future, pin::Pin, task::{Context, Poll}, }; use tokio::sync::mpsc; use tower_service::Service; /// Task that handles processing the buffer. This type should not be used /// directly, instead `Buffer` requires an `Executor` that can accept this task. /// /// The struct is `pub` in the private module and the type is *not* re-exported /// as part of the public API. This is the "sealed" pattern to include "private" /// types in public traits that are not meant for consumers of the library to /// implement (only call). #[pin_project] #[derive(Debug)] pub struct Worker where T: Service, T::Error: Into, { current_message: Option>, rx: mpsc::Receiver>, service: T, finish: bool, failed: Option, handle: Handle, } /// Get the error out #[derive(Debug)] pub(crate) struct Handle { inner: Arc>>, } impl Worker where T: Service, T::Error: Into, { pub(crate) fn new( service: T, rx: mpsc::Receiver>, ) -> (Handle, Worker) { let handle = Handle { inner: Arc::new(Mutex::new(None)), }; let worker = Worker { current_message: None, finish: false, failed: None, rx, service, handle: handle.clone(), }; (handle, worker) } /// Return the next queued Message that hasn't been canceled. /// /// If a `Message` is returned, the `bool` is true if this is the first time we received this /// message, and false otherwise (i.e., we tried to forward it to the backing service before). fn poll_next_msg( &mut self, cx: &mut Context<'_>, ) -> Poll, bool)>> { if self.finish { // We've already received None and are shutting down return Poll::Ready(None); } tracing::trace!("worker polling for next message"); if let Some(mut msg) = self.current_message.take() { // poll_closed returns Poll::Ready is the receiver is dropped. // Returning Pending means it is still alive, so we should still // use it. if msg.tx.poll_closed(cx).is_pending() { tracing::trace!("resuming buffered request"); return Poll::Ready(Some((msg, false))); } tracing::trace!("dropping cancelled buffered request"); } // Get the next request while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) { if msg.tx.poll_closed(cx).is_pending() { tracing::trace!("processing new request"); return Poll::Ready(Some((msg, true))); } // Otherwise, request is canceled, so pop the next one. tracing::trace!("dropping cancelled request"); } Poll::Ready(None) } fn failed(&mut self, error: Error) { // The underlying service failed when we called `poll_ready` on it with the given `error`. We // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in // an `Arc`, send that `Arc` to all pending requests, and store it so that subsequent // requests will also fail with the same error. // Note that we need to handle the case where some handle is concurrently trying to send us // a request. We need to make sure that *either* the send of the request fails *or* it // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the // case where we send errors to all outstanding requests, and *then* the caller sends its // request. We do this by *first* exposing the error, *then* closing the channel used to // send more requests (so the client will see the error when the send fails), and *then* // sending the error to all outstanding requests. let error = ServiceError::new(error); let mut inner = self.handle.inner.lock().unwrap(); if inner.is_some() { // Future::poll was called after we've already errored out! return; } *inner = Some(error.clone()); drop(inner); self.rx.close(); // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None), // which will trigger the `self.finish == true` phase. We just need to make sure that any // requests that we receive before we've exhausted the receiver receive the error: self.failed = Some(error); } } impl Future for Worker where T: Service, T::Error: Into, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.finish { return Poll::Ready(()); } loop { match ready!(self.poll_next_msg(cx)) { Some((msg, first)) => { let _guard = msg.span.enter(); if let Some(ref failed) = self.failed { tracing::trace!("notifying caller about worker failure"); let _ = msg.tx.send(Err(failed.clone())); continue; } // Wait for the service to be ready tracing::trace!( resumed = !first, message = "worker received request; waiting for service readiness" ); match self.service.poll_ready(cx) { Poll::Ready(Ok(())) => { tracing::debug!(service.ready = true, message = "processing request"); let response = self.service.call(msg.request); // Send the response future back to the sender. // // An error means the request had been canceled in-between // our calls, the response future will just be dropped. tracing::trace!("returning response future"); let _ = msg.tx.send(Ok(response)); } Poll::Pending => { tracing::trace!(service.ready = false, message = "delay"); // Put out current message back in its slot. drop(_guard); self.current_message = Some(msg); return Poll::Pending; } Poll::Ready(Err(e)) => { let error = e.into(); tracing::debug!({ %error }, "service failed"); drop(_guard); self.failed(error); let _ = msg.tx.send(Err(self .failed .as_ref() .expect("Worker::failed did not set self.failed?") .clone())); } } } None => { // No more more requests _ever_. self.finish = true; return Poll::Ready(()); } } } } } impl Handle { pub(crate) fn get_error_on_closed(&self) -> Error { self.inner .lock() .unwrap() .as_ref() .map(|svc_err| svc_err.clone().into()) .unwrap_or_else(|| Closed::new().into()) } } impl Clone for Handle { fn clone(&self) -> Handle { Handle { inner: self.inner.clone(), } } }