diff --git a/Cargo.lock b/Cargo.lock index 4719e024fa..b4baa51496 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -932,6 +932,15 @@ dependencies = [ "rayon", ] +[[package]] +name = "derivation-path" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193388a8c8c75a490b604ff61775e236541b8975e98e5ca1f6ea97d122b7e2db" +dependencies = [ + "failure", +] + [[package]] name = "derivative" version = "2.1.1" @@ -5071,6 +5080,7 @@ dependencies = [ "byteorder", "chrono", "curve25519-dalek 2.1.0", + "derivation-path", "digest 0.9.0", "ed25519-dalek", "generic-array 0.14.3", diff --git a/programs/bpf/Cargo.lock b/programs/bpf/Cargo.lock index 9cb8b0005a..6f14cbe75e 100644 --- a/programs/bpf/Cargo.lock +++ b/programs/bpf/Cargo.lock @@ -676,6 +676,15 @@ dependencies = [ "rayon", ] +[[package]] +name = "derivation-path" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "193388a8c8c75a490b604ff61775e236541b8975e98e5ca1f6ea97d122b7e2db" +dependencies = [ + "failure", +] + [[package]] name = "derivative" version = "2.1.3" @@ -845,6 +854,27 @@ dependencies = [ "termcolor", ] +[[package]] +name = "failure" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d32e9bd16cc02eae7db7ef620b392808b89f6a5e16bb3497d159c6b92a0f4f86" +dependencies = [ + "failure_derive", +] + +[[package]] +name = "failure_derive" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa4da3c766cd7a0db8242e326e9e4e081edd567072893ed320008189715366a4" +dependencies = [ + "proc-macro2 1.0.24", + "quote 1.0.6", + "syn 1.0.67", + "synstructure", +] + [[package]] name = "fake-simd" version = "0.1.2" @@ -3416,6 +3446,7 @@ dependencies = [ "bv", "byteorder 1.3.4", "chrono", + "derivation-path", "digest 0.9.0", "ed25519-dalek", "generic-array 0.14.3", diff --git a/remote-wallet/src/ledger.rs b/remote-wallet/src/ledger.rs index 546eaad128..741345a179 100644 --- a/remote-wallet/src/ledger.rs +++ b/remote-wallet/src/ledger.rs @@ -34,8 +34,6 @@ const MAX_CHUNK_SIZE: usize = 255; const APDU_SUCCESS_CODE: usize = 0x9000; -const SOL_DERIVATION_PATH_BE: [u8; 8] = [0x80, 0, 0, 44, 0x80, 0, 0x01, 0xF5]; // 44'/501', Solana - /// Ledger vendor ID const LEDGER_VID: u16 = 0x2c97; /// Ledger product IDs: Nano S and Nano X @@ -513,20 +511,16 @@ pub fn is_valid_ledger(vendor_id: u16, product_id: u16) -> bool { /// Build the derivation path byte array from a DerivationPath selection fn extend_and_serialize(derivation_path: &DerivationPath) -> Vec { - let byte = if derivation_path.change.is_some() { + let byte = if derivation_path.change().is_some() { 4 - } else if derivation_path.account.is_some() { + } else if derivation_path.account().is_some() { 3 } else { 2 }; let mut concat_derivation = vec![byte]; - concat_derivation.extend_from_slice(&SOL_DERIVATION_PATH_BE); - if let Some(account) = &derivation_path.account { - concat_derivation.extend_from_slice(&account.as_u32().to_be_bytes()); - if let Some(change) = &derivation_path.change { - concat_derivation.extend_from_slice(&change.as_u32().to_be_bytes()); - } + for index in derivation_path.path() { + concat_derivation.extend_from_slice(&index.to_bits().to_be_bytes()); } concat_derivation } diff --git a/remote-wallet/src/remote_wallet.rs b/remote-wallet/src/remote_wallet.rs index 9e471d42cc..75bcc02482 100644 --- a/remote-wallet/src/remote_wallet.rs +++ b/remote-wallet/src/remote_wallet.rs @@ -6,7 +6,7 @@ use { log::*, parking_lot::{Mutex, RwLock}, solana_sdk::{ - derivation_path::{DerivationPath, DerivationPathComponent, DerivationPathError}, + derivation_path::{DerivationPath, DerivationPathError}, pubkey::Pubkey, signature::{Signature, SignerError}, }, @@ -288,26 +288,10 @@ impl RemoteWalletInfo { if let Some(mut pair) = query_pairs.next() { if pair.0 == "key" { let key_path = pair.1.to_mut(); - let _key_path = key_path.clone(); if key_path.ends_with('/') { key_path.pop(); } - let mut parts = key_path.split('/'); - if let Some(account) = parts.next() { - derivation_path.account = - Some(DerivationPathComponent::from_str(account)?); - } - if let Some(change) = parts.next() { - derivation_path.change = - Some(DerivationPathComponent::from_str(change)?); - } - if parts.next().is_some() { - return Err(DerivationPathError::InvalidDerivationPath(format!( - "key path `{}` too deep, only / supported", - _key_path - )) - .into()); - } + derivation_path = DerivationPath::from_key_str(key_path)?; } else { return Err(DerivationPathError::InvalidDerivationPath(format!( "invalid query string `{}={}`, only `key` supported", @@ -378,13 +362,7 @@ mod tests { pubkey, error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), Some(2))); let (wallet_info, derivation_path) = RemoteWalletInfo::parse_path(format!("usb://ledger/{:?}?key=1'/2'", pubkey)).unwrap(); assert!(wallet_info.matches(&RemoteWalletInfo { @@ -395,13 +373,7 @@ mod tests { pubkey, error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), Some(2))); let (wallet_info, derivation_path) = RemoteWalletInfo::parse_path(format!("usb://ledger/{:?}?key=1\'/2\'", pubkey)).unwrap(); assert!(wallet_info.matches(&RemoteWalletInfo { @@ -412,13 +384,7 @@ mod tests { pubkey, error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), Some(2))); let (wallet_info, derivation_path) = RemoteWalletInfo::parse_path(format!("usb://ledger/{:?}?key=1/2/", pubkey)).unwrap(); assert!(wallet_info.matches(&RemoteWalletInfo { @@ -429,13 +395,7 @@ mod tests { pubkey, error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), Some(2))); let (wallet_info, derivation_path) = RemoteWalletInfo::parse_path(format!("usb://ledger/{:?}?key=1/", pubkey)).unwrap(); assert!(wallet_info.matches(&RemoteWalletInfo { @@ -446,13 +406,7 @@ mod tests { pubkey, error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: None, - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), None)); // Test that wallet id need not be complete for key derivation to work let (wallet_info, derivation_path) = @@ -465,13 +419,7 @@ mod tests { pubkey: Pubkey::default(), error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: None, - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), None)); let (wallet_info, derivation_path) = RemoteWalletInfo::parse_path("usb://ledger/?key=1/2".to_string()).unwrap(); assert!(wallet_info.matches(&RemoteWalletInfo { @@ -482,13 +430,7 @@ mod tests { pubkey: Pubkey::default(), error: None, })); - assert_eq!( - derivation_path, - DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - } - ); + assert_eq!(derivation_path, DerivationPath::new_bip44(Some(1), Some(2))); // Failure cases assert!( diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 3b356f8d66..756e02b380 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -43,6 +43,7 @@ bv = { version = "0.11.1", features = ["serde"] } byteorder = { version = "1.3.4", optional = true } chrono = { version = "0.4", optional = true } curve25519-dalek = { version = "2.1.0", optional = true } +derivation-path = { version = "0.1.3", default-features = false } generic-array = { version = "0.14.3", default-features = false, features = ["serde", "more_lengths"], optional = true } hex = "0.4.2" hmac = "0.10.1" diff --git a/sdk/src/derivation_path.rs b/sdk/src/derivation_path.rs index faef3ac622..c6ff35b87f 100644 --- a/sdk/src/derivation_path.rs +++ b/sdk/src/derivation_path.rs @@ -1,8 +1,13 @@ use { + core::{iter::IntoIterator, slice::Iter}, + derivation_path::{ChildIndex, DerivationPath as DerivationPathInner}, std::{fmt, str::FromStr}, thiserror::Error, }; +const ACCOUNT_INDEX: usize = 2; +const CHANGE_INDEX: usize = 3; + /// Derivation path error. #[derive(Error, Debug, Clone)] pub enum DerivationPathError { @@ -10,85 +15,84 @@ pub enum DerivationPathError { InvalidDerivationPath(String), } -#[derive(Clone, Default, PartialEq)] -pub struct DerivationPathComponent(u32); +#[derive(PartialEq)] +pub struct DerivationPath(DerivationPathInner); -impl DerivationPathComponent { - pub const HARDENED_BIT: u32 = 1 << 31; - - pub fn as_u32(&self) -> u32 { - self.0 - } -} - -impl From for DerivationPathComponent { - fn from(n: u32) -> Self { - Self(n | Self::HARDENED_BIT) - } -} - -impl FromStr for DerivationPathComponent { - type Err = DerivationPathError; - - fn from_str(s: &str) -> Result { - let index_str = if let Some(stripped) = s.strip_suffix('\'') { - stripped - } else { - s - }; - index_str.parse::().map(|ki| ki.into()).map_err(|_| { - DerivationPathError::InvalidDerivationPath(format!( - "failed to parse path component: {:?}", - s - )) - }) - } -} - -impl std::fmt::Display for DerivationPathComponent { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - let hardened = if (self.0 & Self::HARDENED_BIT) == 0 { - "" - } else { - "'" - }; - let index = self.0 & !Self::HARDENED_BIT; - write!(fmt, "{}{}", index, hardened) - } -} - -impl std::fmt::Debug for DerivationPathComponent { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - std::fmt::Display::fmt(self, fmt) - } -} - -#[derive(Default, PartialEq, Clone)] -pub struct DerivationPath { - pub account: Option, - pub change: Option, -} - -impl fmt::Debug for DerivationPath { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let account = if let Some(account) = &self.account { - format!("/{:?}", account) - } else { - "".to_string() - }; - let change = if let Some(change) = &self.change { - format!("/{:?}", change) - } else { - "".to_string() - }; - write!(f, "m/44'/501'{}{}", account, change) +impl Default for DerivationPath { + fn default() -> Self { + Self::new_bip44(None, None) } } impl DerivationPath { + fn new>>(path: P) -> Self { + Self(DerivationPathInner::new(path)) + } + + pub fn from_key_str(path: &str) -> Result { + Self::from_key_str_with_coin(path, Solana) + } + + fn from_key_str_with_coin(path: &str, coin: T) -> Result { + let path = format!("m/{}", path); + let extend = DerivationPathInner::from_str(&path) + .map_err(|err| DerivationPathError::InvalidDerivationPath(err.to_string()))?; + let mut extend = extend.into_iter(); + let account = extend.next().map(|index| index.to_u32()); + let change = extend.next().map(|index| index.to_u32()); + if extend.next().is_some() { + return Err(DerivationPathError::InvalidDerivationPath(format!( + "key path `{}` too deep, only / supported", + path + ))); + } + Ok(Self::new_bip44_with_coin(coin, account, change)) + } + + fn _from_absolute_path_str(path: &str) -> Result { + let inner = DerivationPath::_from_absolute_path_insecure_str(path)? + .into_iter() + .map(|c| ChildIndex::Hardened(c.to_u32())) + .collect::>(); + Ok(Self(DerivationPathInner::new(inner))) + } + + fn _from_absolute_path_insecure_str(path: &str) -> Result { + Ok(Self(DerivationPathInner::from_str(&path).map_err( + |err| DerivationPathError::InvalidDerivationPath(err.to_string()), + )?)) + } + + pub fn new_bip44(account: Option, change: Option) -> Self { + Self::new_bip44_with_coin(Solana, account, change) + } + + fn new_bip44_with_coin(coin: T, account: Option, change: Option) -> Self { + let mut indexes = coin.base_indexes(); + if let Some(account) = account { + indexes.push(ChildIndex::Hardened(account)); + if let Some(change) = change { + indexes.push(ChildIndex::Hardened(change)); + } + } + Self::new(indexes) + } + + pub fn account(&self) -> Option<&ChildIndex> { + self.0.path().get(ACCOUNT_INDEX) + } + + pub fn change(&self) -> Option<&ChildIndex> { + self.0.path().get(CHANGE_INDEX) + } + + pub fn path(&self) -> &[ChildIndex] { + self.0.path() + } + pub fn get_query(&self) -> String { - if let Some(account) = &self.account { - if let Some(change) = &self.change { + if let Some(account) = &self.account() { + if let Some(change) = &self.change() { format!("?key={}/{}", account, change) } else { format!("?key={}", account) @@ -99,65 +103,166 @@ impl DerivationPath { } } +impl fmt::Debug for DerivationPath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "m")?; + for index in self.0.path() { + write!(f, "/{}", index)?; + } + Ok(()) + } +} + +impl<'a> IntoIterator for &'a DerivationPath { + type IntoIter = Iter<'a, ChildIndex>; + type Item = &'a ChildIndex; + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +trait Bip44 { + const PURPOSE: u32 = 44; + const COIN: u32; + + fn base_indexes(&self) -> Vec { + vec![ + ChildIndex::Hardened(Self::PURPOSE), + ChildIndex::Hardened(Self::COIN), + ] + } +} + +struct Solana; + +impl Bip44 for Solana { + const COIN: u32 = 501; +} + #[cfg(test)] mod tests { use super::*; + struct TestCoin; + impl Bip44 for TestCoin { + const COIN: u32 = 999; + } + + #[test] + fn test_from_key_str() { + let s = "1/2"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)) + ); + let s = "1'/2'"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)) + ); + let s = "1\'/2\'"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)) + ); + let s = "1"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), None) + ); + let s = "1'"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), None) + ); + let s = "1\'"; + assert_eq!( + DerivationPath::from_key_str_with_coin(s, TestCoin).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), None) + ); + + assert!(DerivationPath::from_key_str_with_coin("1/2/3", TestCoin).is_err()); + assert!(DerivationPath::from_key_str_with_coin("other", TestCoin).is_err()); + assert!(DerivationPath::from_key_str_with_coin("1o", TestCoin).is_err()); + } + + #[test] + fn test_from_absolute_path_str() { + let s = "m/44/501"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::default() + ); + let s = "m/44'/501'"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::default() + ); + let s = "m/44'/501'/1/2"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new_bip44(Some(1), Some(2)) + ); + let s = "m/44'/501'/1'/2'"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new_bip44(Some(1), Some(2)) + ); + + // Test non-Solana Bip44 + let s = "m/44'/999'/1/2"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)) + ); + let s = "m/44'/999'/1'/2'"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)) + ); + + // Test non-bip44 paths + let s = "m/501'/0'/0/0"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new(vec![ + ChildIndex::Hardened(501), + ChildIndex::Hardened(0), + ChildIndex::Hardened(0), + ChildIndex::Hardened(0), + ]) + ); + let s = "m/501'/0'/0'/0'"; + assert_eq!( + DerivationPath::_from_absolute_path_str(s).unwrap(), + DerivationPath::new(vec![ + ChildIndex::Hardened(501), + ChildIndex::Hardened(0), + ChildIndex::Hardened(0), + ChildIndex::Hardened(0), + ]) + ); + } + #[test] fn test_get_query() { - let derivation_path = DerivationPath { - account: None, - change: None, - }; + let derivation_path = DerivationPath::new_bip44_with_coin(TestCoin, None, None); assert_eq!(derivation_path.get_query(), "".to_string()); - let derivation_path = DerivationPath { - account: Some(1.into()), - change: None, - }; - assert_eq!( - derivation_path.get_query(), - format!("?key={}", DerivationPathComponent::from(1)) - ); - let derivation_path = DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - }; - assert_eq!( - derivation_path.get_query(), - format!( - "?key={}/{}", - DerivationPathComponent::from(1), - DerivationPathComponent::from(2) - ) - ); + let derivation_path = DerivationPath::new_bip44_with_coin(TestCoin, Some(1), None); + assert_eq!(derivation_path.get_query(), "?key=1'".to_string()); + let derivation_path = DerivationPath::new_bip44_with_coin(TestCoin, Some(1), Some(2)); + assert_eq!(derivation_path.get_query(), "?key=1'/2'".to_string()); } #[test] fn test_derivation_path_debug() { - let mut path = DerivationPath::default(); + let path = DerivationPath::default(); assert_eq!(format!("{:?}", path), "m/44'/501'".to_string()); - path.account = Some(1.into()); + let path = DerivationPath::new_bip44(Some(1), None); assert_eq!(format!("{:?}", path), "m/44'/501'/1'".to_string()); - path.change = Some(2.into()); + let path = DerivationPath::new_bip44(Some(1), Some(2)); assert_eq!(format!("{:?}", path), "m/44'/501'/1'/2'".to_string()); } - - #[test] - fn test_derivation_path_component() { - let f = DerivationPathComponent::from(1); - assert_eq!(f.as_u32(), 1 | DerivationPathComponent::HARDENED_BIT); - - let fs = DerivationPathComponent::from_str("1").unwrap(); - assert_eq!(fs, f); - - let fs = DerivationPathComponent::from_str("1'").unwrap(); - assert_eq!(fs, f); - - assert!(DerivationPathComponent::from_str("-1").is_err()); - - assert_eq!(format!("{}", f), "1'".to_string()); - assert_eq!(format!("{:?}", f), "1'".to_string()); - } }