diff --git a/tower/Cargo.toml b/tower/Cargo.toml
index a316629..5a293e6 100644
--- a/tower/Cargo.toml
+++ b/tower/Cargo.toml
@@ -55,7 +55,7 @@ hdrhistogram = { version = "6.0", optional = true }
 indexmap = { version = "1.0.2", optional = true }
 rand = { version = "0.7", features = ["small_rng"], optional = true }
 slab = { version = "0.4", optional = true }
-tokio = { version = "0.2", optional = true }
+tokio = { version = "0.2", optional = true, features = ["sync"] }
 
 [dev-dependencies]
 futures-util = { version = "0.3", default-features = false, features = ["alloc", "async-await"] }
diff --git a/tower/src/limit/concurrency/future.rs b/tower/src/limit/concurrency/future.rs
index eb1ac33..ff82c21 100644
--- a/tower/src/limit/concurrency/future.rs
+++ b/tower/src/limit/concurrency/future.rs
@@ -1,27 +1,27 @@
 //! Future types
 //!
-use super::sync::semaphore::Semaphore;
 use futures_core::ready;
-use pin_project::{pin_project, pinned_drop};
-use std::sync::Arc;
+use pin_project::pin_project;
 use std::{
     future::Future,
     pin::Pin,
     task::{Context, Poll},
 };
+use tokio::sync::OwnedSemaphorePermit;
 
 /// Future for the `ConcurrencyLimit` service.
-#[pin_project(PinnedDrop)]
+#[pin_project]
 #[derive(Debug)]
 pub struct ResponseFuture<T> {
     #[pin]
     inner: T,
-    semaphore: Arc<Semaphore>,
+    // Keep this around so that it is dropped when the future completes
+    _permit: OwnedSemaphorePermit,
 }
 
 impl<T> ResponseFuture<T> {
-    pub(crate) fn new(inner: T, semaphore: Arc<Semaphore>) -> ResponseFuture<T> {
-        ResponseFuture { inner, semaphore }
+    pub(crate) fn new(inner: T, _permit: OwnedSemaphorePermit) -> ResponseFuture<T> {
+        ResponseFuture { inner, _permit }
     }
 }
 
@@ -35,10 +35,3 @@ where
         Poll::Ready(ready!(self.project().inner.poll(cx)))
     }
 }
-
-#[pinned_drop]
-impl<T> PinnedDrop for ResponseFuture<T> {
-    fn drop(self: Pin<&mut Self>) {
-        self.project().semaphore.add_permits(1);
-    }
-}
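With the vendored semaphore gone, permit release is RAII: `OwnedSemaphorePermit` returns its permit to the semaphore when it is dropped, which is what makes the `PinnedDrop` impl above removable. A minimal sketch of that behavior, assuming the tokio 0.2 `sync` API this patch targets (where `acquire_owned` resolves directly to a permit; running it also needs tokio's `rt`/`macros` features, which are not part of this change):

use std::sync::Arc;
use tokio::sync::Semaphore;

#[tokio::main]
async fn main() {
    let semaphore = Arc::new(Semaphore::new(2));

    // `acquire_owned` takes the `Arc<Semaphore>` by value, so the returned
    // `OwnedSemaphorePermit` is 'static and can be stored inside a future.
    let permit = semaphore.clone().acquire_owned().await;
    assert_eq!(semaphore.available_permits(), 1);

    // Dropping the permit releases it; no manual `add_permits(1)` needed.
    drop(permit);
    assert_eq!(semaphore.available_permits(), 2);
}
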
diff --git a/tower/src/limit/concurrency/mod.rs b/tower/src/limit/concurrency/mod.rs
index 0a9f6b6..a1eba7e 100644
--- a/tower/src/limit/concurrency/mod.rs
+++ b/tower/src/limit/concurrency/mod.rs
@@ -3,6 +3,5 @@
 pub mod future;
 mod layer;
 mod service;
-mod sync;
 
 pub use self::{layer::ConcurrencyLimitLayer, service::ConcurrencyLimit};
diff --git a/tower/src/limit/concurrency/service.rs b/tower/src/limit/concurrency/service.rs
index cf4ba4c..1a544a7 100644
--- a/tower/src/limit/concurrency/service.rs
+++ b/tower/src/limit/concurrency/service.rs
@@ -2,34 +2,38 @@
 use super::future::ResponseFuture;
 use tower_service::Service;
 
-use super::sync::semaphore::{self, Semaphore};
 use futures_core::ready;
+use std::fmt;
+use std::future::Future;
+use std::mem;
+use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 
 /// Enforces a limit on the concurrent number of requests the underlying
 /// service can handle.
 #[derive(Debug)]
 pub struct ConcurrencyLimit<T> {
     inner: T,
-    limit: Limit,
+    semaphore: Arc<Semaphore>,
+    state: State,
 }
 
-#[derive(Debug)]
-struct Limit {
-    semaphore: Arc<Semaphore>,
-    permit: semaphore::Permit,
+enum State {
+    Waiting(Pin<Box<dyn Future<Output = OwnedSemaphorePermit> + Send + 'static>>),
+    Ready(OwnedSemaphorePermit),
+    Empty,
 }
 
 impl<T> ConcurrencyLimit<T> {
     /// Create a new concurrency limiter.
     pub fn new(inner: T, max: usize) -> Self {
+        let semaphore = Arc::new(Semaphore::new(max));
         ConcurrencyLimit {
             inner,
-            limit: Limit {
-                semaphore: Arc::new(Semaphore::new(max)),
-                permit: semaphore::Permit::new(),
-            },
+            semaphore,
+            state: State::Empty,
         }
     }
@@ -58,31 +62,32 @@ where
     type Future = ResponseFuture<T::Future>;
 
     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        ready!(self.limit.permit.poll_acquire(cx, &self.limit.semaphore))
-            .expect("poll_acquire after semaphore closed");
-
-        Poll::Ready(ready!(self.inner.poll_ready(cx)))
+        loop {
+            self.state = match self.state {
+                State::Ready(_) => return self.inner.poll_ready(cx),
+                State::Waiting(ref mut fut) => {
+                    tokio::pin!(fut);
+                    let permit = ready!(fut.poll(cx));
+                    State::Ready(permit)
+                }
+                State::Empty => State::Waiting(Box::pin(self.semaphore.clone().acquire_owned())),
+            };
+        }
     }
 
     fn call(&mut self, request: Request) -> Self::Future {
         // Make sure a permit has been acquired
-        if self
-            .limit
-            .permit
-            .try_acquire(&self.limit.semaphore)
-            .is_err()
-        {
-            panic!("max requests in-flight; poll_ready must be called first");
-        }
+        let permit = match mem::replace(&mut self.state, State::Empty) {
+            // Take the permit.
+            State::Ready(permit) => permit,
+            // whoopsie!
+            _ => panic!("max requests in-flight; poll_ready must be called first"),
+        };
 
         // Call the inner service
         let future = self.inner.call(request);
 
-        // Forget the permit; it will be returned when
-        // `future::ResponseFuture` is dropped.
-        self.limit.permit.forget();
-
-        ResponseFuture::new(future, self.limit.semaphore.clone())
+        ResponseFuture::new(future, permit)
     }
 }
 
@@ -104,16 +109,21 @@ where
     fn clone(&self) -> ConcurrencyLimit<T> {
         ConcurrencyLimit {
             inner: self.inner.clone(),
-            limit: Limit {
-                semaphore: self.limit.semaphore.clone(),
-                permit: semaphore::Permit::new(),
-            },
+            semaphore: self.semaphore.clone(),
+            state: State::Empty,
         }
     }
 }
 
-impl Drop for Limit {
-    fn drop(&mut self) {
-        self.permit.release(&self.semaphore);
+impl fmt::Debug for State {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            State::Waiting(_) => f
+                .debug_tuple("State::Waiting")
+                .field(&format_args!("..."))
+                .finish(),
+            State::Ready(ref r) => f.debug_tuple("State::Ready").field(&r).finish(),
+            State::Empty => f.debug_tuple("State::Empty").finish(),
+        }
     }
 }
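The `State` machine above enforces the `Service` contract: `poll_ready` must have returned `Poll::Ready(Ok(()))` (permit acquired and stored in `State::Ready`) before `call` consumes it, and `call` panics otherwise. A hedged sketch of a caller that honors this contract, using only the `tower_service::Service` trait plus `futures-util` from the dev-dependencies (the helper name is illustrative, not part of this patch):

use futures_util::future::poll_fn;
use tower_service::Service;

/// Wait for readiness, then call: the only safe order for any `Service`,
/// including `ConcurrencyLimit`. `poll_ready` is where the semaphore permit
/// is acquired; skipping it would hit the panic in `call`.
async fn call_when_ready<S, Request>(
    svc: &mut S,
    req: Request,
) -> Result<S::Response, S::Error>
where
    S: Service<Request>,
{
    poll_fn(|cx| svc.poll_ready(cx)).await?;
    svc.call(req).await
}
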
diff --git a/tower/src/limit/concurrency/sync/cell.rs b/tower/src/limit/concurrency/sync/cell.rs
deleted file mode 100644
index daeddaa..0000000
--- a/tower/src/limit/concurrency/sync/cell.rs
+++ /dev/null
@@ -1,51 +0,0 @@
-#![allow(dead_code)]
-
-use std::cell::UnsafeCell;
-
-#[derive(Debug)]
-pub(crate) struct CausalCell<T>(UnsafeCell<T>);
-
-#[derive(Default)]
-pub(crate) struct CausalCheck(());
-
-impl<T> CausalCell<T> {
-    pub(crate) fn new(data: T) -> CausalCell<T> {
-        CausalCell(UnsafeCell::new(data))
-    }
-
-    pub(crate) fn with<F, R>(&self, f: F) -> R
-    where
-        F: FnOnce(*const T) -> R,
-    {
-        f(self.0.get())
-    }
-
-    pub(crate) fn with_unchecked<F, R>(&self, f: F) -> R
-    where
-        F: FnOnce(*const T) -> R,
-    {
-        f(self.0.get())
-    }
-
-    pub(crate) fn check(&self) {}
-
-    pub(crate) fn with_deferred<F, R>(&self, f: F) -> (R, CausalCheck)
-    where
-        F: FnOnce(*const T) -> R,
-    {
-        (f(self.0.get()), CausalCheck::default())
-    }
-
-    pub(crate) fn with_mut<F, R>(&self, f: F) -> R
-    where
-        F: FnOnce(*mut T) -> R,
-    {
-        f(self.0.get())
-    }
-}
-
-impl CausalCheck {
-    pub(crate) fn check(self) {}
-
-    pub(crate) fn join(&mut self, _other: CausalCheck) {}
-}
diff --git a/tower/src/limit/concurrency/sync/mod.rs b/tower/src/limit/concurrency/sync/mod.rs
deleted file mode 100644
index f06e8af..0000000
--- a/tower/src/limit/concurrency/sync/mod.rs
+++ /dev/null
@@ -1,7 +0,0 @@
-// Vendored `tokio/src/sync/semaphore.rs` and `tokio/src/sync/task/atomic_waker.rs`
-// Commit sha: 24cd6d67f76f122f67cbbb101d555018fc27820b
-
-mod cell;
-mod waker;
-
-pub(super) mod semaphore;
diff --git a/tower/src/limit/concurrency/sync/semaphore.rs b/tower/src/limit/concurrency/sync/semaphore.rs
deleted file mode 100644
index 7e27978..0000000
--- a/tower/src/limit/concurrency/sync/semaphore.rs
+++ /dev/null
@@ -1,1070 +0,0 @@
-#![allow(dead_code)]
-
-//! Thread-safe, asynchronous counting semaphore.
-//!
-//! A `Semaphore` instance holds a set of permits. Permits are used to
-//! synchronize access to a shared resource.
-//!
-//! Before accessing the shared resource, callers acquire a permit from the
-//! semaphore. Once the permit is acquired, the caller then enters the critical
-//! section. If no permits are available, then acquiring the semaphore returns
-//! `Pending`. The task is woken once a permit becomes available.
-
-use super::cell::CausalCell;
-use super::waker::AtomicWaker;
-use std::{
-    sync::atomic::{AtomicPtr, AtomicUsize},
-    thread,
-};
-
-use std::fmt;
-use std::ptr::{self, NonNull};
-use std::sync::atomic::Ordering::{self, AcqRel, Acquire, Relaxed, Release};
-use std::sync::Arc;
-use std::task::Poll::{Pending, Ready};
-use std::task::{Context, Poll};
-use std::usize;
-
-/// Futures-aware semaphore.
-pub(crate) struct Semaphore {
-    /// Tracks both the waiter queue tail pointer and the number of remaining
-    /// permits.
-    state: AtomicUsize,
-
-    /// Waiter queue head pointer.
-    head: CausalCell<NonNull<WaiterNode>>,
-
-    /// Coordinates access to the queue head.
-    rx_lock: AtomicUsize,
-
-    /// Stub waiter node used as part of the MPSC channel algorithm.
-    stub: Box<WaiterNode>,
-}
-
-/// A semaphore permit
-///
-/// Tracks the lifecycle of a semaphore permit.
-///
-/// An instance of `Permit` is intended to be used with a **single** instance of
-/// `Semaphore`. Using a single instance of `Permit` with multiple semaphore
-/// instances will result in unexpected behavior.
-///
-/// `Permit` does **not** release the permit back to the semaphore on drop. It
-/// is the user's responsibility to ensure that `Permit::release` is called
-/// before dropping the permit.
-#[derive(Debug)]
-pub(crate) struct Permit {
-    waiter: Option<Arc<WaiterNode>>,
-    state: PermitState,
-}
-
-/// Error returned by `Permit::poll_acquire`.
-#[derive(Debug)]
-pub(crate) struct AcquireError(());
-
-/// Error returned by `Permit::try_acquire`.
-#[derive(Debug)]
-pub(crate) struct TryAcquireError {
-    kind: ErrorKind,
-}
-
-#[derive(Debug)]
-enum ErrorKind {
-    Closed,
-    NoPermits,
-}
-
-/// Node used to notify the semaphore waiter when a permit is available.
-#[derive(Debug)]
-struct WaiterNode {
-    /// Stores waiter state.
-    ///
-    /// See `NodeState` for more details.
-    state: AtomicUsize,
-
-    /// Task to wake when a permit is made available.
-    waker: AtomicWaker,
-
-    /// Next pointer in the queue of waiting senders.
-    next: AtomicPtr<WaiterNode>,
-}
-
-/// Semaphore state
-///
-/// The 2 low bits track the modes.
-///
-/// - Closed
-/// - Full
-///
-/// When not full, the rest of the `usize` tracks the total number of messages
-/// in the channel. When full, the rest of the `usize` is a pointer to the tail
-/// of the "waiting senders" queue.
-#[derive(Copy, Clone)]
-struct SemState(usize);
-
-/// Permit state
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-enum PermitState {
-    /// The permit has not been requested.
-    Idle,
-
-    /// Currently waiting for a permit to be made available and assigned to the
-    /// waiter.
-    Waiting,
-
-    /// The permit has been acquired.
-    Acquired,
-}
-
-/// Waiter node state
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-#[repr(usize)]
-enum NodeState {
-    /// Not waiting for a permit and the node is not in the wait queue.
-    ///
-    /// This is the initial state.
-    Idle = 0,
-
-    /// Not waiting for a permit but the node is in the wait queue.
-    ///
-    /// This happens when the waiter has previously requested a permit, but has
-    /// since canceled the request. The node cannot be removed by the waiter, so
-    /// this state informs the receiver to skip the node when it pops it from
-    /// the wait queue.
-    Queued = 1,
-
-    /// Waiting for a permit and the node is in the wait queue.
-    QueuedWaiting = 2,
-
-    /// The waiter has been assigned a permit and the node has been removed from
-    /// the queue.
-    Assigned = 3,
-
-    /// The semaphore has been closed. No more permits will be issued.
-    Closed = 4,
-}
-
-// ===== impl Semaphore =====
-
-impl Semaphore {
-    /// Creates a new semaphore with the initial number of permits
-    ///
-    /// # Panics
-    ///
-    /// Panics if `permits` is zero.
-    pub(crate) fn new(permits: usize) -> Semaphore {
-        let stub = Box::new(WaiterNode::new());
-        let ptr = NonNull::new(&*stub as *const _ as *mut _).unwrap();
-
-        // Allocations are aligned
-        debug_assert!(ptr.as_ptr() as usize & NUM_FLAG == 0);
-
-        let state = SemState::new(permits, &stub);
-
-        Semaphore {
-            state: AtomicUsize::new(state.to_usize()),
-            head: CausalCell::new(ptr),
-            rx_lock: AtomicUsize::new(0),
-            stub,
-        }
-    }
-
-    /// Returns the current number of available permits
-    pub(crate) fn available_permits(&self) -> usize {
-        let curr = SemState::load(&self.state, Acquire);
-        curr.available_permits()
-    }
-
-    /// Poll for a permit
-    fn poll_permit(
-        &self,
-        mut permit: Option<(&mut Context<'_>, &mut Permit)>,
-    ) -> Poll<Result<(), AcquireError>> {
-        // Load the current state
-        let mut curr = SemState::load(&self.state, Acquire);
-
-        // Tracks a *mut WaiterNode representing an Arc clone.
-        //
-        // This avoids having to bump the ref count unless required.
-        let mut maybe_strong: Option<NonNull<WaiterNode>> = None;
-
-        macro_rules! undo_strong {
-            () => {
-                if let Some(waiter) = maybe_strong {
-                    // The waiter was cloned, but never got queued.
-                    // Before entering `poll_permit`, the waiter was in the
-                    // `Idle` state. We must transition the node back to the
-                    // idle state.
-                    let waiter = unsafe { Arc::from_raw(waiter.as_ptr()) };
-                    waiter.revert_to_idle();
-                }
-            };
-        }
-
-        loop {
-            let mut next = curr;
-
-            if curr.is_closed() {
-                undo_strong!();
-                return Ready(Err(AcquireError::closed()));
-            }
-
-            if !next.acquire_permit(&self.stub) {
-                debug_assert!(curr.waiter().is_some());
-
-                if maybe_strong.is_none() {
-                    if let Some((ref mut cx, ref mut permit)) = permit {
-                        // Get the Sender's waiter node, or initialize one
-                        let waiter = permit
-                            .waiter
-                            .get_or_insert_with(|| Arc::new(WaiterNode::new()));
-
-                        waiter.register(cx);
-
-                        if !waiter.to_queued_waiting() {
-                            // The node is already queued, there is no further work
-                            // to do.
-                            return Pending;
-                        }
-
-                        maybe_strong = Some(WaiterNode::into_non_null(waiter.clone()));
-                    } else {
-                        // If no `waiter`, then the task is not registered and there
-                        // is no further work to do.
-                        return Pending;
-                    }
-                }
-
-                next.set_waiter(maybe_strong.unwrap());
-            }
-
-            debug_assert_ne!(curr.0, 0);
-            debug_assert_ne!(next.0, 0);
-
-            match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
-                Ok(_) => {
-                    match curr.waiter() {
-                        Some(prev_waiter) => {
-                            let waiter = maybe_strong.unwrap();
-
-                            // Finish pushing
-                            unsafe {
-                                prev_waiter.as_ref().next.store(waiter.as_ptr(), Release);
-                            }
-
-                            return Pending;
-                        }
-                        None => {
-                            undo_strong!();
-
-                            return Ready(Ok(()));
-                        }
-                    }
-                }
-                Err(actual) => {
-                    curr = actual;
-                }
-            }
-        }
-    }
-
-    /// Close the semaphore. This prevents the semaphore from issuing new
-    /// permits and notifies all pending waiters.
-    pub(crate) fn close(&self) {
-        // Acquire the `rx_lock`, setting the "closed" flag on the lock.
-        let prev = self.rx_lock.fetch_or(1, AcqRel);
-
-        if prev != 0 {
-            // Another thread has the lock and will be responsible for notifying
-            // pending waiters.
-            return;
-        }
-
-        self.add_permits_locked(0, true);
-    }
-
-    /// Add `n` new permits to the semaphore.
-    pub(crate) fn add_permits(&self, n: usize) {
-        if n == 0 {
-            return;
-        }
-
-        // TODO: Handle overflow. A panic is not sufficient, the process must
-        // abort.
-        let prev = self.rx_lock.fetch_add(n << 1, AcqRel);
-
-        if prev != 0 {
-            // Another thread has the lock and will be responsible for notifying
-            // pending waiters.
-            return;
-        }
-
-        self.add_permits_locked(n, false);
-    }
-
-    fn add_permits_locked(&self, mut rem: usize, mut closed: bool) {
-        while rem > 0 || closed {
-            if closed {
-                SemState::fetch_set_closed(&self.state, AcqRel);
-            }
-
-            // Release the permits and notify
-            self.add_permits_locked2(rem, closed);
-
-            let n = rem << 1;
-
-            let actual = if closed {
-                let actual = self.rx_lock.fetch_sub(n | 1, AcqRel);
-                closed = false;
-                actual
-            } else {
-                let actual = self.rx_lock.fetch_sub(n, AcqRel);
-                closed = actual & 1 == 1;
-                actual
-            };
-
-            rem = (actual >> 1) - rem;
-        }
-    }
-
-    /// Release a specific amount of permits to the semaphore
-    ///
-    /// This function is called by `add_permits` after the add lock has been
-    /// acquired.
-    fn add_permits_locked2(&self, mut n: usize, closed: bool) {
-        while n > 0 || closed {
-            let waiter = match self.pop(n, closed) {
-                Some(waiter) => waiter,
-                None => {
-                    return;
-                }
-            };
-
-            if waiter.notify(closed) {
-                n = n.saturating_sub(1);
-            }
-        }
-    }
-
-    /// Pop a waiter
-    ///
-    /// `rem` represents the remaining number of times the caller will pop. If
-    /// there are no more waiters to pop, `rem` is used to set the available
-    /// permits.
-    fn pop(&self, rem: usize, closed: bool) -> Option<Arc<WaiterNode>> {
-        'outer: loop {
-            unsafe {
-                let mut head = self.head.with(|head| *head);
-                let mut next_ptr = head.as_ref().next.load(Acquire);
-
-                let stub = self.stub();
-
-                if head == stub {
-                    let next = match NonNull::new(next_ptr) {
-                        Some(next) => next,
-                        None => {
-                            // This loop is not part of the standard intrusive mpsc
-                            // channel algorithm. This is where we atomically pop
-                            // the last task and add `rem` to the remaining capacity.
-                            //
-                            // This modification to the pop algorithm works because,
-                            // at this point, we have not done any work (only done
-                            // reading). We have a *pretty* good idea that there is
-                            // no concurrent pusher.
-                            //
-                            // The capacity is then atomically added by doing an
-                            // AcqRel CAS on `state`. The `state` cell is the
-                            // linchpin of the algorithm.
-                            //
-                            // By successfully CASing `head` w/ AcqRel, we ensure
-                            // that, if any thread was racing and entered a push, we
-                            // see that and abort pop, retrying as it is
-                            // "inconsistent".
-                            let mut curr = SemState::load(&self.state, Acquire);
-
-                            loop {
-                                if curr.has_waiter(&self.stub) {
-                                    // Inconsistent
-                                    thread::yield_now();
-                                    continue 'outer;
-                                }
-
-                                // When closing the semaphore, nodes are popped
-                                // with `rem == 0`. In this case, we are not
-                                // adding permits, but notifying waiters of the
-                                // semaphore's closed state.
-                                if rem == 0 {
-                                    debug_assert!(curr.is_closed(), "state = {:?}", curr);
-                                    return None;
-                                }
-
-                                let mut next = curr;
-                                next.release_permits(rem, &self.stub);
-
-                                match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
-                                    Ok(_) => return None,
-                                    Err(actual) => {
-                                        curr = actual;
-                                    }
-                                }
-                            }
-                        }
-                    };
-
-                    self.head.with_mut(|head| *head = next);
-                    head = next;
-                    next_ptr = next.as_ref().next.load(Acquire);
-                }
-
-                if let Some(next) = NonNull::new(next_ptr) {
-                    self.head.with_mut(|head| *head = next);
-
-                    return Some(Arc::from_raw(head.as_ptr()));
-                }
-
-                let state = SemState::load(&self.state, Acquire);
-
-                // This must always be a pointer as the wait list is not empty.
-                let tail = state.waiter().unwrap();
-
-                if tail != head {
-                    // Inconsistent
-                    thread::yield_now();
-                    continue 'outer;
-                }
-
-                self.push_stub(closed);
-
-                next_ptr = head.as_ref().next.load(Acquire);
-
-                if let Some(next) = NonNull::new(next_ptr) {
-                    self.head.with_mut(|head| *head = next);
-
-                    return Some(Arc::from_raw(head.as_ptr()));
-                }
-
-                // Inconsistent state, loop
-                thread::yield_now();
-            }
-        }
-    }
-
-    unsafe fn push_stub(&self, closed: bool) {
-        let stub = self.stub();
-
-        // Set the next pointer. This does not require an atomic operation as
-        // this node is not accessible. The write will be flushed with the next
-        // operation.
-        stub.as_ref().next.store(ptr::null_mut(), Relaxed);
-
-        // Update the tail to point to the new node. We need to see the previous
-        // node in order to update the next pointer as well as release `task`
-        // to any other threads calling `push`.
-        let prev = SemState::new_ptr(stub, closed).swap(&self.state, AcqRel);
-
-        debug_assert_eq!(closed, prev.is_closed());
-
-        // The stub is only pushed when there are pending tasks. Because of
-        // this, the state must *always* be in pointer mode.
-        let prev = prev.waiter().unwrap();
-
-        // We don't want the *existing* pointer to be a stub.
-        debug_assert_ne!(prev, stub);
-
-        // Release `task` to the consume end.
-        prev.as_ref().next.store(stub.as_ptr(), Release);
-    }
-
-    fn stub(&self) -> NonNull<WaiterNode> {
-        unsafe { NonNull::new_unchecked(&*self.stub as *const _ as *mut _) }
-    }
-}
-
-impl fmt::Debug for Semaphore {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        fmt.debug_struct("Semaphore")
-            .field("state", &SemState::load(&self.state, Relaxed))
-            .field("head", &self.head.with(|ptr| ptr))
-            .field("rx_lock", &self.rx_lock.load(Relaxed))
-            .field("stub", &self.stub)
-            .finish()
-    }
-}
-
-unsafe impl Send for Semaphore {}
-unsafe impl Sync for Semaphore {}
-
-// ===== impl Permit =====
-
-impl Permit {
-    /// Create a new `Permit`.
-    ///
-    /// The permit begins in the "unacquired" state.
-    pub(crate) fn new() -> Permit {
-        Permit {
-            waiter: None,
-            state: PermitState::Idle,
-        }
-    }
-
-    /// Returns true if the permit has been acquired
-    pub(crate) fn is_acquired(&self) -> bool {
-        self.state == PermitState::Acquired
-    }
-
-    /// Try to acquire the permit. If no permits are available, the current
-    /// task is notified once a new permit becomes available.
-    pub(crate) fn poll_acquire(
-        &mut self,
-        cx: &mut Context<'_>,
-        semaphore: &Semaphore,
-    ) -> Poll<Result<(), AcquireError>> {
-        match self.state {
-            PermitState::Idle => {}
-            PermitState::Waiting => {
-                let waiter = self.waiter.as_ref().unwrap();
-
-                if waiter.acquire(cx)? {
-                    self.state = PermitState::Acquired;
-                    return Ready(Ok(()));
-                } else {
-                    return Pending;
-                }
-            }
-            PermitState::Acquired => {
-                return Ready(Ok(()));
-            }
-        }
-
-        match semaphore.poll_permit(Some((cx, self)))? {
-            Ready(()) => {
-                self.state = PermitState::Acquired;
-                Ready(Ok(()))
-            }
-            Pending => {
-                self.state = PermitState::Waiting;
-                Pending
-            }
-        }
-    }
-
-    /// Try to acquire the permit.
-    pub(crate) fn try_acquire(&mut self, semaphore: &Semaphore) -> Result<(), TryAcquireError> {
-        match self.state {
-            PermitState::Idle => {}
-            PermitState::Waiting => {
-                let waiter = self.waiter.as_ref().unwrap();
-
-                if waiter.acquire2().map_err(to_try_acquire)? {
-                    self.state = PermitState::Acquired;
-                    return Ok(());
-                } else {
-                    return Err(TryAcquireError::no_permits());
-                }
-            }
-            PermitState::Acquired => {
-                return Ok(());
-            }
-        }
-
-        match semaphore.poll_permit(None).map_err(to_try_acquire)? {
-            Ready(()) => {
-                self.state = PermitState::Acquired;
-                Ok(())
-            }
-            Pending => Err(TryAcquireError::no_permits()),
-        }
-    }
-
-    /// Release a permit back to the semaphore
-    pub(crate) fn release(&mut self, semaphore: &Semaphore) {
-        if self.forget2() {
-            semaphore.add_permits(1);
-        }
-    }
-
-    /// Forget the permit **without** releasing it back to the semaphore.
-    ///
-    /// After calling `forget`, `poll_acquire` is able to acquire a new permit
-    /// from the semaphore.
-    ///
-    /// Repeatedly calling `forget` without associated calls to `add_permits`
-    /// will result in the semaphore losing all permits.
-    pub(crate) fn forget(&mut self) {
-        self.forget2();
-    }
-
-    /// Returns `true` if the permit was acquired
-    fn forget2(&mut self) -> bool {
-        match self.state {
-            PermitState::Idle => false,
-            PermitState::Waiting => {
-                let ret = self.waiter.as_ref().unwrap().cancel_interest();
-                self.state = PermitState::Idle;
-                ret
-            }
-            PermitState::Acquired => {
-                self.state = PermitState::Idle;
-                true
-            }
-        }
-    }
-}
-
-impl Default for Permit {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-// ===== impl AcquireError =====
-
-impl AcquireError {
-    fn closed() -> AcquireError {
-        AcquireError(())
-    }
-}
-
-fn to_try_acquire(_: AcquireError) -> TryAcquireError {
-    TryAcquireError::closed()
-}
-
-impl fmt::Display for AcquireError {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(fmt, "semaphore closed")
-    }
-}
-
-impl ::std::error::Error for AcquireError {}
-
-// ===== impl TryAcquireError =====
-
-impl TryAcquireError {
-    fn closed() -> TryAcquireError {
-        TryAcquireError {
-            kind: ErrorKind::Closed,
-        }
-    }
-
-    fn no_permits() -> TryAcquireError {
-        TryAcquireError {
-            kind: ErrorKind::NoPermits,
-        }
-    }
-
-    /// Returns true if the error was caused by a closed semaphore.
-    pub(crate) fn is_closed(&self) -> bool {
-        match self.kind {
-            ErrorKind::Closed => true,
-            _ => false,
-        }
-    }
-
-    /// Returns true if the error was caused by calling `try_acquire` on a
-    /// semaphore with no available permits.
-    pub(crate) fn is_no_permits(&self) -> bool {
-        match self.kind {
-            ErrorKind::NoPermits => true,
-            _ => false,
-        }
-    }
-}
-
-impl fmt::Display for TryAcquireError {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        let descr = match self.kind {
-            ErrorKind::Closed => "semaphore closed",
-            ErrorKind::NoPermits => "no permits available",
-        };
-        write!(fmt, "{}", descr)
-    }
-}
-
-impl ::std::error::Error for TryAcquireError {}
-
-// ===== impl WaiterNode =====
-
-impl WaiterNode {
-    fn new() -> WaiterNode {
-        WaiterNode {
-            state: AtomicUsize::new(NodeState::new().to_usize()),
-            waker: AtomicWaker::new(),
-            next: AtomicPtr::new(ptr::null_mut()),
-        }
-    }
-
-    fn acquire(&self, cx: &mut Context<'_>) -> Result<bool, AcquireError> {
-        if self.acquire2()? {
-            return Ok(true);
-        }
-
-        self.waker.register_by_ref(cx.waker());
-
-        self.acquire2()
-    }
-
-    fn acquire2(&self) -> Result<bool, AcquireError> {
-        use self::NodeState::*;
-
-        match Idle.compare_exchange(&self.state, Assigned, AcqRel, Acquire) {
-            Ok(_) => Ok(true),
-            Err(Closed) => Err(AcquireError::closed()),
-            Err(_) => Ok(false),
-        }
-    }
-
-    fn register(&self, cx: &mut Context<'_>) {
-        self.waker.register_by_ref(cx.waker())
-    }
-
-    /// Returns `true` if the permit has been acquired
-    fn cancel_interest(&self) -> bool {
-        use self::NodeState::*;
-
-        match Queued.compare_exchange(&self.state, QueuedWaiting, AcqRel, Acquire) {
-            // Successfully removed interest from the queued node. The permit
-            // has not been assigned to the node.
-            Ok(_) => false,
-            // The semaphore has been closed, there is no further action to
-            // take.
-            Err(Closed) => false,
-            // The permit has been assigned. It must be acquired in order to
-            // be released back to the semaphore.
-            Err(Assigned) => {
-                match self.acquire2() {
-                    Ok(true) => true,
-                    // Not a reachable state
-                    Ok(false) => panic!(),
-                    // The semaphore has been closed, no further action to take.
-                    Err(_) => false,
-                }
-            }
-            Err(state) => panic!("unexpected state = {:?}", state),
-        }
-    }
-
-    /// Transition the state to `QueuedWaiting`.
-    ///
-    /// This step can only happen from `Queued` or from `Idle`.
-    ///
-    /// Returns `true` if transitioning into a queued state.
-    fn to_queued_waiting(&self) -> bool {
-        use self::NodeState::*;
-
-        let mut curr = NodeState::load(&self.state, Acquire);
-
-        loop {
-            debug_assert!(curr == Idle || curr == Queued, "actual = {:?}", curr);
-            let next = QueuedWaiting;
-
-            match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
-                Ok(_) => {
-                    if curr.is_queued() {
-                        return false;
-                    } else {
-                        // Transitioned to queued, reset next pointer
-                        self.next.store(ptr::null_mut(), Relaxed);
-                        return true;
-                    }
-                }
-                Err(actual) => {
-                    curr = actual;
-                }
-            }
-        }
-    }
-
-    /// Notify the waiter
-    ///
-    /// Returns `true` if the waiter accepts the notification
-    fn notify(&self, closed: bool) -> bool {
-        use self::NodeState::*;
-
-        // Assume QueuedWaiting state
-        let mut curr = QueuedWaiting;
-
-        loop {
-            let next = match curr {
-                Queued => Idle,
-                QueuedWaiting => {
-                    if closed {
-                        Closed
-                    } else {
-                        Assigned
-                    }
-                }
-                actual => panic!("actual = {:?}", actual),
-            };
-
-            match next.compare_exchange(&self.state, curr, AcqRel, Acquire) {
-                Ok(_) => match curr {
-                    QueuedWaiting => {
-                        self.waker.wake();
-                        return true;
-                    }
-                    _ => return false,
-                },
-                Err(actual) => curr = actual,
-            }
-        }
-    }
-
-    fn revert_to_idle(&self) {
-        use self::NodeState::Idle;
-
-        // There are no other handles to the node
-        NodeState::store(&self.state, Idle, Relaxed);
-    }
-
-    #[allow(clippy::wrong_self_convention)] // https://github.com/rust-lang/rust-clippy/issues/4293
-    fn into_non_null(self: Arc<Self>) -> NonNull<WaiterNode> {
-        let ptr = Arc::into_raw(self);
-        unsafe { NonNull::new_unchecked(ptr as *mut _) }
-    }
-}
-
-// ===== impl State =====
-
-/// Flag differentiating between available permits and waiter pointers.
-///
-/// If we assume pointers are properly aligned, then the least significant bit
-/// will always be zero. So, we use that bit to track if the value represents a
-/// number.
-const NUM_FLAG: usize = 0b01;
-
-const CLOSED_FLAG: usize = 0b10;
-
-const MAX_PERMITS: usize = usize::MAX >> NUM_SHIFT;
-
-/// When representing "numbers", the state has to be shifted this much (to get
-/// rid of the flag bit).
-const NUM_SHIFT: usize = 2;
-
-impl SemState {
-    /// Returns a new default `State` value.
-    fn new(permits: usize, stub: &WaiterNode) -> SemState {
-        assert!(permits <= MAX_PERMITS);
-
-        if permits > 0 {
-            SemState((permits << NUM_SHIFT) | NUM_FLAG)
-        } else {
-            SemState(stub as *const _ as usize)
-        }
-    }
-
-    /// Returns a `State` tracking `ptr` as the tail of the queue.
-    fn new_ptr(tail: NonNull<WaiterNode>, closed: bool) -> SemState {
-        let mut val = tail.as_ptr() as usize;
-
-        if closed {
-            val |= CLOSED_FLAG;
-        }
-
-        SemState(val)
-    }
-
-    /// Returns the amount of remaining capacity
-    fn available_permits(self) -> usize {
-        if !self.has_available_permits() {
-            return 0;
-        }
-
-        self.0 >> NUM_SHIFT
-    }
-
-    /// Returns true if the state has permits that can be claimed by a waiter.
-    fn has_available_permits(self) -> bool {
-        self.0 & NUM_FLAG == NUM_FLAG
-    }
-
-    fn has_waiter(self, stub: &WaiterNode) -> bool {
-        !self.has_available_permits() && !self.is_stub(stub)
-    }
-
-    /// Try to acquire a permit
-    ///
-    /// # Return
-    ///
-    /// Returns `true` if the permit was acquired, `false` otherwise. If `false`
-    /// is returned, it can be assumed that `State` represents the head pointer
-    /// in the mpsc channel.
-    fn acquire_permit(&mut self, stub: &WaiterNode) -> bool {
-        if !self.has_available_permits() {
-            return false;
-        }
-
-        debug_assert!(self.waiter().is_none());
-
-        self.0 -= 1 << NUM_SHIFT;
-
-        if self.0 == NUM_FLAG {
-            // Set the state to the stub pointer.
-            self.0 = stub as *const _ as usize;
-        }
-
-        true
-    }
-
-    /// Release permits
-    ///
-    /// Returns `true` if the permits were accepted.
-    fn release_permits(&mut self, permits: usize, stub: &WaiterNode) {
-        debug_assert!(permits > 0);
-
-        if self.is_stub(stub) {
-            self.0 = (permits << NUM_SHIFT) | NUM_FLAG | (self.0 & CLOSED_FLAG);
-            return;
-        }
-
-        debug_assert!(self.has_available_permits());
-
-        self.0 += permits << NUM_SHIFT;
-    }
-
-    fn is_waiter(self) -> bool {
-        self.0 & NUM_FLAG == 0
-    }
-
-    /// Returns the waiter, if one is set.
-    fn waiter(self) -> Option<NonNull<WaiterNode>> {
-        if self.is_waiter() {
-            let waiter = NonNull::new(self.as_ptr()).expect("null pointer stored");
-
-            Some(waiter)
-        } else {
-            None
-        }
-    }
-
-    /// Assumes `self` represents a pointer
-    fn as_ptr(self) -> *mut WaiterNode {
-        (self.0 & !CLOSED_FLAG) as *mut WaiterNode
-    }
-
-    /// Set to a pointer to a waiter.
-    ///
-    /// This can only be done from the full state.
-    fn set_waiter(&mut self, waiter: NonNull<WaiterNode>) {
-        let waiter = waiter.as_ptr() as usize;
-        debug_assert!(waiter & NUM_FLAG == 0);
-        debug_assert!(!self.is_closed());
-
-        self.0 = waiter;
-    }
-
-    fn is_stub(self, stub: &WaiterNode) -> bool {
-        self.as_ptr() as usize == stub as *const _ as usize
-    }
-
-    /// Load the state from an AtomicUsize.
-    fn load(cell: &AtomicUsize, ordering: Ordering) -> SemState {
-        let value = cell.load(ordering);
-        SemState(value)
-    }
-
-    /// Swap the values
-    fn swap(self, cell: &AtomicUsize, ordering: Ordering) -> SemState {
-        let prev = SemState(cell.swap(self.to_usize(), ordering));
-        debug_assert_eq!(prev.is_closed(), self.is_closed());
-        prev
-    }
-
-    /// Compare and exchange the current value into the provided cell
-    fn compare_exchange(
-        self,
-        cell: &AtomicUsize,
-        prev: SemState,
-        success: Ordering,
-        failure: Ordering,
-    ) -> Result<SemState, SemState> {
-        debug_assert_eq!(prev.is_closed(), self.is_closed());
-
-        let res = cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure);
-
-        res.map(SemState).map_err(SemState)
-    }
-
-    fn fetch_set_closed(cell: &AtomicUsize, ordering: Ordering) -> SemState {
-        let value = cell.fetch_or(CLOSED_FLAG, ordering);
-        SemState(value)
-    }
-
-    fn is_closed(self) -> bool {
-        self.0 & CLOSED_FLAG == CLOSED_FLAG
-    }
-
-    /// Converts the state into a `usize` representation.
-    fn to_usize(self) -> usize {
-        self.0
-    }
-}
-
-impl fmt::Debug for SemState {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        let mut fmt = fmt.debug_struct("SemState");
-
-        if self.is_waiter() {
-            fmt.field("state", &"<waiter>");
-        } else {
-            fmt.field("permits", &self.available_permits());
-        }
-
-        fmt.finish()
-    }
-}
-
-// ===== impl NodeState =====
-
-impl NodeState {
-    fn new() -> NodeState {
-        NodeState::Idle
-    }
-
-    fn from_usize(value: usize) -> NodeState {
-        use self::NodeState::*;
-
-        match value {
-            0 => Idle,
-            1 => Queued,
-            2 => QueuedWaiting,
-            3 => Assigned,
-            4 => Closed,
-            _ => panic!(),
-        }
-    }
-
-    fn load(cell: &AtomicUsize, ordering: Ordering) -> NodeState {
-        NodeState::from_usize(cell.load(ordering))
-    }
-
-    /// Store a value
-    fn store(cell: &AtomicUsize, value: NodeState, ordering: Ordering) {
-        cell.store(value.to_usize(), ordering);
-    }
-
-    fn compare_exchange(
-        self,
-        cell: &AtomicUsize,
-        prev: NodeState,
-        success: Ordering,
-        failure: Ordering,
-    ) -> Result<NodeState, NodeState> {
-        cell.compare_exchange(prev.to_usize(), self.to_usize(), success, failure)
-            .map(NodeState::from_usize)
-            .map_err(NodeState::from_usize)
-    }
-
-    /// Returns `true` if `self` represents a queued state.
-    fn is_queued(self) -> bool {
-        use self::NodeState::*;
-
-        match self {
-            Queued | QueuedWaiting => true,
-            _ => false,
-        }
-    }
-
-    fn to_usize(self) -> usize {
-        self as usize
-    }
-}
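The deleted `SemState` above packs either a permit count or a waiter-queue tail pointer into a single `usize`, using the two low bits as a number flag (`NUM_FLAG`) and a closed flag (`CLOSED_FLAG`) and relying on pointer alignment to keep those bits free. A standalone sketch of that tagging scheme (constant names mirror the deleted code; the `decode` helper is illustrative, not from the patch):

// Low bit set: the remaining bits (shifted by 2) are a permit count.
// Low bit clear: the value is an aligned `WaiterNode` pointer.
const NUM_FLAG: usize = 0b01;
const NUM_SHIFT: usize = 2;

fn encode_permits(permits: usize) -> usize {
    (permits << NUM_SHIFT) | NUM_FLAG
}

fn decode(state: usize) -> Result<usize, *const ()> {
    if state & NUM_FLAG == NUM_FLAG {
        Ok(state >> NUM_SHIFT) // available permits
    } else {
        Err(state as *const ()) // tail of the waiter queue
    }
}

fn main() {
    assert_eq!(decode(encode_permits(5)), Ok(5));
}
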
diff --git a/tower/src/limit/concurrency/sync/waker.rs b/tower/src/limit/concurrency/sync/waker.rs
deleted file mode 100644
index b467a14..0000000
--- a/tower/src/limit/concurrency/sync/waker.rs
+++ /dev/null
@@ -1,316 +0,0 @@
-use super::cell::CausalCell;
-use std::sync::atomic::{self, AtomicUsize};
-
-use std::fmt;
-use std::sync::atomic::Ordering::{AcqRel, Acquire, Release};
-use std::task::Waker;
-
-/// A synchronization primitive for task waking.
-///
-/// `AtomicWaker` will coordinate concurrent wakes with the consumer
-/// potentially "waking" the underlying task. This is useful in scenarios
-/// where a computation completes in another thread and wants to wake the
-/// consumer, but the consumer is in the process of being migrated to a new
-/// logical task.
-///
-/// Consumers should call `register` before checking the result of a computation
-/// and producers should call `wake` after producing the computation (this
-/// differs from the usual `thread::park` pattern). It is also permitted for
-/// `wake` to be called **before** `register`. This results in a no-op.
-///
-/// A single `AtomicWaker` may be reused for any number of calls to `register` or
-/// `wake`.
-pub(crate) struct AtomicWaker {
-    state: AtomicUsize,
-    waker: CausalCell<Option<Waker>>,
-}
-
-// `AtomicWaker` is a multi-consumer, single-producer transfer cell. The cell
-// stores a `Waker` value produced by calls to `register` and many threads can
-// race to take the waker by calling `wake`.
-//
-// If a new `Waker` instance is produced by calling `register` before an existing
-// one is consumed, then the existing one is overwritten.
-//
-// While `AtomicWaker` is single-producer, the implementation ensures memory
-// safety. In the event of concurrent calls to `register`, there will be a
-// single winner whose waker will get stored in the cell. The losers will not
-// have their tasks woken. As such, callers should ensure to add synchronization
-// to calls to `register`.
-//
-// The implementation uses a single `AtomicUsize` value to coordinate access to
-// the `Waker` cell. There are two bits that are operated on independently. These
-// are represented by `REGISTERING` and `WAKING`.
-//
-// The `REGISTERING` bit is set when a producer enters the critical section. The
-// `WAKING` bit is set when a consumer enters the critical section. Neither
-// bit being set is represented by `WAITING`.
-//
-// A thread obtains an exclusive lock on the waker cell by transitioning the
-// state from `WAITING` to `REGISTERING` or `WAKING`, depending on the
-// operation the thread wishes to perform. When this transition is made, it is
-// guaranteed that no other thread will access the waker cell.
-//
-// # Registering
-//
-// On a call to `register`, an attempt to transition the state from WAITING to
-// REGISTERING is made. On success, the caller obtains a lock on the waker cell.
-//
-// If the lock is obtained, then the thread sets the waker cell to the waker
-// provided as an argument. Then it attempts to transition the state back from
-// `REGISTERING` -> `WAITING`.
-//
-// If this transition is successful, then the registering process is complete
-// and the next call to `wake` will observe the waker.
-//
-// If the transition fails, then there was a concurrent call to `wake` that
-// was unable to access the waker cell (due to the registering thread holding the
-// lock). To handle this, the registering thread removes the waker it just set
-// from the cell and calls `wake` on it. This call to wake represents the
-// attempt to wake by the other thread (that set the `WAKING` bit). The
-// state is then transitioned from `REGISTERING | WAKING` back to `WAITING`.
-// This transition must succeed because, at this point, the state cannot be
-// transitioned by another thread.
-//
-// # Waking
-//
-// On a call to `wake`, an attempt to transition the state from `WAITING` to
-// `WAKING` is made. On success, the caller obtains a lock on the waker cell.
-//
-// If the lock is obtained, then the thread takes ownership of the current value
-// in the waker cell, and calls `wake` on it. The state is then transitioned
-// back to `WAITING`. This transition must succeed as, at this point, the state
-// cannot be transitioned by another thread.
-//
-// If the thread is unable to obtain the lock, the `WAKING` bit is still set.
-// This is because it has either been set by the current thread but the previous
-// value included the `REGISTERING` bit **or** a concurrent thread is in the
-// `WAKING` critical section. Either way, no action must be taken.
-//
-// If the current thread is the only concurrent call to `wake` and another
-// thread is in the `register` critical section, when the other thread **exits**
-// the `register` critical section, it will observe the `WAKING` bit and
-// handle the waker itself.
-//
-// If another thread is in the `wake` critical section, then it will handle
-// waking the caller task.
-//
-// # A potential race (is safely handled).
-//
-// Imagine the following situation:
-//
-// * Thread A obtains the `wake` lock and wakes a task.
-//
-// * Before thread A releases the `wake` lock, the woken task is scheduled.
-//
-// * Thread B attempts to wake the task. In theory this should result in the
-//   task being woken, but it cannot because thread A still holds the wake
-//   lock.
-//
-// This case is handled by requiring users of `AtomicWaker` to call `register`
-// **before** attempting to observe the application state change that resulted
-// in the task being woken. The wakers also change the application state
-// before calling wake.
-//
-// Because of this, the task will do one of two things.
-//
-// 1) Observe the application state change that Thread B is waking on. In
-//    this case, it is OK for Thread B's wake to be lost.
-//
-// 2) Call register before attempting to observe the application state. Since
-//    Thread A still holds the `wake` lock, the call to `register` will result
-//    in the task waking itself and get scheduled again.
-
-/// Idle state
-const WAITING: usize = 0;
-
-/// A new waker value is being registered with the `AtomicWaker` cell.
-const REGISTERING: usize = 0b01;
-
-/// The task currently registered with the `AtomicWaker` cell is being woken.
-const WAKING: usize = 0b10;
-
-impl AtomicWaker {
-    /// Create an `AtomicWaker`
-    pub(crate) fn new() -> AtomicWaker {
-        AtomicWaker {
-            state: AtomicUsize::new(WAITING),
-            waker: CausalCell::new(None),
-        }
-    }
-
-    /// Registers the current waker to be notified on calls to `wake`.
-    ///
-    /// This is the same as calling `register_task` with `task::current()`.
-    #[cfg(feature = "io-driver")]
-    pub(crate) fn register(&self, waker: Waker) {
-        self.do_register(waker);
-    }
-
-    /// Registers the provided waker to be notified on calls to `wake`.
-    ///
-    /// The new waker will take place of any previous wakers that were registered
-    /// by previous calls to `register`. Any calls to `wake` that happen after
-    /// a call to `register` (as defined by the memory ordering rules), will
-    /// wake the `register` caller's task.
-    ///
-    /// It is safe to call `register` with multiple other threads concurrently
-    /// calling `wake`. This will result in the `register` caller's current
-    /// task being woken once.
-    ///
-    /// This function is safe to call concurrently, but this is generally a bad
-    /// idea. Concurrent calls to `register` will attempt to register different
-    /// tasks to be woken. One of the callers will win and have its task set,
-    /// but there is no guarantee as to which caller will succeed.
-    pub(crate) fn register_by_ref(&self, waker: &Waker) {
-        self.do_register(waker);
-    }
-
-    fn do_register<W>(&self, waker: W)
-    where
-        W: WakerRef,
-    {
-        match self.state.compare_and_swap(WAITING, REGISTERING, Acquire) {
-            WAITING => {
-                unsafe {
-                    // Lock acquired, update the waker cell
-                    self.waker.with_mut(|t| *t = Some(waker.into_waker()));
-
-                    // Release the lock. If the state transitioned to include
-                    // the `WAKING` bit, this means that a wake has been
-                    // called concurrently, so we have to remove the waker and
-                    // wake it.
-                    //
-                    // Start by assuming that the state is `REGISTERING` as this
-                    // is what we just set it to.
-                    let res = self
-                        .state
-                        .compare_exchange(REGISTERING, WAITING, AcqRel, Acquire);
-
-                    match res {
-                        Ok(_) => {}
-                        Err(actual) => {
-                            // This branch can only be reached if a
-                            // concurrent thread called `wake`. In this
-                            // case, `actual` **must** be
-                            // `REGISTERING | WAKING`.
-                            debug_assert_eq!(actual, REGISTERING | WAKING);
-
-                            // Take the waker to wake once the atomic operation has
-                            // completed.
-                            let waker = self.waker.with_mut(|t| (*t).take()).unwrap();
-
-                            // Just swap, because no one could change state
-                            // while state == `REGISTERING | WAKING`.
-                            self.state.swap(WAITING, AcqRel);
-
-                            // The atomic swap was complete, now
-                            // wake the waker and return.
-                            waker.wake();
-                        }
-                    }
-                }
-            }
-            WAKING => {
-                // Currently in the process of waking the task, i.e.,
-                // `wake` is currently being called on the old waker.
-                // So, we call wake on the new waker.
-                waker.wake();
-
-                // This is equivalent to a spin lock, so use a spin hint.
-                atomic::spin_loop_hint();
-            }
-            state => {
-                // In this case, a concurrent thread is holding the
-                // "registering" lock. This probably indicates a bug in the
-                // caller's code as racing to call `register` doesn't make much
-                // sense.
-                //
-                // We just want to maintain memory safety. It is ok to drop the
-                // call to `register`.
-                debug_assert!(state == REGISTERING || state == REGISTERING | WAKING);
-            }
-        }
-    }
-
-    /// Wakes the task that last called `register`.
-    ///
-    /// If `register` has not been called yet, then this does nothing.
-    pub(crate) fn wake(&self) {
-        if let Some(waker) = self.take_waker() {
-            waker.wake();
-        }
-    }
-
-    /// Attempts to take the `Waker` value out of the `AtomicWaker` with the
-    /// intention that the caller will wake the task later.
-    pub(crate) fn take_waker(&self) -> Option<Waker> {
-        // AcqRel ordering is used in order to acquire the value of the `waker`
-        // cell as well as to establish a `release` ordering with whatever
-        // memory the `AtomicWaker` is associated with.
-        match self.state.fetch_or(WAKING, AcqRel) {
-            WAITING => {
-                // The waking lock has been acquired.
-                let waker = unsafe { self.waker.with_mut(|t| (*t).take()) };
-
-                // Release the lock
-                self.state.fetch_and(!WAKING, Release);
-
-                waker
-            }
-            state => {
-                // There is a concurrent thread currently updating the
-                // associated waker.
-                //
-                // Nothing more to do as the `WAKING` bit has been set. It
-                // doesn't matter if there are concurrent registering threads or
-                // not.
-                debug_assert!(
-                    state == REGISTERING || state == REGISTERING | WAKING || state == WAKING
-                );
-                None
-            }
-        }
-    }
-}
-
-impl Default for AtomicWaker {
-    fn default() -> Self {
-        AtomicWaker::new()
-    }
-}
-
-impl fmt::Debug for AtomicWaker {
-    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
-        write!(fmt, "AtomicWaker")
-    }
-}
-
-unsafe impl Send for AtomicWaker {}
-unsafe impl Sync for AtomicWaker {}
-
-trait WakerRef {
-    fn wake(self);
-    fn into_waker(self) -> Waker;
-}
-
-impl WakerRef for Waker {
-    fn wake(self) {
-        self.wake()
-    }
-
-    fn into_waker(self) -> Waker {
-        self
-    }
-}
-
-impl WakerRef for &Waker {
-    fn wake(self) {
-        self.wake_by_ref()
-    }
-
-    fn into_waker(self) -> Waker {
-        self.clone()
-    }
-}
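The long comment block above specifies the `AtomicWaker` protocol: the consumer registers its waker *before* checking shared state, and the producer updates state *before* calling `wake`, so a wake can never be lost. The vendored copy is deleted here, but the same primitive is available as `futures_util::task::AtomicWaker`; a hedged sketch of the pattern (the `Flag` types are illustrative, not part of this patch):

use futures_util::task::AtomicWaker;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

struct Flag {
    set: AtomicBool,
    waker: AtomicWaker,
}

struct WaitForFlag(Arc<Flag>);

impl Future for WaitForFlag {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Register first, then check: a concurrent `wake` between these two
        // steps is then guaranteed to observe the freshly stored waker.
        self.0.waker.register(cx.waker());
        if self.0.set.load(Ordering::Acquire) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

// Producer side: publish the state change first, then wake.
fn set_flag(flag: &Flag) {
    flag.set.store(true, Ordering::Release);
    flag.waker.wake();
}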