From f4466c8c0a5a30b67a4804b4cab778092cacdd5b Mon Sep 17 00:00:00 2001 From: Stephen Akridge Date: Mon, 26 Mar 2018 21:07:11 -0700 Subject: [PATCH] Change for cuda verify integration --- Cargo.toml | 3 + build.rs | 12 ++++ src/accountant.rs | 10 +-- src/accountant_skel.rs | 155 ++++++++++++++++++++++++++++++----------- src/bin/testnode.rs | 4 +- src/ecdsa.rs | 147 ++++++++++++++++++++++++++++++++++++++ src/event.rs | 2 +- src/lib.rs | 6 +- src/mint.rs | 2 +- src/packet.rs | 26 +++---- src/plan.rs | 1 + src/streamer.rs | 34 ++++----- src/transaction.rs | 107 ++++++++++++++++++++-------- 13 files changed, 400 insertions(+), 109 deletions(-) create mode 100644 build.rs create mode 100644 src/ecdsa.rs diff --git a/Cargo.toml b/Cargo.toml index 1e914dc35d..b73553896d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,7 @@ codecov = { repository = "solana-labs/solana", branch = "master", service = "git [features] unstable = [] ipv6 = [] +cuda = [] [dependencies] rayon = "1.0.0" @@ -54,5 +55,7 @@ untrusted = "0.5.1" bincode = "1.0.0" chrono = { version = "0.4.0", features = ["serde"] } log = "^0.4.1" +env_logger = "^0.4.1" matches = "^0.1.6" byteorder = "^1.2.1" +libc = "^0.2.1" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000000..b8d72fe8ed --- /dev/null +++ b/build.rs @@ -0,0 +1,12 @@ +use std::env; + +fn main() { + if !env::var("CARGO_FEATURE_CUDA").is_err() { + println!("cargo:rustc-link-search=native=."); + println!("cargo:rustc-link-lib=static=cuda_verify_ed25519"); + println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64"); + println!("cargo:rustc-link-lib=dylib=cudart"); + println!("cargo:rustc-link-lib=dylib=cuda"); + println!("cargo:rustc-link-lib=dylib=cudadevrt"); + } +} diff --git a/src/accountant.rs b/src/accountant.rs index efe80df37e..7f8a4de0d3 100644 --- a/src/accountant.rs +++ b/src/accountant.rs @@ -3,6 +3,8 @@ //! on behalf of the caller, and a private low-level API for when they have //! already been signed and verified. +extern crate libc; + use chrono::prelude::*; use event::Event; use hash::Hash; @@ -104,19 +106,19 @@ impl Accountant { /// Process a Transaction that has already been verified. pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> { - if self.get_balance(&tr.from).unwrap_or(0) < tr.tokens { + if self.get_balance(&tr.from).unwrap_or(0) < tr.data.tokens { return Err(AccountingError::InsufficientFunds); } - if !self.reserve_signature_with_last_id(&tr.sig, &tr.last_id) { + if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) { return Err(AccountingError::InvalidTransferSignature); } if let Some(x) = self.balances.read().unwrap().get(&tr.from) { - *x.write().unwrap() -= tr.tokens; + *x.write().unwrap() -= tr.data.tokens; } - let mut plan = tr.plan.clone(); + let mut plan = tr.data.plan.clone(); plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap())); if let Some(ref payment) = plan.final_payment() { diff --git a/src/accountant_skel.rs b/src/accountant_skel.rs index bba0ab7407..d827b2d5ad 100644 --- a/src/accountant_skel.rs +++ b/src/accountant_skel.rs @@ -6,24 +6,27 @@ use accountant::Accountant; use bincode::{deserialize, serialize}; use entry::Entry; use event::Event; +use ecdsa; use hash::Hash; use historian::Historian; +use packet; +use packet::SharedPackets; use rayon::prelude::*; use recorder::Signal; use result::Result; use serde_json; use signature::PublicKey; +use std::cmp::max; +use std::collections::VecDeque; use std::io::Write; use std::net::{SocketAddr, UdpSocket}; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::mpsc::{channel, SendError}; +use std::sync::mpsc::{channel, Receiver, SendError, Sender}; +use std::sync::{Arc, Mutex}; use std::thread::{spawn, JoinHandle}; use std::time::Duration; use streamer; -use packet; -use std::sync::{Arc, Mutex}; use transaction::Transaction; -use std::collections::VecDeque; pub struct AccountantSkel { acc: Accountant, @@ -44,14 +47,14 @@ impl Request { /// Verify the request is valid. pub fn verify(&self) -> bool { match *self { - Request::Transaction(ref tr) => tr.verify(), + Request::Transaction(ref tr) => tr.verify_plan(), _ => true, } } } /// Parallel verfication of a batch of requests. -fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> { +pub fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> { reqs.into_par_iter().filter({ |x| x.0.verify() }).collect() } @@ -84,16 +87,20 @@ impl AccountantSkel { } /// Process Request items sent by clients. - pub fn log_verified_request(&mut self, msg: Request) -> Option { + pub fn log_verified_request(&mut self, msg: Request, verify: u8) -> Option { match msg { + Request::Transaction(_) if verify == 0 => { + trace!("Transaction failed sigverify"); + None + } Request::Transaction(tr) => { if let Err(err) = self.acc.process_verified_transaction(&tr) { - eprintln!("Transaction error: {:?}", err); + trace!("Transaction error: {:?}", err); } else if let Err(SendError(_)) = self.historian .sender - .send(Signal::Event(Event::Transaction(tr))) + .send(Signal::Event(Event::Transaction(tr.clone()))) { - eprintln!("Channel send error"); + error!("Channel send error"); } None } @@ -105,46 +112,87 @@ impl AccountantSkel { } } + fn verifier( + recvr: &streamer::PacketReceiver, + sendr: &Sender<(Vec, Vec>)>, + ) -> Result<()> { + let timer = Duration::new(1, 0); + let msgs = recvr.recv_timeout(timer)?; + trace!("got msgs"); + let mut v = Vec::new(); + v.push(msgs); + while let Ok(more) = recvr.try_recv() { + trace!("got more msgs"); + v.push(more); + } + info!("batch {}", v.len()); + let chunk = max(1, (v.len() + 3) / 4); + let chunks: Vec<_> = v.chunks(chunk).collect(); + let rvs: Vec<_> = chunks + .into_par_iter() + .map(|x| ecdsa::ed25519_verify(&x.to_vec())) + .collect(); + for (v, r) in v.chunks(chunk).zip(rvs) { + sendr.send((v.to_vec(), r))?; + } + Ok(()) + } + + pub fn deserialize_packets(p: &packet::Packets) -> Vec> { + // TODO: deserealize in parallel + let mut r = vec![]; + for x in &p.packets { + let rsp_addr = x.meta.addr(); + let sz = x.meta.size; + if let Ok(req) = deserialize(&x.data[0..sz]) { + r.push(Some((req, rsp_addr))); + } else { + r.push(None); + } + } + r + } + fn process( obj: &Arc>>, - packet_receiver: &streamer::PacketReceiver, + verified_receiver: &Receiver<(Vec, Vec>)>, blob_sender: &streamer::BlobSender, packet_recycler: &packet::PacketRecycler, blob_recycler: &packet::BlobRecycler, ) -> Result<()> { let timer = Duration::new(1, 0); - let msgs = packet_receiver.recv_timeout(timer)?; - let msgs_ = msgs.clone(); - let mut rsps = VecDeque::new(); - { - let mut reqs = vec![]; - for packet in &msgs.read().unwrap().packets { - let rsp_addr = packet.meta.addr(); - let sz = packet.meta.size; - let req = deserialize(&packet.data[0..sz])?; - reqs.push((req, rsp_addr)); - } - let reqs = filter_valid_requests(reqs); - for (req, rsp_addr) in reqs { - if let Some(resp) = obj.lock().unwrap().log_verified_request(req) { - let blob = blob_recycler.allocate(); - { - let mut b = blob.write().unwrap(); - let v = serialize(&resp)?; - let len = v.len(); - b.data[..len].copy_from_slice(&v); - b.meta.size = len; - b.meta.set_addr(&rsp_addr); + let (mms, vvs) = verified_receiver.recv_timeout(timer)?; + for (msgs, vers) in mms.into_iter().zip(vvs.into_iter()) { + let msgs_ = msgs.clone(); + let mut rsps = VecDeque::new(); + { + let reqs = Self::deserialize_packets(&((*msgs).read().unwrap())); + for (data, v) in reqs.into_iter().zip(vers.into_iter()) { + if let Some((req, rsp_addr)) = data { + if !req.verify() { + continue; + } + if let Some(resp) = obj.lock().unwrap().log_verified_request(req, v) { + let blob = blob_recycler.allocate(); + { + let mut b = blob.write().unwrap(); + let v = serialize(&resp)?; + let len = v.len(); + b.data[..len].copy_from_slice(&v); + b.meta.size = len; + b.meta.set_addr(&rsp_addr); + } + rsps.push_back(blob); + } } - rsps.push_back(blob); } } + if !rsps.is_empty() { + //don't wake up the other side if there is nothing + blob_sender.send(rsps)?; + } + packet_recycler.recycle(msgs_); } - if !rsps.is_empty() { - //don't wake up the other side if there is nothing - blob_sender.send(rsps)?; - } - packet_recycler.recycle(msgs_); Ok(()) } @@ -169,11 +217,21 @@ impl AccountantSkel { let (blob_sender, blob_receiver) = channel(); let t_responder = streamer::responder(write, exit.clone(), blob_recycler.clone(), blob_receiver); + let (verified_sender, verified_receiver) = channel(); + + let exit_ = exit.clone(); + let t_verifier = spawn(move || loop { + let e = Self::verifier(&packet_receiver, &verified_sender); + if e.is_err() && exit_.load(Ordering::Relaxed) { + break; + } + }); + let skel = obj.clone(); let t_server = spawn(move || loop { let e = AccountantSkel::process( &skel, - &packet_receiver, + &verified_receiver, &blob_sender, &packet_recycler, &blob_recycler, @@ -182,6 +240,21 @@ impl AccountantSkel { break; } }); - Ok(vec![t_receiver, t_responder, t_server]) + Ok(vec![t_receiver, t_responder, t_server, t_verifier]) + } +} + +#[cfg(test)] +mod tests { + use accountant_skel::Request; + use bincode::serialize; + use ecdsa; + use transaction::{memfind, test_tx}; + #[test] + fn test_layout() { + let tr = test_tx(); + let tx = serialize(&tr).unwrap(); + let packet = serialize(&Request::Transaction(tr)).unwrap(); + assert_matches!(memfind(&packet, &tx), Some(ecdsa::TX_OFFSET)); } } diff --git a/src/bin/testnode.rs b/src/bin/testnode.rs index bbce2d91ee..37410d3d85 100644 --- a/src/bin/testnode.rs +++ b/src/bin/testnode.rs @@ -1,3 +1,4 @@ +extern crate env_logger; extern crate serde_json; extern crate solana; @@ -11,6 +12,7 @@ use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex}; fn main() { + env_logger::init().unwrap(); let addr = "127.0.0.1:8000"; let stdin = io::stdin(); let mut entries = stdin @@ -27,7 +29,7 @@ fn main() { // transfer to oneself. let entry1: Entry = entries.next().unwrap(); let deposit = if let Event::Transaction(ref tr) = entry1.events[0] { - tr.plan.final_payment() + tr.data.plan.final_payment() } else { None }; diff --git a/src/ecdsa.rs b/src/ecdsa.rs new file mode 100644 index 0000000000..f3ede709b9 --- /dev/null +++ b/src/ecdsa.rs @@ -0,0 +1,147 @@ +// Cuda-only imports +#[cfg(feature = "cuda")] +use packet::PACKET_DATA_SIZE; +#[cfg(feature = "cuda")] +use std::mem::size_of; + +// Non-cuda imports +#[cfg(not(feature = "cuda"))] +use rayon::prelude::*; +#[cfg(not(feature = "cuda"))] +use untrusted; +#[cfg(not(feature = "cuda"))] +use ring::signature; + +// Shared imports +use packet::{Packet, SharedPackets}; + +pub const TX_OFFSET: usize = 4; +pub const SIGNED_DATA_OFFSET: usize = 112; +pub const SIG_OFFSET: usize = 8; +pub const PUB_KEY_OFFSET: usize = 80; + +pub const SIG_SIZE: usize = 64; +pub const PUB_KEY_SIZE: usize = 32; + +#[cfg(feature = "cuda")] +#[repr(C)] +struct Elems { + elems: *const Packet, + num: u32, +} + +#[cfg(feature = "cuda")] +#[link(name = "cuda_verify_ed25519")] +extern "C" { + fn ed25519_verify_many( + vecs: *const Elems, + num: u32, //number of vecs + message_size: u32, //size of each element inside the elems field of the vec + public_key_offset: u32, + signature_offset: u32, + signed_message_offset: u32, + signed_message_len_offset: u32, + out: *mut u8, //combined length of all the items in vecs + ) -> u32; +} + +#[cfg(not(feature = "cuda"))] +fn verify_packet(packet: &Packet) -> u8 { + let msg_start = TX_OFFSET + SIGNED_DATA_OFFSET; + let sig_start = TX_OFFSET + SIG_OFFSET; + let sig_end = sig_start + SIG_SIZE; + let pub_key_start = TX_OFFSET + PUB_KEY_OFFSET; + let pub_key_end = pub_key_start + PUB_KEY_SIZE; + + if packet.meta.size > msg_start { + let msg_end = packet.meta.size; + return if signature::verify( + &signature::ED25519, + untrusted::Input::from(&packet.data[pub_key_start..pub_key_end]), + untrusted::Input::from(&packet.data[msg_start..msg_end]), + untrusted::Input::from(&packet.data[sig_start..sig_end]), + ).is_ok() + { + 1 + } else { + 0 + }; + } else { + return 0; + } +} + +#[cfg(not(feature = "cuda"))] +pub fn ed25519_verify(batches: &Vec) -> Vec> { + let mut locks = Vec::new(); + let mut rvs = Vec::new(); + for packets in batches { + locks.push(packets.read().unwrap()); + } + + for p in locks { + let mut v = Vec::new(); + v.resize(p.packets.len(), 0); + v = p.packets.par_iter().map(|x| verify_packet(x)).collect(); + rvs.push(v); + } + rvs +} + +#[cfg(feature = "cuda")] +pub fn ed25519_verify(batches: &Vec) -> Vec> { + let mut out = Vec::new(); + let mut elems = Vec::new(); + let mut locks = Vec::new(); + let mut rvs = Vec::new(); + + for packets in batches { + locks.push(packets.read().unwrap()); + } + let mut num = 0; + for p in locks { + elems.push(Elems { + elems: p.packets.as_ptr(), + num: p.packets.len() as u32, + }); + let mut v = Vec::new(); + v.resize(p.packets.len(), 0); + rvs.push(v); + num += p.packets.len(); + } + out.resize(num, 0); + trace!("Starting verify num packets: {}", num); + trace!("elem len: {}", elems.len() as u32); + trace!("packet sizeof: {}", size_of::() as u32); + trace!("pub key: {}", (TX_OFFSET + PUB_KEY_OFFSET) as u32); + trace!("sig offset: {}", (TX_OFFSET + SIG_OFFSET) as u32); + trace!("sign data: {}", (TX_OFFSET + SIGNED_DATA_OFFSET) as u32); + trace!("len offset: {}", PACKET_DATA_SIZE as u32); + unsafe { + let res = ed25519_verify_many( + elems.as_ptr(), + elems.len() as u32, + size_of::() as u32, + (TX_OFFSET + PUB_KEY_OFFSET) as u32, + (TX_OFFSET + SIG_OFFSET) as u32, + (TX_OFFSET + SIGNED_DATA_OFFSET) as u32, + PACKET_DATA_SIZE as u32, + out.as_mut_ptr(), + ); + if res != 0 { + trace!("RETURN!!!: {}", res); + } + } + trace!("done verify"); + let mut num = 0; + for vs in rvs.iter_mut() { + for mut v in vs.iter_mut() { + *v = out[num]; + if *v != 0 { + trace!("VERIFIED PACKET!!!!!"); + } + num += 1; + } + } + rvs +} diff --git a/src/event.rs b/src/event.rs index dbd2a93412..fabd6d7735 100644 --- a/src/event.rs +++ b/src/event.rs @@ -47,7 +47,7 @@ impl Event { /// spending plan is valid. pub fn verify(&self) -> bool { match *self { - Event::Transaction(ref tr) => tr.verify(), + Event::Transaction(ref tr) => tr.verify_sig(), Event::Signature { from, tx_sig, sig } => sig.verify(&from, &tx_sig), Event::Timestamp { from, dt, sig } => sig.verify(&from, &serialize(&dt).unwrap()), } diff --git a/src/lib.rs b/src/lib.rs index a44a14be2b..1c8ed2ebad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,14 @@ pub mod accountant_skel; pub mod accountant_stub; pub mod entry; pub mod event; +pub mod ecdsa; pub mod hash; +pub mod historian; pub mod ledger; pub mod mint; +pub mod packet; pub mod plan; pub mod recorder; -pub mod historian; -pub mod packet; pub mod result; pub mod signature; pub mod streamer; @@ -19,6 +20,7 @@ extern crate bincode; extern crate byteorder; extern crate chrono; extern crate generic_array; +extern crate libc; #[macro_use] extern crate log; extern crate rayon; diff --git a/src/mint.rs b/src/mint.rs index 5b869c3152..2f385914de 100644 --- a/src/mint.rs +++ b/src/mint.rs @@ -68,7 +68,7 @@ mod tests { fn test_create_events() { let mut events = Mint::new(100).create_events().into_iter(); if let Event::Transaction(tr) = events.next().unwrap() { - if let Plan::Pay(payment) = tr.plan { + if let Plan::Pay(payment) = tr.data.plan { assert_eq!(tr.from, payment.to); } } diff --git a/src/packet.rs b/src/packet.rs index e6c91592d0..e75f0a820b 100644 --- a/src/packet.rs +++ b/src/packet.rs @@ -1,10 +1,10 @@ -use std::sync::{Arc, Mutex, RwLock}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use result::{Error, Result}; +use std::collections::VecDeque; use std::fmt; use std::io; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; -use std::collections::VecDeque; -use result::{Error, Result}; +use std::sync::{Arc, Mutex, RwLock}; pub type SharedPackets = Arc>; pub type SharedBlob = Arc>; @@ -13,10 +13,11 @@ pub type BlobRecycler = Recycler; const NUM_PACKETS: usize = 1024 * 8; const BLOB_SIZE: usize = 64 * 1024; -pub const PACKET_SIZE: usize = 256; -pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_SIZE) / BLOB_SIZE; +pub const PACKET_DATA_SIZE: usize = 256; +pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_DATA_SIZE) / BLOB_SIZE; #[derive(Clone, Default)] +#[repr(C)] pub struct Meta { pub size: usize, pub addr: [u16; 8], @@ -25,8 +26,9 @@ pub struct Meta { } #[derive(Clone)] +#[repr(C)] pub struct Packet { - pub data: [u8; PACKET_SIZE], + pub data: [u8; PACKET_DATA_SIZE], pub meta: Meta, } @@ -44,7 +46,7 @@ impl fmt::Debug for Packet { impl Default for Packet { fn default() -> Packet { Packet { - data: [0u8; PACKET_SIZE], + data: [0u8; PACKET_DATA_SIZE], meta: Meta::default(), } } @@ -279,11 +281,11 @@ impl Blob { #[cfg(test)] mod test { - use std::net::UdpSocket; - use std::io::Write; - use std::io; - use std::collections::VecDeque; use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets}; + use std::collections::VecDeque; + use std::io; + use std::io::Write; + use std::net::UdpSocket; #[test] pub fn packet_recycler_test() { let r = PacketRecycler::default(); diff --git a/src/plan.rs b/src/plan.rs index d1bc40745e..72aa41f1c8 100644 --- a/src/plan.rs +++ b/src/plan.rs @@ -35,6 +35,7 @@ pub struct Payment { pub to: PublicKey, } +#[repr(C)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub enum Plan { Pay(Payment), diff --git a/src/streamer.rs b/src/streamer.rs index 87581f32df..fad9f54070 100644 --- a/src/streamer.rs +++ b/src/streamer.rs @@ -1,12 +1,12 @@ +use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS}; +use result::Result; +use std::collections::VecDeque; +use std::net::UdpSocket; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; -use std::time::Duration; -use std::net::UdpSocket; use std::thread::{spawn, JoinHandle}; -use std::collections::VecDeque; -use result::Result; -use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS}; +use std::time::Duration; pub type PacketReceiver = mpsc::Receiver; pub type PacketSender = mpsc::Sender; @@ -67,7 +67,7 @@ pub fn responder( r: BlobReceiver, ) -> JoinHandle<()> { spawn(move || loop { - if recv_send(&sock, &recycler, &r).is_err() || exit.load(Ordering::Relaxed) { + if recv_send(&sock, &recycler, &r).is_err() && exit.load(Ordering::Relaxed) { break; } }) @@ -141,16 +141,16 @@ pub fn window( mod bench { extern crate test; use self::test::Bencher; + use packet::{Packet, PacketRecycler, PACKET_DATA_SIZE}; use result::Result; use std::net::{SocketAddr, UdpSocket}; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::mpsc::channel; use std::sync::{Arc, Mutex}; use std::thread::sleep; use std::thread::{spawn, JoinHandle}; use std::time::Duration; use std::time::SystemTime; - use std::sync::mpsc::channel; - use std::sync::atomic::{AtomicBool, Ordering}; - use packet::{Packet, PacketRecycler, PACKET_SIZE}; use streamer::{receiver, PacketReceiver}; fn producer( @@ -163,7 +163,7 @@ mod bench { let msgs_ = msgs.clone(); msgs.write().unwrap().packets.resize(10, Packet::default()); for w in msgs.write().unwrap().packets.iter_mut() { - w.meta.size = PACKET_SIZE; + w.meta.size = PACKET_DATA_SIZE; w.meta.set_addr(&addr); } spawn(move || loop { @@ -241,15 +241,15 @@ mod bench { #[cfg(test)] mod test { + use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_DATA_SIZE}; + use std::collections::VecDeque; + use std::io; + use std::io::Write; use std::net::UdpSocket; + use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc::channel; - use std::io::Write; - use std::io; - use std::collections::VecDeque; use std::time::Duration; - use std::sync::Arc; - use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_SIZE}; use streamer::{receiver, responder, window, BlobReceiver, PacketReceiver}; fn get_msgs(r: PacketReceiver, num: &mut usize) { @@ -288,7 +288,7 @@ mod test { let b_ = b.clone(); let mut w = b.write().unwrap(); w.data[0] = i as u8; - w.meta.size = PACKET_SIZE; + w.meta.size = PACKET_DATA_SIZE; w.meta.set_addr(&addr); msgs.push_back(b_); } @@ -338,7 +338,7 @@ mod test { let mut w = b.write().unwrap(); w.set_index(i).unwrap(); assert_eq!(i, w.get_index().unwrap()); - w.meta.size = PACKET_SIZE; + w.meta.size = PACKET_DATA_SIZE; w.meta.set_addr(&addr); msgs.push_back(b_); } diff --git a/src/transaction.rs b/src/transaction.rs index be43b88373..bf7127975c 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -8,12 +8,17 @@ use rayon::prelude::*; use signature::{KeyPair, KeyPairUtil, PublicKey, Signature, SignatureUtil}; #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] -pub struct Transaction { - pub from: PublicKey, - pub plan: Plan, +pub struct TransactionData { pub tokens: i64, pub last_id: Hash, + pub plan: Plan, +} + +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] +pub struct Transaction { pub sig: Signature, + pub from: PublicKey, + pub data: TransactionData, } impl Transaction { @@ -22,11 +27,13 @@ impl Transaction { let from = from_keypair.pubkey(); let plan = Plan::Pay(Payment { tokens, to }); let mut tr = Transaction { - from, - plan, - tokens, - last_id, sig: Signature::default(), + data: TransactionData { + plan, + tokens, + last_id, + }, + from: from, }; tr.sign(from_keypair); tr @@ -46,10 +53,12 @@ impl Transaction { (Condition::Signature(from), Payment { tokens, to: from }), ); let mut tr = Transaction { - from, - plan, - tokens, - last_id, + data: TransactionData { + plan, + tokens, + last_id, + }, + from: from, sig: Signature::default(), }; tr.sign(from_keypair); @@ -57,7 +66,7 @@ impl Transaction { } fn get_sign_data(&self) -> Vec { - serialize(&(&self.plan, &self.tokens, &self.last_id)).unwrap() + serialize(&(&self.data)).unwrap() } /// Sign this transaction. @@ -66,20 +75,45 @@ impl Transaction { self.sig = Signature::clone_from_slice(keypair.sign(&sign_data).as_ref()); } - /// Verify this transaction's signature and its spending plan. - pub fn verify(&self) -> bool { - self.sig.verify(&self.from, &self.get_sign_data()) && self.plan.verify(self.tokens) + pub fn verify_sig(&self) -> bool { + self.sig.verify(&self.from, &self.get_sign_data()) } + + pub fn verify_plan(&self) -> bool { + self.data.plan.verify(self.data.tokens) + } +} + +#[cfg(test)] +pub fn test_tx() -> Transaction { + let keypair1 = KeyPair::new(); + let pubkey1 = keypair1.pubkey(); + let zero = Hash::default(); + let mut tr = Transaction::new(&keypair1, pubkey1, 42, zero); + tr.sign(&keypair1); + return tr; +} + +#[cfg(test)] +pub fn memfind(a: &[A], b: &[A]) -> Option { + assert!(a.len() >= b.len()); + let end = a.len() - b.len() + 1; + for i in 0..end { + if a[i..i + b.len()] == b[..] { + return Some(i); + } + } + None } /// Verify a batch of signatures. pub fn verify_signatures(transactions: &[Transaction]) -> bool { - transactions.par_iter().all(|tr| tr.verify()) + transactions.par_iter().all(|tr| tr.verify_sig()) } /// Verify a batch of spending plans. pub fn verify_plans(transactions: &[Transaction]) -> bool { - transactions.par_iter().all(|tr| tr.plan.verify(tr.tokens)) + transactions.par_iter().all(|tr| tr.verify_plan()) } /// Verify a batch of transactions. @@ -91,13 +125,14 @@ pub fn verify_transactions(transactions: &[Transaction]) -> bool { mod tests { use super::*; use bincode::{deserialize, serialize}; + use ecdsa; #[test] fn test_claim() { let keypair = KeyPair::new(); let zero = Hash::default(); let tr0 = Transaction::new(&keypair, keypair.pubkey(), 42, zero); - assert!(tr0.verify()); + assert!(tr0.verify_plan()); } #[test] @@ -107,7 +142,7 @@ mod tests { let keypair1 = KeyPair::new(); let pubkey1 = keypair1.pubkey(); let tr0 = Transaction::new(&keypair0, pubkey1, 42, zero); - assert!(tr0.verify()); + assert!(tr0.verify_plan()); } #[test] @@ -117,10 +152,12 @@ mod tests { to: Default::default(), }); let claim0 = Transaction { + data: TransactionData { + plan, + tokens: 0, + last_id: Default::default(), + }, from: Default::default(), - plan, - tokens: 0, - last_id: Default::default(), sig: Default::default(), }; let buf = serialize(&claim0).unwrap(); @@ -135,8 +172,8 @@ mod tests { let pubkey = keypair.pubkey(); let mut tr = Transaction::new(&keypair, pubkey, 42, zero); tr.sign(&keypair); - tr.tokens = 1_000_000; // <-- attack! - assert!(!tr.verify()); + tr.data.tokens = 1_000_000; // <-- attack! + assert!(!tr.verify_plan()); } #[test] @@ -148,10 +185,20 @@ mod tests { let zero = Hash::default(); let mut tr = Transaction::new(&keypair0, pubkey1, 42, zero); tr.sign(&keypair0); - if let Plan::Pay(ref mut payment) = tr.plan { + if let Plan::Pay(ref mut payment) = tr.data.plan { payment.to = thief_keypair.pubkey(); // <-- attack! }; - assert!(!tr.verify()); + assert!(tr.verify_plan()); + assert!(!tr.verify_sig()); + } + #[test] + fn test_layout() { + let tr = test_tx(); + let sign_data = tr.get_sign_data(); + let tx = serialize(&tr).unwrap(); + assert_matches!(memfind(&tx, &sign_data), Some(ecdsa::SIGNED_DATA_OFFSET)); + assert_matches!(memfind(&tx, &tr.sig), Some(ecdsa::SIG_OFFSET)); + assert_matches!(memfind(&tx, &tr.from), Some(ecdsa::PUB_KEY_OFFSET)); } #[test] @@ -160,16 +207,16 @@ mod tests { let keypair1 = KeyPair::new(); let zero = Hash::default(); let mut tr = Transaction::new(&keypair0, keypair1.pubkey(), 1, zero); - if let Plan::Pay(ref mut payment) = tr.plan { + if let Plan::Pay(ref mut payment) = tr.data.plan { payment.tokens = 2; // <-- attack! } - assert!(!tr.verify()); + assert!(!tr.verify_plan()); // Also, ensure all branchs of the plan spend all tokens - if let Plan::Pay(ref mut payment) = tr.plan { + if let Plan::Pay(ref mut payment) = tr.data.plan { payment.tokens = 0; // <-- whoops! } - assert!(!tr.verify()); + assert!(!tr.verify_plan()); } #[test]