diff --git a/remote-wallet/src/ledger.rs b/remote-wallet/src/ledger.rs index 131877bd35..19343de1b8 100644 --- a/remote-wallet/src/ledger.rs +++ b/remote-wallet/src/ledger.rs @@ -1,16 +1,16 @@ -use crate::{ - ledger_error::LedgerError, - remote_wallet::{ - DerivationPath, RemoteWallet, RemoteWalletError, RemoteWalletInfo, RemoteWalletManager, +use { + crate::{ + ledger_error::LedgerError, + remote_wallet::{RemoteWallet, RemoteWalletError, RemoteWalletInfo, RemoteWalletManager}, }, + console::Emoji, + dialoguer::{theme::ColorfulTheme, Select}, + log::*, + num_traits::FromPrimitive, + semver::Version as FirmwareVersion, + solana_sdk::{derivation_path::DerivationPath, pubkey::Pubkey, signature::Signature}, + std::{cmp::min, fmt, sync::Arc}, }; -use console::Emoji; -use dialoguer::{theme::ColorfulTheme, Select}; -use log::*; -use num_traits::FromPrimitive; -use semver::Version as FirmwareVersion; -use solana_sdk::{pubkey::Pubkey, signature::Signature}; -use std::{cmp::min, fmt, sync::Arc}; static CHECK_MARK: Emoji = Emoji("✅ ", ""); diff --git a/remote-wallet/src/ledger_error.rs b/remote-wallet/src/ledger_error.rs index 6e5b6e2296..a434f776b3 100644 --- a/remote-wallet/src/ledger_error.rs +++ b/remote-wallet/src/ledger_error.rs @@ -1,5 +1,4 @@ -use num_derive::FromPrimitive; -use thiserror::Error; +use {num_derive::FromPrimitive, thiserror::Error}; #[derive(Error, Debug, Clone, FromPrimitive, PartialEq)] pub enum LedgerError { diff --git a/remote-wallet/src/remote_keypair.rs b/remote-wallet/src/remote_keypair.rs index 55baa6ae80..9028b7a6ce 100644 --- a/remote-wallet/src/remote_keypair.rs +++ b/remote-wallet/src/remote_keypair.rs @@ -1,13 +1,16 @@ -use crate::{ - ledger::get_ledger_from_info, - remote_wallet::{ - DerivationPath, RemoteWallet, RemoteWalletError, RemoteWalletInfo, RemoteWalletManager, - RemoteWalletType, +use { + crate::{ + ledger::get_ledger_from_info, + remote_wallet::{ + RemoteWallet, RemoteWalletError, RemoteWalletInfo, RemoteWalletManager, + RemoteWalletType, + }, + }, + solana_sdk::{ + derivation_path::DerivationPath, + pubkey::Pubkey, + signature::{Signature, Signer, SignerError}, }, -}; -use solana_sdk::{ - pubkey::Pubkey, - signature::{Signature, Signer, SignerError}, }; pub struct RemoteKeypair { diff --git a/remote-wallet/src/remote_wallet.rs b/remote-wallet/src/remote_wallet.rs index 1488d00fd0..9e471d42cc 100644 --- a/remote-wallet/src/remote_wallet.rs +++ b/remote-wallet/src/remote_wallet.rs @@ -1,21 +1,23 @@ -use crate::{ - ledger::{is_valid_ledger, LedgerWallet}, - ledger_error::LedgerError, +use { + crate::{ + ledger::{is_valid_ledger, LedgerWallet}, + ledger_error::LedgerError, + }, + log::*, + parking_lot::{Mutex, RwLock}, + solana_sdk::{ + derivation_path::{DerivationPath, DerivationPathComponent, DerivationPathError}, + pubkey::Pubkey, + signature::{Signature, SignerError}, + }, + std::{ + str::FromStr, + sync::Arc, + time::{Duration, Instant}, + }, + thiserror::Error, + url::Url, }; -use log::*; -use parking_lot::{Mutex, RwLock}; -use solana_sdk::{ - pubkey::Pubkey, - signature::{Signature, SignerError}, -}; -use std::{ - fmt, - str::FromStr, - sync::Arc, - time::{Duration, Instant}, -}; -use thiserror::Error; -use url::Url; const HID_GLOBAL_USAGE_PAGE: u16 = 0xFF00; const HID_USB_DEVICE_CLASS: u8 = 0; @@ -32,8 +34,8 @@ pub enum RemoteWalletError { #[error("device with non-supported product ID or vendor ID was detected")] InvalidDevice, - #[error("invalid derivation path: {0}")] - InvalidDerivationPath(String), + #[error(transparent)] + DerivationPathError(#[from] DerivationPathError), #[error("invalid input: {0}")] InvalidInput(String), @@ -251,13 +253,17 @@ pub struct RemoteWalletInfo { impl RemoteWalletInfo { pub fn parse_path(path: String) -> Result<(Self, DerivationPath), RemoteWalletError> { let wallet_path = Url::parse(&path).map_err(|e| { - RemoteWalletError::InvalidDerivationPath(format!("parse error: {:?}", e)) + Into::::into(DerivationPathError::InvalidDerivationPath(format!( + "parse error: {:?}", + e + ))) })?; if wallet_path.host_str().is_none() { - return Err(RemoteWalletError::InvalidDerivationPath( + return Err(DerivationPathError::InvalidDerivationPath( "missing remote wallet type".to_string(), - )); + ) + .into()); } let mut wallet_info = RemoteWalletInfo { @@ -268,9 +274,8 @@ impl RemoteWalletInfo { if let Some(wallet_id) = wallet_path.path_segments().map(|c| c.collect::>()) { if !wallet_id[0].is_empty() { wallet_info.pubkey = Pubkey::from_str(wallet_id[0]).map_err(|e| { - RemoteWalletError::InvalidDerivationPath(format!( - "pubkey from_str error: {:?}", - e + Into::::into(DerivationPathError::InvalidDerivationPath( + format!("pubkey from_str error: {:?}", e), )) })?; } @@ -297,22 +302,25 @@ impl RemoteWalletInfo { Some(DerivationPathComponent::from_str(change)?); } if parts.next().is_some() { - return Err(RemoteWalletError::InvalidDerivationPath(format!( + return Err(DerivationPathError::InvalidDerivationPath(format!( "key path `{}` too deep, only / supported", _key_path - ))); + )) + .into()); } } else { - return Err(RemoteWalletError::InvalidDerivationPath(format!( + return Err(DerivationPathError::InvalidDerivationPath(format!( "invalid query string `{}={}`, only `key` supported", pair.0, pair.1 - ))); + )) + .into()); } } if query_pairs.next().is_some() { - return Err(RemoteWalletError::InvalidDerivationPath( + return Err(DerivationPathError::InvalidDerivationPath( "invalid query string, extra fields not supported".to_string(), - )); + ) + .into()); } } } @@ -331,96 +339,6 @@ impl RemoteWalletInfo { } } -#[derive(Clone, Default, PartialEq)] -pub struct DerivationPathComponent(u32); - -impl DerivationPathComponent { - pub const HARDENED_BIT: u32 = 1 << 31; - - pub fn as_u32(&self) -> u32 { - self.0 - } -} - -impl From for DerivationPathComponent { - fn from(n: u32) -> Self { - Self(n | Self::HARDENED_BIT) - } -} - -impl FromStr for DerivationPathComponent { - type Err = RemoteWalletError; - - fn from_str(s: &str) -> Result { - let index_str = if let Some(stripped) = s.strip_suffix('\'') { - eprintln!("all path components are promoted to hardened representation"); - stripped - } else { - s - }; - index_str.parse::().map(|ki| ki.into()).map_err(|_| { - RemoteWalletError::InvalidDerivationPath(format!( - "failed to parse path component: {:?}", - s - )) - }) - } -} - -impl std::fmt::Display for DerivationPathComponent { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - let hardened = if (self.0 & Self::HARDENED_BIT) == 0 { - "" - } else { - "'" - }; - let index = self.0 & !Self::HARDENED_BIT; - write!(fmt, "{}{}", index, hardened) - } -} - -impl std::fmt::Debug for DerivationPathComponent { - fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { - std::fmt::Display::fmt(self, fmt) - } -} - -#[derive(Default, PartialEq, Clone)] -pub struct DerivationPath { - pub account: Option, - pub change: Option, -} - -impl fmt::Debug for DerivationPath { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let account = if let Some(account) = &self.account { - format!("/{:?}", account) - } else { - "".to_string() - }; - let change = if let Some(change) = &self.change { - format!("/{:?}", change) - } else { - "".to_string() - }; - write!(f, "m/44'/501'{}{}", account, change) - } -} - -impl DerivationPath { - pub fn get_query(&self) -> String { - if let Some(account) = &self.account { - if let Some(change) = &self.change { - format!("?key={}/{}", account, change) - } else { - format!("?key={}", account) - } - } else { - "".to_string() - } - } -} - /// Helper to determine if a device is a valid HID pub fn is_valid_hid_device(usage_page: u16, interface_number: i32) -> bool { usage_page == HID_GLOBAL_USAGE_PAGE || interface_number == HID_USB_DEVICE_CLASS as i32 @@ -636,62 +554,4 @@ mod tests { format!("usb://ledger/{}", pubkey_str) ); } - - #[test] - fn test_get_query() { - let derivation_path = DerivationPath { - account: None, - change: None, - }; - assert_eq!(derivation_path.get_query(), "".to_string()); - let derivation_path = DerivationPath { - account: Some(1.into()), - change: None, - }; - assert_eq!( - derivation_path.get_query(), - format!("?key={}", DerivationPathComponent::from(1)) - ); - let derivation_path = DerivationPath { - account: Some(1.into()), - change: Some(2.into()), - }; - assert_eq!( - derivation_path.get_query(), - format!( - "?key={}/{}", - DerivationPathComponent::from(1), - DerivationPathComponent::from(2) - ) - ); - } - - #[test] - fn test_derivation_path_debug() { - let mut path = DerivationPath::default(); - assert_eq!(format!("{:?}", path), "m/44'/501'".to_string()); - - path.account = Some(1.into()); - assert_eq!(format!("{:?}", path), "m/44'/501'/1'".to_string()); - - path.change = Some(2.into()); - assert_eq!(format!("{:?}", path), "m/44'/501'/1'/2'".to_string()); - } - - #[test] - fn test_derivation_path_component() { - let f = DerivationPathComponent::from(1); - assert_eq!(f.as_u32(), 1 | DerivationPathComponent::HARDENED_BIT); - - let fs = DerivationPathComponent::from_str("1").unwrap(); - assert_eq!(fs, f); - - let fs = DerivationPathComponent::from_str("1'").unwrap(); - assert_eq!(fs, f); - - assert!(DerivationPathComponent::from_str("-1").is_err()); - - assert_eq!(format!("{}", f), "1'".to_string()); - assert_eq!(format!("{:?}", f), "1'".to_string()); - } } diff --git a/sdk/src/derivation_path.rs b/sdk/src/derivation_path.rs new file mode 100644 index 0000000000..faef3ac622 --- /dev/null +++ b/sdk/src/derivation_path.rs @@ -0,0 +1,163 @@ +use { + std::{fmt, str::FromStr}, + thiserror::Error, +}; + +/// Derivation path error. +#[derive(Error, Debug, Clone)] +pub enum DerivationPathError { + #[error("invalid derivation path: {0}")] + InvalidDerivationPath(String), +} + +#[derive(Clone, Default, PartialEq)] +pub struct DerivationPathComponent(u32); + +impl DerivationPathComponent { + pub const HARDENED_BIT: u32 = 1 << 31; + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +impl From for DerivationPathComponent { + fn from(n: u32) -> Self { + Self(n | Self::HARDENED_BIT) + } +} + +impl FromStr for DerivationPathComponent { + type Err = DerivationPathError; + + fn from_str(s: &str) -> Result { + let index_str = if let Some(stripped) = s.strip_suffix('\'') { + stripped + } else { + s + }; + index_str.parse::().map(|ki| ki.into()).map_err(|_| { + DerivationPathError::InvalidDerivationPath(format!( + "failed to parse path component: {:?}", + s + )) + }) + } +} + +impl std::fmt::Display for DerivationPathComponent { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + let hardened = if (self.0 & Self::HARDENED_BIT) == 0 { + "" + } else { + "'" + }; + let index = self.0 & !Self::HARDENED_BIT; + write!(fmt, "{}{}", index, hardened) + } +} + +impl std::fmt::Debug for DerivationPathComponent { + fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result { + std::fmt::Display::fmt(self, fmt) + } +} + +#[derive(Default, PartialEq, Clone)] +pub struct DerivationPath { + pub account: Option, + pub change: Option, +} + +impl fmt::Debug for DerivationPath { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let account = if let Some(account) = &self.account { + format!("/{:?}", account) + } else { + "".to_string() + }; + let change = if let Some(change) = &self.change { + format!("/{:?}", change) + } else { + "".to_string() + }; + write!(f, "m/44'/501'{}{}", account, change) + } +} + +impl DerivationPath { + pub fn get_query(&self) -> String { + if let Some(account) = &self.account { + if let Some(change) = &self.change { + format!("?key={}/{}", account, change) + } else { + format!("?key={}", account) + } + } else { + "".to_string() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_query() { + let derivation_path = DerivationPath { + account: None, + change: None, + }; + assert_eq!(derivation_path.get_query(), "".to_string()); + let derivation_path = DerivationPath { + account: Some(1.into()), + change: None, + }; + assert_eq!( + derivation_path.get_query(), + format!("?key={}", DerivationPathComponent::from(1)) + ); + let derivation_path = DerivationPath { + account: Some(1.into()), + change: Some(2.into()), + }; + assert_eq!( + derivation_path.get_query(), + format!( + "?key={}/{}", + DerivationPathComponent::from(1), + DerivationPathComponent::from(2) + ) + ); + } + + #[test] + fn test_derivation_path_debug() { + let mut path = DerivationPath::default(); + assert_eq!(format!("{:?}", path), "m/44'/501'".to_string()); + + path.account = Some(1.into()); + assert_eq!(format!("{:?}", path), "m/44'/501'/1'".to_string()); + + path.change = Some(2.into()); + assert_eq!(format!("{:?}", path), "m/44'/501'/1'/2'".to_string()); + } + + #[test] + fn test_derivation_path_component() { + let f = DerivationPathComponent::from(1); + assert_eq!(f.as_u32(), 1 | DerivationPathComponent::HARDENED_BIT); + + let fs = DerivationPathComponent::from_str("1").unwrap(); + assert_eq!(fs, f); + + let fs = DerivationPathComponent::from_str("1'").unwrap(); + assert_eq!(fs, f); + + assert!(DerivationPathComponent::from_str("-1").is_err()); + + assert_eq!(format!("{}", f), "1'".to_string()); + assert_eq!(format!("{:?}", f), "1'".to_string()); + } +} diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index f9979e824d..64f89324ee 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -13,6 +13,7 @@ pub mod arithmetic; pub mod builtins; pub mod client; pub mod commitment_config; +pub mod derivation_path; pub mod deserialize_utils; pub mod entrypoint; pub mod entrypoint_deprecated;