diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 51437b2..990bf0a 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -14,7 +14,7 @@ honggfuzz_fuzz = ["honggfuzz"] [dependencies] honggfuzz = { version = "0.5", optional = true } afl = { version = "0.3", optional = true } -bitcoin = { path = "..", features = ["fuzztarget"] } +bitcoin = { path = "..", features = ["fuzztarget", "serde-decimal"] } # Prevent this from interfering with workspaces [workspace] @@ -35,3 +35,11 @@ path = "fuzz_targets/deserialize_transaction.rs" [[bin]] name = "deserialize_address" path = "fuzz_targets/deserialize_address.rs" + +[[bin]] +name = "deserialize_decimal" +path = "fuzz_targets/deserialize_decimal.rs" + +[[bin]] +name = "deserialize_udecimal" +path = "fuzz_targets/deserialize_udecimal.rs" diff --git a/fuzz/fuzz_targets/deserialize_decimal.rs b/fuzz/fuzz_targets/deserialize_decimal.rs new file mode 100644 index 0000000..002bb76 --- /dev/null +++ b/fuzz/fuzz_targets/deserialize_decimal.rs @@ -0,0 +1,61 @@ +extern crate bitcoin; +use std::str::FromStr; +fn do_test(data: &[u8]) { + let data_str = String::from_utf8_lossy(data); + let dec = match bitcoin::util::decimal::Decimal::from_str(&data_str) { + Ok(dec) => dec, + Err(_) => return, + }; + let dec_roundtrip = match bitcoin::util::decimal::Decimal::from_str(&dec.to_string()) { + Ok(dec) => dec, + Err(_) => return, + }; + assert_eq!(dec, dec_roundtrip); +} + +#[cfg(feature = "afl")] +extern crate afl; +#[cfg(feature = "afl")] +fn main() { + afl::read_stdio_bytes(|data| { + do_test(&data); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + do_test(data); + }); + } +} + +#[cfg(test)] +mod tests { + fn extend_vec_from_hex(hex: &str, out: &mut Vec) { + let mut b = 0; + for (idx, c) in hex.as_bytes().iter().enumerate() { + b <<= 4; + match *c { + b'A'...b'F' => b |= c - b'A' + 10, + b'a'...b'f' => b |= c - b'a' + 10, + b'0'...b'9' => b |= c - b'0', + _ => panic!("Bad hex"), + } + if (idx & 1) == 1 { + out.push(b); + b = 0; + } + } + } + + #[test] + fn duplicate_crash() { + let mut a = Vec::new(); + extend_vec_from_hex("00000000", &mut a); + super::do_test(&a); + } +} diff --git a/fuzz/fuzz_targets/deserialize_udecimal.rs b/fuzz/fuzz_targets/deserialize_udecimal.rs new file mode 100644 index 0000000..558b1e3 --- /dev/null +++ b/fuzz/fuzz_targets/deserialize_udecimal.rs @@ -0,0 +1,61 @@ +extern crate bitcoin; +use std::str::FromStr; +fn do_test(data: &[u8]) { + let data_str = String::from_utf8_lossy(data); + let dec = match bitcoin::util::decimal::UDecimal::from_str(&data_str) { + Ok(dec) => dec, + Err(_) => return, + }; + let dec_roundtrip = match bitcoin::util::decimal::UDecimal::from_str(&dec.to_string()) { + Ok(dec) => dec, + Err(_) => return, + }; + assert_eq!(dec, dec_roundtrip); +} + +#[cfg(feature = "afl")] +extern crate afl; +#[cfg(feature = "afl")] +fn main() { + afl::read_stdio_bytes(|data| { + do_test(&data); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + do_test(data); + }); + } +} + +#[cfg(test)] +mod tests { + fn extend_vec_from_hex(hex: &str, out: &mut Vec) { + let mut b = 0; + for (idx, c) in hex.as_bytes().iter().enumerate() { + b <<= 4; + match *c { + b'A'...b'F' => b |= c - b'A' + 10, + b'a'...b'f' => b |= c - b'a' + 10, + b'0'...b'9' => b |= c - b'0', + _ => panic!("Bad hex"), + } + if (idx & 1) == 1 { + out.push(b); + b = 0; + } + } + } + + #[test] + fn duplicate_crash() { + let mut a = Vec::new(); + extend_vec_from_hex("00000000", &mut a); + super::do_test(&a); + } +} diff --git a/src/util/decimal.rs b/src/util/decimal.rs index 9b839c6..ab308eb 100644 --- a/src/util/decimal.rs +++ b/src/util/decimal.rs @@ -22,9 +22,11 @@ //! use std::{fmt, ops}; +#[cfg(feature = "serde-decimal")] use std::error; +#[cfg(feature = "serde-decimal")] use std::str::FromStr; #[cfg(feature = "serde-decimal")] use serde; -#[cfg(feature = "serde-decimal")] use strason::Json; +#[cfg(feature = "serde-decimal")] use strason::{self, Json}; /// A fixed-point decimal type #[derive(Copy, Clone, Debug, Eq, Ord)] @@ -61,7 +63,11 @@ impl fmt::Display for Decimal { let ten = 10i64.pow(self.exponent as u32); let int_part = self.mantissa / ten; let dec_part = (self.mantissa % ten).abs(); - write!(f, "{}.{:02$}", int_part, dec_part, self.exponent) + if int_part == 0 && self.mantissa < 0 { + write!(f, "-{}.{:02$}", int_part, dec_part, self.exponent) + } else { + write!(f, "{}.{:02$}", int_part, dec_part, self.exponent) + } } } @@ -126,6 +132,60 @@ impl Decimal { /// Returns whether or not the number is nonnegative #[inline] pub fn nonnegative(&self) -> bool { self.mantissa >= 0 } + + // Converts a JSON number to a Decimal previously parsed by strason + #[cfg(feature = "serde-decimal")] + fn parse_decimal(s: &str) -> Result { + // We know this will be a well-formed Json number, so we can + // be pretty lax about parsing + let mut negative = false; + let mut past_dec = false; + let mut exponent = 0; + let mut mantissa = 0i64; + + for b in s.as_bytes() { + match *b { + b'-' => { negative = true; } + b'0'...b'9' => { + match 10i64.checked_mul(mantissa) { + None => return Err(ParseDecimalError::TooBig), + Some(n) => { + match n.checked_add((b - b'0') as i64) { + None => return Err(ParseDecimalError::TooBig), + Some(n) => mantissa = n, + } + } + } + if past_dec { + exponent += 1; + if exponent > 18 { + return Err(ParseDecimalError::TooBig); + } + } + } + b'.' => { past_dec = true; } + _ => { /* whitespace or something, just ignore it */ } + } + } + if negative { mantissa *= -1; } + Ok(Decimal { + mantissa: mantissa, + exponent: exponent, + }) + } +} + +#[cfg(feature = "serde-decimal")] +impl FromStr for Decimal { + type Err = ParseDecimalError; + + /// Parses a `Decimal` from the given amount string. + fn from_str(s: &str) -> Result { + Json::from_str(s)? + .num() + .ok_or(ParseDecimalError::NotANumber) + .and_then(Decimal::parse_decimal) + } } #[cfg(feature = "serde-decimal")] @@ -140,35 +200,12 @@ impl<'de> serde::Deserialize<'de> for Decimal { where D: serde::Deserializer<'de>, { - let json = Json::deserialize(deserializer)?; - match json.num() { - Some(s) => { - // We know this will be a well-formed Json number, so we can - // be pretty lax about parsing - let mut negative = false; - let mut past_dec = false; - let mut exponent = 0; - let mut mantissa = 0i64; + use serde::de; - for b in s.as_bytes() { - match *b { - b'-' => { negative = true; } - b'0'...b'9' => { - mantissa = 10 * mantissa + (b - b'0') as i64; - if past_dec { exponent += 1; } - } - b'.' => { past_dec = true; } - _ => { /* whitespace or something, just ignore it */ } - } - } - if negative { mantissa *= -1; } - Ok(Decimal { - mantissa: mantissa, - exponent: exponent, - }) - } - None => Err(serde::de::Error::custom("expected decimal, got non-numeric")) - } + Json::deserialize(deserializer)? + .num() + .ok_or(de::Error::custom("expected decimal, got non-numeric")) + .and_then(|s| Decimal::parse_decimal(s).map_err(de::Error::custom)) } } @@ -260,6 +297,57 @@ impl UDecimal { self.mantissa * 10u64.pow((exponent - self.exponent) as u32) } } + + // Converts a JSON number to a Decimal previously parsed by strason + #[cfg(feature = "serde-decimal")] + fn parse_udecimal(s: &str) -> Result { + // We know this will be a well-formed Json number, so we can + // be pretty lax about parsing + let mut past_dec = false; + let mut exponent = 0; + let mut mantissa = 0u64; + + for b in s.as_bytes() { + match *b { + b'0'...b'9' => { + match 10u64.checked_mul(mantissa) { + None => return Err(ParseDecimalError::TooBig), + Some(n) => { + match n.checked_add((b - b'0') as u64) { + None => return Err(ParseDecimalError::TooBig), + Some(n) => mantissa = n, + } + } + } + if past_dec { + exponent += 1; + if exponent > 18 { + return Err(ParseDecimalError::TooBig); + } + } + } + b'.' => { past_dec = true; } + _ => { /* whitespace or something, just ignore it */ } + } + } + Ok(UDecimal { + mantissa: mantissa, + exponent: exponent, + }) + } +} + +#[cfg(feature = "serde-decimal")] +impl FromStr for UDecimal { + type Err = ParseDecimalError; + + /// Parses a `UDecimal` from the given amount string. + fn from_str(s: &str) -> Result { + Json::from_str(s)? + .num() + .ok_or(ParseDecimalError::NotANumber) + .and_then(UDecimal::parse_udecimal) + } } #[cfg(feature = "serde-decimal")] @@ -274,32 +362,12 @@ impl<'de> serde::Deserialize<'de> for UDecimal { where D: serde::Deserializer<'de>, { - let json = Json::deserialize(deserializer)?; - match json.num() { - Some(s) => { - // We know this will be a well-formed Json number, so we can - // be pretty lax about parsing - let mut past_dec = false; - let mut exponent = 0; - let mut mantissa = 0u64; + use serde::de; - for b in s.as_bytes() { - match *b { - b'0'...b'9' => { - mantissa = 10 * mantissa + (b - b'0') as u64; - if past_dec { exponent += 1; } - } - b'.' => { past_dec = true; } - _ => { /* whitespace or something, just ignore it */ } - } - } - Ok(UDecimal { - mantissa: mantissa, - exponent: exponent, - }) - } - None => Err(serde::de::Error::custom("expected decimal, got non-numeric")) - } + Json::deserialize(deserializer)? + .num() + .ok_or(de::Error::custom("expected decimal, got non-numeric")) + .and_then(|s| UDecimal::parse_udecimal(s).map_err(de::Error::custom)) } } @@ -321,6 +389,55 @@ impl serde::Serialize for UDecimal { } } +/// Errors that occur during `Decimal`/`UDecimal` parsing. +#[cfg(feature = "serde-decimal")] +#[derive(Debug)] +pub enum ParseDecimalError { + /// An error ocurred while parsing the JSON number. + Json(strason::Error), + /// Not a number. + NotANumber, + /// The number is too big to fit in a `Decimal` or `UDecimal`. + TooBig, +} + +#[cfg(feature = "serde-decimal")] +#[doc(hidden)] +impl From for ParseDecimalError { + fn from(e: strason::Error) -> ParseDecimalError { + ParseDecimalError::Json(e) + } +} + +#[cfg(feature = "serde-decimal")] +impl fmt::Display for ParseDecimalError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match *self { + ParseDecimalError::Json(ref e) => fmt::Display::fmt(e, fmt), + ParseDecimalError::NotANumber => fmt.write_str("not a valid JSON number"), + ParseDecimalError::TooBig => fmt.write_str("number is too big"), + } + } +} + +#[cfg(feature = "serde-decimal")] +impl error::Error for ParseDecimalError { + fn description(&self) -> &str { + match *self { + ParseDecimalError::Json(ref e) => e.description(), + ParseDecimalError::NotANumber => "not a valid JSON number", + ParseDecimalError::TooBig => "number is too big", + } + } + + fn cause(&self) -> Option<&error::Error> { + match *self { + ParseDecimalError::Json(ref e) => Some(e), + _ => None, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -468,4 +585,20 @@ mod tests { let dec: UDecimal = json.into_deserialize().unwrap(); assert_eq!(dec, UDecimal::new(98000, 7)); } + + #[test] + #[cfg(feature = "serde-decimal")] + fn parse_decimal_udecimal() { + let dec = "0.00980000".parse::().unwrap(); + assert_eq!(dec, Decimal::new(980000, 8)); + + let dec = "0.00980000".parse::().unwrap(); + assert_eq!(dec, UDecimal::new(980000, 8)); + + let dec = "0.00980".parse::().unwrap(); + assert_eq!(dec, Decimal::new(98000, 7)); + + let dec = "0.00980".parse::().unwrap(); + assert_eq!(dec, UDecimal::new(98000, 7)); + } }