server: add sessions timeouts (#386)
* server: add sessions timeouts
* clippy fixes
This commit is contained in:
parent 5e860be303
commit f5cb068ed2
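The commit replaces the server's plain `HashMap` of sessions with a `HashMapDelay` from the delay_map crate, so that sessions (like the existing challenges and access tokens) expire automatically after a timeout. Below is a minimal sketch of the expiry behavior the server relies on, using only the delay_map APIs the diff itself exercises (`new`, `insert`, and the `Stream` of expired entries); the short timeout and the tokio runtime are illustrative assumptions, not part of the commit.

use std::time::Duration;

use delay_map::HashMapDelay;
use futures::StreamExt as _;
use uuid::Uuid;

#[tokio::main]
async fn main() {
    // Entries expire after the default timeout passed to `new`, unless
    // they are removed and re-inserted before it elapses.
    let mut sessions: HashMapDelay<Uuid, &str> =
        HashMapDelay::new(Duration::from_millis(100));
    let id = Uuid::new_v4();
    sessions.insert(id, "session data");

    // The map is also a `Stream` that yields entries as they time out,
    // which is why the commit spawns background tasks looping on
    // `next()` (see the final hunk below).
    if let Some(Ok((expired_id, _data))) = sessions.next().await {
        assert_eq!(expired_id, id);
    }
}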
@@ -2676,6 +2676,8 @@ dependencies = [
  "frost-core",
  "frost-ed25519",
  "frost-rerandomized",
+ "futures",
+ "futures-util",
  "hex",
  "rand",
  "rcgen",

@@ -29,6 +29,8 @@ tracing = "0.1"
 tracing-subscriber = { version = "0.3", features = ["env-filter"] }
 uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
 xeddsa = "1.0.2"
+futures-util = "0.3.31"
+futures = "0.3.31"
 hex = "0.4.3"
 
 [dev-dependencies]

@@ -106,12 +106,12 @@ pub(crate) async fn create_new_session(
     // Create new session object.
     let id = Uuid::new_v4();
 
-    let mut state = state.sessions.write().unwrap();
+    let mut sessions = state.sessions.sessions.write().unwrap();
+    let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();
 
     // Save session ID in global state
     for pubkey in &args.pubkeys {
-        state
-            .sessions_by_pubkey
+        sessions_by_pubkey
             .entry(pubkey.0.clone())
             .or_default()
             .insert(id);

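Note on the locking change above: `SessionState` now exposes two independent locks instead of one lock around the whole struct, so readers of one map no longer block writers of the other. As far as this diff shows, every handler that takes both locks acquires `sessions` before `sessions_by_pubkey`, which keeps the lock order consistent and avoids deadlocks between handlers (an observation about this commit, not a documented invariant).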
@@ -125,7 +125,7 @@ pub(crate) async fn create_new_session(
         queue: Default::default(),
     };
     // Save session into global state.
-    state.sessions.insert(id, session);
+    sessions.insert(id, session);
 
     let user = CreateNewSessionOutput { session_id: id };
     Ok(Json(user))

@@ -137,10 +137,9 @@ pub(crate) async fn list_sessions(
     State(state): State<SharedState>,
     user: User,
 ) -> Result<Json<ListSessionsOutput>, AppError> {
-    let state = state.sessions.read().unwrap();
+    let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();
 
-    let session_ids = state
-        .sessions_by_pubkey
+    let session_ids = sessions_by_pubkey
         .get(&user.pubkey)
         .map(|s| s.iter().cloned().collect())
         .unwrap_or_default();

@@ -155,24 +154,22 @@ pub(crate) async fn get_session_info(
     user: User,
     Json(args): Json<GetSessionInfoArgs>,
 ) -> Result<Json<GetSessionInfoOutput>, AppError> {
-    let state_lock = state.sessions.read().unwrap();
+    let sessions = state.sessions.sessions.read().unwrap();
+    let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();
 
-    let sessions = state_lock
-        .sessions_by_pubkey
-        .get(&user.pubkey)
-        .ok_or(AppError(
+    let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
         StatusCode::NOT_FOUND,
         eyre!("user is not in any session").into(),
     ))?;
 
-    if !sessions.contains(&args.session_id) {
+    if !user_sessions.contains(&args.session_id) {
         return Err(AppError(
             StatusCode::NOT_FOUND,
             eyre!("session ID not found").into(),
         ));
     }
 
-    let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
+    let session = sessions.get(&args.session_id).ok_or(AppError(
         StatusCode::NOT_FOUND,
         eyre!("session ID not found").into(),
     ))?;

@@ -194,12 +191,11 @@ pub(crate) async fn send(
     Json(args): Json<SendArgs>,
 ) -> Result<(), AppError> {
     // Get the mutex lock to read and write from the state
-    let mut state_lock = state.sessions.write().unwrap();
+    let mut sessions = state.sessions.sessions.write().unwrap();
 
-    let session = state_lock
-        .sessions
-        .get_mut(&args.session_id)
-        .ok_or(AppError(
+    // TODO: change to get_mut and modify in-place, if HashMapDelay ever
+    // adds support to it
+    let mut session = sessions.remove(&args.session_id).ok_or(AppError(
         StatusCode::NOT_FOUND,
         eyre!("session ID not found").into(),
     ))?;

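`HashMapDelay` has no `get_mut` (hence the TODO above), so mutating a session means removing it and inserting it back, and the re-insert restarts that entry's timeout as a side effect, which suits `send`: activity keeps a session alive. A minimal sketch of the pattern under that assumption, with a hypothetical `Vec<String>` queue standing in for the `Session` type:

use std::time::Duration;

use delay_map::HashMapDelay;
use uuid::Uuid;

/// Append a message to a session's queue. Removing and re-inserting is
/// the only way to mutate the value, and it also renews the entry's
/// timeout.
fn push_msg(
    sessions: &mut HashMapDelay<Uuid, Vec<String>>,
    id: Uuid,
    msg: String,
) -> Result<(), &'static str> {
    let mut queue = sessions.remove(&id).ok_or("session ID not found")?;
    queue.push(msg);
    sessions.insert(id, queue); // timeout for `id` starts over here
    Ok(())
}

fn main() {
    let mut sessions = HashMapDelay::new(Duration::from_secs(60));
    let id = Uuid::new_v4();
    sessions.insert(id, Vec::new());
    push_msg(&mut sessions, id, "hello".into()).unwrap();
}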
@@ -219,6 +215,7 @@ pub(crate) async fn send(
             msg: args.msg.clone(),
         });
     }
+    sessions.insert(args.session_id, session);
 
     Ok(())
 }

@@ -232,12 +229,13 @@ pub(crate) async fn receive(
     Json(args): Json<ReceiveArgs>,
 ) -> Result<Json<ReceiveOutput>, AppError> {
     // Get the mutex lock to read and write from the state
-    let mut state_lock = state.sessions.write().unwrap();
+    let sessions = state.sessions.sessions.read().unwrap();
 
-    let session = state_lock
-        .sessions
-        .get_mut(&args.session_id)
-        .ok_or(AppError(
+    // TODO: change to get_mut and modify in-place, if HashMapDelay ever
+    // adds support to it. This will also simplify the code since
+    // we have to do a workaround in order to not renew the timeout if there
+    // are no messages. See https://github.com/AgeManning/delay_map/issues/26
+    let session = sessions.get(&args.session_id).ok_or(AppError(
         StatusCode::NOT_FOUND,
         eyre!("session ID not found").into(),
     ))?;

@@ -248,7 +246,22 @@ pub(crate) async fn receive(
         user.pubkey
     };
 
-    let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
+    // If there are no new messages, we don't want to renew the timeout.
+    // Thus only if there are new messages we drop the read-only lock
+    // to get the write lock and re-insert the updated session.
+    let msgs = if session.queue.contains_key(&pubkey) {
+        drop(sessions);
+        let mut sessions = state.sessions.sessions.write().unwrap();
+        let mut session = sessions.remove(&args.session_id).ok_or(AppError(
+            StatusCode::NOT_FOUND,
+            eyre!("session ID not found").into(),
+        ))?;
+        let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
+        sessions.insert(args.session_id, session);
+        msgs
+    } else {
+        vec![]
+    };
 
     Ok(Json(ReceiveOutput { msgs }))
 }

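The conditional above exists because merely polling for messages must not renew the session timeout: the write lock, and the timeout-renewing remove/re-insert, are taken only when the queue actually holds messages. Since `std::sync::RwLock` cannot upgrade a read guard in place, the guard is dropped first, leaving a small window in which the session can expire; the repeated lookup under the write lock (the second `ok_or`) covers exactly that race. A generic, self-contained sketch of the same shape:

use std::collections::HashMap;
use std::sync::RwLock;

/// Drain `key`'s queue only if it is non-empty, taking the write lock
/// only in that case (mirroring `receive` above, where skipping the
/// write path also skips the timeout renewal).
fn drain_if_nonempty(map: &RwLock<HashMap<u32, Vec<String>>>, key: u32) -> Vec<String> {
    let read = map.read().unwrap();
    if read.get(&key).map_or(true, |q| q.is_empty()) {
        return Vec::new(); // nothing to do: never take the write lock
    }
    // A read guard cannot be upgraded in place; release it first.
    drop(read);
    let mut write = map.write().unwrap();
    // Re-check: the entry may have changed (or expired) while unlocked.
    write
        .get_mut(&key)
        .map(|q| q.drain(..).collect())
        .unwrap_or_default()
}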
@@ -260,21 +273,22 @@ pub(crate) async fn close_session(
     user: User,
     Json(args): Json<CloseSessionArgs>,
 ) -> Result<Json<()>, AppError> {
-    let mut state = state.sessions.write().unwrap();
+    let mut sessions = state.sessions.sessions.write().unwrap();
+    let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();
 
-    let sessions = state.sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
+    let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
         StatusCode::NOT_FOUND,
         eyre!("user is not in any session").into(),
     ))?;
 
-    if !sessions.contains(&args.session_id) {
+    if !user_sessions.contains(&args.session_id) {
         return Err(AppError(
             StatusCode::NOT_FOUND,
             eyre!("session ID not found").into(),
         ));
     }
 
-    let session = state.sessions.get(&args.session_id).ok_or(AppError(
+    let session = sessions.get(&args.session_id).ok_or(AppError(
         StatusCode::INTERNAL_SERVER_ERROR,
         eyre!("invalid session ID").into(),
     ))?;

@@ -287,10 +301,10 @@ pub(crate) async fn close_session(
     }
 
     for username in session.pubkeys.clone() {
-        if let Some(v) = state.sessions_by_pubkey.get_mut(&username) {
+        if let Some(v) = sessions_by_pubkey.get_mut(&username) {
             v.remove(&args.session_id);
         }
     }
-    state.sessions.remove(&args.session_id);
+    sessions.remove(&args.session_id);
     Ok(Json(()))
 }

@@ -1,18 +1,40 @@
 use std::{
     collections::{HashMap, HashSet, VecDeque},
+    pin::Pin,
     sync::{Arc, RwLock},
+    task::{Context, Poll},
+    time::Duration,
 };
 
 use delay_map::{HashMapDelay, HashSetDelay};
+use futures::{Stream, StreamExt as _};
 use uuid::Uuid;
 
 use crate::Msg;
 
+/// How long a session stays open.
+const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60 * 24);
 /// How long a challenge can be replied to.
 const CHALLENGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
 /// How long an access token lasts.
 const ACCESS_TOKEN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60);
 
+/// Helper struct that allows calling `next()` on a `Stream` behind a `RwLock`
+/// (namely a `HashMapDelay` or `HashSetDelay` in our case) without locking
+/// the `RwLock` while waiting.
+// From https://users.rust-lang.org/t/how-do-i-poll-a-stream-behind-a-rwlock/121787/2
+struct RwLockStream<'a, T>(pub &'a RwLock<T>);
+
+impl<T: Stream + Unpin> Stream for RwLockStream<'_, T> {
+    type Item = T::Item;
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<<Self as Stream>::Item>> {
+        self.0.write().unwrap().poll_next_unpin(cx)
+    }
+}
+
 /// A particular signing session.
 #[derive(Debug)]
 pub struct Session {

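The `RwLockStream` helper re-acquires the lock on every poll instead of holding a guard across an `await`, so request handlers can still lock the maps between wake-ups. A self-contained demonstration of the same pattern, assuming a tokio runtime and using `futures::stream::iter` as a stand-in for the delay maps:

use std::pin::Pin;
use std::sync::RwLock;
use std::task::{Context, Poll};

use futures::{stream, Stream, StreamExt as _};

/// Polls the inner stream while taking the write lock only for the
/// duration of each `poll_next` call.
struct RwLockStream<'a, T>(pub &'a RwLock<T>);

impl<T: Stream + Unpin> Stream for RwLockStream<'_, T> {
    type Item = T::Item;
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
        // The guard is dropped before returning, even on `Poll::Pending`,
        // so no lock is held while the task sleeps.
        self.0.write().unwrap().poll_next_unpin(cx)
    }
}

#[tokio::main]
async fn main() {
    let shared = RwLock::new(stream::iter(vec![1, 2, 3]));
    while let Some(n) = RwLockStream(&shared).next().await {
        // Between `next()` calls, other tasks are free to lock `shared`.
        println!("got {n}");
    }
}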
@@ -22,8 +44,6 @@ pub struct Session {
     pub(crate) coordinator_pubkey: Vec<u8>,
     /// The number of signers in the session.
     pub(crate) num_signers: u16,
-    /// The set of identifiers for the session.
-    // pub(crate) identifiers: BTreeSet<SerializedIdentifier>,
     /// The number of messages being simultaneously signed.
     pub(crate) message_count: u8,
     /// The message queue.

@@ -33,7 +53,7 @@ pub struct Session {
 /// The global state of the server.
 #[derive(Debug)]
 pub struct AppState {
-    pub(crate) sessions: Arc<RwLock<SessionState>>,
+    pub(crate) sessions: SessionState,
     pub(crate) challenges: Arc<RwLock<HashSetDelay<Uuid>>>,
     pub(crate) access_tokens: Arc<RwLock<HashMapDelay<Uuid, Vec<u8>>>>,
 }

@@ -41,18 +61,85 @@ pub struct AppState {
 #[derive(Debug, Default)]
 pub struct SessionState {
     /// Mapping of signing sessions by UUID.
-    pub(crate) sessions: HashMap<Uuid, Session>,
-    pub(crate) sessions_by_pubkey: HashMap<Vec<u8>, HashSet<Uuid>>,
+    pub(crate) sessions: Arc<RwLock<HashMapDelay<Uuid, Session>>>,
+    pub(crate) sessions_by_pubkey: Arc<RwLock<HashMap<Vec<u8>, HashSet<Uuid>>>>,
 }
 
+impl SessionState {
+    /// Create a new SessionState
+    pub fn new(timeout: Duration) -> Self {
+        Self {
+            sessions: RwLock::new(HashMapDelay::new(timeout)).into(),
+            sessions_by_pubkey: Default::default(),
+        }
+    }
+}
+
 impl AppState {
     pub async fn new() -> Result<SharedState, Box<dyn std::error::Error>> {
-        let state = Self {
-            sessions: Default::default(),
+        let state = Arc::new(Self {
+            sessions: SessionState::new(SESSION_TIMEOUT),
             challenges: RwLock::new(HashSetDelay::new(CHALLENGE_TIMEOUT)).into(),
             access_tokens: RwLock::new(HashMapDelay::new(ACCESS_TOKEN_TIMEOUT)).into(),
-        };
-        Ok(Arc::new(state))
+        });
+
+        // In order to effectively remove timed out entries, we need to
+        // repeatedly call `next()` on them.
+        // These tasks will just run forever and will stop when the server stops.
+
+        let state_clone = state.clone();
+        tokio::task::spawn(async move {
+            loop {
+                match RwLockStream(&state_clone.sessions.sessions).next().await {
+                    Some(Ok((uuid, session))) => {
+                        tracing::debug!("session {} timed out", uuid);
+                        let mut sessions_by_pubkey =
+                            state_clone.sessions.sessions_by_pubkey.write().unwrap();
+                        for pubkey in session.pubkeys {
+                            if let Some(sessions) = sessions_by_pubkey.get_mut(&pubkey) {
+                                sessions.remove(&uuid);
+                            }
+                        }
+                    }
+                    _ => {
+                        // Annoyingly, if the map is empty, it returns
+                        // immediately instead of waiting for an entry to be
+                        // inserted and waiting for that to timeout. To avoid a
+                        // busy loop when the map is empty, we sleep for a bit.
+                        tokio::time::sleep(Duration::from_secs(1)).await;
+                    }
+                }
+            }
+        });
+        // TODO: we could refactor these two loops with a generic function
+        // but it's just simpler to do this directly currently
+        let state_clone = state.clone();
+        tokio::task::spawn(async move {
+            loop {
+                match RwLockStream(&state_clone.challenges).next().await {
+                    Some(Ok(challenge)) => {
+                        tracing::debug!("challenge {} timed out", challenge);
+                    }
+                    _ => {
+                        tokio::time::sleep(Duration::from_secs(1)).await;
+                    }
+                }
+            }
+        });
+        let state_clone = state.clone();
+        tokio::task::spawn(async move {
+            loop {
+                match RwLockStream(&state_clone.access_tokens).next().await {
+                    Some(Ok((access_token, _pubkey))) => {
+                        tracing::debug!("access_token {} timed out", access_token);
+                    }
+                    _ => {
+                        tokio::time::sleep(Duration::from_secs(1)).await;
+                    }
+                }
+            }
+        });
+        Ok(state)
     }
 }
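Design note on the three tasks above: each one drives a delay map's `Stream` so that expired entries are actually evicted, and the session reaper additionally scrubs the expired UUID from `sessions_by_pubkey` so the per-user index cannot leak. The one-second sleep in the empty-map branch only bounds the busy loop described in the comment; its cost is that an expiry can be observed up to a second late, which is acceptable at the configured timeout scales.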