Add user registration (#252)
* add user registration * addint support to coordinator and participant * Update participant/src/comms/http.rs Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com> --------- Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com>
This commit is contained in:
parent
4503c10790
commit
4c6e860d69
File diff suppressed because it is too large
Load Diff
|
@ -17,6 +17,21 @@ pub struct Args {
|
|||
#[arg(long, default_value_t = false)]
|
||||
pub http: bool,
|
||||
|
||||
/// The username to use in HTTP mode.
|
||||
#[arg(short = 'u', long, default_value = "")]
|
||||
pub username: String,
|
||||
|
||||
/// The password to use in HTTP mode. If specified, it will be read from the
|
||||
/// environment variable with the given name.
|
||||
#[arg(short = 'w', long, default_value = "")]
|
||||
pub password: String,
|
||||
|
||||
/// The comma-separated usernames of the signers to use in HTTP mode.
|
||||
/// If HTTP mode is enabled and this is empty, then the session ID
|
||||
/// will be printed and will have to be shared manually.
|
||||
#[arg(short = 'S', long, value_delimiter = ',')]
|
||||
pub signers: Vec<String>,
|
||||
|
||||
/// The number of participants. If 0, will prompt for a value.
|
||||
#[arg(short = 'n', long, default_value_t = 0)]
|
||||
pub num_signers: u16,
|
||||
|
|
|
@ -22,7 +22,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
|
|||
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
||||
Box::new(CLIComms::new())
|
||||
} else if args.http {
|
||||
Box::new(HTTPComms::new(args))
|
||||
Box::new(HTTPComms::new(args)?)
|
||||
} else {
|
||||
Box::new(SocketComms::new(args))
|
||||
};
|
||||
|
|
|
@ -16,10 +16,12 @@ use frost::{
|
|||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
env,
|
||||
error::Error,
|
||||
io::{BufRead, Write},
|
||||
marker::PhantomData,
|
||||
time::Duration,
|
||||
vec,
|
||||
};
|
||||
|
||||
use super::Comms;
|
||||
|
@ -29,18 +31,27 @@ pub struct HTTPComms<C: Ciphersuite> {
|
|||
client: reqwest::Client,
|
||||
host_port: String,
|
||||
session_id: Option<Uuid>,
|
||||
username: String,
|
||||
password: String,
|
||||
access_token: String,
|
||||
signers: Vec<String>,
|
||||
_phantom: PhantomData<C>,
|
||||
}
|
||||
|
||||
impl<C: Ciphersuite> HTTPComms<C> {
|
||||
pub fn new(args: &Args) -> Self {
|
||||
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
Self {
|
||||
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host_port: format!("http://{}:{}", args.ip, args.port),
|
||||
session_id: None,
|
||||
username: args.username.clone(),
|
||||
password,
|
||||
access_token: String::new(),
|
||||
signers: args.signers.clone(),
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -53,10 +64,26 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
_pub_key_package: &PublicKeyPackage<C>,
|
||||
num_signers: u16,
|
||||
) -> Result<BTreeMap<Identifier<C>, SigningCommitments<C>>, Box<dyn Error>> {
|
||||
self.access_token = self
|
||||
.client
|
||||
.post(format!("{}/login", self.host_port))
|
||||
.json(&server::LoginArgs {
|
||||
username: self.username.clone(),
|
||||
password: self.password.clone(),
|
||||
})
|
||||
.send()
|
||||
.await?
|
||||
.json::<server::LoginOutput>()
|
||||
.await?
|
||||
.access_token
|
||||
.to_string();
|
||||
|
||||
let r = self
|
||||
.client
|
||||
.post(format!("{}/create_new_session", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::CreateNewSessionArgs {
|
||||
usernames: self.signers.clone(),
|
||||
num_signers,
|
||||
message_count: 1,
|
||||
})
|
||||
|
@ -65,10 +92,12 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
.json::<server::CreateNewSessionOutput>()
|
||||
.await?;
|
||||
|
||||
eprintln!(
|
||||
"Send the following session ID to participants: {}",
|
||||
r.session_id
|
||||
);
|
||||
if self.signers.is_empty() {
|
||||
eprintln!(
|
||||
"Send the following session ID to participants: {}",
|
||||
r.session_id
|
||||
);
|
||||
}
|
||||
self.session_id = Some(r.session_id);
|
||||
eprint!("Waiting for participants to send their commitments...");
|
||||
|
||||
|
@ -76,6 +105,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
let r = self
|
||||
.client
|
||||
.post(format!("{}/get_commitments", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::GetCommitmentsArgs {
|
||||
session_id: r.session_id,
|
||||
})
|
||||
|
@ -114,6 +144,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
let _r = self
|
||||
.client
|
||||
.post(format!("{}/send_signing_package", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::SendSigningPackageArgs {
|
||||
aux_msg: Default::default(),
|
||||
session_id: self.session_id.unwrap(),
|
||||
|
@ -131,6 +162,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
let r = self
|
||||
.client
|
||||
.post(format!("{}/get_signature_shares", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::GetSignatureSharesArgs {
|
||||
session_id: self.session_id.unwrap(),
|
||||
})
|
||||
|
@ -145,6 +177,23 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
|||
};
|
||||
eprintln!();
|
||||
|
||||
let _r = self
|
||||
.client
|
||||
.post(format!("{}/close_session", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::CloseSessionArgs {
|
||||
session_id: self.session_id.unwrap(),
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let _r = self
|
||||
.client
|
||||
.post(format!("{}/logout", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let signature_shares = r
|
||||
.signature_shares
|
||||
.first()
|
||||
|
|
|
@ -48,7 +48,9 @@ async fn read_commitments<C: Ciphersuite>(
|
|||
|
||||
let pub_key_package: PublicKeyPackage<C> = serde_json::from_str(&out)?;
|
||||
|
||||
let num_of_participants = if args.num_signers == 0 {
|
||||
let num_of_participants = if !args.signers.is_empty() {
|
||||
args.signers.len() as u16
|
||||
} else if args.num_signers == 0 {
|
||||
writeln!(logger, "The number of participants: ")?;
|
||||
|
||||
let mut participants = String::new();
|
||||
|
|
|
@ -17,6 +17,15 @@ pub struct Args {
|
|||
#[arg(long, default_value_t = false)]
|
||||
pub http: bool,
|
||||
|
||||
/// The username to use in HTTP mode.
|
||||
#[arg(short = 'u', long, default_value = "")]
|
||||
pub username: String,
|
||||
|
||||
/// The password to use in HTTP mode. If specified, it will be read from the
|
||||
/// environment variable with the given name.
|
||||
#[arg(short = 'w', long, default_value = "")]
|
||||
pub password: String,
|
||||
|
||||
/// Public key package to use. Can be a file with a JSON-encoded
|
||||
/// package, or "". If the file does not exist or if "" is specified,
|
||||
/// then it will be read from standard input.
|
||||
|
|
|
@ -20,7 +20,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
|
|||
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
||||
Box::new(CLIComms::new())
|
||||
} else if args.http {
|
||||
Box::new(HTTPComms::new(args))
|
||||
Box::new(HTTPComms::new(args)?)
|
||||
} else {
|
||||
Box::new(SocketComms::new(args))
|
||||
};
|
||||
|
|
|
@ -10,6 +10,7 @@ use frost::{round1::SigningCommitments, round2::SignatureShare, Identifier};
|
|||
|
||||
use super::Comms;
|
||||
|
||||
use std::env;
|
||||
use std::io::{BufRead, Write};
|
||||
|
||||
use std::error::Error;
|
||||
|
@ -22,7 +23,10 @@ use crate::args::Args;
|
|||
pub struct HTTPComms<C: Ciphersuite> {
|
||||
client: reqwest::Client,
|
||||
host_port: String,
|
||||
session_id: Uuid,
|
||||
session_id: Option<Uuid>,
|
||||
username: String,
|
||||
password: String,
|
||||
access_token: String,
|
||||
_phantom: PhantomData<C>,
|
||||
}
|
||||
|
||||
|
@ -33,14 +37,18 @@ impl<C> HTTPComms<C>
|
|||
where
|
||||
C: Ciphersuite,
|
||||
{
|
||||
pub fn new(args: &Args) -> Self {
|
||||
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
|
||||
let client = reqwest::Client::new();
|
||||
Self {
|
||||
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
|
||||
Ok(Self {
|
||||
client,
|
||||
host_port: format!("http://{}:{}", args.ip, args.port),
|
||||
session_id: Uuid::parse_str(&args.session_id).expect("invalid session id"),
|
||||
session_id: Uuid::parse_str(&args.session_id).ok(),
|
||||
username: args.username.clone(),
|
||||
password,
|
||||
access_token: String::new(),
|
||||
_phantom: Default::default(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -63,11 +71,48 @@ where
|
|||
),
|
||||
Box<dyn Error>,
|
||||
> {
|
||||
self.access_token = self
|
||||
.client
|
||||
.post(format!("{}/login", self.host_port))
|
||||
.json(&server::LoginArgs {
|
||||
username: self.username.clone(),
|
||||
password: self.password.clone(),
|
||||
})
|
||||
.send()
|
||||
.await?
|
||||
.json::<server::LoginOutput>()
|
||||
.await?
|
||||
.access_token
|
||||
.to_string();
|
||||
|
||||
let session_id = match self.session_id {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
// Get session ID from server
|
||||
let r = self
|
||||
.client
|
||||
.post(format!("{}/list_sessions", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.send()
|
||||
.await?
|
||||
.json::<server::ListSessionsOutput>()
|
||||
.await?;
|
||||
if r.session_ids.len() > 1 {
|
||||
return Err(eyre!("user has more than one FROST session active, which is still not supported by this tool").into());
|
||||
} else if r.session_ids.is_empty() {
|
||||
return Err(eyre!("User has no current sessions active. The Coordinator should either specify your username, or manually share the session ID which you can specify with --session_id").into());
|
||||
}
|
||||
r.session_ids[0]
|
||||
}
|
||||
};
|
||||
self.session_id = Some(session_id);
|
||||
|
||||
// Send Commitments to Server
|
||||
self.client
|
||||
.post(format!("{}/send_commitments", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::SendCommitmentsArgs {
|
||||
session_id: self.session_id,
|
||||
session_id,
|
||||
identifier: identifier.into(),
|
||||
commitments: vec![(&commitments).try_into()?],
|
||||
})
|
||||
|
@ -82,9 +127,8 @@ where
|
|||
let r = self
|
||||
.client
|
||||
.post(format!("{}/get_signing_package", self.host_port))
|
||||
.json(&server::GetSigningPackageArgs {
|
||||
session_id: self.session_id,
|
||||
})
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::GetSigningPackageArgs { session_id })
|
||||
.send()
|
||||
.await?;
|
||||
if r.status() != 200 {
|
||||
|
@ -126,14 +170,22 @@ where
|
|||
let _r = self
|
||||
.client
|
||||
.post(format!("{}/send_signature_share", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.json(&server::SendSignatureShareArgs {
|
||||
identifier: identifier.into(),
|
||||
session_id: self.session_id,
|
||||
session_id: self.session_id.unwrap(),
|
||||
signature_share: vec![signature_share.into()],
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let _r = self
|
||||
.client
|
||||
.post(format!("{}/logout", self.host_port))
|
||||
.bearer_auth(&self.access_token)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,6 +46,8 @@ async fn check_valid_round_1_inputs() {
|
|||
port: 80,
|
||||
session_id: "session-id".to_string(),
|
||||
http: false,
|
||||
username: "".to_string(),
|
||||
password: "".to_string(),
|
||||
};
|
||||
let input = SECRET_SHARE_JSON;
|
||||
let mut valid_input = input.as_bytes();
|
||||
|
|
|
@ -7,15 +7,19 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
axum = "0.7.5"
|
||||
axum-extra = { version = "0.9.3", features = ["typed-header"] }
|
||||
axum-macros = "0.4.1"
|
||||
clap = { version = "4.5.13", features = ["derive"] }
|
||||
derivative = "2.2.0"
|
||||
eyre = "0.6.11"
|
||||
frost-core = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||
frost-rerandomized = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||
password-auth = "1.0.0"
|
||||
rand = "0.8"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serdect = { version = "0.2.0" }
|
||||
serde_json = "1.0.122"
|
||||
sqlx = { version = "0.7.3", features = ["sqlite", "time", "runtime-tokio", "uuid"] }
|
||||
tokio = { version = "1.38", features = ["full"] }
|
||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
||||
tracing = "0.1"
|
||||
|
@ -23,7 +27,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
|||
uuid = { version = "1.10.0", features = ["v4", "fast-rng", "serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum-test = "14.10.0"
|
||||
axum-test = "15.2.0"
|
||||
frost-ed25519 = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||
reddsa = { git = "https://github.com/ZcashFoundation/reddsa.git", rev = "4d8c4bb337231e6e89117334d7c61dada589a953", features = [
|
||||
"frost",
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
// generated by `sqlx migrate build-script`
|
||||
fn main() {
|
||||
// trigger recompilation when a new migration is added
|
||||
println!("cargo:rerun-if-changed=migrations");
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
-- Add down migration script here
|
||||
drop table if exists users;
|
||||
drop table if exists access_tokens;
|
|
@ -0,0 +1,16 @@
|
|||
-- Create users table.
|
||||
create table if not exists users
|
||||
(
|
||||
id integer primary key not null,
|
||||
username text not null unique,
|
||||
password text not null,
|
||||
pubkey blob not null
|
||||
);
|
||||
|
||||
create table if not exists access_tokens
|
||||
(
|
||||
id integer primary key not null,
|
||||
user_id integer not null,
|
||||
access_token blob not null,
|
||||
foreign key(user_id) references users(id) on delete cascade
|
||||
);
|
|
@ -10,4 +10,8 @@ pub struct Args {
|
|||
/// Port to bind to
|
||||
#[arg(short, long, default_value_t = 2744)]
|
||||
pub port: u16,
|
||||
|
||||
/// Database to use.
|
||||
#[arg(short, long, default_value = "db.sqlite")]
|
||||
pub database: String,
|
||||
}
|
||||
|
|
|
@ -1,32 +1,174 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use axum::{extract::State, http::StatusCode, Json};
|
||||
|
||||
use eyre::eyre;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
state::{Session, SessionState, SharedState},
|
||||
types::*,
|
||||
user::{
|
||||
add_access_token, authenticate_user, create_user, delete_user, get_user,
|
||||
remove_access_token, User,
|
||||
},
|
||||
AppError,
|
||||
};
|
||||
|
||||
/// Implement the register API.
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
|
||||
pub(crate) async fn register(
|
||||
State(state): State<SharedState>,
|
||||
Json(args): Json<RegisterArgs>,
|
||||
) -> Result<Json<()>, AppError> {
|
||||
let username = args.username.trim();
|
||||
let password = args.password.trim();
|
||||
|
||||
if username.is_empty() || password.is_empty() {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("empty args").into(),
|
||||
));
|
||||
}
|
||||
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
|
||||
create_user(db, username, password, args.pubkey)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
Ok(Json(()))
|
||||
}
|
||||
|
||||
/// Implement the login API.
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
|
||||
pub(crate) async fn login(
|
||||
State(state): State<SharedState>,
|
||||
Json(args): Json<LoginArgs>,
|
||||
) -> Result<Json<LoginOutput>, AppError> {
|
||||
// Check if the user sent the credentials
|
||||
if args.username.is_empty() || args.password.is_empty() {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("empty args").into(),
|
||||
));
|
||||
}
|
||||
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
|
||||
let user = authenticate_user(db.clone(), &args.username, &args.password)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
let user = match user {
|
||||
Some(user) => user,
|
||||
None => {
|
||||
return Err(AppError(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
eyre!("invalid user or password").into(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let access_token = add_access_token(db.clone(), user.id)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
let token = LoginOutput { access_token };
|
||||
|
||||
Ok(Json(token))
|
||||
}
|
||||
|
||||
/// Implement the logout API.
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn logout(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
) -> Result<Json<()>, AppError> {
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
|
||||
remove_access_token(
|
||||
db.clone(),
|
||||
user.current_token
|
||||
.expect("user is logged in so they must have a token"),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
Ok(Json(()))
|
||||
}
|
||||
|
||||
/// Implement the unregister API.
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn unregister(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
) -> Result<Json<()>, AppError> {
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
|
||||
delete_user(db, user.id)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
Ok(Json(()))
|
||||
}
|
||||
|
||||
/// Implement the create_new_session API.
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn create_new_session(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<CreateNewSessionArgs>,
|
||||
) -> Result<Json<CreateNewSessionOutput>, AppError> {
|
||||
tracing::info!("create_new_session");
|
||||
if args.message_count == 0 {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("invalid message_count"),
|
||||
eyre!("invalid message_count").into(),
|
||||
));
|
||||
}
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
for username in &args.usernames {
|
||||
if get_user(db.clone(), username)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?
|
||||
.is_none()
|
||||
{
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("invalid user").into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
// Create new session object.
|
||||
let id = Uuid::new_v4();
|
||||
|
||||
let mut state = state.write().unwrap();
|
||||
|
||||
// Save session ID in global state
|
||||
for username in &args.usernames {
|
||||
state
|
||||
.sessions_by_username
|
||||
.entry(username.to_string())
|
||||
.or_default()
|
||||
.insert(id);
|
||||
}
|
||||
// Create Session object
|
||||
let session = Session {
|
||||
usernames: args.usernames,
|
||||
num_signers: args.num_signers,
|
||||
message_count: args.message_count,
|
||||
state: SessionState::WaitingForCommitments {
|
||||
|
@ -34,22 +176,41 @@ pub(crate) async fn create_new_session(
|
|||
},
|
||||
};
|
||||
// Save session into global state.
|
||||
state.write().unwrap().sessions.insert(id, session);
|
||||
state.sessions.insert(id, session);
|
||||
|
||||
let user = CreateNewSessionOutput { session_id: id };
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
/// Implement the create_new_session API.
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn list_sessions(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
) -> Result<Json<ListSessionsOutput>, AppError> {
|
||||
let state = state.read().unwrap();
|
||||
|
||||
let session_ids = state
|
||||
.sessions_by_username
|
||||
.get(&user.username)
|
||||
.map(|s| s.iter().cloned().collect())
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(Json(ListSessionsOutput { session_ids }))
|
||||
}
|
||||
|
||||
/// Implement the get_session_info API
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn get_session_info(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<GetSessionInfoArgs>,
|
||||
) -> Result<Json<GetSessionInfoOutput>, AppError> {
|
||||
let state_lock = state.read().unwrap();
|
||||
|
||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
Ok(Json(GetSessionInfoOutput {
|
||||
|
@ -60,9 +221,10 @@ pub(crate) async fn get_session_info(
|
|||
|
||||
/// Implement the send_commitments API
|
||||
// TODO: get identifier from channel rather from arguments
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn send_commitments(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<SendCommitmentsArgs>,
|
||||
) -> Result<(), AppError> {
|
||||
// Get the mutex lock to read and write from the state
|
||||
|
@ -73,7 +235,7 @@ pub(crate) async fn send_commitments(
|
|||
.get_mut(&args.session_id)
|
||||
.ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &mut session.state {
|
||||
|
@ -81,7 +243,7 @@ pub(crate) async fn send_commitments(
|
|||
if args.commitments.len() != session.message_count as usize {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("wrong number of commitments"),
|
||||
eyre!("wrong number of commitments").into(),
|
||||
));
|
||||
}
|
||||
// Add commitment to map.
|
||||
|
@ -89,6 +251,11 @@ pub(crate) async fn send_commitments(
|
|||
// (it seems better to ignore overwrites, which could be caused by
|
||||
// poor networking connectivity leading to retries)
|
||||
commitments.insert(args.identifier, args.commitments);
|
||||
tracing::debug!(
|
||||
"added commitments, currently {}/{}",
|
||||
commitments.len(),
|
||||
session.num_signers
|
||||
);
|
||||
// If complete, advance to next state
|
||||
if commitments.len() == session.num_signers as usize {
|
||||
session.state = SessionState::CommitmentsReady {
|
||||
|
@ -99,7 +266,7 @@ pub(crate) async fn send_commitments(
|
|||
_ => {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -107,16 +274,17 @@ pub(crate) async fn send_commitments(
|
|||
}
|
||||
|
||||
/// Implement the get_commitments API
|
||||
// #[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn get_commitments(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<GetCommitmentsArgs>,
|
||||
) -> Result<Json<GetCommitmentsOutput>, AppError> {
|
||||
let state_lock = state.read().unwrap();
|
||||
|
||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &session.state {
|
||||
|
@ -135,15 +303,16 @@ pub(crate) async fn get_commitments(
|
|||
})),
|
||||
_ => Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Implement the send_signing_package API
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn send_signing_package(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<SendSigningPackageArgs>,
|
||||
) -> Result<(), AppError> {
|
||||
let mut state_lock = state.write().unwrap();
|
||||
|
@ -153,7 +322,7 @@ pub(crate) async fn send_signing_package(
|
|||
.get_mut(&args.session_id)
|
||||
.ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &mut session.state {
|
||||
|
@ -161,7 +330,7 @@ pub(crate) async fn send_signing_package(
|
|||
if args.signing_package.len() != session.message_count as usize {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("wrong number of inputs"),
|
||||
eyre!("wrong number of inputs").into(),
|
||||
));
|
||||
}
|
||||
if args.randomizer.len() != session.message_count as usize
|
||||
|
@ -169,7 +338,7 @@ pub(crate) async fn send_signing_package(
|
|||
{
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("wrong number of inputs"),
|
||||
eyre!("wrong number of inputs").into(),
|
||||
));
|
||||
}
|
||||
session.state = SessionState::WaitingForSignatureShares {
|
||||
|
@ -183,7 +352,7 @@ pub(crate) async fn send_signing_package(
|
|||
_ => {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -191,16 +360,17 @@ pub(crate) async fn send_signing_package(
|
|||
}
|
||||
|
||||
/// Implement the get_signing_package API
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn get_signing_package(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<GetSigningPackageArgs>,
|
||||
) -> Result<Json<GetSigningPackageOutput>, AppError> {
|
||||
let state_lock = state.read().unwrap();
|
||||
|
||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &session.state {
|
||||
|
@ -217,16 +387,17 @@ pub(crate) async fn get_signing_package(
|
|||
})),
|
||||
_ => Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Implement the send_signature_share API
|
||||
// TODO: get identifier from channel rather from arguments
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn send_signature_share(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<SendSignatureShareArgs>,
|
||||
) -> Result<(), AppError> {
|
||||
let mut state_lock = state.write().unwrap();
|
||||
|
@ -236,7 +407,7 @@ pub(crate) async fn send_signature_share(
|
|||
.get_mut(&args.session_id)
|
||||
.ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &mut session.state {
|
||||
|
@ -248,12 +419,15 @@ pub(crate) async fn send_signature_share(
|
|||
aux_msg: _,
|
||||
} => {
|
||||
if !identifiers.contains(&args.identifier) {
|
||||
return Err(AppError(StatusCode::NOT_FOUND, eyre!("invalid identifier")));
|
||||
return Err(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("invalid identifier").into(),
|
||||
));
|
||||
}
|
||||
if args.signature_share.len() != session.message_count as usize {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("wrong number of signature shares"),
|
||||
eyre!("wrong number of signature shares").into(),
|
||||
));
|
||||
}
|
||||
// Currently ignoring the possibility of overwriting previous values
|
||||
|
@ -270,7 +444,7 @@ pub(crate) async fn send_signature_share(
|
|||
_ => {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -278,16 +452,17 @@ pub(crate) async fn send_signature_share(
|
|||
}
|
||||
|
||||
/// Implement the get_signature_shares API
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn get_signature_shares(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<GetSignatureSharesArgs>,
|
||||
) -> Result<Json<GetSignatureSharesOutput>, AppError> {
|
||||
let state_lock = state.read().unwrap();
|
||||
|
||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||
StatusCode::NOT_FOUND,
|
||||
eyre!("session ID not found"),
|
||||
eyre!("session ID not found").into(),
|
||||
))?;
|
||||
|
||||
match &session.state {
|
||||
|
@ -308,17 +483,34 @@ pub(crate) async fn get_signature_shares(
|
|||
}
|
||||
_ => Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("incompatible session state"),
|
||||
eyre!("incompatible session state").into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Implement the close_session API.
|
||||
#[tracing::instrument(ret, err(Debug))]
|
||||
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||
pub(crate) async fn close_session(
|
||||
State(state): State<SharedState>,
|
||||
user: User,
|
||||
Json(args): Json<CloseSessionArgs>,
|
||||
) -> Result<Json<()>, AppError> {
|
||||
state.write().unwrap().sessions.remove(&args.session_id);
|
||||
let mut state = state.write().unwrap();
|
||||
|
||||
for username in state
|
||||
.sessions
|
||||
.get(&args.session_id)
|
||||
.ok_or(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("invalid session ID").into(),
|
||||
))?
|
||||
.usernames
|
||||
.clone()
|
||||
{
|
||||
if let Some(v) = state.sessions_by_username.get_mut(&username) {
|
||||
v.remove(&args.session_id);
|
||||
}
|
||||
}
|
||||
state.sessions.remove(&args.session_id);
|
||||
Ok(Json(()))
|
||||
}
|
||||
|
|
|
@ -2,6 +2,9 @@ pub mod args;
|
|||
mod functions;
|
||||
mod state;
|
||||
mod types;
|
||||
mod user;
|
||||
|
||||
pub use state::{AppState, SharedState};
|
||||
use tower_http::trace::TraceLayer;
|
||||
pub use types::*;
|
||||
|
||||
|
@ -16,11 +19,15 @@ use axum::{
|
|||
/// Create the axum Router for the server.
|
||||
/// Maps specific endpoints to handler functions.
|
||||
// TODO: use methods of a single object instead of separate functions?
|
||||
pub fn router() -> Router {
|
||||
pub fn router(shared_state: SharedState) -> Router {
|
||||
// Shared state that is passed to each handler by axum
|
||||
let shared_state = state::SharedState::default();
|
||||
Router::new()
|
||||
.route("/register", post(functions::register))
|
||||
.route("/login", post(functions::login))
|
||||
.route("/logout", post(functions::logout))
|
||||
.route("/unregister", post(functions::unregister))
|
||||
.route("/create_new_session", post(functions::create_new_session))
|
||||
.route("/list_sessions", post(functions::list_sessions))
|
||||
.route("/get_session_info", post(functions::get_session_info))
|
||||
.route("/send_commitments", post(functions::send_commitments))
|
||||
.route("/get_commitments", post(functions::get_commitments))
|
||||
|
@ -44,7 +51,8 @@ pub fn router() -> Router {
|
|||
|
||||
/// Run the server with the specified arguments.
|
||||
pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let app = router();
|
||||
let shared_state = AppState::new(&args.database).await?;
|
||||
let app = router(shared_state);
|
||||
|
||||
let addr = format!("{}:{}", args.ip, args.port);
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
@ -55,7 +63,7 @@ pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
|
|||
/// error happens during a API call, and a generic eyre::Report.
|
||||
// TODO: create an enum with specific errors
|
||||
#[derive(Debug)]
|
||||
pub struct AppError(StatusCode, eyre::Report);
|
||||
pub struct AppError(StatusCode, Box<dyn std::error::Error>);
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
str::FromStr,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
|
||||
use sqlx::{
|
||||
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
|
||||
SqlitePool,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
|
@ -67,6 +72,8 @@ impl Default for SessionState {
|
|||
/// A particular signing session.
|
||||
#[derive(Debug)]
|
||||
pub struct Session {
|
||||
/// The usernames of the participants
|
||||
pub(crate) usernames: Vec<String>,
|
||||
/// The number of signers in the session.
|
||||
pub(crate) num_signers: u16,
|
||||
/// The set of identifiers for the session.
|
||||
|
@ -78,10 +85,27 @@ pub struct Session {
|
|||
}
|
||||
|
||||
/// The global state of the server.
|
||||
#[derive(Default, Debug)]
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
/// Mapping of signing sessions by UUID.
|
||||
pub(crate) sessions: HashMap<Uuid, Session>,
|
||||
pub(crate) sessions_by_username: HashMap<String, HashSet<Uuid>>,
|
||||
pub(crate) db: SqlitePool,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub async fn new(database: &str) -> Result<SharedState, Box<dyn std::error::Error>> {
|
||||
tracing::event!(tracing::Level::INFO, "opening database {}", database);
|
||||
let options = SqliteConnectOptions::from_str(database)?.create_if_missing(true);
|
||||
let db = SqlitePoolOptions::new().connect_with(options).await?;
|
||||
sqlx::migrate!().run(&db).await?;
|
||||
let state = Self {
|
||||
sessions: Default::default(),
|
||||
sessions_by_username: Default::default(),
|
||||
db,
|
||||
};
|
||||
Ok(Arc::new(RwLock::new(state)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for the global state under a reference-counted RW mutex,
|
||||
|
|
|
@ -147,8 +147,31 @@ impl<C: frost_core::Ciphersuite> TryFrom<&SerializedSignatureShare>
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RegisterArgs {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
#[serde(
|
||||
serialize_with = "serdect::slice::serialize_hex_lower_or_bin",
|
||||
deserialize_with = "serdect::slice::deserialize_hex_or_bin_vec"
|
||||
)]
|
||||
pub pubkey: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginOutput {
|
||||
pub access_token: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct LoginArgs {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CreateNewSessionArgs {
|
||||
pub usernames: Vec<String>,
|
||||
pub num_signers: u16,
|
||||
pub message_count: u8,
|
||||
}
|
||||
|
@ -158,6 +181,11 @@ pub struct CreateNewSessionOutput {
|
|||
pub session_id: Uuid,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ListSessionsOutput {
|
||||
pub session_ids: Vec<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GetSessionInfoArgs {
|
||||
pub session_id: Uuid,
|
||||
|
|
|
@ -0,0 +1,237 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use axum::{
|
||||
async_trait,
|
||||
extract::FromRequestParts,
|
||||
http::{request::Parts, StatusCode},
|
||||
RequestPartsExt,
|
||||
};
|
||||
use axum_extra::{
|
||||
headers::{authorization::Bearer, Authorization},
|
||||
TypedHeader,
|
||||
};
|
||||
use eyre::eyre;
|
||||
use sqlx::{FromRow, SqlitePool};
|
||||
use tokio::task;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{state::SharedState, AppError};
|
||||
|
||||
/// An User, as stored in the database.
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
pub struct User {
|
||||
pub(crate) id: i64,
|
||||
pub(crate) username: String,
|
||||
pub(crate) password: String,
|
||||
pub(crate) pubkey: Vec<u8>,
|
||||
#[sqlx(skip)]
|
||||
pub(crate) access_tokens: Vec<AccessToken>,
|
||||
#[sqlx(skip)]
|
||||
pub(crate) current_token: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(Debug, FromRow)]
|
||||
#[allow(dead_code)]
|
||||
pub struct AccessToken {
|
||||
pub(crate) id: i64,
|
||||
pub(crate) user_id: i64,
|
||||
pub(crate) access_token: Option<Uuid>,
|
||||
}
|
||||
|
||||
/// Create user in the database.
|
||||
///
|
||||
/// The password is hashed and its hash is written in the DB.
|
||||
pub(crate) async fn create_user(
|
||||
db: SqlitePool,
|
||||
username: &str,
|
||||
password: &str,
|
||||
pubkey: Vec<u8>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// TODO: enforce mininum password length
|
||||
let password = password.to_owned();
|
||||
let pwhash = task::spawn_blocking(|| password_auth::generate_hash(password)).await?;
|
||||
sqlx::query(
|
||||
r#"
|
||||
insert into users (username, password, pubkey)
|
||||
values (?, ?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(username)
|
||||
.bind(pwhash)
|
||||
.bind(pubkey)
|
||||
.execute(&db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get user from database, or None if it's not registered.
|
||||
pub(crate) async fn get_user(
|
||||
db: SqlitePool,
|
||||
username: &str,
|
||||
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||
let user: Option<User> = sqlx::query_as("select * from users where username = ? ")
|
||||
.bind(username)
|
||||
.fetch_optional(&db)
|
||||
.await?;
|
||||
if let Some(mut user) = user {
|
||||
let access_tokens: Vec<AccessToken> =
|
||||
sqlx::query_as("select * from access_tokens where user_id = ?")
|
||||
.bind(user.id)
|
||||
.fetch_all(&db)
|
||||
.await?;
|
||||
user.access_tokens = access_tokens;
|
||||
Ok(Some(user))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete an User from the database, given its database ID.
|
||||
pub(crate) async fn delete_user(db: SqlitePool, id: i64) -> Result<(), Box<dyn std::error::Error>> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
delete from users where id = ?
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&db)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Authenticate user registered in the database. Returns the User if
|
||||
/// authentication is successful, or None if the username or password is wrong.
|
||||
///
|
||||
/// The given password is hashed and verified against the stored hash.
|
||||
pub(crate) async fn authenticate_user(
|
||||
db: SqlitePool,
|
||||
username: &str,
|
||||
password: &str,
|
||||
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||
let user: Option<User> = get_user(db, username).await?;
|
||||
|
||||
// Verifying the password is blocking and potentially slow, so we'll do so
|
||||
// via `spawn_blocking`.
|
||||
let password = password.to_owned();
|
||||
let r: Result<_, password_auth::VerifyError> = task::spawn_blocking(|| {
|
||||
// We're using password-based authentication--this works by comparing our form
|
||||
// input with an argon2 password hash.
|
||||
Ok(user.filter(|user| password_auth::verify_password(password, &user.password).is_ok()))
|
||||
})
|
||||
.await?;
|
||||
Ok(r?)
|
||||
}
|
||||
|
||||
/// Refreshes the user's access token, identified by its id in the database.
|
||||
///
|
||||
/// Generates a new token and overwrites the old one in the database, if any.
|
||||
pub(crate) async fn add_access_token(
|
||||
db: SqlitePool,
|
||||
id: i64,
|
||||
) -> Result<Uuid, Box<dyn std::error::Error>> {
|
||||
let access_token = Uuid::new_v4();
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
insert into access_tokens (user_id, access_token)
|
||||
values (?, ?)
|
||||
"#,
|
||||
)
|
||||
.bind(id)
|
||||
.bind(access_token)
|
||||
.execute(&db)
|
||||
.await?;
|
||||
|
||||
Ok(access_token)
|
||||
}
|
||||
|
||||
/// Remove a user's access token.
|
||||
pub(crate) async fn remove_access_token(
|
||||
db: SqlitePool,
|
||||
access_token: Uuid,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
delete from access_tokens where access_token = ?
|
||||
"#,
|
||||
)
|
||||
.bind(access_token)
|
||||
.execute(&db)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the User for a given access token, or None if there is no match.
|
||||
pub(crate) async fn get_user_for_access_token(
|
||||
db: SqlitePool,
|
||||
access_token: Uuid,
|
||||
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||
let user: Option<User> = sqlx::query_as(
|
||||
r#"
|
||||
select * from users inner join access_tokens on users.id = access_tokens.user_id where access_tokens.access_token = ?
|
||||
"#,
|
||||
)
|
||||
.bind(access_token)
|
||||
.fetch_optional(&db)
|
||||
.await?;
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// Read a User from a request. This is used to authenticate users. If any axum
|
||||
/// handler has an User argument, this will be called and the authentication
|
||||
/// will be carried out.
|
||||
#[async_trait]
|
||||
impl FromRequestParts<SharedState> for User {
|
||||
type Rejection = AppError;
|
||||
|
||||
#[tracing::instrument(ret, err(Debug), skip(parts, state))]
|
||||
// Can be removed after this fix is released
|
||||
// https://github.com/rust-lang/rust-clippy/issues/12281
|
||||
#[allow(clippy::blocks_in_conditions)]
|
||||
async fn from_request_parts(
|
||||
parts: &mut Parts,
|
||||
state: &SharedState,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// Extract the token from the authorization header
|
||||
let TypedHeader(Authorization(bearer)) = parts
|
||||
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||
.await
|
||||
.map_err(|_| {
|
||||
AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("Bearer token missing").into(),
|
||||
)
|
||||
})?;
|
||||
// Decode the user data
|
||||
let access_token = Uuid::from_str(bearer.token()).map_err(|_| {
|
||||
AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("invalid access token").into(),
|
||||
)
|
||||
})?;
|
||||
|
||||
let db = {
|
||||
let state_lock = state.read().unwrap();
|
||||
state_lock.db.clone()
|
||||
};
|
||||
|
||||
let user = get_user_for_access_token(db, access_token)
|
||||
.await
|
||||
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||
|
||||
match user {
|
||||
Some(mut user) => {
|
||||
user.current_token = Some(access_token);
|
||||
Ok(user)
|
||||
}
|
||||
None => {
|
||||
return Err(AppError(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
eyre!("user not found").into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -2,7 +2,7 @@ use std::{collections::BTreeMap, time::Duration};
|
|||
|
||||
use axum_test::TestServer;
|
||||
use rand::thread_rng;
|
||||
use server::{args::Args, router, SerializedSignatureShare, SerializedSigningPackage};
|
||||
use server::{args::Args, router, AppState, SerializedSignatureShare, SerializedSigningPackage};
|
||||
|
||||
use frost_core as frost;
|
||||
|
||||
|
@ -47,14 +47,52 @@ async fn test_main_router<
|
|||
.collect();
|
||||
|
||||
// Instantiate test server using axum_test
|
||||
let router = router();
|
||||
let shared_state = AppState::new(":memory:").await?;
|
||||
let router = router(shared_state);
|
||||
let server = TestServer::new(router)?;
|
||||
|
||||
// Create a dummy user. We make all requests with the same user since
|
||||
// it currently it doesn't really matter who the user is, users are only
|
||||
// used to share session IDs. This will likely change soon.
|
||||
|
||||
let res = server
|
||||
.post("/register")
|
||||
.json(&server::RegisterArgs {
|
||||
username: "alice".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
pubkey: vec![],
|
||||
})
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
||||
let res = server
|
||||
.post("/register")
|
||||
.json(&server::RegisterArgs {
|
||||
username: "bob".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
pubkey: vec![],
|
||||
})
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
||||
let res = server
|
||||
.post("/login")
|
||||
.json(&server::LoginArgs {
|
||||
username: "alice".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
})
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
let r: server::LoginOutput = res.json();
|
||||
let token = r.access_token;
|
||||
|
||||
// As the coordinator, create a new signing session with all participants,
|
||||
// for 2 messages
|
||||
let res = server
|
||||
.post("/create_new_session")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::CreateNewSessionArgs {
|
||||
usernames: vec!["alice".to_string(), "bob".to_string()],
|
||||
num_signers: 2,
|
||||
message_count: 2,
|
||||
})
|
||||
|
@ -75,6 +113,7 @@ async fn test_main_router<
|
|||
// asking the server).
|
||||
let res = server
|
||||
.post("/get_session_info")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::GetSessionInfoArgs { session_id })
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
@ -96,6 +135,7 @@ async fn test_main_router<
|
|||
// Send commitments to server
|
||||
let res = server
|
||||
.post("/send_commitments")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::SendCommitmentsArgs {
|
||||
identifier: (*identifier).into(),
|
||||
session_id,
|
||||
|
@ -110,6 +150,7 @@ async fn test_main_router<
|
|||
// As the coordinator, get the commitments
|
||||
let res = server
|
||||
.post("/get_commitments")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::GetCommitmentsArgs { session_id })
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
@ -146,6 +187,7 @@ async fn test_main_router<
|
|||
// As the coordinator, send the SigningPackages to the server
|
||||
let res = server
|
||||
.post("/send_signing_package")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::SendSigningPackageArgs {
|
||||
session_id,
|
||||
signing_package: signing_packages
|
||||
|
@ -173,6 +215,7 @@ async fn test_main_router<
|
|||
// Get SigningPackages
|
||||
let res = server
|
||||
.post("get_signing_package")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::GetSigningPackageArgs { session_id })
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
@ -210,6 +253,7 @@ async fn test_main_router<
|
|||
// Send SignatureShares to the server
|
||||
let res = server
|
||||
.post("/send_signature_share")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::SendSignatureShareArgs {
|
||||
identifier: (*identifier).into(),
|
||||
session_id,
|
||||
|
@ -225,6 +269,7 @@ async fn test_main_router<
|
|||
// As the coordinator, get SignatureShares
|
||||
let res = server
|
||||
.post("/get_signature_shares")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::GetSignatureSharesArgs { session_id })
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
@ -265,6 +310,7 @@ async fn test_main_router<
|
|||
// Close the session
|
||||
let res = server
|
||||
.post("/close_session")
|
||||
.authorization_bearer(token)
|
||||
.json(&server::CloseSessionArgs { session_id })
|
||||
.await;
|
||||
res.assert_status_ok();
|
||||
|
@ -288,9 +334,12 @@ async fn test_main_router<
|
|||
/// A better example on how to write client code.
|
||||
#[tokio::test]
|
||||
async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
// Spawn server for testing
|
||||
tokio::spawn(async move {
|
||||
server::run(&Args {
|
||||
database: ":memory:".to_string(),
|
||||
ip: "127.0.0.1".to_string(),
|
||||
port: 2744,
|
||||
})
|
||||
|
@ -302,19 +351,67 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// TODO: this could possibly be not enough, use some retry logic instead
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
// Call create_new_session
|
||||
// Create a client to make requests
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Call register to create users
|
||||
let r = client
|
||||
.post("http://127.0.0.1:2744/create_new_session")
|
||||
.json(&server::CreateNewSessionArgs {
|
||||
num_signers: 2,
|
||||
message_count: 1,
|
||||
.post("http://127.0.0.1:2744/register")
|
||||
.json(&server::RegisterArgs {
|
||||
username: "alice".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
pubkey: vec![],
|
||||
})
|
||||
.send()
|
||||
.await?
|
||||
.json::<server::CreateNewSessionOutput>()
|
||||
.await?;
|
||||
println!("{}", r.session_id);
|
||||
if r.status() != reqwest::StatusCode::OK {
|
||||
panic!("{}", r.text().await?)
|
||||
}
|
||||
let r = client
|
||||
.post("http://127.0.0.1:2744/register")
|
||||
.json(&server::RegisterArgs {
|
||||
username: "bob".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
pubkey: vec![],
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
if r.status() != reqwest::StatusCode::OK {
|
||||
panic!("{}", r.text().await?)
|
||||
}
|
||||
|
||||
// Call login to authenticate
|
||||
let r = client
|
||||
.post("http://127.0.0.1:2744/login")
|
||||
.json(&server::LoginArgs {
|
||||
username: "alice".to_string(),
|
||||
password: "passw0rd".to_string(),
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
if r.status() != reqwest::StatusCode::OK {
|
||||
panic!("{}", r.text().await?)
|
||||
}
|
||||
let r = r.json::<server::LoginOutput>().await?;
|
||||
let access_token = r.access_token;
|
||||
|
||||
// Call create_new_session
|
||||
let r = client
|
||||
.post("http://127.0.0.1:2744/create_new_session")
|
||||
.bearer_auth(access_token)
|
||||
.json(&server::CreateNewSessionArgs {
|
||||
usernames: vec!["alice".to_string(), "bob".to_string()],
|
||||
message_count: 1,
|
||||
num_signers: 2,
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
if r.status() != reqwest::StatusCode::OK {
|
||||
panic!("{}", r.text().await?)
|
||||
}
|
||||
let r = r.json::<server::CreateNewSessionOutput>().await?;
|
||||
let session_id = r.session_id;
|
||||
println!("Session ID: {}", session_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue