Add user registration (#252)

* add user registration

* addint support to coordinator and participant

* Update participant/src/comms/http.rs

Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com>

---------

Co-authored-by: Pili Guerra <mpguerra@users.noreply.github.com>
This commit is contained in:
Conrado Gouvea 2024-08-14 18:48:59 -03:00 committed by GitHub
parent 4503c10790
commit 4c6e860d69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 1594 additions and 77 deletions

788
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -17,6 +17,21 @@ pub struct Args {
#[arg(long, default_value_t = false)]
pub http: bool,
/// The username to use in HTTP mode.
#[arg(short = 'u', long, default_value = "")]
pub username: String,
/// The password to use in HTTP mode. If specified, it will be read from the
/// environment variable with the given name.
#[arg(short = 'w', long, default_value = "")]
pub password: String,
/// The comma-separated usernames of the signers to use in HTTP mode.
/// If HTTP mode is enabled and this is empty, then the session ID
/// will be printed and will have to be shared manually.
#[arg(short = 'S', long, value_delimiter = ',')]
pub signers: Vec<String>,
/// The number of participants. If 0, will prompt for a value.
#[arg(short = 'n', long, default_value_t = 0)]
pub num_signers: u16,

View File

@ -22,7 +22,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
let mut comms: Box<dyn Comms<C>> = if args.cli {
Box::new(CLIComms::new())
} else if args.http {
Box::new(HTTPComms::new(args))
Box::new(HTTPComms::new(args)?)
} else {
Box::new(SocketComms::new(args))
};

View File

@ -16,10 +16,12 @@ use frost::{
use std::{
collections::BTreeMap,
env,
error::Error,
io::{BufRead, Write},
marker::PhantomData,
time::Duration,
vec,
};
use super::Comms;
@ -29,18 +31,27 @@ pub struct HTTPComms<C: Ciphersuite> {
client: reqwest::Client,
host_port: String,
session_id: Option<Uuid>,
username: String,
password: String,
access_token: String,
signers: Vec<String>,
_phantom: PhantomData<C>,
}
impl<C: Ciphersuite> HTTPComms<C> {
pub fn new(args: &Args) -> Self {
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
let client = reqwest::Client::new();
Self {
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
Ok(Self {
client,
host_port: format!("http://{}:{}", args.ip, args.port),
session_id: None,
username: args.username.clone(),
password,
access_token: String::new(),
signers: args.signers.clone(),
_phantom: Default::default(),
}
})
}
}
@ -53,10 +64,26 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
_pub_key_package: &PublicKeyPackage<C>,
num_signers: u16,
) -> Result<BTreeMap<Identifier<C>, SigningCommitments<C>>, Box<dyn Error>> {
self.access_token = self
.client
.post(format!("{}/login", self.host_port))
.json(&server::LoginArgs {
username: self.username.clone(),
password: self.password.clone(),
})
.send()
.await?
.json::<server::LoginOutput>()
.await?
.access_token
.to_string();
let r = self
.client
.post(format!("{}/create_new_session", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::CreateNewSessionArgs {
usernames: self.signers.clone(),
num_signers,
message_count: 1,
})
@ -65,10 +92,12 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
.json::<server::CreateNewSessionOutput>()
.await?;
eprintln!(
"Send the following session ID to participants: {}",
r.session_id
);
if self.signers.is_empty() {
eprintln!(
"Send the following session ID to participants: {}",
r.session_id
);
}
self.session_id = Some(r.session_id);
eprint!("Waiting for participants to send their commitments...");
@ -76,6 +105,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
let r = self
.client
.post(format!("{}/get_commitments", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::GetCommitmentsArgs {
session_id: r.session_id,
})
@ -114,6 +144,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
let _r = self
.client
.post(format!("{}/send_signing_package", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::SendSigningPackageArgs {
aux_msg: Default::default(),
session_id: self.session_id.unwrap(),
@ -131,6 +162,7 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
let r = self
.client
.post(format!("{}/get_signature_shares", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::GetSignatureSharesArgs {
session_id: self.session_id.unwrap(),
})
@ -145,6 +177,23 @@ impl<C: Ciphersuite + 'static> Comms<C> for HTTPComms<C> {
};
eprintln!();
let _r = self
.client
.post(format!("{}/close_session", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::CloseSessionArgs {
session_id: self.session_id.unwrap(),
})
.send()
.await?;
let _r = self
.client
.post(format!("{}/logout", self.host_port))
.bearer_auth(&self.access_token)
.send()
.await?;
let signature_shares = r
.signature_shares
.first()

View File

@ -48,7 +48,9 @@ async fn read_commitments<C: Ciphersuite>(
let pub_key_package: PublicKeyPackage<C> = serde_json::from_str(&out)?;
let num_of_participants = if args.num_signers == 0 {
let num_of_participants = if !args.signers.is_empty() {
args.signers.len() as u16
} else if args.num_signers == 0 {
writeln!(logger, "The number of participants: ")?;
let mut participants = String::new();

View File

@ -17,6 +17,15 @@ pub struct Args {
#[arg(long, default_value_t = false)]
pub http: bool,
/// The username to use in HTTP mode.
#[arg(short = 'u', long, default_value = "")]
pub username: String,
/// The password to use in HTTP mode. If specified, it will be read from the
/// environment variable with the given name.
#[arg(short = 'w', long, default_value = "")]
pub password: String,
/// Public key package to use. Can be a file with a JSON-encoded
/// package, or "". If the file does not exist or if "" is specified,
/// then it will be read from standard input.

View File

@ -20,7 +20,7 @@ pub async fn cli<C: RandomizedCiphersuite + 'static>(
let mut comms: Box<dyn Comms<C>> = if args.cli {
Box::new(CLIComms::new())
} else if args.http {
Box::new(HTTPComms::new(args))
Box::new(HTTPComms::new(args)?)
} else {
Box::new(SocketComms::new(args))
};

View File

@ -10,6 +10,7 @@ use frost::{round1::SigningCommitments, round2::SignatureShare, Identifier};
use super::Comms;
use std::env;
use std::io::{BufRead, Write};
use std::error::Error;
@ -22,7 +23,10 @@ use crate::args::Args;
pub struct HTTPComms<C: Ciphersuite> {
client: reqwest::Client,
host_port: String,
session_id: Uuid,
session_id: Option<Uuid>,
username: String,
password: String,
access_token: String,
_phantom: PhantomData<C>,
}
@ -33,14 +37,18 @@ impl<C> HTTPComms<C>
where
C: Ciphersuite,
{
pub fn new(args: &Args) -> Self {
pub fn new(args: &Args) -> Result<Self, Box<dyn Error>> {
let client = reqwest::Client::new();
Self {
let password = env::var(&args.password).map_err(|_| eyre!("The password argument must specify the name of a environment variable containing the password"))?;
Ok(Self {
client,
host_port: format!("http://{}:{}", args.ip, args.port),
session_id: Uuid::parse_str(&args.session_id).expect("invalid session id"),
session_id: Uuid::parse_str(&args.session_id).ok(),
username: args.username.clone(),
password,
access_token: String::new(),
_phantom: Default::default(),
}
})
}
}
@ -63,11 +71,48 @@ where
),
Box<dyn Error>,
> {
self.access_token = self
.client
.post(format!("{}/login", self.host_port))
.json(&server::LoginArgs {
username: self.username.clone(),
password: self.password.clone(),
})
.send()
.await?
.json::<server::LoginOutput>()
.await?
.access_token
.to_string();
let session_id = match self.session_id {
Some(s) => s,
None => {
// Get session ID from server
let r = self
.client
.post(format!("{}/list_sessions", self.host_port))
.bearer_auth(&self.access_token)
.send()
.await?
.json::<server::ListSessionsOutput>()
.await?;
if r.session_ids.len() > 1 {
return Err(eyre!("user has more than one FROST session active, which is still not supported by this tool").into());
} else if r.session_ids.is_empty() {
return Err(eyre!("User has no current sessions active. The Coordinator should either specify your username, or manually share the session ID which you can specify with --session_id").into());
}
r.session_ids[0]
}
};
self.session_id = Some(session_id);
// Send Commitments to Server
self.client
.post(format!("{}/send_commitments", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::SendCommitmentsArgs {
session_id: self.session_id,
session_id,
identifier: identifier.into(),
commitments: vec![(&commitments).try_into()?],
})
@ -82,9 +127,8 @@ where
let r = self
.client
.post(format!("{}/get_signing_package", self.host_port))
.json(&server::GetSigningPackageArgs {
session_id: self.session_id,
})
.bearer_auth(&self.access_token)
.json(&server::GetSigningPackageArgs { session_id })
.send()
.await?;
if r.status() != 200 {
@ -126,14 +170,22 @@ where
let _r = self
.client
.post(format!("{}/send_signature_share", self.host_port))
.bearer_auth(&self.access_token)
.json(&server::SendSignatureShareArgs {
identifier: identifier.into(),
session_id: self.session_id,
session_id: self.session_id.unwrap(),
signature_share: vec![signature_share.into()],
})
.send()
.await?;
let _r = self
.client
.post(format!("{}/logout", self.host_port))
.bearer_auth(&self.access_token)
.send()
.await?;
Ok(())
}
}

View File

@ -46,6 +46,8 @@ async fn check_valid_round_1_inputs() {
port: 80,
session_id: "session-id".to_string(),
http: false,
username: "".to_string(),
password: "".to_string(),
};
let input = SECRET_SHARE_JSON;
let mut valid_input = input.as_bytes();

View File

@ -7,15 +7,19 @@ edition = "2021"
[dependencies]
axum = "0.7.5"
axum-extra = { version = "0.9.3", features = ["typed-header"] }
axum-macros = "0.4.1"
clap = { version = "4.5.13", features = ["derive"] }
derivative = "2.2.0"
eyre = "0.6.11"
frost-core = { version = "2.0.0-rc.0", features = ["serde"] }
frost-rerandomized = { version = "2.0.0-rc.0", features = ["serde"] }
password-auth = "1.0.0"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serdect = { version = "0.2.0" }
serde_json = "1.0.122"
sqlx = { version = "0.7.3", features = ["sqlite", "time", "runtime-tokio", "uuid"] }
tokio = { version = "1.38", features = ["full"] }
tower-http = { version = "0.5.2", features = ["trace"] }
tracing = "0.1"
@ -23,7 +27,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.10.0", features = ["v4", "fast-rng", "serde"] }
[dev-dependencies]
axum-test = "14.10.0"
axum-test = "15.2.0"
frost-ed25519 = { version = "2.0.0-rc.0", features = ["serde"] }
reddsa = { git = "https://github.com/ZcashFoundation/reddsa.git", rev = "4d8c4bb337231e6e89117334d7c61dada589a953", features = [
"frost",

5
server/build.rs Normal file
View File

@ -0,0 +1,5 @@
// generated by `sqlx migrate build-script`
fn main() {
// trigger recompilation when a new migration is added
println!("cargo:rerun-if-changed=migrations");
}

View File

@ -0,0 +1,3 @@
-- Add down migration script here
drop table if exists users;
drop table if exists access_tokens;

View File

@ -0,0 +1,16 @@
-- Create users table.
create table if not exists users
(
id integer primary key not null,
username text not null unique,
password text not null,
pubkey blob not null
);
create table if not exists access_tokens
(
id integer primary key not null,
user_id integer not null,
access_token blob not null,
foreign key(user_id) references users(id) on delete cascade
);

View File

@ -10,4 +10,8 @@ pub struct Args {
/// Port to bind to
#[arg(short, long, default_value_t = 2744)]
pub port: u16,
/// Database to use.
#[arg(short, long, default_value = "db.sqlite")]
pub database: String,
}

View File

@ -1,32 +1,174 @@
use std::collections::HashSet;
use axum::{extract::State, http::StatusCode, Json};
use eyre::eyre;
use uuid::Uuid;
use crate::{
state::{Session, SessionState, SharedState},
types::*,
user::{
add_access_token, authenticate_user, create_user, delete_user, get_user,
remove_access_token, User,
},
AppError,
};
/// Implement the register API.
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
pub(crate) async fn register(
State(state): State<SharedState>,
Json(args): Json<RegisterArgs>,
) -> Result<Json<()>, AppError> {
let username = args.username.trim();
let password = args.password.trim();
if username.is_empty() || password.is_empty() {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("empty args").into(),
));
}
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
create_user(db, username, password, args.pubkey)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
Ok(Json(()))
}
/// Implement the login API.
#[tracing::instrument(ret, err(Debug), skip(state,args), fields(args.username = %args.username))]
pub(crate) async fn login(
State(state): State<SharedState>,
Json(args): Json<LoginArgs>,
) -> Result<Json<LoginOutput>, AppError> {
// Check if the user sent the credentials
if args.username.is_empty() || args.password.is_empty() {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("empty args").into(),
));
}
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
let user = authenticate_user(db.clone(), &args.username, &args.password)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
let user = match user {
Some(user) => user,
None => {
return Err(AppError(
StatusCode::UNAUTHORIZED,
eyre!("invalid user or password").into(),
))
}
};
let access_token = add_access_token(db.clone(), user.id)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
let token = LoginOutput { access_token };
Ok(Json(token))
}
/// Implement the logout API.
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn logout(
State(state): State<SharedState>,
user: User,
) -> Result<Json<()>, AppError> {
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
remove_access_token(
db.clone(),
user.current_token
.expect("user is logged in so they must have a token"),
)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
Ok(Json(()))
}
/// Implement the unregister API.
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn unregister(
State(state): State<SharedState>,
user: User,
) -> Result<Json<()>, AppError> {
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
delete_user(db, user.id)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
Ok(Json(()))
}
/// Implement the create_new_session API.
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn create_new_session(
State(state): State<SharedState>,
user: User,
Json(args): Json<CreateNewSessionArgs>,
) -> Result<Json<CreateNewSessionOutput>, AppError> {
tracing::info!("create_new_session");
if args.message_count == 0 {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid message_count"),
eyre!("invalid message_count").into(),
));
}
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
for username in &args.usernames {
if get_user(db.clone(), username)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?
.is_none()
{
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid user").into(),
));
}
}
// Create new session object.
let id = Uuid::new_v4();
let mut state = state.write().unwrap();
// Save session ID in global state
for username in &args.usernames {
state
.sessions_by_username
.entry(username.to_string())
.or_default()
.insert(id);
}
// Create Session object
let session = Session {
usernames: args.usernames,
num_signers: args.num_signers,
message_count: args.message_count,
state: SessionState::WaitingForCommitments {
@ -34,22 +176,41 @@ pub(crate) async fn create_new_session(
},
};
// Save session into global state.
state.write().unwrap().sessions.insert(id, session);
state.sessions.insert(id, session);
let user = CreateNewSessionOutput { session_id: id };
Ok(Json(user))
}
/// Implement the create_new_session API.
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn list_sessions(
State(state): State<SharedState>,
user: User,
) -> Result<Json<ListSessionsOutput>, AppError> {
let state = state.read().unwrap();
let session_ids = state
.sessions_by_username
.get(&user.username)
.map(|s| s.iter().cloned().collect())
.unwrap_or_default();
Ok(Json(ListSessionsOutput { session_ids }))
}
/// Implement the get_session_info API
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn get_session_info(
State(state): State<SharedState>,
user: User,
Json(args): Json<GetSessionInfoArgs>,
) -> Result<Json<GetSessionInfoOutput>, AppError> {
let state_lock = state.read().unwrap();
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
Ok(Json(GetSessionInfoOutput {
@ -60,9 +221,10 @@ pub(crate) async fn get_session_info(
/// Implement the send_commitments API
// TODO: get identifier from channel rather from arguments
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn send_commitments(
State(state): State<SharedState>,
user: User,
Json(args): Json<SendCommitmentsArgs>,
) -> Result<(), AppError> {
// Get the mutex lock to read and write from the state
@ -73,7 +235,7 @@ pub(crate) async fn send_commitments(
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &mut session.state {
@ -81,7 +243,7 @@ pub(crate) async fn send_commitments(
if args.commitments.len() != session.message_count as usize {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("wrong number of commitments"),
eyre!("wrong number of commitments").into(),
));
}
// Add commitment to map.
@ -89,6 +251,11 @@ pub(crate) async fn send_commitments(
// (it seems better to ignore overwrites, which could be caused by
// poor networking connectivity leading to retries)
commitments.insert(args.identifier, args.commitments);
tracing::debug!(
"added commitments, currently {}/{}",
commitments.len(),
session.num_signers
);
// If complete, advance to next state
if commitments.len() == session.num_signers as usize {
session.state = SessionState::CommitmentsReady {
@ -99,7 +266,7 @@ pub(crate) async fn send_commitments(
_ => {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
));
}
}
@ -107,16 +274,17 @@ pub(crate) async fn send_commitments(
}
/// Implement the get_commitments API
// #[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn get_commitments(
State(state): State<SharedState>,
user: User,
Json(args): Json<GetCommitmentsArgs>,
) -> Result<Json<GetCommitmentsOutput>, AppError> {
let state_lock = state.read().unwrap();
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &session.state {
@ -135,15 +303,16 @@ pub(crate) async fn get_commitments(
})),
_ => Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
)),
}
}
/// Implement the send_signing_package API
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn send_signing_package(
State(state): State<SharedState>,
user: User,
Json(args): Json<SendSigningPackageArgs>,
) -> Result<(), AppError> {
let mut state_lock = state.write().unwrap();
@ -153,7 +322,7 @@ pub(crate) async fn send_signing_package(
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &mut session.state {
@ -161,7 +330,7 @@ pub(crate) async fn send_signing_package(
if args.signing_package.len() != session.message_count as usize {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("wrong number of inputs"),
eyre!("wrong number of inputs").into(),
));
}
if args.randomizer.len() != session.message_count as usize
@ -169,7 +338,7 @@ pub(crate) async fn send_signing_package(
{
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("wrong number of inputs"),
eyre!("wrong number of inputs").into(),
));
}
session.state = SessionState::WaitingForSignatureShares {
@ -183,7 +352,7 @@ pub(crate) async fn send_signing_package(
_ => {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
));
}
}
@ -191,16 +360,17 @@ pub(crate) async fn send_signing_package(
}
/// Implement the get_signing_package API
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn get_signing_package(
State(state): State<SharedState>,
user: User,
Json(args): Json<GetSigningPackageArgs>,
) -> Result<Json<GetSigningPackageOutput>, AppError> {
let state_lock = state.read().unwrap();
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &session.state {
@ -217,16 +387,17 @@ pub(crate) async fn get_signing_package(
})),
_ => Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
)),
}
}
/// Implement the send_signature_share API
// TODO: get identifier from channel rather from arguments
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn send_signature_share(
State(state): State<SharedState>,
user: User,
Json(args): Json<SendSignatureShareArgs>,
) -> Result<(), AppError> {
let mut state_lock = state.write().unwrap();
@ -236,7 +407,7 @@ pub(crate) async fn send_signature_share(
.get_mut(&args.session_id)
.ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &mut session.state {
@ -248,12 +419,15 @@ pub(crate) async fn send_signature_share(
aux_msg: _,
} => {
if !identifiers.contains(&args.identifier) {
return Err(AppError(StatusCode::NOT_FOUND, eyre!("invalid identifier")));
return Err(AppError(
StatusCode::NOT_FOUND,
eyre!("invalid identifier").into(),
));
}
if args.signature_share.len() != session.message_count as usize {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("wrong number of signature shares"),
eyre!("wrong number of signature shares").into(),
));
}
// Currently ignoring the possibility of overwriting previous values
@ -270,7 +444,7 @@ pub(crate) async fn send_signature_share(
_ => {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
));
}
}
@ -278,16 +452,17 @@ pub(crate) async fn send_signature_share(
}
/// Implement the get_signature_shares API
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn get_signature_shares(
State(state): State<SharedState>,
user: User,
Json(args): Json<GetSignatureSharesArgs>,
) -> Result<Json<GetSignatureSharesOutput>, AppError> {
let state_lock = state.read().unwrap();
let session = state_lock.sessions.get(&args.session_id).ok_or(AppError(
StatusCode::NOT_FOUND,
eyre!("session ID not found"),
eyre!("session ID not found").into(),
))?;
match &session.state {
@ -308,17 +483,34 @@ pub(crate) async fn get_signature_shares(
}
_ => Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("incompatible session state"),
eyre!("incompatible session state").into(),
)),
}
}
/// Implement the close_session API.
#[tracing::instrument(ret, err(Debug))]
#[tracing::instrument(ret, err(Debug), skip(state,user), fields(user.username = %user.username))]
pub(crate) async fn close_session(
State(state): State<SharedState>,
user: User,
Json(args): Json<CloseSessionArgs>,
) -> Result<Json<()>, AppError> {
state.write().unwrap().sessions.remove(&args.session_id);
let mut state = state.write().unwrap();
for username in state
.sessions
.get(&args.session_id)
.ok_or(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid session ID").into(),
))?
.usernames
.clone()
{
if let Some(v) = state.sessions_by_username.get_mut(&username) {
v.remove(&args.session_id);
}
}
state.sessions.remove(&args.session_id);
Ok(Json(()))
}

View File

@ -2,6 +2,9 @@ pub mod args;
mod functions;
mod state;
mod types;
mod user;
pub use state::{AppState, SharedState};
use tower_http::trace::TraceLayer;
pub use types::*;
@ -16,11 +19,15 @@ use axum::{
/// Create the axum Router for the server.
/// Maps specific endpoints to handler functions.
// TODO: use methods of a single object instead of separate functions?
pub fn router() -> Router {
pub fn router(shared_state: SharedState) -> Router {
// Shared state that is passed to each handler by axum
let shared_state = state::SharedState::default();
Router::new()
.route("/register", post(functions::register))
.route("/login", post(functions::login))
.route("/logout", post(functions::logout))
.route("/unregister", post(functions::unregister))
.route("/create_new_session", post(functions::create_new_session))
.route("/list_sessions", post(functions::list_sessions))
.route("/get_session_info", post(functions::get_session_info))
.route("/send_commitments", post(functions::send_commitments))
.route("/get_commitments", post(functions::get_commitments))
@ -44,7 +51,8 @@ pub fn router() -> Router {
/// Run the server with the specified arguments.
pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
let app = router();
let shared_state = AppState::new(&args.database).await?;
let app = router(shared_state);
let addr = format!("{}:{}", args.ip, args.port);
let listener = tokio::net::TcpListener::bind(addr).await?;
@ -55,7 +63,7 @@ pub async fn run(args: &Args) -> Result<(), Box<dyn std::error::Error>> {
/// error happens during a API call, and a generic eyre::Report.
// TODO: create an enum with specific errors
#[derive(Debug)]
pub struct AppError(StatusCode, eyre::Report);
pub struct AppError(StatusCode, Box<dyn std::error::Error>);
impl IntoResponse for AppError {
fn into_response(self) -> Response {

View File

@ -1,8 +1,13 @@
use std::{
collections::{HashMap, HashSet},
str::FromStr,
sync::{Arc, RwLock},
};
use sqlx::{
sqlite::{SqliteConnectOptions, SqlitePoolOptions},
SqlitePool,
};
use uuid::Uuid;
use crate::{
@ -67,6 +72,8 @@ impl Default for SessionState {
/// A particular signing session.
#[derive(Debug)]
pub struct Session {
/// The usernames of the participants
pub(crate) usernames: Vec<String>,
/// The number of signers in the session.
pub(crate) num_signers: u16,
/// The set of identifiers for the session.
@ -78,10 +85,27 @@ pub struct Session {
}
/// The global state of the server.
#[derive(Default, Debug)]
#[derive(Debug)]
pub struct AppState {
/// Mapping of signing sessions by UUID.
pub(crate) sessions: HashMap<Uuid, Session>,
pub(crate) sessions_by_username: HashMap<String, HashSet<Uuid>>,
pub(crate) db: SqlitePool,
}
impl AppState {
pub async fn new(database: &str) -> Result<SharedState, Box<dyn std::error::Error>> {
tracing::event!(tracing::Level::INFO, "opening database {}", database);
let options = SqliteConnectOptions::from_str(database)?.create_if_missing(true);
let db = SqlitePoolOptions::new().connect_with(options).await?;
sqlx::migrate!().run(&db).await?;
let state = Self {
sessions: Default::default(),
sessions_by_username: Default::default(),
db,
};
Ok(Arc::new(RwLock::new(state)))
}
}
/// Type alias for the global state under a reference-counted RW mutex,

View File

@ -147,8 +147,31 @@ impl<C: frost_core::Ciphersuite> TryFrom<&SerializedSignatureShare>
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct RegisterArgs {
pub username: String,
pub password: String,
#[serde(
serialize_with = "serdect::slice::serialize_hex_lower_or_bin",
deserialize_with = "serdect::slice::deserialize_hex_or_bin_vec"
)]
pub pubkey: Vec<u8>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginOutput {
pub access_token: Uuid,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginArgs {
pub username: String,
pub password: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CreateNewSessionArgs {
pub usernames: Vec<String>,
pub num_signers: u16,
pub message_count: u8,
}
@ -158,6 +181,11 @@ pub struct CreateNewSessionOutput {
pub session_id: Uuid,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ListSessionsOutput {
pub session_ids: Vec<Uuid>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GetSessionInfoArgs {
pub session_id: Uuid,

237
server/src/user.rs Normal file
View File

@ -0,0 +1,237 @@
use std::str::FromStr;
use axum::{
async_trait,
extract::FromRequestParts,
http::{request::Parts, StatusCode},
RequestPartsExt,
};
use axum_extra::{
headers::{authorization::Bearer, Authorization},
TypedHeader,
};
use eyre::eyre;
use sqlx::{FromRow, SqlitePool};
use tokio::task;
use uuid::Uuid;
use crate::{state::SharedState, AppError};
/// An User, as stored in the database.
#[derive(Debug, FromRow)]
#[allow(dead_code)]
pub struct User {
pub(crate) id: i64,
pub(crate) username: String,
pub(crate) password: String,
pub(crate) pubkey: Vec<u8>,
#[sqlx(skip)]
pub(crate) access_tokens: Vec<AccessToken>,
#[sqlx(skip)]
pub(crate) current_token: Option<Uuid>,
}
#[derive(Debug, FromRow)]
#[allow(dead_code)]
pub struct AccessToken {
pub(crate) id: i64,
pub(crate) user_id: i64,
pub(crate) access_token: Option<Uuid>,
}
/// Create user in the database.
///
/// The password is hashed and its hash is written in the DB.
pub(crate) async fn create_user(
db: SqlitePool,
username: &str,
password: &str,
pubkey: Vec<u8>,
) -> Result<(), Box<dyn std::error::Error>> {
// TODO: enforce mininum password length
let password = password.to_owned();
let pwhash = task::spawn_blocking(|| password_auth::generate_hash(password)).await?;
sqlx::query(
r#"
insert into users (username, password, pubkey)
values (?, ?, ?)
"#,
)
.bind(username)
.bind(pwhash)
.bind(pubkey)
.execute(&db)
.await?;
Ok(())
}
/// Get user from database, or None if it's not registered.
pub(crate) async fn get_user(
db: SqlitePool,
username: &str,
) -> Result<Option<User>, Box<dyn std::error::Error>> {
let user: Option<User> = sqlx::query_as("select * from users where username = ? ")
.bind(username)
.fetch_optional(&db)
.await?;
if let Some(mut user) = user {
let access_tokens: Vec<AccessToken> =
sqlx::query_as("select * from access_tokens where user_id = ?")
.bind(user.id)
.fetch_all(&db)
.await?;
user.access_tokens = access_tokens;
Ok(Some(user))
} else {
Ok(None)
}
}
/// Delete an User from the database, given its database ID.
pub(crate) async fn delete_user(db: SqlitePool, id: i64) -> Result<(), Box<dyn std::error::Error>> {
sqlx::query(
r#"
delete from users where id = ?
"#,
)
.bind(id)
.execute(&db)
.await?;
Ok(())
}
/// Authenticate user registered in the database. Returns the User if
/// authentication is successful, or None if the username or password is wrong.
///
/// The given password is hashed and verified against the stored hash.
pub(crate) async fn authenticate_user(
db: SqlitePool,
username: &str,
password: &str,
) -> Result<Option<User>, Box<dyn std::error::Error>> {
let user: Option<User> = get_user(db, username).await?;
// Verifying the password is blocking and potentially slow, so we'll do so
// via `spawn_blocking`.
let password = password.to_owned();
let r: Result<_, password_auth::VerifyError> = task::spawn_blocking(|| {
// We're using password-based authentication--this works by comparing our form
// input with an argon2 password hash.
Ok(user.filter(|user| password_auth::verify_password(password, &user.password).is_ok()))
})
.await?;
Ok(r?)
}
/// Refreshes the user's access token, identified by its id in the database.
///
/// Generates a new token and overwrites the old one in the database, if any.
pub(crate) async fn add_access_token(
db: SqlitePool,
id: i64,
) -> Result<Uuid, Box<dyn std::error::Error>> {
let access_token = Uuid::new_v4();
sqlx::query(
r#"
insert into access_tokens (user_id, access_token)
values (?, ?)
"#,
)
.bind(id)
.bind(access_token)
.execute(&db)
.await?;
Ok(access_token)
}
/// Remove a user's access token.
pub(crate) async fn remove_access_token(
db: SqlitePool,
access_token: Uuid,
) -> Result<(), Box<dyn std::error::Error>> {
sqlx::query(
r#"
delete from access_tokens where access_token = ?
"#,
)
.bind(access_token)
.execute(&db)
.await?;
Ok(())
}
/// Return the User for a given access token, or None if there is no match.
pub(crate) async fn get_user_for_access_token(
db: SqlitePool,
access_token: Uuid,
) -> Result<Option<User>, Box<dyn std::error::Error>> {
let user: Option<User> = sqlx::query_as(
r#"
select * from users inner join access_tokens on users.id = access_tokens.user_id where access_tokens.access_token = ?
"#,
)
.bind(access_token)
.fetch_optional(&db)
.await?;
Ok(user)
}
/// Read a User from a request. This is used to authenticate users. If any axum
/// handler has an User argument, this will be called and the authentication
/// will be carried out.
#[async_trait]
impl FromRequestParts<SharedState> for User {
type Rejection = AppError;
#[tracing::instrument(ret, err(Debug), skip(parts, state))]
// Can be removed after this fix is released
// https://github.com/rust-lang/rust-clippy/issues/12281
#[allow(clippy::blocks_in_conditions)]
async fn from_request_parts(
parts: &mut Parts,
state: &SharedState,
) -> Result<Self, Self::Rejection> {
// Extract the token from the authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("Bearer token missing").into(),
)
})?;
// Decode the user data
let access_token = Uuid::from_str(bearer.token()).map_err(|_| {
AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("invalid access token").into(),
)
})?;
let db = {
let state_lock = state.read().unwrap();
state_lock.db.clone()
};
let user = get_user_for_access_token(db, access_token)
.await
.map_err(|e| AppError(StatusCode::INTERNAL_SERVER_ERROR, e))?;
match user {
Some(mut user) => {
user.current_token = Some(access_token);
Ok(user)
}
None => {
return Err(AppError(
StatusCode::INTERNAL_SERVER_ERROR,
eyre!("user not found").into(),
))
}
}
}
}

View File

@ -2,7 +2,7 @@ use std::{collections::BTreeMap, time::Duration};
use axum_test::TestServer;
use rand::thread_rng;
use server::{args::Args, router, SerializedSignatureShare, SerializedSigningPackage};
use server::{args::Args, router, AppState, SerializedSignatureShare, SerializedSigningPackage};
use frost_core as frost;
@ -47,14 +47,52 @@ async fn test_main_router<
.collect();
// Instantiate test server using axum_test
let router = router();
let shared_state = AppState::new(":memory:").await?;
let router = router(shared_state);
let server = TestServer::new(router)?;
// Create a dummy user. We make all requests with the same user since
// it currently it doesn't really matter who the user is, users are only
// used to share session IDs. This will likely change soon.
let res = server
.post("/register")
.json(&server::RegisterArgs {
username: "alice".to_string(),
password: "passw0rd".to_string(),
pubkey: vec![],
})
.await;
res.assert_status_ok();
let res = server
.post("/register")
.json(&server::RegisterArgs {
username: "bob".to_string(),
password: "passw0rd".to_string(),
pubkey: vec![],
})
.await;
res.assert_status_ok();
let res = server
.post("/login")
.json(&server::LoginArgs {
username: "alice".to_string(),
password: "passw0rd".to_string(),
})
.await;
res.assert_status_ok();
let r: server::LoginOutput = res.json();
let token = r.access_token;
// As the coordinator, create a new signing session with all participants,
// for 2 messages
let res = server
.post("/create_new_session")
.authorization_bearer(token)
.json(&server::CreateNewSessionArgs {
usernames: vec!["alice".to_string(), "bob".to_string()],
num_signers: 2,
message_count: 2,
})
@ -75,6 +113,7 @@ async fn test_main_router<
// asking the server).
let res = server
.post("/get_session_info")
.authorization_bearer(token)
.json(&server::GetSessionInfoArgs { session_id })
.await;
res.assert_status_ok();
@ -96,6 +135,7 @@ async fn test_main_router<
// Send commitments to server
let res = server
.post("/send_commitments")
.authorization_bearer(token)
.json(&server::SendCommitmentsArgs {
identifier: (*identifier).into(),
session_id,
@ -110,6 +150,7 @@ async fn test_main_router<
// As the coordinator, get the commitments
let res = server
.post("/get_commitments")
.authorization_bearer(token)
.json(&server::GetCommitmentsArgs { session_id })
.await;
res.assert_status_ok();
@ -146,6 +187,7 @@ async fn test_main_router<
// As the coordinator, send the SigningPackages to the server
let res = server
.post("/send_signing_package")
.authorization_bearer(token)
.json(&server::SendSigningPackageArgs {
session_id,
signing_package: signing_packages
@ -173,6 +215,7 @@ async fn test_main_router<
// Get SigningPackages
let res = server
.post("get_signing_package")
.authorization_bearer(token)
.json(&server::GetSigningPackageArgs { session_id })
.await;
res.assert_status_ok();
@ -210,6 +253,7 @@ async fn test_main_router<
// Send SignatureShares to the server
let res = server
.post("/send_signature_share")
.authorization_bearer(token)
.json(&server::SendSignatureShareArgs {
identifier: (*identifier).into(),
session_id,
@ -225,6 +269,7 @@ async fn test_main_router<
// As the coordinator, get SignatureShares
let res = server
.post("/get_signature_shares")
.authorization_bearer(token)
.json(&server::GetSignatureSharesArgs { session_id })
.await;
res.assert_status_ok();
@ -265,6 +310,7 @@ async fn test_main_router<
// Close the session
let res = server
.post("/close_session")
.authorization_bearer(token)
.json(&server::CloseSessionArgs { session_id })
.await;
res.assert_status_ok();
@ -288,9 +334,12 @@ async fn test_main_router<
/// A better example on how to write client code.
#[tokio::test]
async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt::init();
// Spawn server for testing
tokio::spawn(async move {
server::run(&Args {
database: ":memory:".to_string(),
ip: "127.0.0.1".to_string(),
port: 2744,
})
@ -302,19 +351,67 @@ async fn test_http() -> Result<(), Box<dyn std::error::Error>> {
// TODO: this could possibly be not enough, use some retry logic instead
tokio::time::sleep(Duration::from_secs(2)).await;
// Call create_new_session
// Create a client to make requests
let client = reqwest::Client::new();
// Call register to create users
let r = client
.post("http://127.0.0.1:2744/create_new_session")
.json(&server::CreateNewSessionArgs {
num_signers: 2,
message_count: 1,
.post("http://127.0.0.1:2744/register")
.json(&server::RegisterArgs {
username: "alice".to_string(),
password: "passw0rd".to_string(),
pubkey: vec![],
})
.send()
.await?
.json::<server::CreateNewSessionOutput>()
.await?;
println!("{}", r.session_id);
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
}
let r = client
.post("http://127.0.0.1:2744/register")
.json(&server::RegisterArgs {
username: "bob".to_string(),
password: "passw0rd".to_string(),
pubkey: vec![],
})
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
}
// Call login to authenticate
let r = client
.post("http://127.0.0.1:2744/login")
.json(&server::LoginArgs {
username: "alice".to_string(),
password: "passw0rd".to_string(),
})
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
}
let r = r.json::<server::LoginOutput>().await?;
let access_token = r.access_token;
// Call create_new_session
let r = client
.post("http://127.0.0.1:2744/create_new_session")
.bearer_auth(access_token)
.json(&server::CreateNewSessionArgs {
usernames: vec!["alice".to_string(), "bob".to_string()],
message_count: 1,
num_signers: 2,
})
.send()
.await?;
if r.status() != reqwest::StatusCode::OK {
panic!("{}", r.text().await?)
}
let r = r.json::<server::CreateNewSessionOutput>().await?;
let session_id = r.session_id;
println!("Session ID: {}", session_id);
Ok(())
}