Change for cuda verify integration

This commit is contained in:
Stephen Akridge 2018-03-26 21:07:11 -07:00
parent bc6d6b20fa
commit f4466c8c0a
13 changed files with 400 additions and 109 deletions

View File

@ -41,6 +41,7 @@ codecov = { repository = "solana-labs/solana", branch = "master", service = "git
[features] [features]
unstable = [] unstable = []
ipv6 = [] ipv6 = []
cuda = []
[dependencies] [dependencies]
rayon = "1.0.0" rayon = "1.0.0"
@ -54,5 +55,7 @@ untrusted = "0.5.1"
bincode = "1.0.0" bincode = "1.0.0"
chrono = { version = "0.4.0", features = ["serde"] } chrono = { version = "0.4.0", features = ["serde"] }
log = "^0.4.1" log = "^0.4.1"
env_logger = "^0.4.1"
matches = "^0.1.6" matches = "^0.1.6"
byteorder = "^1.2.1" byteorder = "^1.2.1"
libc = "^0.2.1"

12
build.rs Normal file
View File

@ -0,0 +1,12 @@
use std::env;
fn main() {
if !env::var("CARGO_FEATURE_CUDA").is_err() {
println!("cargo:rustc-link-search=native=.");
println!("cargo:rustc-link-lib=static=cuda_verify_ed25519");
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
}
}

View File

@ -3,6 +3,8 @@
//! on behalf of the caller, and a private low-level API for when they have //! on behalf of the caller, and a private low-level API for when they have
//! already been signed and verified. //! already been signed and verified.
extern crate libc;
use chrono::prelude::*; use chrono::prelude::*;
use event::Event; use event::Event;
use hash::Hash; use hash::Hash;
@ -104,19 +106,19 @@ impl Accountant {
/// Process a Transaction that has already been verified. /// Process a Transaction that has already been verified.
pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> { pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> {
if self.get_balance(&tr.from).unwrap_or(0) < tr.tokens { if self.get_balance(&tr.from).unwrap_or(0) < tr.data.tokens {
return Err(AccountingError::InsufficientFunds); return Err(AccountingError::InsufficientFunds);
} }
if !self.reserve_signature_with_last_id(&tr.sig, &tr.last_id) { if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) {
return Err(AccountingError::InvalidTransferSignature); return Err(AccountingError::InvalidTransferSignature);
} }
if let Some(x) = self.balances.read().unwrap().get(&tr.from) { if let Some(x) = self.balances.read().unwrap().get(&tr.from) {
*x.write().unwrap() -= tr.tokens; *x.write().unwrap() -= tr.data.tokens;
} }
let mut plan = tr.plan.clone(); let mut plan = tr.data.plan.clone();
plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap())); plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap()));
if let Some(ref payment) = plan.final_payment() { if let Some(ref payment) = plan.final_payment() {

View File

@ -6,24 +6,27 @@ use accountant::Accountant;
use bincode::{deserialize, serialize}; use bincode::{deserialize, serialize};
use entry::Entry; use entry::Entry;
use event::Event; use event::Event;
use ecdsa;
use hash::Hash; use hash::Hash;
use historian::Historian; use historian::Historian;
use packet;
use packet::SharedPackets;
use rayon::prelude::*; use rayon::prelude::*;
use recorder::Signal; use recorder::Signal;
use result::Result; use result::Result;
use serde_json; use serde_json;
use signature::PublicKey; use signature::PublicKey;
use std::cmp::max;
use std::collections::VecDeque;
use std::io::Write; use std::io::Write;
use std::net::{SocketAddr, UdpSocket}; use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, SendError}; use std::sync::mpsc::{channel, Receiver, SendError, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{spawn, JoinHandle}; use std::thread::{spawn, JoinHandle};
use std::time::Duration; use std::time::Duration;
use streamer; use streamer;
use packet;
use std::sync::{Arc, Mutex};
use transaction::Transaction; use transaction::Transaction;
use std::collections::VecDeque;
pub struct AccountantSkel<W: Write + Send + 'static> { pub struct AccountantSkel<W: Write + Send + 'static> {
acc: Accountant, acc: Accountant,
@ -44,14 +47,14 @@ impl Request {
/// Verify the request is valid. /// Verify the request is valid.
pub fn verify(&self) -> bool { pub fn verify(&self) -> bool {
match *self { match *self {
Request::Transaction(ref tr) => tr.verify(), Request::Transaction(ref tr) => tr.verify_plan(),
_ => true, _ => true,
} }
} }
} }
/// Parallel verfication of a batch of requests. /// Parallel verfication of a batch of requests.
fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> { pub fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> {
reqs.into_par_iter().filter({ |x| x.0.verify() }).collect() reqs.into_par_iter().filter({ |x| x.0.verify() }).collect()
} }
@ -84,16 +87,20 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
} }
/// Process Request items sent by clients. /// Process Request items sent by clients.
pub fn log_verified_request(&mut self, msg: Request) -> Option<Response> { pub fn log_verified_request(&mut self, msg: Request, verify: u8) -> Option<Response> {
match msg { match msg {
Request::Transaction(_) if verify == 0 => {
trace!("Transaction failed sigverify");
None
}
Request::Transaction(tr) => { Request::Transaction(tr) => {
if let Err(err) = self.acc.process_verified_transaction(&tr) { if let Err(err) = self.acc.process_verified_transaction(&tr) {
eprintln!("Transaction error: {:?}", err); trace!("Transaction error: {:?}", err);
} else if let Err(SendError(_)) = self.historian } else if let Err(SendError(_)) = self.historian
.sender .sender
.send(Signal::Event(Event::Transaction(tr))) .send(Signal::Event(Event::Transaction(tr.clone())))
{ {
eprintln!("Channel send error"); error!("Channel send error");
} }
None None
} }
@ -105,46 +112,87 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
} }
} }
fn verifier(
recvr: &streamer::PacketReceiver,
sendr: &Sender<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
) -> Result<()> {
let timer = Duration::new(1, 0);
let msgs = recvr.recv_timeout(timer)?;
trace!("got msgs");
let mut v = Vec::new();
v.push(msgs);
while let Ok(more) = recvr.try_recv() {
trace!("got more msgs");
v.push(more);
}
info!("batch {}", v.len());
let chunk = max(1, (v.len() + 3) / 4);
let chunks: Vec<_> = v.chunks(chunk).collect();
let rvs: Vec<_> = chunks
.into_par_iter()
.map(|x| ecdsa::ed25519_verify(&x.to_vec()))
.collect();
for (v, r) in v.chunks(chunk).zip(rvs) {
sendr.send((v.to_vec(), r))?;
}
Ok(())
}
pub fn deserialize_packets(p: &packet::Packets) -> Vec<Option<(Request, SocketAddr)>> {
// TODO: deserealize in parallel
let mut r = vec![];
for x in &p.packets {
let rsp_addr = x.meta.addr();
let sz = x.meta.size;
if let Ok(req) = deserialize(&x.data[0..sz]) {
r.push(Some((req, rsp_addr)));
} else {
r.push(None);
}
}
r
}
fn process( fn process(
obj: &Arc<Mutex<AccountantSkel<W>>>, obj: &Arc<Mutex<AccountantSkel<W>>>,
packet_receiver: &streamer::PacketReceiver, verified_receiver: &Receiver<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
blob_sender: &streamer::BlobSender, blob_sender: &streamer::BlobSender,
packet_recycler: &packet::PacketRecycler, packet_recycler: &packet::PacketRecycler,
blob_recycler: &packet::BlobRecycler, blob_recycler: &packet::BlobRecycler,
) -> Result<()> { ) -> Result<()> {
let timer = Duration::new(1, 0); let timer = Duration::new(1, 0);
let msgs = packet_receiver.recv_timeout(timer)?; let (mms, vvs) = verified_receiver.recv_timeout(timer)?;
let msgs_ = msgs.clone(); for (msgs, vers) in mms.into_iter().zip(vvs.into_iter()) {
let mut rsps = VecDeque::new(); let msgs_ = msgs.clone();
{ let mut rsps = VecDeque::new();
let mut reqs = vec![]; {
for packet in &msgs.read().unwrap().packets { let reqs = Self::deserialize_packets(&((*msgs).read().unwrap()));
let rsp_addr = packet.meta.addr(); for (data, v) in reqs.into_iter().zip(vers.into_iter()) {
let sz = packet.meta.size; if let Some((req, rsp_addr)) = data {
let req = deserialize(&packet.data[0..sz])?; if !req.verify() {
reqs.push((req, rsp_addr)); continue;
} }
let reqs = filter_valid_requests(reqs); if let Some(resp) = obj.lock().unwrap().log_verified_request(req, v) {
for (req, rsp_addr) in reqs { let blob = blob_recycler.allocate();
if let Some(resp) = obj.lock().unwrap().log_verified_request(req) { {
let blob = blob_recycler.allocate(); let mut b = blob.write().unwrap();
{ let v = serialize(&resp)?;
let mut b = blob.write().unwrap(); let len = v.len();
let v = serialize(&resp)?; b.data[..len].copy_from_slice(&v);
let len = v.len(); b.meta.size = len;
b.data[..len].copy_from_slice(&v); b.meta.set_addr(&rsp_addr);
b.meta.size = len; }
b.meta.set_addr(&rsp_addr); rsps.push_back(blob);
}
} }
rsps.push_back(blob);
} }
} }
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
} }
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
Ok(()) Ok(())
} }
@ -169,11 +217,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
let (blob_sender, blob_receiver) = channel(); let (blob_sender, blob_receiver) = channel();
let t_responder = let t_responder =
streamer::responder(write, exit.clone(), blob_recycler.clone(), blob_receiver); streamer::responder(write, exit.clone(), blob_recycler.clone(), blob_receiver);
let (verified_sender, verified_receiver) = channel();
let exit_ = exit.clone();
let t_verifier = spawn(move || loop {
let e = Self::verifier(&packet_receiver, &verified_sender);
if e.is_err() && exit_.load(Ordering::Relaxed) {
break;
}
});
let skel = obj.clone(); let skel = obj.clone();
let t_server = spawn(move || loop { let t_server = spawn(move || loop {
let e = AccountantSkel::process( let e = AccountantSkel::process(
&skel, &skel,
&packet_receiver, &verified_receiver,
&blob_sender, &blob_sender,
&packet_recycler, &packet_recycler,
&blob_recycler, &blob_recycler,
@ -182,6 +240,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
break; break;
} }
}); });
Ok(vec![t_receiver, t_responder, t_server]) Ok(vec![t_receiver, t_responder, t_server, t_verifier])
}
}
#[cfg(test)]
mod tests {
use accountant_skel::Request;
use bincode::serialize;
use ecdsa;
use transaction::{memfind, test_tx};
#[test]
fn test_layout() {
let tr = test_tx();
let tx = serialize(&tr).unwrap();
let packet = serialize(&Request::Transaction(tr)).unwrap();
assert_matches!(memfind(&packet, &tx), Some(ecdsa::TX_OFFSET));
} }
} }

