Add user registration (#252)
* add user registration * addint support to coordinator and participant * Update participant/src/comms/http.rs Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com> --------- Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com>
This commit is contained in:
parent
4503c10790
commit
4c6e860d69
File diff suppressed because it is too large
Load Diff
|
@ -17,6 +17,21 @@ pub struct Args {
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
pub http: bool,
|
pub http: bool,
|
||||||
|
|
||||||
|
/// The username to use in HTTP mode.
|
||||||
|
#[arg(short = 'u', long, default_value = "")]
|
||||||
|
pub username: String,
|
||||||
|
|
||||||
|
/// The password to use in HTTP mode. If specified, it will be read from the
|
||||||
|
/// environment variable with the given name.
|
||||||
|
#[arg(short = 'w', long, default_value = "")]
|
||||||
|
pub password: String,
|
||||||
|
|
||||||
|
/// The comma-separated usernames of the signers to use in HTTP mode.
|
||||||
|
/// If HTTP mode is enabled and this is empty, then the session ID
|
||||||
|
/// will be printed and will have to be shared manually.
|
||||||
|
#[arg(short = 'S', long, value_delimiter = ',')]
|
||||||
|
pub signers: Vec<String>,
|
||||||
|
|
||||||
/// The number of participants. If 0, will prompt for a value.
|
/// The number of participants. If 0, will prompt for a value.
|
||||||
#[arg(short = 'n', long, default_value_t = 0)]
|
#[arg(short = 'n', long, default_value_t = 0)]
|
||||||
pub num_signers: u16,
|
pub num_signers: u16,
|
||||||
|
|
|
@ -22,7 +22,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
|
||||||
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
||||||
Box::new(CLIComms::new())
|
Box::new(CLIComms::new())
|
||||||
} else if args.http {
|
} else if args.http {
|
||||||
Box::new(HTTPComms::new(args))
|
Box::new(HTTPComms::new(args)?)
|
||||||
} else {
|
} else {
|
||||||
Box::new(SocketComms::new(args))
|
Box::new(SocketComms::new(args))
|
||||||
};
|
};
|
||||||
|
|
|
@ -16,10 +16,12 @@ use frost::{
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::BTreeMap,
|
collections::BTreeMap,
|
||||||
|
env,
|
||||||
error::Error,
|
error::Error,
|
||||||
io::{BufRead, Write},
|
io::{BufRead, Write},
|
||||||
marker::PhantomData,
|
marker::PhantomData,
|
||||||
time::Duration,
|
time::Duration,
|
||||||
|
vec,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::Comms;
|
use super::Comms;
|
||||||
|
@ -29,18 +31,27 @@ pub struct HTTPComms<C: Ciphersuite> {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
host_port: String,
|
host_port: String,
|
||||||
session_id: Option<Uuid>,
|
session_id: Option<Uuid>,
|
||||||
|
username: String,
|
||||||
|
password: String,
|
||||||
|
access_token: String,
|
||||||
|
signers: Vec<String>,
|
||||||
_phantom: PhantomData<C>,
|
_phantom: PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Ciphersuite> HTTPComms<C> {
|
impl<C: Ciphersuite> HTTPComms<C> {
|
||||||
pub fn new(args: &Args) -> Self {
|
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
Self {
|
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
|
||||||
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
host_port: format!("http://{}:{}", args.ip, args.port),
|
host_port: format!("http://{}:{}", args.ip, args.port),
|
||||||
session_id: None,
|
session_id: None,
|
||||||
|
username: args.username.clone(),
|
||||||
|
password,
|
||||||
|
access_token: String::new(),
|
||||||
|
signers: args.signers.clone(),
|
||||||
_phantom: Default::default(),
|
_phantom: Default::default(),
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -53,10 +64,26 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
_pub_key_package: &PublicKeyPackage<C>,
|
_pub_key_package: &PublicKeyPackage<C>,
|
||||||
num_signers: u16,
|
num_signers: u16,
|
||||||
) -> Result<BTreeMap<Identifier<C>, SigningCommitments<C>>, Box<dyn Error>> {
|
) -> Result<BTreeMap<Identifier<C>, SigningCommitments<C>>, Box<dyn Error>> {
|
||||||
|
self.access_token = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/login", self.host_port))
|
||||||
|
.json(&server::LoginArgs {
|
||||||
|
username: self.username.clone(),
|
||||||
|
password: self.password.clone(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json::<server::LoginOutput>()
|
||||||
|
.await?
|
||||||
|
.access_token
|
||||||
|
.to_string();
|
||||||
|
|
||||||
let r = self
|
let r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/create_new_session", self.host_port))
|
.post(format!("{}/create_new_session", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::CreateNewSessionArgs {
|
.json(&server::CreateNewSessionArgs {
|
||||||
|
usernames: self.signers.clone(),
|
||||||
num_signers,
|
num_signers,
|
||||||
message_count: 1,
|
message_count: 1,
|
||||||
})
|
})
|
||||||
|
@ -65,10 +92,12 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
.json::<server::CreateNewSessionOutput>()
|
.json::<server::CreateNewSessionOutput>()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
if self.signers.is_empty() {
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"Send the following session ID to participants: {}",
|
"Send the following session ID to participants: {}",
|
||||||
r.session_id
|
r.session_id
|
||||||
);
|
);
|
||||||
|
}
|
||||||
self.session_id = Some(r.session_id);
|
self.session_id = Some(r.session_id);
|
||||||
eprint!("Waiting for participants to send their commitments...");
|
eprint!("Waiting for participants to send their commitments...");
|
||||||
|
|
||||||
|
@ -76,6 +105,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
let r = self
|
let r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/get_commitments", self.host_port))
|
.post(format!("{}/get_commitments", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::GetCommitmentsArgs {
|
.json(&server::GetCommitmentsArgs {
|
||||||
session_id: r.session_id,
|
session_id: r.session_id,
|
||||||
})
|
})
|
||||||
|
@ -114,6 +144,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
let _r = self
|
let _r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/send_signing_package", self.host_port))
|
.post(format!("{}/send_signing_package", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::SendSigningPackageArgs {
|
.json(&server::SendSigningPackageArgs {
|
||||||
aux_msg: Default::default(),
|
aux_msg: Default::default(),
|
||||||
session_id: self.session_id.unwrap(),
|
session_id: self.session_id.unwrap(),
|
||||||
|
@ -131,6 +162,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
let r = self
|
let r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/get_signature_shares", self.host_port))
|
.post(format!("{}/get_signature_shares", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::GetSignatureSharesArgs {
|
.json(&server::GetSignatureSharesArgs {
|
||||||
session_id: self.session_id.unwrap(),
|
session_id: self.session_id.unwrap(),
|
||||||
})
|
})
|
||||||
|
@ -145,6 +177,23 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
|
||||||
};
|
};
|
||||||
eprintln!();
|
eprintln!();
|
||||||
|
|
||||||
|
let _r = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/close_session", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
|
.json(&server::CloseSessionArgs {
|
||||||
|
session_id: self.session_id.unwrap(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let _r = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/logout", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
let signature_shares = r
|
let signature_shares = r
|
||||||
.signature_shares
|
.signature_shares
|
||||||
.first()
|
.first()
|
||||||
|
|
|
@ -48,7 +48,9 @@ async fn read_commitments<C: Ciphersuite>(
|
||||||
|
|
||||||
let pub_key_package: PublicKeyPackage<C> = serde_json::from_str(&out)?;
|
let pub_key_package: PublicKeyPackage<C> = serde_json::from_str(&out)?;
|
||||||
|
|
||||||
let num_of_participants = if args.num_signers == 0 {
|
let num_of_participants = if !args.signers.is_empty() {
|
||||||
|
args.signers.len() as u16
|
||||||
|
} else if args.num_signers == 0 {
|
||||||
writeln!(logger, "The number of participants: ")?;
|
writeln!(logger, "The number of participants: ")?;
|
||||||
|
|
||||||
let mut participants = String::new();
|
let mut participants = String::new();
|
||||||
|
|
|
@ -17,6 +17,15 @@ pub struct Args {
|
||||||
#[arg(long, default_value_t = false)]
|
#[arg(long, default_value_t = false)]
|
||||||
pub http: bool,
|
pub http: bool,
|
||||||
|
|
||||||
|
/// The username to use in HTTP mode.
|
||||||
|
#[arg(short = 'u', long, default_value = "")]
|
||||||
|
pub username: String,
|
||||||
|
|
||||||
|
/// The password to use in HTTP mode. If specified, it will be read from the
|
||||||
|
/// environment variable with the given name.
|
||||||
|
#[arg(short = 'w', long, default_value = "")]
|
||||||
|
pub password: String,
|
||||||
|
|
||||||
/// Public key package to use. Can be a file with a JSON-encoded
|
/// Public key package to use. Can be a file with a JSON-encoded
|
||||||
/// package, or "". If the file does not exist or if "" is specified,
|
/// package, or "". If the file does not exist or if "" is specified,
|
||||||
/// then it will be read from standard input.
|
/// then it will be read from standard input.
|
||||||
|
|
|
@ -20,7 +20,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
|
||||||
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
let mut comms: Box<dyn Comms<C>> = if args.cli {
|
||||||
Box::new(CLIComms::new())
|
Box::new(CLIComms::new())
|
||||||
} else if args.http {
|
} else if args.http {
|
||||||
Box::new(HTTPComms::new(args))
|
Box::new(HTTPComms::new(args)?)
|
||||||
} else {
|
} else {
|
||||||
Box::new(SocketComms::new(args))
|
Box::new(SocketComms::new(args))
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,6 +10,7 @@ use frost::{round1::SigningCommitments, round2::SignatureShare, Identifier};
|
||||||
|
|
||||||
use super::Comms;
|
use super::Comms;
|
||||||
|
|
||||||
|
use std::env;
|
||||||
use std::io::{BufRead, Write};
|
use std::io::{BufRead, Write};
|
||||||
|
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
@ -22,7 +23,10 @@ use crate::args::Args;
|
||||||
pub struct HTTPComms<C: Ciphersuite> {
|
pub struct HTTPComms<C: Ciphersuite> {
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
host_port: String,
|
host_port: String,
|
||||||
session_id: Uuid,
|
session_id: Option<Uuid>,
|
||||||
|
username: String,
|
||||||
|
password: String,
|
||||||
|
access_token: String,
|
||||||
_phantom: PhantomData<C>,
|
_phantom: PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,14 +37,18 @@ impl<C> HTTPComms<C>
|
||||||
where
|
where
|
||||||
C: Ciphersuite,
|
C: Ciphersuite,
|
||||||
{
|
{
|
||||||
pub fn new(args: &Args) -> Self {
|
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
Self {
|
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
|
||||||
|
Ok(Self {
|
||||||
client,
|
client,
|
||||||
host_port: format!("http://{}:{}", args.ip, args.port),
|
host_port: format!("http://{}:{}", args.ip, args.port),
|
||||||
session_id: Uuid::parse_str(&args.session_id).expect("invalid session id"),
|
session_id: Uuid::parse_str(&args.session_id).ok(),
|
||||||
|
username: args.username.clone(),
|
||||||
|
password,
|
||||||
|
access_token: String::new(),
|
||||||
_phantom: Default::default(),
|
_phantom: Default::default(),
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,11 +71,48 @@ where
|
||||||
),
|
),
|
||||||
Box<dyn Error>,
|
Box<dyn Error>,
|
||||||
> {
|
> {
|
||||||
|
self.access_token = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/login", self.host_port))
|
||||||
|
.json(&server::LoginArgs {
|
||||||
|
username: self.username.clone(),
|
||||||
|
password: self.password.clone(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json::<server::LoginOutput>()
|
||||||
|
.await?
|
||||||
|
.access_token
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let session_id = match self.session_id {
|
||||||
|
Some(s) => s,
|
||||||
|
None => {
|
||||||
|
// Get session ID from server
|
||||||
|
let r = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/list_sessions", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json::<server::ListSessionsOutput>()
|
||||||
|
.await?;
|
||||||
|
if r.session_ids.len() > 1 {
|
||||||
|
return Err(eyre!("user has more than one FROST session active, which is still not supported by this tool").into());
|
||||||
|
} else if r.session_ids.is_empty() {
|
||||||
|
return Err(eyre!("User has no current sessions active. The Coordinator should either specify your username, or manually share the session ID which you can specify with --session_id").into());
|
||||||
|
}
|
||||||
|
r.session_ids[0]
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.session_id = Some(session_id);
|
||||||
|
|
||||||
// Send Commitments to Server
|
// Send Commitments to Server
|
||||||
self.client
|
self.client
|
||||||
.post(format!("{}/send_commitments", self.host_port))
|
.post(format!("{}/send_commitments", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::SendCommitmentsArgs {
|
.json(&server::SendCommitmentsArgs {
|
||||||
session_id: self.session_id,
|
session_id,
|
||||||
identifier: identifier.into(),
|
identifier: identifier.into(),
|
||||||
commitments: vec![(&commitments).try_into()?],
|
commitments: vec![(&commitments).try_into()?],
|
||||||
})
|
})
|
||||||
|
@ -82,9 +127,8 @@ where
|
||||||
let r = self
|
let r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/get_signing_package", self.host_port))
|
.post(format!("{}/get_signing_package", self.host_port))
|
||||||
.json(&server::GetSigningPackageArgs {
|
.bearer_auth(&self.access_token)
|
||||||
session_id: self.session_id,
|
.json(&server::GetSigningPackageArgs { session_id })
|
||||||
})
|
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
if r.status() != 200 {
|
if r.status() != 200 {
|
||||||
|
@ -126,14 +170,22 @@ where
|
||||||
let _r = self
|
let _r = self
|
||||||
.client
|
.client
|
||||||
.post(format!("{}/send_signature_share", self.host_port))
|
.post(format!("{}/send_signature_share", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
.json(&server::SendSignatureShareArgs {
|
.json(&server::SendSignatureShareArgs {
|
||||||
identifier: identifier.into(),
|
identifier: identifier.into(),
|
||||||
session_id: self.session_id,
|
session_id: self.session_id.unwrap(),
|
||||||
signature_share: vec![signature_share.into()],
|
signature_share: vec![signature_share.into()],
|
||||||
})
|
})
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
let _r = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/logout", self.host_port))
|
||||||
|
.bearer_auth(&self.access_token)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,8 @@ async fn check_valid_round_1_inputs() {
|
||||||
port: 80,
|
port: 80,
|
||||||
session_id: "session-id".to_string(),
|
session_id: "session-id".to_string(),
|
||||||
http: false,
|
http: false,
|
||||||
|
username: "".to_string(),
|
||||||
|
password: "".to_string(),
|
||||||
};
|
};
|
||||||
let input = SECRET_SHARE_JSON;
|
let input = SECRET_SHARE_JSON;
|
||||||
let mut valid_input = input.as_bytes();
|
let mut valid_input = input.as_bytes();
|
||||||
|
|
|
@ -7,15 +7,19 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.7.5"
|
axum = "0.7.5"
|
||||||
|
axum-extra = { version = "0.9.3", features = ["typed-header"] }
|
||||||
|
axum-macros = "0.4.1"
|
||||||
clap = { version = "4.5.13", features = ["derive"] }
|
clap = { version = "4.5.13", features = ["derive"] }
|
||||||
derivative = "2.2.0"
|
derivative = "2.2.0"
|
||||||
eyre = "0.6.11"
|
eyre = "0.6.11"
|
||||||
frost-core = { version = "2.0.0-rc.0", features = ["serde"] }
|
frost-core = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||||
frost-rerandomized = { version = "2.0.0-rc.0", features = ["serde"] }
|
frost-rerandomized = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||||
|
password-auth = "1.0.0"
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serdect = { version = "0.2.0" }
|
serdect = { version = "0.2.0" }
|
||||||
serde_json = "1.0.122"
|
serde_json = "1.0.122"
|
||||||
|
sqlx = { version = "0.7.3", features = ["sqlite", "time", "runtime-tokio", "uuid"] }
|
||||||
tokio = { version = "1.38", features = ["full"] }
|
tokio = { version = "1.38", features = ["full"] }
|
||||||
tower-http = { version = "0.5.2", features = ["trace"] }
|
tower-http = { version = "0.5.2", features = ["trace"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
|
@ -23,7 +27,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
uuid = { version = "1.10.0", features = ["v4", "fast-rng", "serde"] }
|
uuid = { version = "1.10.0", features = ["v4", "fast-rng", "serde"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
axum-test = "14.10.0"
|
axum-test = "15.2.0"
|
||||||
frost-ed25519 = { version = "2.0.0-rc.0", features = ["serde"] }
|
frost-ed25519 = { version = "2.0.0-rc.0", features = ["serde"] }
|
||||||
reddsa = { git = "https://github.com/ZcashFoundation/reddsa.git", rev = "4d8c4bb337231e6e89117334d7c61dada589a953", features = [
|
reddsa = { git = "https://github.com/ZcashFoundation/reddsa.git", rev = "4d8c4bb337231e6e89117334d7c61dada589a953", features = [
|
||||||
"frost",
|
"frost",
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
// generated by `sqlx migrate build-script`
|
||||||
|
fn main() {
|
||||||
|
// trigger recompilation when a new migration is added
|
||||||
|
println!("cargo:rerun-if-changed=migrations");
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
-- Add down migration script here
|
||||||
|
drop table if exists users;
|
||||||
|
drop table if exists access_tokens;
|
|
@ -0,0 +1,16 @@
|
||||||
|
-- Create users table.
|
||||||
|
create table if not exists users
|
||||||
|
(
|
||||||
|
id integer primary key not null,
|
||||||
|
username text not null unique,
|
||||||
|
password text not null,
|
||||||
|
pubkey blob not null
|
||||||
|
);
|
||||||
|
|
||||||
|
create table if not exists access_tokens
|
||||||
|
(
|
||||||
|
id integer primary key not null,
|
||||||
|
user_id integer not null,
|
||||||
|
access_token blob not null,
|
||||||
|
foreign key(user_id) references users(id) on delete cascade
|
||||||
|
);
|
|
@ -10,4 +10,8 @@ pub struct Args {
|
||||||
/// Port to bind to
|
/// Port to bind to
|
||||||
#[arg(short, long, default_value_t = 2744)]
|
#[arg(short, long, default_value_t = 2744)]
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
|
||||||
|
/// Database to use.
|
||||||
|
#[arg(short, long, default_value = "db.sqlite")]
|
||||||
|
pub database: String,
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,32 +1,174 @@
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use axum::{extract::State, http::StatusCode, Json};
|
use axum::{extract::State, http::StatusCode, Json};
|
||||||
|
|
||||||
use eyre::eyre;
|
use eyre::eyre;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
state::{Session, SessionState, SharedState},
|
state::{Session, SessionState, SharedState},
|
||||||
types::*,
|
types::*,
|
||||||
|
user::{
|
||||||
|
add_access_token, authenticate_user, create_user, delete_user, get_user,
|
||||||
|
remove_access_token, User,
|
||||||
|
},
|
||||||
AppError,
|
AppError,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Implement the register API.
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
|
||||||
|
pub(crate) async fn register(
|
||||||
|
State(state): State<SharedState>,
|
||||||
|
Json(args): Json<RegisterArgs>,
|
||||||
|
) -> Result<Json<()>, AppError> {
|
||||||
|
let username = args.username.trim();
|
||||||
|
let password = args.password.trim();
|
||||||
|
|
||||||
|
if username.is_empty() || password.is_empty() {
|
||||||
|
return Err(AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("empty args").into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
create_user(db, username, password, args.pubkey)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
Ok(Json(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the login API.
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
|
||||||
|
pub(crate) async fn login(
|
||||||
|
State(state): State<SharedState>,
|
||||||
|
Json(args): Json<LoginArgs>,
|
||||||
|
) -> Result<Json<LoginOutput>, AppError> {
|
||||||
|
// Check if the user sent the credentials
|
||||||
|
if args.username.is_empty() || args.password.is_empty() {
|
||||||
|
return Err(AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("empty args").into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let user = authenticate_user(db.clone(), &args.username, &args.password)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
let user = match user {
|
||||||
|
Some(user) => user,
|
||||||
|
None => {
|
||||||
|
return Err(AppError(
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
eyre!("invalid user or password").into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let access_token = add_access_token(db.clone(), user.id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
let token = LoginOutput { access_token };
|
||||||
|
|
||||||
|
Ok(Json(token))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the logout API.
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
|
pub(crate) async fn logout(
|
||||||
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
|
) -> Result<Json<()>, AppError> {
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
remove_access_token(
|
||||||
|
db.clone(),
|
||||||
|
user.current_token
|
||||||
|
.expect("user is logged in so they must have a token"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
Ok(Json(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implement the unregister API.
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
|
pub(crate) async fn unregister(
|
||||||
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
|
) -> Result<Json<()>, AppError> {
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
delete_user(db, user.id)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
Ok(Json(()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Implement the create_new_session API.
|
/// Implement the create_new_session API.
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn create_new_session(
|
pub(crate) async fn create_new_session(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<CreateNewSessionArgs>,
|
Json(args): Json<CreateNewSessionArgs>,
|
||||||
) -> Result<Json<CreateNewSessionOutput>, AppError> {
|
) -> Result<Json<CreateNewSessionOutput>, AppError> {
|
||||||
tracing::info!("create_new_session");
|
|
||||||
if args.message_count == 0 {
|
if args.message_count == 0 {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("invalid message_count"),
|
eyre!("invalid message_count").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
for username in &args.usernames {
|
||||||
|
if get_user(db.clone(), username)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?
|
||||||
|
.is_none()
|
||||||
|
{
|
||||||
|
return Err(AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("invalid user").into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
// Create new session object.
|
// Create new session object.
|
||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
|
|
||||||
|
let mut state = state.write().unwrap();
|
||||||
|
|
||||||
|
// Save session ID in global state
|
||||||
|
for username in &args.usernames {
|
||||||
|
state
|
||||||
|
.sessions_by_username
|
||||||
|
.entry(username.to_string())
|
||||||
|
.or_default()
|
||||||
|
.insert(id);
|
||||||
|
}
|
||||||
|
// Create Session object
|
||||||
let session = Session {
|
let session = Session {
|
||||||
|
usernames: args.usernames,
|
||||||
num_signers: args.num_signers,
|
num_signers: args.num_signers,
|
||||||
message_count: args.message_count,
|
message_count: args.message_count,
|
||||||
state: SessionState::WaitingForCommitments {
|
state: SessionState::WaitingForCommitments {
|
||||||
|
@ -34,22 +176,41 @@ pub(crate) async fn create_new_session(
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
// Save session into global state.
|
// Save session into global state.
|
||||||
state.write().unwrap().sessions.insert(id, session);
|
state.sessions.insert(id, session);
|
||||||
|
|
||||||
let user = CreateNewSessionOutput { session_id: id };
|
let user = CreateNewSessionOutput { session_id: id };
|
||||||
Ok(Json(user))
|
Ok(Json(user))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Implement the create_new_session API.
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
|
pub(crate) async fn list_sessions(
|
||||||
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
|
) -> Result<Json<ListSessionsOutput>, AppError> {
|
||||||
|
let state = state.read().unwrap();
|
||||||
|
|
||||||
|
let session_ids = state
|
||||||
|
.sessions_by_username
|
||||||
|
.get(&user.username)
|
||||||
|
.map(|s| s.iter().cloned().collect())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
Ok(Json(ListSessionsOutput { session_ids }))
|
||||||
|
}
|
||||||
|
|
||||||
/// Implement the get_session_info API
|
/// Implement the get_session_info API
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn get_session_info(
|
pub(crate) async fn get_session_info(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<GetSessionInfoArgs>,
|
Json(args): Json<GetSessionInfoArgs>,
|
||||||
) -> Result<Json<GetSessionInfoOutput>, AppError> {
|
) -> Result<Json<GetSessionInfoOutput>, AppError> {
|
||||||
let state_lock = state.read().unwrap();
|
let state_lock = state.read().unwrap();
|
||||||
|
|
||||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
Ok(Json(GetSessionInfoOutput {
|
Ok(Json(GetSessionInfoOutput {
|
||||||
|
@ -60,9 +221,10 @@ pub(crate) async fn get_session_info(
|
||||||
|
|
||||||
/// Implement the send_commitments API
|
/// Implement the send_commitments API
|
||||||
// TODO: get identifier from channel rather from arguments
|
// TODO: get identifier from channel rather from arguments
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn send_commitments(
|
pub(crate) async fn send_commitments(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<SendCommitmentsArgs>,
|
Json(args): Json<SendCommitmentsArgs>,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
// Get the mutex lock to read and write from the state
|
// Get the mutex lock to read and write from the state
|
||||||
|
@ -73,7 +235,7 @@ pub(crate) async fn send_commitments(
|
||||||
.get_mut(&args.session_id)
|
.get_mut(&args.session_id)
|
||||||
.ok_or(AppError(
|
.ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &mut session.state {
|
match &mut session.state {
|
||||||
|
@ -81,7 +243,7 @@ pub(crate) async fn send_commitments(
|
||||||
if args.commitments.len() != session.message_count as usize {
|
if args.commitments.len() != session.message_count as usize {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("wrong number of commitments"),
|
eyre!("wrong number of commitments").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
// Add commitment to map.
|
// Add commitment to map.
|
||||||
|
@ -89,6 +251,11 @@ pub(crate) async fn send_commitments(
|
||||||
// (it seems better to ignore overwrites, which could be caused by
|
// (it seems better to ignore overwrites, which could be caused by
|
||||||
// poor networking connectivity leading to retries)
|
// poor networking connectivity leading to retries)
|
||||||
commitments.insert(args.identifier, args.commitments);
|
commitments.insert(args.identifier, args.commitments);
|
||||||
|
tracing::debug!(
|
||||||
|
"added commitments, currently {}/{}",
|
||||||
|
commitments.len(),
|
||||||
|
session.num_signers
|
||||||
|
);
|
||||||
// If complete, advance to next state
|
// If complete, advance to next state
|
||||||
if commitments.len() == session.num_signers as usize {
|
if commitments.len() == session.num_signers as usize {
|
||||||
session.state = SessionState::CommitmentsReady {
|
session.state = SessionState::CommitmentsReady {
|
||||||
|
@ -99,7 +266,7 @@ pub(crate) async fn send_commitments(
|
||||||
_ => {
|
_ => {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,16 +274,17 @@ pub(crate) async fn send_commitments(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the get_commitments API
|
/// Implement the get_commitments API
|
||||||
// #[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn get_commitments(
|
pub(crate) async fn get_commitments(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<GetCommitmentsArgs>,
|
Json(args): Json<GetCommitmentsArgs>,
|
||||||
) -> Result<Json<GetCommitmentsOutput>, AppError> {
|
) -> Result<Json<GetCommitmentsOutput>, AppError> {
|
||||||
let state_lock = state.read().unwrap();
|
let state_lock = state.read().unwrap();
|
||||||
|
|
||||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &session.state {
|
match &session.state {
|
||||||
|
@ -135,15 +303,16 @@ pub(crate) async fn get_commitments(
|
||||||
})),
|
})),
|
||||||
_ => Err(AppError(
|
_ => Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the send_signing_package API
|
/// Implement the send_signing_package API
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn send_signing_package(
|
pub(crate) async fn send_signing_package(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<SendSigningPackageArgs>,
|
Json(args): Json<SendSigningPackageArgs>,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let mut state_lock = state.write().unwrap();
|
let mut state_lock = state.write().unwrap();
|
||||||
|
@ -153,7 +322,7 @@ pub(crate) async fn send_signing_package(
|
||||||
.get_mut(&args.session_id)
|
.get_mut(&args.session_id)
|
||||||
.ok_or(AppError(
|
.ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &mut session.state {
|
match &mut session.state {
|
||||||
|
@ -161,7 +330,7 @@ pub(crate) async fn send_signing_package(
|
||||||
if args.signing_package.len() != session.message_count as usize {
|
if args.signing_package.len() != session.message_count as usize {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("wrong number of inputs"),
|
eyre!("wrong number of inputs").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if args.randomizer.len() != session.message_count as usize
|
if args.randomizer.len() != session.message_count as usize
|
||||||
|
@ -169,7 +338,7 @@ pub(crate) async fn send_signing_package(
|
||||||
{
|
{
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("wrong number of inputs"),
|
eyre!("wrong number of inputs").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
session.state = SessionState::WaitingForSignatureShares {
|
session.state = SessionState::WaitingForSignatureShares {
|
||||||
|
@ -183,7 +352,7 @@ pub(crate) async fn send_signing_package(
|
||||||
_ => {
|
_ => {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -191,16 +360,17 @@ pub(crate) async fn send_signing_package(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the get_signing_package API
|
/// Implement the get_signing_package API
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn get_signing_package(
|
pub(crate) async fn get_signing_package(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<GetSigningPackageArgs>,
|
Json(args): Json<GetSigningPackageArgs>,
|
||||||
) -> Result<Json<GetSigningPackageOutput>, AppError> {
|
) -> Result<Json<GetSigningPackageOutput>, AppError> {
|
||||||
let state_lock = state.read().unwrap();
|
let state_lock = state.read().unwrap();
|
||||||
|
|
||||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &session.state {
|
match &session.state {
|
||||||
|
@ -217,16 +387,17 @@ pub(crate) async fn get_signing_package(
|
||||||
})),
|
})),
|
||||||
_ => Err(AppError(
|
_ => Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the send_signature_share API
|
/// Implement the send_signature_share API
|
||||||
// TODO: get identifier from channel rather from arguments
|
// TODO: get identifier from channel rather from arguments
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn send_signature_share(
|
pub(crate) async fn send_signature_share(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<SendSignatureShareArgs>,
|
Json(args): Json<SendSignatureShareArgs>,
|
||||||
) -> Result<(), AppError> {
|
) -> Result<(), AppError> {
|
||||||
let mut state_lock = state.write().unwrap();
|
let mut state_lock = state.write().unwrap();
|
||||||
|
@ -236,7 +407,7 @@ pub(crate) async fn send_signature_share(
|
||||||
.get_mut(&args.session_id)
|
.get_mut(&args.session_id)
|
||||||
.ok_or(AppError(
|
.ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &mut session.state {
|
match &mut session.state {
|
||||||
|
@ -248,12 +419,15 @@ pub(crate) async fn send_signature_share(
|
||||||
aux_msg: _,
|
aux_msg: _,
|
||||||
} => {
|
} => {
|
||||||
if !identifiers.contains(&args.identifier) {
|
if !identifiers.contains(&args.identifier) {
|
||||||
return Err(AppError(StatusCode::NOT_FOUND, eyre!("invalid identifier")));
|
return Err(AppError(
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
eyre!("invalid identifier").into(),
|
||||||
|
));
|
||||||
}
|
}
|
||||||
if args.signature_share.len() != session.message_count as usize {
|
if args.signature_share.len() != session.message_count as usize {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("wrong number of signature shares"),
|
eyre!("wrong number of signature shares").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
// Currently ignoring the possibility of overwriting previous values
|
// Currently ignoring the possibility of overwriting previous values
|
||||||
|
@ -270,7 +444,7 @@ pub(crate) async fn send_signature_share(
|
||||||
_ => {
|
_ => {
|
||||||
return Err(AppError(
|
return Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -278,16 +452,17 @@ pub(crate) async fn send_signature_share(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the get_signature_shares API
|
/// Implement the get_signature_shares API
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn get_signature_shares(
|
pub(crate) async fn get_signature_shares(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<GetSignatureSharesArgs>,
|
Json(args): Json<GetSignatureSharesArgs>,
|
||||||
) -> Result<Json<GetSignatureSharesOutput>, AppError> {
|
) -> Result<Json<GetSignatureSharesOutput>, AppError> {
|
||||||
let state_lock = state.read().unwrap();
|
let state_lock = state.read().unwrap();
|
||||||
|
|
||||||
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
eyre!("session ID not found"),
|
eyre!("session ID not found").into(),
|
||||||
))?;
|
))?;
|
||||||
|
|
||||||
match &session.state {
|
match &session.state {
|
||||||
|
@ -308,17 +483,34 @@ pub(crate) async fn get_signature_shares(
|
||||||
}
|
}
|
||||||
_ => Err(AppError(
|
_ => Err(AppError(
|
||||||
StatusCode::INTERNAL_SERVER_ERROR,
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
eyre!("incompatible session state"),
|
eyre!("incompatible session state").into(),
|
||||||
)),
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implement the close_session API.
|
/// Implement the close_session API.
|
||||||
#[tracing::instrument(ret, err(Debug))]
|
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
|
||||||
pub(crate) async fn close_session(
|
pub(crate) async fn close_session(
|
||||||
State(state): State<SharedState>,
|
State(state): State<SharedState>,
|
||||||
|
user: User,
|
||||||
Json(args): Json<CloseSessionArgs>,
|
Json(args): Json<CloseSessionArgs>,
|
||||||
) -> Result<Json<()>, AppError> {
|
) -> Result<Json<()>, AppError> {
|
||||||
state.write().unwrap().sessions.remove(&args.session_id);
|
let mut state = state.write().unwrap();
|
||||||
|
|
||||||
|
for username in state
|
||||||
|
.sessions
|
||||||
|
.get(&args.session_id)
|
||||||
|
.ok_or(AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("invalid session ID").into(),
|
||||||
|
))?
|
||||||
|
.usernames
|
||||||
|
.clone()
|
||||||
|
{
|
||||||
|
if let Some(v) = state.sessions_by_username.get_mut(&username) {
|
||||||
|
v.remove(&args.session_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.sessions.remove(&args.session_id);
|
||||||
Ok(Json(()))
|
Ok(Json(()))
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,6 +2,9 @@ pub mod args;
|
||||||
mod functions;
|
mod functions;
|
||||||
mod state;
|
mod state;
|
||||||
mod types;
|
mod types;
|
||||||
|
mod user;
|
||||||
|
|
||||||
|
pub use state::{AppState, SharedState};
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
pub use types::*;
|
pub use types::*;
|
||||||
|
|
||||||
|
@ -16,11 +19,15 @@ use axum::{
|
||||||
/// Create the axum Router for the server.
|
/// Create the axum Router for the server.
|
||||||
/// Maps specific endpoints to handler functions.
|
/// Maps specific endpoints to handler functions.
|
||||||
// TODO: use methods of a single object instead of separate functions?
|
// TODO: use methods of a single object instead of separate functions?
|
||||||
pub fn router() -> Router {
|
pub fn router(shared_state: SharedState) -> Router {
|
||||||
// Shared state that is passed to each handler by axum
|
// Shared state that is passed to each handler by axum
|
||||||
let shared_state = state::SharedState::default();
|
|
||||||
Router::new()
|
Router::new()
|
||||||
|
.route("/register", post(functions::register))
|
||||||
|
.route("/login", post(functions::login))
|
||||||
|
.route("/logout", post(functions::logout))
|
||||||
|
.route("/unregister", post(functions::unregister))
|
||||||
.route("/create_new_session", post(functions::create_new_session))
|
.route("/create_new_session", post(functions::create_new_session))
|
||||||
|
.route("/list_sessions", post(functions::list_sessions))
|
||||||
.route("/get_session_info", post(functions::get_session_info))
|
.route("/get_session_info", post(functions::get_session_info))
|
||||||
.route("/send_commitments", post(functions::send_commitments))
|
.route("/send_commitments", post(functions::send_commitments))
|
||||||
.route("/get_commitments", post(functions::get_commitments))
|
.route("/get_commitments", post(functions::get_commitments))
|
||||||
|
@ -44,7 +51,8 @@ pub fn router() -> Router {
|
||||||
|
|
||||||
/// Run the server with the specified arguments.
|
/// Run the server with the specified arguments.
|
||||||
pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let app = router();
|
let shared_state = AppState::new(&args.database).await?;
|
||||||
|
let app = router(shared_state);
|
||||||
|
|
||||||
let addr = format!("{}:{}", args.ip, args.port);
|
let addr = format!("{}:{}", args.ip, args.port);
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
@ -55,7 +63,7 @@ pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
/// error happens during a API call, and a generic eyre::Report.
|
/// error happens during a API call, and a generic eyre::Report.
|
||||||
// TODO: create an enum with specific errors
|
// TODO: create an enum with specific errors
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AppError(StatusCode, eyre::Report);
|
pub struct AppError(StatusCode, Box<dyn std::error::Error>);
|
||||||
|
|
||||||
impl IntoResponse for AppError {
|
impl IntoResponse for AppError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
|
|
|
@ -1,8 +1,13 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::{HashMap, HashSet},
|
collections::{HashMap, HashSet},
|
||||||
|
str::FromStr,
|
||||||
sync::{Arc, RwLock},
|
sync::{Arc, RwLock},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use sqlx::{
|
||||||
|
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
|
||||||
|
SqlitePool,
|
||||||
|
};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -67,6 +72,8 @@ impl Default for SessionState {
|
||||||
/// A particular signing session.
|
/// A particular signing session.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
|
/// The usernames of the participants
|
||||||
|
pub(crate) usernames: Vec<String>,
|
||||||
/// The number of signers in the session.
|
/// The number of signers in the session.
|
||||||
pub(crate) num_signers: u16,
|
pub(crate) num_signers: u16,
|
||||||
/// The set of identifiers for the session.
|
/// The set of identifiers for the session.
|
||||||
|
@ -78,10 +85,27 @@ pub struct Session {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The global state of the server.
|
/// The global state of the server.
|
||||||
#[derive(Default, Debug)]
|
#[derive(Debug)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
/// Mapping of signing sessions by UUID.
|
/// Mapping of signing sessions by UUID.
|
||||||
pub(crate) sessions: HashMap<Uuid, Session>,
|
pub(crate) sessions: HashMap<Uuid, Session>,
|
||||||
|
pub(crate) sessions_by_username: HashMap<String, HashSet<Uuid>>,
|
||||||
|
pub(crate) db: SqlitePool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub async fn new(database: &str) -> Result<SharedState, Box<dyn std::error::Error>> {
|
||||||
|
tracing::event!(tracing::Level::INFO, "opening database {}", database);
|
||||||
|
let options = SqliteConnectOptions::from_str(database)?.create_if_missing(true);
|
||||||
|
let db = SqlitePoolOptions::new().connect_with(options).await?;
|
||||||
|
sqlx::migrate!().run(&db).await?;
|
||||||
|
let state = Self {
|
||||||
|
sessions: Default::default(),
|
||||||
|
sessions_by_username: Default::default(),
|
||||||
|
db,
|
||||||
|
};
|
||||||
|
Ok(Arc::new(RwLock::new(state)))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Type alias for the global state under a reference-counted RW mutex,
|
/// Type alias for the global state under a reference-counted RW mutex,
|
||||||
|
|
|
@ -147,8 +147,31 @@ impl<C: frost_core::Ciphersuite> TryFrom<&SerializedSignatureShare>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct RegisterArgs {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
#[serde(
|
||||||
|
serialize_with = "serdect::slice::serialize_hex_lower_or_bin",
|
||||||
|
deserialize_with = "serdect::slice::deserialize_hex_or_bin_vec"
|
||||||
|
)]
|
||||||
|
pub pubkey: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct LoginOutput {
|
||||||
|
pub access_token: Uuid,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct LoginArgs {
|
||||||
|
pub username: String,
|
||||||
|
pub password: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct CreateNewSessionArgs {
|
pub struct CreateNewSessionArgs {
|
||||||
|
pub usernames: Vec<String>,
|
||||||
pub num_signers: u16,
|
pub num_signers: u16,
|
||||||
pub message_count: u8,
|
pub message_count: u8,
|
||||||
}
|
}
|
||||||
|
@ -158,6 +181,11 @@ pub struct CreateNewSessionOutput {
|
||||||
pub session_id: Uuid,
|
pub session_id: Uuid,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ListSessionsOutput {
|
||||||
|
pub session_ids: Vec<Uuid>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct GetSessionInfoArgs {
|
pub struct GetSessionInfoArgs {
|
||||||
pub session_id: Uuid,
|
pub session_id: Uuid,
|
||||||
|
|
|
@ -0,0 +1,237 @@
|
||||||
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
async_trait,
|
||||||
|
extract::FromRequestParts,
|
||||||
|
http::{request::Parts, StatusCode},
|
||||||
|
RequestPartsExt,
|
||||||
|
};
|
||||||
|
use axum_extra::{
|
||||||
|
headers::{authorization::Bearer, Authorization},
|
||||||
|
TypedHeader,
|
||||||
|
};
|
||||||
|
use eyre::eyre;
|
||||||
|
use sqlx::{FromRow, SqlitePool};
|
||||||
|
use tokio::task;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{state::SharedState, AppError};
|
||||||
|
|
||||||
|
/// An User, as stored in the database.
|
||||||
|
#[derive(Debug, FromRow)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct User {
|
||||||
|
pub(crate) id: i64,
|
||||||
|
pub(crate) username: String,
|
||||||
|
pub(crate) password: String,
|
||||||
|
pub(crate) pubkey: Vec<u8>,
|
||||||
|
#[sqlx(skip)]
|
||||||
|
pub(crate) access_tokens: Vec<AccessToken>,
|
||||||
|
#[sqlx(skip)]
|
||||||
|
pub(crate) current_token: Option<Uuid>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, FromRow)]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub struct AccessToken {
|
||||||
|
pub(crate) id: i64,
|
||||||
|
pub(crate) user_id: i64,
|
||||||
|
pub(crate) access_token: Option<Uuid>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create user in the database.
|
||||||
|
///
|
||||||
|
/// The password is hashed and its hash is written in the DB.
|
||||||
|
pub(crate) async fn create_user(
|
||||||
|
db: SqlitePool,
|
||||||
|
username: &str,
|
||||||
|
password: &str,
|
||||||
|
pubkey: Vec<u8>,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
// TODO: enforce mininum password length
|
||||||
|
let password = password.to_owned();
|
||||||
|
let pwhash = task::spawn_blocking(|| password_auth::generate_hash(password)).await?;
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
insert into users (username, password, pubkey)
|
||||||
|
values (?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(username)
|
||||||
|
.bind(pwhash)
|
||||||
|
.bind(pubkey)
|
||||||
|
.execute(&db)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get user from database, or None if it's not registered.
|
||||||
|
pub(crate) async fn get_user(
|
||||||
|
db: SqlitePool,
|
||||||
|
username: &str,
|
||||||
|
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||||
|
let user: Option<User> = sqlx::query_as("select * from users where username = ? ")
|
||||||
|
.bind(username)
|
||||||
|
.fetch_optional(&db)
|
||||||
|
.await?;
|
||||||
|
if let Some(mut user) = user {
|
||||||
|
let access_tokens: Vec<AccessToken> =
|
||||||
|
sqlx::query_as("select * from access_tokens where user_id = ?")
|
||||||
|
.bind(user.id)
|
||||||
|
.fetch_all(&db)
|
||||||
|
.await?;
|
||||||
|
user.access_tokens = access_tokens;
|
||||||
|
Ok(Some(user))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete an User from the database, given its database ID.
|
||||||
|
pub(crate) async fn delete_user(db: SqlitePool, id: i64) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
delete from users where id = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(id)
|
||||||
|
.execute(&db)
|
||||||
|
.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Authenticate user registered in the database. Returns the User if
|
||||||
|
/// authentication is successful, or None if the username or password is wrong.
|
||||||
|
///
|
||||||
|
/// The given password is hashed and verified against the stored hash.
|
||||||
|
pub(crate) async fn authenticate_user(
|
||||||
|
db: SqlitePool,
|
||||||
|
username: &str,
|
||||||
|
password: &str,
|
||||||
|
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||||
|
let user: Option<User> = get_user(db, username).await?;
|
||||||
|
|
||||||
|
// Verifying the password is blocking and potentially slow, so we'll do so
|
||||||
|
// via `spawn_blocking`.
|
||||||
|
let password = password.to_owned();
|
||||||
|
let r: Result<_, password_auth::VerifyError> = task::spawn_blocking(|| {
|
||||||
|
// We're using password-based authentication--this works by comparing our form
|
||||||
|
// input with an argon2 password hash.
|
||||||
|
Ok(user.filter(|user| password_auth::verify_password(password, &user.password).is_ok()))
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
Ok(r?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refreshes the user's access token, identified by its id in the database.
|
||||||
|
///
|
||||||
|
/// Generates a new token and overwrites the old one in the database, if any.
|
||||||
|
pub(crate) async fn add_access_token(
|
||||||
|
db: SqlitePool,
|
||||||
|
id: i64,
|
||||||
|
) -> Result<Uuid, Box<dyn std::error::Error>> {
|
||||||
|
let access_token = Uuid::new_v4();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
insert into access_tokens (user_id, access_token)
|
||||||
|
values (?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(id)
|
||||||
|
.bind(access_token)
|
||||||
|
.execute(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(access_token)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a user's access token.
|
||||||
|
pub(crate) async fn remove_access_token(
|
||||||
|
db: SqlitePool,
|
||||||
|
access_token: Uuid,
|
||||||
|
) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
delete from access_tokens where access_token = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(access_token)
|
||||||
|
.execute(&db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the User for a given access token, or None if there is no match.
|
||||||
|
pub(crate) async fn get_user_for_access_token(
|
||||||
|
db: SqlitePool,
|
||||||
|
access_token: Uuid,
|
||||||
|
) -> Result<Option<User>, Box<dyn std::error::Error>> {
|
||||||
|
let user: Option<User> = sqlx::query_as(
|
||||||
|
r#"
|
||||||
|
select * from users inner join access_tokens on users.id = access_tokens.user_id where access_tokens.access_token = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(access_token)
|
||||||
|
.fetch_optional(&db)
|
||||||
|
.await?;
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read a User from a request. This is used to authenticate users. If any axum
|
||||||
|
/// handler has an User argument, this will be called and the authentication
|
||||||
|
/// will be carried out.
|
||||||
|
#[async_trait]
|
||||||
|
impl FromRequestParts<SharedState> for User {
|
||||||
|
type Rejection = AppError;
|
||||||
|
|
||||||
|
#[tracing::instrument(ret, err(Debug), skip(parts, state))]
|
||||||
|
// Can be removed after this fix is released
|
||||||
|
// https://github.com/rust-lang/rust-clippy/issues/12281
|
||||||
|
#[allow(clippy::blocks_in_conditions)]
|
||||||
|
async fn from_request_parts(
|
||||||
|
parts: &mut Parts,
|
||||||
|
state: &SharedState,
|
||||||
|
) -> Result<Self, Self::Rejection> {
|
||||||
|
// Extract the token from the authorization header
|
||||||
|
let TypedHeader(Authorization(bearer)) = parts
|
||||||
|
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("Bearer token missing").into(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
// Decode the user data
|
||||||
|
let access_token = Uuid::from_str(bearer.token()).map_err(|_| {
|
||||||
|
AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("invalid access token").into(),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let db = {
|
||||||
|
let state_lock = state.read().unwrap();
|
||||||
|
state_lock.db.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let user = get_user_for_access_token(db, access_token)
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
||||||
|
|
||||||
|
match user {
|
||||||
|
Some(mut user) => {
|
||||||
|
user.current_token = Some(access_token);
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err(AppError(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
eyre!("user not found").into(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -2,7 +2,7 @@ use std::{collections::BTreeMap, time::Duration};
|
||||||
|
|
||||||
use axum_test::TestServer;
|
use axum_test::TestServer;
|
||||||
use rand::thread_rng;
|
use rand::thread_rng;
|
||||||
use server::{args::Args, router, SerializedSignatureShare, SerializedSigningPackage};
|
use server::{args::Args, router, AppState, SerializedSignatureShare, SerializedSigningPackage};
|
||||||
|
|
||||||
use frost_core as frost;
|
use frost_core as frost;
|
||||||
|
|
||||||
|
@ -47,14 +47,52 @@ async fn test_main_router<
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Instantiate test server using axum_test
|
// Instantiate test server using axum_test
|
||||||
let router = router();
|
let shared_state = AppState::new(":memory:").await?;
|
||||||
|
let router = router(shared_state);
|
||||||
let server = TestServer::new(router)?;
|
let server = TestServer::new(router)?;
|
||||||
|
|
||||||
|
// Create a dummy user. We make all requests with the same user since
|
||||||
|
// it currently it doesn't really matter who the user is, users are only
|
||||||
|
// used to share session IDs. This will likely change soon.
|
||||||
|
|
||||||
|
let res = server
|
||||||
|
.post("/register")
|
||||||
|
.json(&server::RegisterArgs {
|
||||||
|
username: "alice".to_string(),
|
||||||
|
password: "passw0rd".to_string(),
|
||||||
|
pubkey: vec![],
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
res.assert_status_ok();
|
||||||
|
|
||||||
|
let res = server
|
||||||
|
.post("/register")
|
||||||
|
.json(&server::RegisterArgs {
|
||||||
|
username: "bob".to_string(),
|
||||||
|
password: "passw0rd".to_string(),
|
||||||
|
pubkey: vec![],
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
res.assert_status_ok();
|
||||||
|
|
||||||
|
let res = server
|
||||||
|
.post("/login")
|
||||||
|
.json(&server::LoginArgs {
|
||||||
|
username: "alice".to_string(),
|
||||||
|
password: "passw0rd".to_string(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
res.assert_status_ok();
|
||||||
|
let r: server::LoginOutput = res.json();
|
||||||
|
let token = r.access_token;
|
||||||
|
|
||||||
// As the coordinator, create a new signing session with all participants,
|
// As the coordinator, create a new signing session with all participants,
|
||||||
// for 2 messages
|
// for 2 messages
|
||||||
let res = server
|
let res = server
|
||||||
.post("/create_new_session")
|
.post("/create_new_session")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::CreateNewSessionArgs {
|
.json(&server::CreateNewSessionArgs {
|
||||||
|
usernames: vec!["alice".to_string(), "bob".to_string()],
|
||||||
num_signers: 2,
|
num_signers: 2,
|
||||||
message_count: 2,
|
message_count: 2,
|
||||||
})
|
})
|
||||||
|
@ -75,6 +113,7 @@ async fn test_main_router<
|
||||||
// asking the server).
|
// asking the server).
|
||||||
let res = server
|
let res = server
|
||||||
.post("/get_session_info")
|
.post("/get_session_info")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::GetSessionInfoArgs { session_id })
|
.json(&server::GetSessionInfoArgs { session_id })
|
||||||
.await;
|
.await;
|
||||||
res.assert_status_ok();
|
res.assert_status_ok();
|
||||||
|
@ -96,6 +135,7 @@ async fn test_main_router<
|
||||||
// Send commitments to server
|
// Send commitments to server
|
||||||
let res = server
|
let res = server
|
||||||
.post("/send_commitments")
|
.post("/send_commitments")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::SendCommitmentsArgs {
|
.json(&server::SendCommitmentsArgs {
|
||||||
identifier: (*identifier).into(),
|
identifier: (*identifier).into(),
|
||||||
session_id,
|
session_id,
|
||||||
|
@ -110,6 +150,7 @@ async fn test_main_router<
|
||||||
// As the coordinator, get the commitments
|
// As the coordinator, get the commitments
|
||||||
let res = server
|
let res = server
|
||||||
.post("/get_commitments")
|
.post("/get_commitments")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::GetCommitmentsArgs { session_id })
|
.json(&server::GetCommitmentsArgs { session_id })
|
||||||
.await;
|
.await;
|
||||||
res.assert_status_ok();
|
res.assert_status_ok();
|
||||||
|
@ -146,6 +187,7 @@ async fn test_main_router<
|
||||||
// As the coordinator, send the SigningPackages to the server
|
// As the coordinator, send the SigningPackages to the server
|
||||||
let res = server
|
let res = server
|
||||||
.post("/send_signing_package")
|
.post("/send_signing_package")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::SendSigningPackageArgs {
|
.json(&server::SendSigningPackageArgs {
|
||||||
session_id,
|
session_id,
|
||||||
signing_package: signing_packages
|
signing_package: signing_packages
|
||||||
|
@ -173,6 +215,7 @@ async fn test_main_router<
|
||||||
// Get SigningPackages
|
// Get SigningPackages
|
||||||
let res = server
|
let res = server
|
||||||
.post("get_signing_package")
|
.post("get_signing_package")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::GetSigningPackageArgs { session_id })
|
.json(&server::GetSigningPackageArgs { session_id })
|
||||||
.await;
|
.await;
|
||||||
res.assert_status_ok();
|
res.assert_status_ok();
|
||||||
|
@ -210,6 +253,7 @@ async fn test_main_router<
|
||||||
// Send SignatureShares to the server
|
// Send SignatureShares to the server
|
||||||
let res = server
|
let res = server
|
||||||
.post("/send_signature_share")
|
.post("/send_signature_share")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::SendSignatureShareArgs {
|
.json(&server::SendSignatureShareArgs {
|
||||||
identifier: (*identifier).into(),
|
identifier: (*identifier).into(),
|
||||||
session_id,
|
session_id,
|
||||||
|
@ -225,6 +269,7 @@ async fn test_main_router<
|
||||||
// As the coordinator, get SignatureShares
|
// As the coordinator, get SignatureShares
|
||||||
let res = server
|
let res = server
|
||||||
.post("/get_signature_shares")
|
.post("/get_signature_shares")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::GetSignatureSharesArgs { session_id })
|
.json(&server::GetSignatureSharesArgs { session_id })
|
||||||
.await;
|
.await;
|
||||||
res.assert_status_ok();
|
res.assert_status_ok();
|
||||||
|
@ -265,6 +310,7 @@ async fn test_main_router<
|
||||||
// Close the session
|
// Close the session
|
||||||
let res = server
|
let res = server
|
||||||
.post("/close_session")
|
.post("/close_session")
|
||||||
|
.authorization_bearer(token)
|
||||||
.json(&server::CloseSessionArgs { session_id })
|
.json(&server::CloseSessionArgs { session_id })
|
||||||
.await;
|
.await;
|
||||||
res.assert_status_ok();
|
res.assert_status_ok();
|
||||||
|
@ -288,9 +334,12 @@ async fn test_main_router<
|
||||||
/// A better example on how to write client code.
|
/// A better example on how to write client code.
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
|
async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
// Spawn server for testing
|
// Spawn server for testing
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
server::run(&Args {
|
server::run(&Args {
|
||||||
|
database: ":memory:".to_string(),
|
||||||
ip: "127.0.0.1".to_string(),
|
ip: "127.0.0.1".to_string(),
|
||||||
port: 2744,
|
port: 2744,
|
||||||
})
|
})
|
||||||
|
@ -302,19 +351,67 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// TODO: this could possibly be not enough, use some retry logic instead
|
// TODO: this could possibly be not enough, use some retry logic instead
|
||||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||||
|
|
||||||
// Call create_new_session
|
// Create a client to make requests
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::new();
|
||||||
|
|
||||||
|
// Call register to create users
|
||||||
let r = client
|
let r = client
|
||||||
.post("http://127.0.0.1:2744/create_new_session")
|
.post("http://127.0.0.1:2744/register")
|
||||||
.json(&server::CreateNewSessionArgs {
|
.json(&server::RegisterArgs {
|
||||||
num_signers: 2,
|
username: "alice".to_string(),
|
||||||
message_count: 1,
|
password: "passw0rd".to_string(),
|
||||||
|
pubkey: vec![],
|
||||||
})
|
})
|
||||||
.send()
|
.send()
|
||||||
.await?
|
|
||||||
.json::<server::CreateNewSessionOutput>()
|
|
||||||
.await?;
|
.await?;
|
||||||
println!("{}", r.session_id);
|
if r.status() != reqwest::StatusCode::OK {
|
||||||
|
panic!("{}", r.text().await?)
|
||||||
|
}
|
||||||
|
let r = client
|
||||||
|
.post("http://127.0.0.1:2744/register")
|
||||||
|
.json(&server::RegisterArgs {
|
||||||
|
username: "bob".to_string(),
|
||||||
|
password: "passw0rd".to_string(),
|
||||||
|
pubkey: vec![],
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
if r.status() != reqwest::StatusCode::OK {
|
||||||
|
panic!("{}", r.text().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call login to authenticate
|
||||||
|
let r = client
|
||||||
|
.post("http://127.0.0.1:2744/login")
|
||||||
|
.json(&server::LoginArgs {
|
||||||
|
username: "alice".to_string(),
|
||||||
|
password: "passw0rd".to_string(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
if r.status() != reqwest::StatusCode::OK {
|
||||||
|
panic!("{}", r.text().await?)
|
||||||
|
}
|
||||||
|
let r = r.json::<server::LoginOutput>().await?;
|
||||||
|
let access_token = r.access_token;
|
||||||
|
|
||||||
|
// Call create_new_session
|
||||||
|
let r = client
|
||||||
|
.post("http://127.0.0.1:2744/create_new_session")
|
||||||
|
.bearer_auth(access_token)
|
||||||
|
.json(&server::CreateNewSessionArgs {
|
||||||
|
usernames: vec!["alice".to_string(), "bob".to_string()],
|
||||||
|
message_count: 1,
|
||||||
|
num_signers: 2,
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
if r.status() != reqwest::StatusCode::OK {
|
||||||
|
panic!("{}", r.text().await?)
|
||||||
|
}
|
||||||
|
let r = r.json::<server::CreateNewSessionOutput>().await?;
|
||||||
|
let session_id = r.session_id;
|
||||||
|
println!("Session ID: {}", session_id);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue