From 842cab27391e2dc18f0d072a6967ca28abff3e11 Mon Sep 17 00:00:00 2001 From: Trent Nelson Date: Mon, 22 Jun 2020 11:10:11 -0600 Subject: [PATCH] Remote Wallet: Stricter derivation path component parsing (#10725) --- remote-wallet/src/ledger.rs | 12 +-- remote-wallet/src/remote_wallet.rs | 156 +++++++++++++++++++++++------ 2 files changed, 129 insertions(+), 39 deletions(-) diff --git a/remote-wallet/src/ledger.rs b/remote-wallet/src/ledger.rs index af38463fa6..8f512d22d0 100644 --- a/remote-wallet/src/ledger.rs +++ b/remote-wallet/src/ledger.rs @@ -22,8 +22,6 @@ const DEPRECATE_VERSION_BEFORE: FirmwareVersion = FirmwareVersion { build: Vec::new(), }; -const HARDENED_BIT: u32 = 1 << 31; - const APDU_TAG: u8 = 0x05; const APDU_CLA: u8 = 0xe0; const APDU_PAYLOAD_HEADER_LEN: usize = 7; @@ -475,12 +473,10 @@ fn extend_and_serialize(derivation_path: &DerivationPath) -> Vec { }; let mut concat_derivation = vec![byte]; concat_derivation.extend_from_slice(&SOL_DERIVATION_PATH_BE); - if let Some(account) = derivation_path.account { - let hardened_account = account | HARDENED_BIT; - concat_derivation.extend_from_slice(&hardened_account.to_be_bytes()); - if let Some(change) = derivation_path.change { - let hardened_change = change | HARDENED_BIT; - concat_derivation.extend_from_slice(&hardened_change.to_be_bytes()); + if let Some(account) = &derivation_path.account { + concat_derivation.extend_from_slice(&account.as_u32().to_be_bytes()); + if let Some(change) = &derivation_path.change { + concat_derivation.extend_from_slice(&change.as_u32().to_be_bytes()); } } concat_derivation diff --git a/remote-wallet/src/remote_wallet.rs b/remote-wallet/src/remote_wallet.rs index 4f92fd8cf4..11cbd7ac52 100644 --- a/remote-wallet/src/remote_wallet.rs +++ b/remote-wallet/src/remote_wallet.rs @@ -281,12 +281,14 @@ impl RemoteWalletInfo { key_path.pop(); } let mut parts = key_path.split('/'); - derivation_path.account = parts - .next() - .and_then(|account| account.replace("'", "").parse::().ok()); - derivation_path.change = parts - .next() - .and_then(|change| change.replace("'", "").parse::().ok()); + if let Some(account) = parts.next() { + derivation_path.account = + Some(DerivationPathComponent::from_str(account)?); + } + if let Some(change) = parts.next() { + derivation_path.change = + Some(DerivationPathComponent::from_str(change)?); + } if parts.next().is_some() { return Err(RemoteWalletError::InvalidDerivationPath(format!( "key path `{}` too deep, only / supported", @@ -322,21 +324,74 @@ impl RemoteWalletInfo { } } +#[derive(Clone, Default, PartialEq)] +pub struct DerivationPathComponent(u32); + +impl DerivationPathComponent { + pub const HARDENED_BIT: u32 = 1 << 31; + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +impl From for DerivationPathComponent { + fn from(n: u32) -> Self { + Self(n | Self::HARDENED_BIT) + } +} + +impl FromStr for DerivationPathComponent { + type Err = RemoteWalletError; + + fn from_str(s: &str) -> Result { + // Replace str::splitn() with str::strip_suffix() once stabilized + let parts: Vec<_> = s.splitn(2, '\'').collect(); + if parts.len() == 2 { + eprintln!("all path components are promoted to hardened representation"); + } + parts[0].parse::().map(|ki| ki.into()).map_err(|_| { + RemoteWalletError::InvalidDerivationPath(format!( + "failed to parse path component: {:?}", + s + )) + }) + } +} + +impl std::fmt::Display for DerivationPathComponent { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + let hardened = if (self.0 & Self::HARDENED_BIT) == 0 { + "" + } else { + "'" + }; + let index = self.0 & !Self::HARDENED_BIT; + write!(fmt, "{}{}", index, hardened) + } +} + +impl std::fmt::Debug for DerivationPathComponent { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(self, fmt) + } +} + #[derive(Default, PartialEq, Clone)] pub struct DerivationPath { - pub account: Option, - pub change: Option, + pub account: Option, + pub change: Option, } impl fmt::Debug for DerivationPath { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let account = if let Some(account) = self.account { - format!("/{:?}'", account) + let account = if let Some(account) = &self.account { + format!("/{:?}", account) } else { "".to_string() }; - let change = if let Some(change) = self.change { - format!("/{:?}'", change) + let change = if let Some(change) = &self.change { + format!("/{:?}", change) } else { "".to_string() }; @@ -346,8 +401,8 @@ impl fmt::Debug for DerivationPath { impl DerivationPath { pub fn get_query(&self) -> String { - if let Some(account) = self.account { - if let Some(change) = self.change { + if let Some(account) = &self.account { + if let Some(change) = &self.change { format!("?key={}/{}", account, change) } else { format!("?key={}", account) @@ -399,8 +454,8 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), } ); let (wallet_info, derivation_path) = @@ -415,8 +470,8 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), } ); let (wallet_info, derivation_path) = @@ -431,8 +486,8 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), } ); let (wallet_info, derivation_path) = @@ -447,8 +502,8 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), } ); let (wallet_info, derivation_path) = @@ -463,7 +518,7 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), + account: Some(1.into()), change: None, } ); @@ -481,7 +536,7 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), + account: Some(1.into()), change: None, } ); @@ -497,8 +552,8 @@ mod tests { assert_eq!( derivation_path, DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), } ); @@ -569,14 +624,53 @@ mod tests { }; assert_eq!(derivation_path.get_query(), "".to_string()); let derivation_path = DerivationPath { - account: Some(1), + account: Some(1.into()), change: None, }; - assert_eq!(derivation_path.get_query(), "?key=1".to_string()); + assert_eq!( + derivation_path.get_query(), + format!("?key={}", DerivationPathComponent::from(1)) + ); let derivation_path = DerivationPath { - account: Some(1), - change: Some(2), + account: Some(1.into()), + change: Some(2.into()), }; - assert_eq!(derivation_path.get_query(), "?key=1/2".to_string()); + assert_eq!( + derivation_path.get_query(), + format!( + "?key={}/{}", + DerivationPathComponent::from(1), + DerivationPathComponent::from(2) + ) + ); + } + + #[test] + fn test_derivation_path_debug() { + let mut path = DerivationPath::default(); + assert_eq!(format!("{:?}", path), "m/44'/501'".to_string()); + + path.account = Some(1.into()); + assert_eq!(format!("{:?}", path), "m/44'/501'/1'".to_string()); + + path.change = Some(2.into()); + assert_eq!(format!("{:?}", path), "m/44'/501'/1'/2'".to_string()); + } + + #[test] + fn test_derivation_path_component() { + let f = DerivationPathComponent::from(1); + assert_eq!(f.as_u32(), 1 | DerivationPathComponent::HARDENED_BIT); + + let fs = DerivationPathComponent::from_str("1").unwrap(); + assert_eq!(fs, f); + + let fs = DerivationPathComponent::from_str("1'").unwrap(); + assert_eq!(fs, f); + + assert!(DerivationPathComponent::from_str("-1").is_err()); + + assert_eq!(format!("{}", f), "1'".to_string()); + assert_eq!(format!("{:?}", f), "1'".to_string()); } }