View File

@ -1,3 +1,4 @@
extern crate env_logger;
extern crate serde_json; extern crate serde_json;
extern crate solana; extern crate solana;
@ -11,6 +12,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
fn main() { fn main() {
env_logger::init().unwrap();
let addr = "127.0.0.1:8000"; let addr = "127.0.0.1:8000";
let stdin = io::stdin(); let stdin = io::stdin();
let mut entries = stdin let mut entries = stdin
@ -27,7 +29,7 @@ fn main() {
// transfer to oneself. // transfer to oneself.
let entry1: Entry = entries.next().unwrap(); let entry1: Entry = entries.next().unwrap();
let deposit = if let Event::Transaction(ref tr) = entry1.events[0] { let deposit = if let Event::Transaction(ref tr) = entry1.events[0] {
tr.plan.final_payment() tr.data.plan.final_payment()
} else { } else {
None None
}; };

147
src/ecdsa.rs Normal file
View File

@ -0,0 +1,147 @@
// Cuda-only imports
#[cfg(feature = "cuda")]
use packet::PACKET_DATA_SIZE;
#[cfg(feature = "cuda")]
use std::mem::size_of;
// Non-cuda imports
#[cfg(not(feature = "cuda"))]
use rayon::prelude::*;
#[cfg(not(feature = "cuda"))]
use untrusted;
#[cfg(not(feature = "cuda"))]
use ring::signature;
// Shared imports
use packet::{Packet, SharedPackets};
pub const TX_OFFSET: usize = 4;
pub const SIGNED_DATA_OFFSET: usize = 112;
pub const SIG_OFFSET: usize = 8;
pub const PUB_KEY_OFFSET: usize = 80;
pub const SIG_SIZE: usize = 64;
pub const PUB_KEY_SIZE: usize = 32;
#[cfg(feature = "cuda")]
#[repr(C)]
struct Elems {
elems: *const Packet,
num: u32,
}
#[cfg(feature = "cuda")]
#[link(name = "cuda_verify_ed25519")]
extern "C" {
fn ed25519_verify_many(
vecs: *const Elems,
num: u32, //number of vecs
message_size: u32, //size of each element inside the elems field of the vec
public_key_offset: u32,
signature_offset: u32,
signed_message_offset: u32,
signed_message_len_offset: u32,
out: *mut u8, //combined length of all the items in vecs
) -> u32;
}
#[cfg(not(feature = "cuda"))]
fn verify_packet(packet: &Packet) -> u8 {
let msg_start = TX_OFFSET + SIGNED_DATA_OFFSET;
let sig_start = TX_OFFSET + SIG_OFFSET;
let sig_end = sig_start + SIG_SIZE;
let pub_key_start = TX_OFFSET + PUB_KEY_OFFSET;
let pub_key_end = pub_key_start + PUB_KEY_SIZE;
if packet.meta.size > msg_start {
let msg_end = packet.meta.size;
return if signature::verify(
&signature::ED25519,
untrusted::Input::from(&packet.data[pub_key_start..pub_key_end]),
untrusted::Input::from(&packet.data[msg_start..msg_end]),
untrusted::Input::from(&packet.data[sig_start..sig_end]),
).is_ok()
{
1
} else {
0
};
} else {
return 0;
}
}
#[cfg(not(feature = "cuda"))]
pub fn ed25519_verify(batches: &Vec<SharedPackets>) -> Vec<Vec<u8>> {
let mut locks = Vec::new();
let mut rvs = Vec::new();
for packets in batches {
locks.push(packets.read().unwrap());
}
for p in locks {
let mut v = Vec::new();
v.resize(p.packets.len(), 0);
v = p.packets.par_iter().map(|x| verify_packet(x)).collect();
rvs.push(v);
}
rvs
}
#[cfg(feature = "cuda")]
pub fn ed25519_verify(batches: &Vec<SharedPackets>) -> Vec<Vec<u8>> {
let mut out = Vec::new();
let mut elems = Vec::new();
let mut locks = Vec::new();
let mut rvs = Vec::new();
for packets in batches {
locks.push(packets.read().unwrap());
}
let mut num = 0;
for p in locks {
elems.push(Elems {
elems: p.packets.as_ptr(),
num: p.packets.len() as u32,
});
let mut v = Vec::new();
v.resize(p.packets.len(), 0);
rvs.push(v);
num += p.packets.len();
}
out.resize(num, 0);
trace!("Starting verify num packets: {}", num);
trace!("elem len: {}", elems.len() as u32);
trace!("packet sizeof: {}", size_of::<Packet>() as u32);
trace!("pub key: {}", (TX_OFFSET + PUB_KEY_OFFSET) as u32);
trace!("sig offset: {}", (TX_OFFSET + SIG_OFFSET) as u32);
trace!("sign data: {}", (TX_OFFSET + SIGNED_DATA_OFFSET) as u32);
trace!("len offset: {}", PACKET_DATA_SIZE as u32);
unsafe {
let res = ed25519_verify_many(
elems.as_ptr(),
elems.len() as u32,
size_of::<Packet>() as u32,
(TX_OFFSET + PUB_KEY_OFFSET) as u32,
(TX_OFFSET + SIG_OFFSET) as u32,
(TX_OFFSET + SIGNED_DATA_OFFSET) as u32,
PACKET_DATA_SIZE as u32,
out.as_mut_ptr(),
);
if res != 0 {
trace!("RETURN!!!: {}", res);
}
}
trace!("done verify");
let mut num = 0;
for vs in rvs.iter_mut() {
for mut v in vs.iter_mut() {
*v = out[num];
if *v != 0 {
trace!("VERIFIED PACKET!!!!!");
}
num += 1;
}
}
rvs
}

View File

@ -47,7 +47,7 @@ impl Event {
/// spending plan is valid. /// spending plan is valid.
pub fn verify(&self) -> bool { pub fn verify(&self) -> bool {
match *self { match *self {
Event::Transaction(ref tr) => tr.verify(), Event::Transaction(ref tr) => tr.verify_sig(),
Event::Signature { from, tx_sig, sig } => sig.verify(&from, &tx_sig), Event::Signature { from, tx_sig, sig } => sig.verify(&from, &tx_sig),
Event::Timestamp { from, dt, sig } => sig.verify(&from, &serialize(&dt).unwrap()), Event::Timestamp { from, dt, sig } => sig.verify(&from, &serialize(&dt).unwrap()),
} }

View File

@ -4,13 +4,14 @@ pub mod accountant_skel;
pub mod accountant_stub; pub mod accountant_stub;
pub mod entry; pub mod entry;
pub mod event; pub mod event;
pub mod ecdsa;
pub mod hash; pub mod hash;
pub mod historian;
pub mod ledger; pub mod ledger;
pub mod mint; pub mod mint;
pub mod packet;
pub mod plan; pub mod plan;
pub mod recorder; pub mod recorder;
pub mod historian;
pub mod packet;
pub mod result; pub mod result;
pub mod signature; pub mod signature;
pub mod streamer; pub mod streamer;
@ -19,6 +20,7 @@ extern crate bincode;
extern crate byteorder; extern crate byteorder;
extern crate chrono; extern crate chrono;
extern crate generic_array; extern crate generic_array;
extern crate libc;
#[macro_use] #[macro_use]
extern crate log; extern crate log;
extern crate rayon; extern crate rayon;

View File

@ -68,7 +68,7 @@ mod tests {
fn test_create_events() { fn test_create_events() {
let mut events = Mint::new(100).create_events().into_iter(); let mut events = Mint::new(100).create_events().into_iter();
if let Event::Transaction(tr) = events.next().unwrap() { if let Event::Transaction(tr) = events.next().unwrap() {
if let Plan::Pay(payment) = tr.plan { if let Plan::Pay(payment) = tr.data.plan {
assert_eq!(tr.from, payment.to); assert_eq!(tr.from, payment.to);
} }
} }

View File

@ -1,10 +1,10 @@
use std::sync::{Arc, Mutex, RwLock}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use result::{Error, Result};
use std::collections::VecDeque;
use std::fmt; use std::fmt;
use std::io; use std::io;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use std::collections::VecDeque; use std::sync::{Arc, Mutex, RwLock};
use result::{Error, Result};
pub type SharedPackets = Arc<RwLock<Packets>>; pub type SharedPackets = Arc<RwLock<Packets>>;
pub type SharedBlob = Arc<RwLock<Blob>>; pub type SharedBlob = Arc<RwLock<Blob>>;
@ -13,10 +13,11 @@ pub type BlobRecycler = Recycler<Blob>;
const NUM_PACKETS: usize = 1024 * 8; const NUM_PACKETS: usize = 1024 * 8;
const BLOB_SIZE: usize = 64 * 1024; const BLOB_SIZE: usize = 64 * 1024;
pub const PACKET_SIZE: usize = 256; pub const PACKET_DATA_SIZE: usize = 256;
pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_SIZE) / BLOB_SIZE; pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_DATA_SIZE) / BLOB_SIZE;
#[derive(Clone, Default)] #[derive(Clone, Default)]
#[repr(C)]
pub struct Meta { pub struct Meta {
pub size: usize, pub size: usize,
pub addr: [u16; 8], pub addr: [u16; 8],
@ -25,8 +26,9 @@ pub struct Meta {
} }
#[derive(Clone)] #[derive(Clone)]
#[repr(C)]
pub struct Packet { pub struct Packet {
pub data: [u8; PACKET_SIZE], pub data: [u8; PACKET_DATA_SIZE],
pub meta: Meta, pub meta: Meta,
} }
@ -44,7 +46,7 @@ impl fmt::Debug for Packet {
impl Default for Packet { impl Default for Packet {
fn default() -> Packet { fn default() -> Packet {
Packet { Packet {
data: [0u8; PACKET_SIZE], data: [0u8; PACKET_DATA_SIZE],
meta: Meta::default(), meta: Meta::default(),
} }
} }
@ -279,11 +281,11 @@ impl Blob {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use std::net::UdpSocket;
use std::io::Write;
use std::io;
use std::collections::VecDeque;
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets}; use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets};
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::net::UdpSocket;
#[test] #[test]
pub fn packet_recycler_test() { pub fn packet_recycler_test() {
let r = PacketRecycler::default(); let r = PacketRecycler::default();

View File

@ -35,6 +35,7 @@ pub struct Payment {
pub to: PublicKey, pub to: PublicKey,
} }
#[repr(C)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub enum Plan { pub enum Plan {
Pay(Payment), Pay(Payment),

View File

@ -1,12 +1,12 @@
use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS};
use result::Result;
use std::collections::VecDeque;
use std::net::UdpSocket;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc; use std::sync::mpsc;
use std::time::Duration;
use std::net::UdpSocket;
use std::thread::{spawn, JoinHandle}; use std::thread::{spawn, JoinHandle};
use std::collections::VecDeque; use std::time::Duration;
use result::Result;
use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS};
pub type PacketReceiver = mpsc::Receiver<SharedPackets>; pub type PacketReceiver = mpsc::Receiver<SharedPackets>;
pub type PacketSender = mpsc::Sender<SharedPackets>; pub type PacketSender = mpsc::Sender<SharedPackets>;
@ -67,7 +67,7 @@ pub fn responder(
r: BlobReceiver, r: BlobReceiver,
) -> JoinHandle<()> { ) -> JoinHandle<()> {
spawn(move || loop { spawn(move || loop {
if recv_send(&sock, &recycler, &r).is_err() || exit.load(Ordering::Relaxed) { if recv_send(&sock, &recycler, &r).is_err() && exit.load(Ordering::Relaxed) {
break; break;
} }
}) })
@ -141,16 +141,16 @@ pub fn window(
mod bench { mod bench {
extern crate test; extern crate test;
use self::test::Bencher; use self::test::Bencher;
use packet::{Packet, PacketRecycler, PACKET_DATA_SIZE};
use result::Result; use result::Result;
use std::net::{SocketAddr, UdpSocket}; use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::thread::sleep; use std::thread::sleep;
use std::thread::{spawn, JoinHandle}; use std::thread::{spawn, JoinHandle};
use std::time::Duration; use std::time::Duration;
use std::time::SystemTime; use std::time::SystemTime;
use std::sync::mpsc::channel;
use std::sync::atomic::{AtomicBool, Ordering};
use packet::{Packet, PacketRecycler, PACKET_SIZE};
use streamer::{receiver, PacketReceiver}; use streamer::{receiver, PacketReceiver};
fn producer( fn producer(
@ -163,7 +163,7 @@ mod bench {
let msgs_ = msgs.clone(); let msgs_ = msgs.clone();
msgs.write().unwrap().packets.resize(10, Packet::default()); msgs.write().unwrap().packets.resize(10, Packet::default());
for w in msgs.write().unwrap().packets.iter_mut() { for w in msgs.write().unwrap().packets.iter_mut() {
w.meta.size = PACKET_SIZE; w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr); w.meta.set_addr(&addr);
} }
spawn(move || loop { spawn(move || loop {
@ -241,15 +241,15 @@ mod bench {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_DATA_SIZE};
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::net::UdpSocket; use std::net::UdpSocket;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
use std::io::Write;
use std::io;
use std::collections::VecDeque;
use std::time::Duration; use std::time::Duration;
use std::sync::Arc;
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_SIZE};
use streamer::{receiver, responder, window, BlobReceiver, PacketReceiver}; use streamer::{receiver, responder, window, BlobReceiver, PacketReceiver};
fn get_msgs(r: PacketReceiver, num: &mut usize) { fn get_msgs(r: PacketReceiver, num: &mut usize) {
@ -288,7 +288,7 @@ mod test {
let b_ = b.clone(); let b_ = b.clone();
let mut w = b.write().unwrap(); let mut w = b.write().unwrap();
w.data[0] = i as u8; w.data[0] = i as u8;
w.meta.size = PACKET_SIZE; w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr); w.meta.set_addr(&addr);
msgs.push_back(b_); msgs.push_back(b_);
} }
@ -338,7 +338,7 @@ mod test {
let mut w = b.write().unwrap(); let mut w = b.write().unwrap();
w.set_index(i).unwrap(); w.set_index(i).unwrap();
assert_eq!(i, w.get_index().unwrap()); assert_eq!(i, w.get_index().unwrap());
w.meta.size = PACKET_SIZE; w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr); w.meta.set_addr(&addr);
msgs.push_back(b_); msgs.push_back(b_);
} }

View File

@ -8,12 +8,17 @@ use rayon::prelude::*;
use signature::{KeyPair, KeyPairUtil, PublicKey, Signature, SignatureUtil}; use signature::{KeyPair, KeyPairUtil, PublicKey, Signature, SignatureUtil};
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct Transaction { pub struct TransactionData {
pub from: PublicKey,
pub plan: Plan,
pub tokens: i64, pub tokens: i64,
pub last_id: Hash, pub last_id: Hash,
pub plan: Plan,
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct Transaction {
pub sig: Signature, pub sig: Signature,
pub from: PublicKey,
pub data: TransactionData,
} }
impl Transaction { impl Transaction {
@ -22,11 +27,13 @@ impl Transaction {
let from = from_keypair.pubkey(); let from = from_keypair.pubkey();
let plan = Plan::Pay(Payment { tokens, to }); let plan = Plan::Pay(Payment { tokens, to });
let mut tr = Transaction { let mut tr = Transaction {
from,
plan,
tokens,
last_id,
sig: Signature::default(), sig: Signature::default(),
data: TransactionData {
plan,
tokens,
last_id,
},
from: from,
}; };
tr.sign(from_keypair); tr.sign(from_keypair);
tr tr
@ -46,10 +53,12 @@ impl Transaction {
(Condition::Signature(from), Payment { tokens, to: from }), (Condition::Signature(from), Payment { tokens, to: from }),
); );
let mut tr = Transaction { let mut tr = Transaction {
from, data: TransactionData {
plan, plan,
tokens, tokens,
last_id, last_id,
},
from: from,
sig: Signature::default(), sig: Signature::default(),
}; };
tr.sign(from_keypair); tr.sign(from_keypair);
@ -57,7 +66,7 @@ impl Transaction {
} }
fn get_sign_data(&self) -> Vec<u8> { fn get_sign_data(&self) -> Vec<u8> {
serialize(&(&self.plan, &self.tokens, &self.last_id)).unwrap() serialize(&(&self.data)).unwrap()
} }
/// Sign this transaction. /// Sign this transaction.
@ -66,20 +75,45 @@ impl Transaction {
self.sig = Signature::clone_from_slice(keypair.sign(&sign_data).as_ref()); self.sig = Signature::clone_from_slice(keypair.sign(&sign_data).as_ref());
} }
/// Verify this transaction's signature and its spending plan. pub fn verify_sig(&self) -> bool {
pub fn verify(&self) -> bool { self.sig.verify(&self.from, &self.get_sign_data())
self.sig.verify(&self.from, &self.get_sign_data()) && self.plan.verify(self.tokens)
} }
pub fn verify_plan(&self) -> bool {
self.data.plan.verify(self.data.tokens)
}
}
#[cfg(test)]
pub fn test_tx() -> Transaction {
let keypair1 = KeyPair::new();
let pubkey1 = keypair1.pubkey();
let zero = Hash::default();
let mut tr = Transaction::new(&keypair1, pubkey1, 42, zero);
tr.sign(&keypair1);
return tr;
}
#[cfg(test)]
pub fn memfind<A: Eq>(a: &[A], b: &[A]) -> Option<usize> {
assert!(a.len() >= b.len());
let end = a.len() - b.len() + 1;
for i in 0..end {
if a[i..i + b.len()] == b[..] {
return Some(i);
}
}
None
} }
/// Verify a batch of signatures. /// Verify a batch of signatures.
pub fn verify_signatures(transactions: &[Transaction]) -> bool { pub fn verify_signatures(transactions: &[Transaction]) -> bool {
transactions.par_iter().all(|tr| tr.verify()) transactions.par_iter().all(|tr| tr.verify_sig())
} }
/// Verify a batch of spending plans. /// Verify a batch of spending plans.
pub fn verify_plans(transactions: &[Transaction]) -> bool { pub fn verify_plans(transactions: &[Transaction]) -> bool {
transactions.par_iter().all(|tr| tr.plan.verify(tr.tokens)) transactions.par_iter().all(|tr| tr.verify_plan())
} }
/// Verify a batch of transactions. /// Verify a batch of transactions.
@ -91,13 +125,14 @@ pub fn verify_transactions(transactions: &[Transaction]) -> bool {
mod tests { mod tests {
use super::*; use super::*;
use bincode::{deserialize, serialize}; use bincode::{deserialize, serialize};
use ecdsa;
#[test] #[test]
fn test_claim() { fn test_claim() {
let keypair = KeyPair::new(); let keypair = KeyPair::new();
let zero = Hash::default(); let zero = Hash::default();
let tr0 = Transaction::new(&keypair, keypair.pubkey(), 42, zero); let tr0 = Transaction::new(&keypair, keypair.pubkey(), 42, zero);
assert!(tr0.verify()); assert!(tr0.verify_plan());
} }
#[test] #[test]
@ -107,7 +142,7 @@ mod tests {
let keypair1 = KeyPair::new(); let keypair1 = KeyPair::new();
let pubkey1 = keypair1.pubkey(); let pubkey1 = keypair1.pubkey();
let tr0 = Transaction::new(&keypair0, pubkey1, 42, zero); let tr0 = Transaction::new(&keypair0, pubkey1, 42, zero);
assert!(tr0.verify()); assert!(tr0.verify_plan());
} }
#[test] #[test]
@ -117,10 +152,12 @@ mod tests {
to: Default::default(), to: Default::default(),
}); });
let claim0 = Transaction { let claim0 = Transaction {
data: TransactionData {
plan,
tokens: 0,
last_id: Default::default(),
},
from: Default::default(), from: Default::default(),
plan,
tokens: 0,
last_id: Default::default(),
sig: Default::default(), sig: Default::default(),
}; };
let buf = serialize(&claim0).unwrap(); let buf = serialize(&claim0).unwrap();
@ -135,8 +172,8 @@ mod tests {
let pubkey = keypair.pubkey(); let pubkey = keypair.pubkey();
let mut tr = Transaction::new(&keypair, pubkey, 42, zero); let mut tr = Transaction::new(&keypair, pubkey, 42, zero);
tr.sign(&keypair); tr.sign(&keypair);
tr.tokens = 1_000_000; // <-- attack! tr.data.tokens = 1_000_000; // <-- attack!
assert!(!tr.verify()); assert!(!tr.verify_plan());
} }
#[test] #[test]
@ -148,10 +185,20 @@ mod tests {
let zero = Hash::default(); let zero = Hash::default();
let mut tr = Transaction::new(&keypair0, pubkey1, 42, zero); let mut tr = Transaction::new(&keypair0, pubkey1, 42, zero);
tr.sign(&keypair0); tr.sign(&keypair0);
if let Plan::Pay(ref mut payment) = tr.plan { if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.to = thief_keypair.pubkey(); // <-- attack! payment.to = thief_keypair.pubkey(); // <-- attack!
}; };
assert!(!tr.verify()); assert!(tr.verify_plan());
assert!(!tr.verify_sig());
}
#[test]
fn test_layout() {
let tr = test_tx();
let sign_data = tr.get_sign_data();
let tx = serialize(&tr).unwrap();
assert_matches!(memfind(&tx, &sign_data), Some(ecdsa::SIGNED_DATA_OFFSET));
assert_matches!(memfind(&tx, &tr.sig), Some(ecdsa::SIG_OFFSET));
assert_matches!(memfind(&tx, &tr.from), Some(ecdsa::PUB_KEY_OFFSET));
} }
#[test] #[test]
@ -160,16 +207,16 @@ mod tests {
let keypair1 = KeyPair::new(); let keypair1 = KeyPair::new();
let zero = Hash::default(); let zero = Hash::default();
let mut tr = Transaction::new(&keypair0, keypair1.pubkey(), 1, zero); let mut tr = Transaction::new(&keypair0, keypair1.pubkey(), 1, zero);
if let Plan::Pay(ref mut payment) = tr.plan { if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.tokens = 2; // <-- attack! payment.tokens = 2; // <-- attack!
} }
assert!(!tr.verify()); assert!(!tr.verify_plan());
// Also, ensure all branchs of the plan spend all tokens // Also, ensure all branchs of the plan spend all tokens
if let Plan::Pay(ref mut payment) = tr.plan { if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.tokens = 0; // <-- whoops! payment.tokens = 0; // <-- whoops!
} }
assert!(!tr.verify()); assert!(!tr.verify_plan());
} }
#[test] #[test]