dkg: use generics

This commit is contained in:
Conrado Gouvea 2024-06-21 16:17:07 -03:00 committed by natalie
parent d6955796c6
commit ab6924ac5f
8 changed files with 67 additions and 54 deletions

2
Cargo.lock generated
View File

@ -562,8 +562,10 @@ dependencies = [
name = "dkg"
version = "0.1.0"
dependencies = [
"clap",
"exitcode",
"eyre",
"frost-core",
"frost-ed25519",
"hex",
"itertools 0.13.0",

View File

@ -7,8 +7,10 @@ edition = "2021"
[dependencies]
eyre = "0.6.12"
frost-core = { version = "1.0.0", features = ["serde"] }
frost-ed25519 = { version = "1.0.0", features = ["serde"] }
reddsa = { git = "https://github.com/ZcashFoundation/reddsa.git", rev = "81c649c412e5b6ba56d491d2857f91fbd28adbc7", features = ["frost"] }
clap = { version = "4.5.7", features = ["derive"] }
hex = { version = "0.4", features = ["serde"] }
thiserror = "1.0"
rand = "0.8"

View File

@ -1,7 +1,4 @@
#[cfg(not(feature = "redpallas"))]
use frost_ed25519 as frost;
#[cfg(feature = "redpallas")]
use reddsa::frost::redpallas as frost;
use frost_core::{self as frost, Ciphersuite};
use rand::thread_rng;
use std::collections::BTreeMap;
@ -9,11 +6,11 @@ use std::io::{BufRead, Write};
use crate::inputs::{read_round1_package, read_round2_package, request_inputs};
pub fn cli(
pub fn cli<C: Ciphersuite + 'static>(
reader: &mut impl BufRead,
logger: &mut impl Write,
) -> Result<(), Box<dyn std::error::Error>> {
let config = request_inputs(reader, logger)?;
let config = request_inputs::<C>(reader, logger)?;
let rng = thread_rng();

View File

@ -1,7 +1,4 @@
#[cfg(not(feature = "redpallas"))]
use frost_ed25519 as frost;
#[cfg(feature = "redpallas")]
use reddsa::frost::redpallas as frost;
use frost_core::{self as frost, Ciphersuite};
use frost::{
keys::dkg::{round1, round2},
@ -13,13 +10,13 @@ use eyre::eyre;
use std::io::{BufRead, Write};
#[derive(Debug, PartialEq, Clone)]
pub struct Config {
pub struct Config<C: Ciphersuite> {
pub min_signers: u16,
pub max_signers: u16,
pub identifier: Identifier,
pub identifier: Identifier<C>,
}
fn validate_inputs(config: &Config) -> Result<(), Error> {
fn validate_inputs<C: Ciphersuite>(config: &Config<C>) -> Result<(), Error<C>> {
if config.min_signers < 2 {
return Err(Error::InvalidMinSigners);
}
@ -35,10 +32,10 @@ fn validate_inputs(config: &Config) -> Result<(), Error> {
Ok(())
}
pub fn request_inputs(
pub fn request_inputs<C: Ciphersuite + 'static>(
input: &mut impl BufRead,
logger: &mut dyn Write,
) -> Result<Config, Box<dyn std::error::Error>> {
) -> Result<Config<C>, Box<dyn std::error::Error>> {
writeln!(logger, "The minimum number of signers: (2 or more)")?;
let mut min = String::new();
@ -47,7 +44,7 @@ pub fn request_inputs(
let min_signers = min
.trim()
.parse::<u16>()
.map_err(|_| Error::InvalidMinSigners)?;
.map_err(|_| Error::<C>::InvalidMinSigners)?;
writeln!(logger, "The maximum number of signers:")?;
@ -56,7 +53,7 @@ pub fn request_inputs(
let max_signers = max
.trim()
.parse::<u16>()
.map_err(|_| Error::InvalidMaxSigners)?;
.map_err(|_| Error::<C>::InvalidMaxSigners)?;
writeln!(
logger,
@ -70,7 +67,7 @@ pub fn request_inputs(
let u16_identifier = identifier_input
.trim()
.parse::<u16>()
.map_err(|_| Error::MalformedIdentifier)?;
.map_err(|_| Error::<C>::MalformedIdentifier)?;
let identifier = u16_identifier.try_into()?;
let config = Config {
@ -84,22 +81,24 @@ pub fn request_inputs(
Ok(config)
}
pub fn read_identifier(input: &mut impl BufRead) -> Result<Identifier, Box<dyn std::error::Error>> {
pub fn read_identifier<C: Ciphersuite + 'static>(
input: &mut impl BufRead,
) -> Result<Identifier<C>, Box<dyn std::error::Error>> {
let mut identifier_input = String::new();
input.read_line(&mut identifier_input)?;
let bytes = hex::decode(identifier_input.trim())?;
let serialization = bytes.try_into().map_err(|_| eyre!("Invalid Identifier"))?;
let identifier = Identifier::deserialize(&serialization)?;
let identifier = Identifier::<C>::deserialize(&serialization)?;
Ok(identifier)
}
pub fn read_round1_package(
pub fn read_round1_package<C: Ciphersuite + 'static>(
input: &mut impl BufRead,
logger: &mut dyn Write,
) -> Result<(Identifier, round1::Package), Box<dyn std::error::Error>> {
) -> Result<(Identifier<C>, round1::Package<C>), Box<dyn std::error::Error>> {
writeln!(logger, "The sender's identifier (hex string):")?;
let identifier = read_identifier(input)?;
let identifier = read_identifier::<C>(input)?;
writeln!(logger, "Their JSON-encoded Round 1 Package:")?;
@ -110,13 +109,13 @@ pub fn read_round1_package(
Ok((identifier, round1_package))
}
pub fn read_round2_package(
pub fn read_round2_package<C: Ciphersuite + 'static>(
input: &mut impl BufRead,
logger: &mut dyn Write,
) -> Result<(Identifier, round2::Package), Box<dyn std::error::Error>> {
) -> Result<(Identifier<C>, round2::Package<C>), Box<dyn std::error::Error>> {
writeln!(logger, "The sender's identifier (hex string):")?;
let identifier = read_identifier(input)?;
let identifier = read_identifier::<C>(input)?;
writeln!(logger, "Their JSON-encoded Round 2 Package:")?;

View File

@ -1,2 +1,6 @@
#[cfg(test)]
mod tests;
pub mod args;
pub mod cli;
pub mod inputs;

View File

@ -1,17 +1,20 @@
mod cli;
mod inputs;
#[cfg(test)]
mod tests;
use std::io;
use cli::cli;
use clap::Parser;
use dkg::{args::Args, cli::cli};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
let mut reader = Box::new(io::stdin().lock());
let mut logger = io::stdout();
cli(&mut reader, &mut logger)?;
if args.ciphersuite == "ed25519" {
cli::<frost_ed25519::Ed25519Sha512>(&mut reader, &mut logger)?;
} else if args.ciphersuite == "redpallas" {
cli::<reddsa::frost::redpallas::PallasBlake2b512>(&mut reader, &mut logger)?;
}
Ok(())
}

View File

@ -1,5 +1,3 @@
#![cfg(not(feature = "redpallas"))]
use std::io::BufWriter;
use crate::inputs::{request_inputs, Config};
@ -8,7 +6,7 @@ use frost_ed25519 as frost;
#[test]
fn check_valid_input_for_signers() {
let config = Config {
let config = Config::<frost_ed25519::Ed25519Sha512> {
min_signers: 2,
max_signers: 3,
identifier: 1u16.try_into().unwrap(),
@ -25,7 +23,8 @@ fn check_valid_input_for_signers() {
fn return_error_if_min_participant_greater_than_max_participant() {
let mut invalid_input = "4\n3\n1\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),
@ -37,7 +36,8 @@ fn return_error_if_min_participant_greater_than_max_participant() {
fn return_error_if_min_participant_is_less_than_2() {
let mut invalid_input = "1\n3\n1\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),
@ -49,7 +49,8 @@ fn return_error_if_min_participant_is_less_than_2() {
fn return_error_if_max_participant_is_less_than_2() {
let mut invalid_input = "2\n1\n1\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),
@ -61,7 +62,8 @@ fn return_error_if_max_participant_is_less_than_2() {
fn return_error_if_invalid_min_signers_input() {
let mut invalid_input = "hello\n6\n1\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),
@ -73,7 +75,8 @@ fn return_error_if_invalid_min_signers_input() {
fn return_error_if_invalid_max_signers_input() {
let mut invalid_input = "4\nworld\n1\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),
@ -85,7 +88,8 @@ fn return_error_if_invalid_max_signers_input() {
fn return_malformed_identifier_error_if_identifier_invalid() {
let mut invalid_input = "4\n6\nasecret\n".as_bytes();
let mut buf = BufWriter::new(Vec::new());
let expected = request_inputs(&mut invalid_input, &mut buf).unwrap_err();
let expected =
request_inputs::<frost_ed25519::Ed25519Sha512>(&mut invalid_input, &mut buf).unwrap_err();
assert_eq!(
*expected.downcast::<Error>().unwrap(),

View File

@ -1,7 +1,4 @@
#[cfg(not(feature = "redpallas"))]
use frost_ed25519 as frost;
#[cfg(feature = "redpallas")]
use reddsa::frost::redpallas as frost;
use frost_core::{self as frost, Ciphersuite};
use dkg::cli::cli;
@ -32,8 +29,13 @@ fn read_line(mut reader: impl BufRead) -> Result<String, std::io::Error> {
// where in the function it's getting stuck and check if the test at that point
// is correct.
#[test]
#[allow(clippy::needless_range_loop)]
fn check_dkg() {
check_dkg_for_ciphersuite::<frost_ed25519::Ed25519Sha512>();
check_dkg_for_ciphersuite::<reddsa::frost::redpallas::PallasBlake2b512>();
}
#[allow(clippy::needless_range_loop)]
fn check_dkg_for_ciphersuite<C: Ciphersuite + 'static>() {
let mut input_writers = Vec::new();
let mut output_readers = Vec::new();
let mut join_handles = Vec::new();
@ -44,7 +46,7 @@ fn check_dkg() {
let (mut input_reader, input_writer) = pipe::pipe();
let (output_reader, mut output_writer) = pipe::pipe();
join_handles.push(thread::spawn(move || {
cli(&mut input_reader, &mut output_writer).unwrap()
cli::<C>(&mut input_reader, &mut output_writer).unwrap()
}));
input_writers.push(input_writer);
output_readers.push(output_reader);
@ -117,7 +119,7 @@ fn check_dkg() {
);
// Write j's identifier
let jid: Identifier = ((j + 1) as u16).try_into().unwrap();
let jid: Identifier<C> = ((j + 1) as u16).try_into().unwrap();
writeln!(&mut input_writers[i], "{}", hex::encode(jid.serialize())).unwrap();
assert_eq!(
@ -186,7 +188,7 @@ fn check_dkg() {
);
// Write j's identifier
let jid: Identifier = ((j + 1) as u16).try_into().unwrap();
let jid: Identifier<C> = ((j + 1) as u16).try_into().unwrap();
writeln!(&mut input_writers[i], "{}", hex::encode(jid.serialize())).unwrap();
assert_eq!(
@ -195,7 +197,7 @@ fn check_dkg() {
);
// Write j's package sent to i
let iid: Identifier = ((i + 1) as u16).try_into().unwrap();
let iid: Identifier<C> = ((i + 1) as u16).try_into().unwrap();
let iids = hex::encode(iid.serialize());
let s = round2_packages.get(&j).expect("j").get(&iids).expect("i");
write!(&mut input_writers[i], "{}", s).unwrap();
@ -215,7 +217,7 @@ fn check_dkg() {
// Read key package
let key_package_json = read_line(&mut output_readers[i]).unwrap();
let _key_package: KeyPackage = serde_json::from_str(&key_package_json).unwrap();
let _key_package: KeyPackage<C> = serde_json::from_str(&key_package_json).unwrap();
assert_eq!(read_line(&mut output_readers[i]).unwrap(), "\n");
assert_eq!(
@ -226,7 +228,7 @@ fn check_dkg() {
// Read public key package
let public_key_package_json = read_line(&mut output_readers[i]).unwrap();
let public_key_package: PublicKeyPackage =
let public_key_package: PublicKeyPackage<C> =
serde_json::from_str(&public_key_package_json).unwrap();
public_key_packages.insert(i, public_key_package);
}