server: add sessions timeouts (#386)

* server: add sessions timeouts

* clippy fixes
This commit is contained in:
Conrado Gouvea 2024-12-26 16:57:43 -03:00 committed by GitHub
parent 5e860be303
commit f5cb068ed2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 154 additions and 49 deletions

2
Cargo.lock generated
View File

@ -2676,6 +2676,8 @@ dependencies = [
"frost-core",
"frost-ed25519",
"frost-rerandomized",
"futures",
"futures-util",
"hex",
"rand",
"rcgen",

View File

@ -29,6 +29,8 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.11.0", features = ["v4", "fast-rng", "serde"] }
xeddsa = "1.0.2"
futures-util = "0.3.31"
futures = "0.3.31"
hex = "0.4.3"
[dev-dependencies]

View File

@ -106,12 +106,12 @@ pub(crate) async fn create_new_session(
// Create new session object.
let id = Uuid::new_v4();
let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();
// Save session ID in global state
for pubkey in &args.pubkeys {
state
.sessions_by_pubkey
sessions_by_pubkey
.entry(pubkey.0.clone())
.or_default()
.insert(id);
@ -125,7 +125,7 @@ pub(crate) async fn create_new_session(
queue: Default::default(),
};
// Save session into global state.
state.sessions.insert(id, session);
sessions.insert(id, session);
let user = CreateNewSessionOutput { session_id: id };
Ok(Json(user))
@ -137,10 +137,9 @@ pub(crate) async fn list_sessions(
State(state): State<SharedState>,
user: User,
) -> Result<Json<ListSessionsOutput>, AppError> {
let state = state.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();
let session_ids = state
.sessions_by_pubkey
let session_ids = sessions_by_pubkey
.get(&user.pubkey)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default();
@ -155,24 +154,22 @@ pub(crate) async fn get_session_info(
user: User,
Json(args): Json<GetSessionInfoArgs>,
) -> Result<Json<GetSessionInfoOutput>, AppError> {
let state_lock = state.sessions.read().unwrap();
let sessions = state.sessions.sessions.read().unwrap();
let sessions_by_pubkey = state.sessions.sessions_by_pubkey.read().unwrap();
let sessions = state_lock
.sessions_by_pubkey
.get(&user.pubkey)
.ok_or(AppError(
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
@ -194,12 +191,11 @@ pub(crate) async fn send(
Json(args): Json<SendArgs>,
) -> Result<(), AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support for it
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
@ -219,6 +215,7 @@ pub(crate) async fn send(
msg: args.msg.clone(),
});
}
sessions.insert(args.session_id, session);
Ok(())
}
@ -232,12 +229,13 @@ pub(crate) async fn receive(
Json(args): Json<ReceiveArgs>,
) -> Result<Json<ReceiveOutput>, AppError> {
// Get the mutex lock to read and write from the state
let mut state_lock = state.sessions.write().unwrap();
let sessions = state.sessions.sessions.read().unwrap();
let session = state_lock
.sessions
.get_mut(&args.session_id)
.ok_or(AppError(
// TODO: change to get_mut and modify in-place, if HashMapDelay ever
// adds support for it. This will also simplify the code since
// we have to do a workaround in order to not renew the timeout if there
// are no messages. See https://github.com/AgeManning/delay_map/issues/26
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
@ -248,7 +246,22 @@ pub(crate) async fn receive(
user.pubkey
};
// If there are no new messages, we don't want to renew the timeout.
// Thus, only if there are new messages do we drop the read-only lock
// to get the write lock and re-insert the updated session.
let msgs = if session.queue.contains_key(&pubkey) {
drop(sessions);
let mut sessions = state.sessions.sessions.write().unwrap();
let mut session = sessions.remove(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
))?;
let msgs = session.queue.entry(pubkey).or_default().drain(..).collect();
sessions.insert(args.session_id, session);
msgs
} else {
vec![]
};
Ok(Json(ReceiveOutput { msgs }))
}
@ -260,21 +273,22 @@ pub(crate) async fn close_session(
user: User,
Json(args): Json<CloseSessionArgs>,
) -> Result<Json<()>, AppError> {
let mut state = state.sessions.write().unwrap();
let mut sessions = state.sessions.sessions.write().unwrap();
let mut sessions_by_pubkey = state.sessions.sessions_by_pubkey.write().unwrap();
let sessions = state.sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
let user_sessions = sessions_by_pubkey.get(&user.pubkey).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("user is not in any session").into(),
))?;
if !sessions.contains(&args.session_id) {
if !user_sessions.contains(&args.session_id) {
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found").into(),
));
}
let session = state.sessions.get(&args.session_id).ok_or(AppError(
let session = sessions.get(&args.session_id).ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?;
@ -287,10 +301,10 @@ pub(crate) async fn close_session(
}
for username in session.pubkeys.clone() {
if let Some(v) = state.sessions_by_pubkey.get_mut(&username) {
if let Some(v) = sessions_by_pubkey.get_mut(&username) {
v.remove(&args.session_id);
}
}
state.sessions.remove(&args.session_id);
sessions.remove(&args.session_id);
Ok(Json(()))
}

View File

@ -1,18 +1,40 @@
use std::{
collections::{HashMap, HashSet, VecDeque},
pin::Pin,
sync::{Arc, RwLock},
task::{Context, Poll},
time::Duration,
};
use delay_map::{HashMapDelay, HashSetDelay};
use futures::{Stream, StreamExt as _};
use uuid::Uuid;
use crate::Msg;
/// How long a session stays open.
const SESSION_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60 * 24);
/// How long a challenge can be replied to.
const CHALLENGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// How long an access token lasts.
const ACCESS_TOKEN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60 * 60);
/// Helper struct that allows calling `next()` on a `Stream` behind a `RwLock`
/// (namely a `HashMapDelay` or `HashSetDelay` in our case) without locking
/// the `RwLock` while waiting.
// From https://users.rust-lang.org/t/how-do-i-poll-a-stream-behind-a-rwlock/121787/2
struct RwLockStream<'a, T>(pub &'a RwLock<T>);
impl<T: Stream + Unpin> Stream for RwLockStream<'_, T> {
type Item = T::Item;
fn poll_next(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<<Self as Stream>::Item>> {
// Take the write lock only for the duration of this single poll,
// so the lock is not held while the task is parked waiting to be woken.
self.0.write().unwrap().poll_next_unpin(cx)
}
}
/// A particular signing session.
#[derive(Debug)]
pub struct Session {
@ -22,8 +44,6 @@ pub struct Session {
pub(crate) coordinator_pubkey: Vec<u8>,
/// The number of signers in the session.
pub(crate) num_signers: u16,
/// The set of identifiers for the session.
// pub(crate) identifiers: BTreeSet<SerializedIdentifier>,
/// The number of messages being simultaneously signed.
pub(crate) message_count: u8,
/// The message queue.
@ -33,7 +53,7 @@ pub struct Session {
/// The global state of the server.
#[derive(Debug)]
pub struct AppState {
pub(crate) sessions: Arc<RwLock<SessionState>>,
pub(crate) sessions: SessionState,
pub(crate) challenges: Arc<RwLock<HashSetDelay<Uuid>>>,
pub(crate) access_tokens: Arc<RwLock<HashMapDelay<Uuid, Vec<u8>>>>,
}
@ -41,18 +61,85 @@ pub struct AppState {
#[derive(Debug, Default)]
pub struct SessionState {
/// Mapping of signing sessions by UUID.
pub(crate) sessions: HashMap<Uuid, Session>,
pub(crate) sessions_by_pubkey: HashMap<Vec<u8>, HashSet<Uuid>>,
pub(crate) sessions: Arc<RwLock<HashMapDelay<Uuid, Session>>>,
pub(crate) sessions_by_pubkey: Arc<RwLock<HashMap<Vec<u8>, HashSet<Uuid>>>>,
}
impl SessionState {
/// Create a new `SessionState` whose session entries expire after
/// `timeout` (entries are timed out by the backing `HashMapDelay`;
/// re-inserting a session renews its timeout).
pub fn new(timeout: Duration) -> Self {
Self {
sessions: RwLock::new(HashMapDelay::new(timeout)).into(),
sessions_by_pubkey: Default::default(),
}
}
}
impl AppState {
pub async fn new() -> Result<SharedState, Box<dyn std::error::Error>> {
let state = Self {
sessions: Default::default(),
let state = Arc::new(Self {
sessions: SessionState::new(SESSION_TIMEOUT),
challenges: RwLock::new(HashSetDelay::new(CHALLENGE_TIMEOUT)).into(),
access_tokens: RwLock::new(HashMapDelay::new(ACCESS_TOKEN_TIMEOUT)).into(),
};
Ok(Arc::new(state))
});
// In order to effectively remove timed-out entries, we need to
// repeatedly call `next()` on them.
// These tasks will just run forever and will stop when the server stops.
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.sessions.sessions).next().await {
Some(Ok((uuid, session))) => {
tracing::debug!("session {} timed out", uuid);
let mut sessions_by_pubkey =
state_clone.sessions.sessions_by_pubkey.write().unwrap();
for pubkey in session.pubkeys {
if let Some(sessions) = sessions_by_pubkey.get_mut(&pubkey) {
sessions.remove(&uuid);
}
}
}
_ => {
// Annoyingly, if the map is empty, it returns
// immediately instead of waiting for an entry to be
// inserted and waiting for that to timeout. To avoid a
// busy loop when the map is empty, we sleep for a bit.
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
// TODO: we could refactor these two loops with a generic function
// but it's just simpler to do this directly currently
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.challenges).next().await {
Some(Ok(challenge)) => {
tracing::debug!("challenge {} timed out", challenge);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
let state_clone = state.clone();
tokio::task::spawn(async move {
loop {
match RwLockStream(&state_clone.access_tokens).next().await {
Some(Ok((access_token, _pubkey))) => {
tracing::debug!("access_token {} timed out", access_token);
}
_ => {
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
});
Ok(state)
}
}