Merge pull request #257 from stevenroose/bip32-increment

bip32: Add increment method to DerivationPath and ChildNumber
This commit is contained in:
Andrew Poelstra 2019-05-04 21:25:21 +00:00 committed by GitHub
commit b471a12487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 101 additions and 2 deletions

View File

@ -137,6 +137,13 @@ impl ChildNumber {
}
}
/// Returns the child number that is a single increment from this one.
pub fn increment(self) -> Result<ChildNumber, Error> {
match self {
ChildNumber::Normal{ index: idx } => ChildNumber::from_normal_idx(idx+1),
ChildNumber::Hardened{ index: idx } => ChildNumber::from_hardened_idx(idx+1),
}
}
}
impl From<u32> for ChildNumber {
@ -260,20 +267,69 @@ impl FromStr for DerivationPath {
}
}
/// An iterator over children of a [DerivationPath].
///
/// It is returned by the methods [DerivationPath::children_since],
/// [DerivationPath::normal_children] and [DerivationPath::hardened_children].
pub struct DerivationPathIterator<'a> {
base: &'a DerivationPath,
next_child: Option<ChildNumber>,
}
impl<'a> DerivationPathIterator<'a> {
/// Start a new [DerivationPathIterator] at the given child.
pub fn start_from(path: &'a DerivationPath, start: ChildNumber) -> DerivationPathIterator<'a> {
DerivationPathIterator {
base: path,
next_child: Some(start),
}
}
}
impl<'a> Iterator for DerivationPathIterator<'a> {
type Item = DerivationPath;
fn next(&mut self) -> Option<Self::Item> {
if self.next_child.is_none() {
return None;
}
let ret = self.next_child.unwrap();
self.next_child = ret.increment().ok();
Some(self.base.child(ret))
}
}
impl DerivationPath {
/// Create a new DerivationPath that is a child of this one.
/// Create a new [DerivationPath] that is a child of this one.
pub fn child(&self, cn: ChildNumber) -> DerivationPath {
let mut path = self.0.clone();
path.push(cn);
DerivationPath(path)
}
/// Convert into a DerivationPath that is a child of this one.
/// Convert into a [DerivationPath] that is a child of this one.
pub fn into_child(self, cn: ChildNumber) -> DerivationPath {
let mut path = self.0;
path.push(cn);
DerivationPath(path)
}
/// Get an [Iterator] over the children of this [DerivationPath]
/// starting with the given [ChildNumber].
pub fn children_from(&self, cn: ChildNumber) -> DerivationPathIterator {
DerivationPathIterator::start_from(&self, cn)
}
/// Get an [Iterator] over the unhardened children of this [DerivationPath].
pub fn normal_children(&self) -> DerivationPathIterator {
DerivationPathIterator::start_from(&self, ChildNumber::Normal{ index: 0 })
}
/// Get an [Iterator] over the hardened children of this [DerivationPath].
pub fn hardened_children(&self) -> DerivationPathIterator {
DerivationPathIterator::start_from(&self, ChildNumber::Hardened{ index: 0 })
}
}
impl fmt::Display for DerivationPath {
@ -799,6 +855,49 @@ mod tests {
assert_eq!(Ok(pk), decoded_pk);
}
#[test]
fn test_increment() {
let idx = 9345497; // randomly generated, I promise
let cn = ChildNumber::from_normal_idx(idx).unwrap();
assert_eq!(cn.increment().ok(), Some(ChildNumber::from_normal_idx(idx+1).unwrap()));
let cn = ChildNumber::from_hardened_idx(idx).unwrap();
assert_eq!(cn.increment().ok(), Some(ChildNumber::from_hardened_idx(idx+1).unwrap()));
let max = (1<<31)-1;
let cn = ChildNumber::from_normal_idx(max).unwrap();
assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1<<31)));
let cn = ChildNumber::from_hardened_idx(max).unwrap();
assert_eq!(cn.increment().err(), Some(Error::InvalidChildNumber(1<<31)));
let cn = ChildNumber::from_normal_idx(350).unwrap();
let path = DerivationPath::from_str("m/42'").unwrap();
let mut iter = path.children_from(cn);
assert_eq!(iter.next(), Some("m/42'/350".parse().unwrap()));
assert_eq!(iter.next(), Some("m/42'/351".parse().unwrap()));
let path = DerivationPath::from_str("m/42'/350'").unwrap();
let mut iter = path.normal_children();
assert_eq!(iter.next(), Some("m/42'/350'/0".parse().unwrap()));
assert_eq!(iter.next(), Some("m/42'/350'/1".parse().unwrap()));
let path = DerivationPath::from_str("m/42'/350'").unwrap();
let mut iter = path.hardened_children();
assert_eq!(iter.next(), Some("m/42'/350'/0'".parse().unwrap()));
assert_eq!(iter.next(), Some("m/42'/350'/1'".parse().unwrap()));
let cn = ChildNumber::from_hardened_idx(42350).unwrap();
let path = DerivationPath::from_str("m/42'").unwrap();
let mut iter = path.children_from(cn);
assert_eq!(iter.next(), Some("m/42'/42350'".parse().unwrap()));
assert_eq!(iter.next(), Some("m/42'/42351'".parse().unwrap()));
let cn = ChildNumber::from_hardened_idx(max).unwrap();
let path = DerivationPath::from_str("m/42'").unwrap();
let mut iter = path.children_from(cn);
assert!(iter.next().is_some());
assert!(iter.next().is_none());
}
#[test]
fn test_vector_1() {
let secp = Secp256k1::new();