diff --git a/src/lib.rs b/src/lib.rs index 7cf2626..38263b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,20 +10,33 @@ #[cfg(feature = "std")] extern crate std; +use core::mem; + use memuse::{self, DynamicUsage}; -use subtle::{Choice, ConditionallySelectable}; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; pub mod fingerprint; /// A type-safe wrapper for account identifiers. +/// +/// Accounts are 31-bit unsigned integers, and are always treated as hardened in +/// derivation paths. #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct AccountId(u32); memuse::impl_no_dynamic_usage!(AccountId); -impl From for AccountId { - fn from(id: u32) -> Self { - Self(id) +impl TryFrom for AccountId { + type Error = TryFromIntError; + + fn try_from(id: u32) -> Result { + // Account IDs are always hardened in derivation paths, so they are effectively at + // most 31 bits. + if id < (1 << 31) { + Ok(Self(id)) + } else { + Err(TryFromIntError(())) + } } } @@ -46,6 +59,19 @@ impl ConditionallySelectable for AccountId { } } +/// The error type returned when a checked integral type conversion fails. +#[derive(Clone, Copy, Debug)] +pub struct TryFromIntError(()); + +impl core::fmt::Display for TryFromIntError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "out of range integral type conversion attempted") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TryFromIntError {} + // ZIP 32 structures /// A child index for a derived key. @@ -54,6 +80,12 @@ impl ConditionallySelectable for AccountId { #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ChildIndex(u32); +impl ConstantTimeEq for ChildIndex { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl ChildIndex { /// Parses the given ZIP 32 child index. /// @@ -87,6 +119,12 @@ impl ChildIndex { #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct ChainCode([u8; 32]); +impl ConstantTimeEq for ChainCode { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl ChainCode { /// Constructs a `ChainCode` from the given array. pub fn new(c: [u8; 32]) -> Self { @@ -102,7 +140,7 @@ impl ChainCode { /// The index for a particular diversifier. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct DiversifierIndex(pub [u8; 11]); +pub struct DiversifierIndex([u8; 11]); impl Default for DiversifierIndex { fn default() -> Self { @@ -110,17 +148,24 @@ impl Default for DiversifierIndex { } } -impl From for DiversifierIndex { - fn from(i: u32) -> Self { - u64::from(i).into() - } +macro_rules! di_from { + ($n:ident) => { + impl From<$n> for DiversifierIndex { + fn from(j: $n) -> Self { + let mut j_bytes = [0; 11]; + j_bytes[..mem::size_of::<$n>()].copy_from_slice(&j.to_le_bytes()); + DiversifierIndex(j_bytes) + } + } + }; } +di_from!(u32); +di_from!(u64); +di_from!(usize); -impl From for DiversifierIndex { - fn from(i: u64) -> Self { - let mut result = DiversifierIndex([0; 11]); - result.0[..8].copy_from_slice(&i.to_le_bytes()); - result +impl From<[u8; 11]> for DiversifierIndex { + fn from(j_bytes: [u8; 11]) -> Self { + DiversifierIndex(j_bytes) } } @@ -140,6 +185,11 @@ impl DiversifierIndex { DiversifierIndex([0; 11]) } + /// Returns the raw bytes of the diversifier index. + pub fn as_bytes(&self) -> &[u8; 11] { + &self.0 + } + /// Increments this index, failing on overflow. pub fn increment(&mut self) -> Result<(), DiversifierIndexOverflowError> { for k in 0..11 {