Change for cuda verify integration

Stephen Akridge 2018-03-26 21:07:11 -07:00
parent bc6d6b20fa
commit f4466c8c0a
13 changed files with 400 additions and 109 deletions


@ -41,6 +41,7 @@ codecov = { repository = "solana-labs/solana", branch = "master", service = "git
[features]
unstable = []
ipv6 = []
cuda = []
[dependencies]
rayon = "1.0.0"
@ -54,5 +55,7 @@ untrusted = "0.5.1"
bincode = "1.0.0"
chrono = { version = "0.4.0", features = ["serde"] }
log = "^0.4.1"
env_logger = "^0.4.1"
matches = "^0.1.6"
byteorder = "^1.2.1"
libc = "^0.2.1"

build.rs (new file, 12 lines added)

@ -0,0 +1,12 @@
use std::env;
fn main() {
if !env::var("CARGO_FEATURE_CUDA").is_err() {
println!("cargo:rustc-link-search=native=.");
println!("cargo:rustc-link-lib=static=cuda_verify_ed25519");
println!("cargo:rustc-link-search=native=/usr/local/cuda/lib64");
println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=cuda");
println!("cargo:rustc-link-lib=dylib=cudadevrt");
}
}
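
Cargo sets the environment variable CARGO_FEATURE_CUDA for the build script whenever the cuda feature is enabled, so the link directives above are only emitted for CUDA builds (for example, cargo build --features cuda). A minimal sketch of the same check written with is_ok(), equivalent to the double negative above:

use std::env;

fn main() {
    // Cargo exports CARGO_FEATURE_<NAME> for every enabled feature, so this
    // branch only runs when the crate is built with --features cuda.
    if env::var("CARGO_FEATURE_CUDA").is_ok() {
        println!("cargo:rustc-link-search=native=.");
        println!("cargo:rustc-link-lib=static=cuda_verify_ed25519");
    }
}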


@ -3,6 +3,8 @@
//! on behalf of the caller, and a private low-level API for when they have
//! already been signed and verified.
extern crate libc;
use chrono::prelude::*;
use event::Event;
use hash::Hash;
@ -104,19 +106,19 @@ impl Accountant {
/// Process a Transaction that has already been verified.
pub fn process_verified_transaction(&self, tr: &Transaction) -> Result<()> {
if self.get_balance(&tr.from).unwrap_or(0) < tr.tokens {
if self.get_balance(&tr.from).unwrap_or(0) < tr.data.tokens {
return Err(AccountingError::InsufficientFunds);
}
if !self.reserve_signature_with_last_id(&tr.sig, &tr.last_id) {
if !self.reserve_signature_with_last_id(&tr.sig, &tr.data.last_id) {
return Err(AccountingError::InvalidTransferSignature);
}
if let Some(x) = self.balances.read().unwrap().get(&tr.from) {
*x.write().unwrap() -= tr.tokens;
*x.write().unwrap() -= tr.data.tokens;
}
let mut plan = tr.plan.clone();
let mut plan = tr.data.plan.clone();
plan.apply_witness(&Witness::Timestamp(*self.last_time.read().unwrap()));
if let Some(ref payment) = plan.final_payment() {


@ -6,24 +6,27 @@ use accountant::Accountant;
use bincode::{deserialize, serialize};
use entry::Entry;
use event::Event;
use ecdsa;
use hash::Hash;
use historian::Historian;
use packet;
use packet::SharedPackets;
use rayon::prelude::*;
use recorder::Signal;
use result::Result;
use serde_json;
use signature::PublicKey;
use std::cmp::max;
use std::collections::VecDeque;
use std::io::Write;
use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, SendError};
use std::sync::mpsc::{channel, Receiver, SendError, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{spawn, JoinHandle};
use std::time::Duration;
use streamer;
use packet;
use std::sync::{Arc, Mutex};
use transaction::Transaction;
use std::collections::VecDeque;
pub struct AccountantSkel<W: Write + Send + 'static> {
acc: Accountant,
@ -44,14 +47,14 @@ impl Request {
/// Verify the request is valid.
pub fn verify(&self) -> bool {
match *self {
Request::Transaction(ref tr) => tr.verify(),
Request::Transaction(ref tr) => tr.verify_plan(),
_ => true,
}
}
}
/// Parallel verification of a batch of requests.
fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> {
pub fn filter_valid_requests(reqs: Vec<(Request, SocketAddr)>) -> Vec<(Request, SocketAddr)> {
reqs.into_par_iter().filter({ |x| x.0.verify() }).collect()
}
@ -84,16 +87,20 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
}
/// Process Request items sent by clients.
pub fn log_verified_request(&mut self, msg: Request) -> Option<Response> {
pub fn log_verified_request(&mut self, msg: Request, verify: u8) -> Option<Response> {
match msg {
Request::Transaction(_) if verify == 0 => {
trace!("Transaction failed sigverify");
None
}
Request::Transaction(tr) => {
if let Err(err) = self.acc.process_verified_transaction(&tr) {
eprintln!("Transaction error: {:?}", err);
trace!("Transaction error: {:?}", err);
} else if let Err(SendError(_)) = self.historian
.sender
.send(Signal::Event(Event::Transaction(tr)))
.send(Signal::Event(Event::Transaction(tr.clone())))
{
eprintln!("Channel send error");
error!("Channel send error");
}
None
}
@ -105,46 +112,87 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
}
}
fn verifier(
recvr: &streamer::PacketReceiver,
sendr: &Sender<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
) -> Result<()> {
let timer = Duration::new(1, 0);
let msgs = recvr.recv_timeout(timer)?;
trace!("got msgs");
let mut v = Vec::new();
v.push(msgs);
while let Ok(more) = recvr.try_recv() {
trace!("got more msgs");
v.push(more);
}
info!("batch {}", v.len());
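// Ceil-divide the accumulated batches into at most four chunks so the
// signature checks below run over the chunks in parallel.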
let chunk = max(1, (v.len() + 3) / 4);
let chunks: Vec<_> = v.chunks(chunk).collect();
let rvs: Vec<_> = chunks
.into_par_iter()
.map(|x| ecdsa::ed25519_verify(&x.to_vec()))
.collect();
for (v, r) in v.chunks(chunk).zip(rvs) {
sendr.send((v.to_vec(), r))?;
}
Ok(())
}
pub fn deserialize_packets(p: &packet::Packets) -> Vec<Option<(Request, SocketAddr)>> {
// TODO: deserialize in parallel
let mut r = vec![];
for x in &p.packets {
let rsp_addr = x.meta.addr();
let sz = x.meta.size;
if let Ok(req) = deserialize(&x.data[0..sz]) {
r.push(Some((req, rsp_addr)));
} else {
r.push(None);
}
}
r
}
fn process(
obj: &Arc<Mutex<AccountantSkel<W>>>,
packet_receiver: &streamer::PacketReceiver,
verified_receiver: &Receiver<(Vec<SharedPackets>, Vec<Vec<u8>>)>,
blob_sender: &streamer::BlobSender,
packet_recycler: &packet::PacketRecycler,
blob_recycler: &packet::BlobRecycler,
) -> Result<()> {
let timer = Duration::new(1, 0);
let msgs = packet_receiver.recv_timeout(timer)?;
let msgs_ = msgs.clone();
let mut rsps = VecDeque::new();
{
let mut reqs = vec![];
for packet in &msgs.read().unwrap().packets {
let rsp_addr = packet.meta.addr();
let sz = packet.meta.size;
let req = deserialize(&packet.data[0..sz])?;
reqs.push((req, rsp_addr));
}
let reqs = filter_valid_requests(reqs);
for (req, rsp_addr) in reqs {
if let Some(resp) = obj.lock().unwrap().log_verified_request(req) {
let blob = blob_recycler.allocate();
{
let mut b = blob.write().unwrap();
let v = serialize(&resp)?;
let len = v.len();
b.data[..len].copy_from_slice(&v);
b.meta.size = len;
b.meta.set_addr(&rsp_addr);
let (mms, vvs) = verified_receiver.recv_timeout(timer)?;
for (msgs, vers) in mms.into_iter().zip(vvs.into_iter()) {
let msgs_ = msgs.clone();
let mut rsps = VecDeque::new();
{
let reqs = Self::deserialize_packets(&((*msgs).read().unwrap()));
for (data, v) in reqs.into_iter().zip(vers.into_iter()) {
if let Some((req, rsp_addr)) = data {
if !req.verify() {
continue;
}
if let Some(resp) = obj.lock().unwrap().log_verified_request(req, v) {
let blob = blob_recycler.allocate();
{
let mut b = blob.write().unwrap();
let v = serialize(&resp)?;
let len = v.len();
b.data[..len].copy_from_slice(&v);
b.meta.size = len;
b.meta.set_addr(&rsp_addr);
}
rsps.push_back(blob);
}
}
rsps.push_back(blob);
}
}
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
}
if !rsps.is_empty() {
//don't wake up the other side if there is nothing
blob_sender.send(rsps)?;
}
packet_recycler.recycle(msgs_);
Ok(())
}
@ -169,11 +217,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
let (blob_sender, blob_receiver) = channel();
let t_responder =
streamer::responder(write, exit.clone(), blob_recycler.clone(), blob_receiver);
let (verified_sender, verified_receiver) = channel();
let exit_ = exit.clone();
let t_verifier = spawn(move || loop {
let e = Self::verifier(&packet_receiver, &verified_sender);
if e.is_err() && exit_.load(Ordering::Relaxed) {
break;
}
});
let skel = obj.clone();
let t_server = spawn(move || loop {
let e = AccountantSkel::process(
&skel,
&packet_receiver,
&verified_receiver,
&blob_sender,
&packet_recycler,
&blob_recycler,
@ -182,6 +240,21 @@ impl<W: Write + Send + 'static> AccountantSkel<W> {
break;
}
});
Ok(vec![t_receiver, t_responder, t_server])
Ok(vec![t_receiver, t_responder, t_server, t_verifier])
}
}
#[cfg(test)]
mod tests {
use accountant_skel::Request;
use bincode::serialize;
use ecdsa;
use transaction::{memfind, test_tx};
#[test]
fn test_layout() {
let tr = test_tx();
let tx = serialize(&tr).unwrap();
let packet = serialize(&Request::Transaction(tr)).unwrap();
assert_matches!(memfind(&packet, &tx), Some(ecdsa::TX_OFFSET));
}
}
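
Taken together, the server now splits request handling into two stages joined by a channel: a verifier thread batches incoming packets and tags every packet with a 0/1 signature result, and the processing thread only acts on requests whose tag is non-zero. A minimal, self-contained sketch of that hand-off pattern using only std channels (hypothetical names, not the project's actual types):

use std::sync::mpsc::channel;
use std::thread::spawn;

fn main() {
    let (raw_sender, raw_receiver) = channel::<Vec<u8>>();
    let (verified_sender, verified_receiver) = channel::<(Vec<u8>, Vec<u8>)>();

    // Stage 1: stand-in for Self::verifier, which calls ecdsa::ed25519_verify.
    let verifier = spawn(move || {
        for batch in raw_receiver {
            let flags: Vec<u8> = batch.iter().map(|_| 1).collect(); // pretend all verify
            if verified_sender.send((batch, flags)).is_err() {
                break;
            }
        }
    });

    // Stage 2: stand-in for Self::process, which skips items whose flag is zero.
    let processor = spawn(move || {
        for (batch, flags) in verified_receiver {
            for (item, flag) in batch.into_iter().zip(flags) {
                if flag != 0 {
                    println!("processing item {}", item);
                }
            }
        }
    });

    raw_sender.send(vec![1, 2, 3]).unwrap();
    drop(raw_sender); // closing the channel lets both threads wind down
    verifier.join().unwrap();
    processor.join().unwrap();
}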


@ -1,3 +1,4 @@
extern crate env_logger;
extern crate serde_json;
extern crate solana;
@ -11,6 +12,7 @@ use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
fn main() {
env_logger::init().unwrap();
let addr = "127.0.0.1:8000";
let stdin = io::stdin();
let mut entries = stdin
@ -27,7 +29,7 @@ fn main() {
// transfer to oneself.
let entry1: Entry = entries.next().unwrap();
let deposit = if let Event::Transaction(ref tr) = entry1.events[0] {
tr.plan.final_payment()
tr.data.plan.final_payment()
} else {
None
};

src/ecdsa.rs (new file, 147 lines added)

@ -0,0 +1,147 @@
// Cuda-only imports
#[cfg(feature = "cuda")]
use packet::PACKET_DATA_SIZE;
#[cfg(feature = "cuda")]
use std::mem::size_of;
// Non-cuda imports
#[cfg(not(feature = "cuda"))]
use rayon::prelude::*;
#[cfg(not(feature = "cuda"))]
use untrusted;
#[cfg(not(feature = "cuda"))]
use ring::signature;
// Shared imports
use packet::{Packet, SharedPackets};
pub const TX_OFFSET: usize = 4;
pub const SIGNED_DATA_OFFSET: usize = 112;
pub const SIG_OFFSET: usize = 8;
pub const PUB_KEY_OFFSET: usize = 80;
pub const SIG_SIZE: usize = 64;
pub const PUB_KEY_SIZE: usize = 32;
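// Layout note (assumed, matching the test_layout tests in accountant_skel and
// transaction): with bincode 1.0 defaults a Request::Transaction serializes as
// a 4-byte enum tag followed by the Transaction, whose GenericArray fields each
// carry an 8-byte length prefix. Roughly:
//   bytes   0..4    u32 variant tag                      -> TX_OFFSET
//   bytes   4..12   length prefix of sig
//   bytes  12..76   64 signature bytes                    -> TX_OFFSET + SIG_OFFSET
//   bytes  76..84   length prefix of from
//   bytes  84..116  32 public-key bytes                   -> TX_OFFSET + PUB_KEY_OFFSET
//   bytes 116..     TransactionData, the signed message   -> TX_OFFSET + SIGNED_DATA_OFFSET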
#[cfg(feature = "cuda")]
#[repr(C)]
struct Elems {
elems: *const Packet,
num: u32,
}
#[cfg(feature = "cuda")]
#[link(name = "cuda_verify_ed25519")]
extern "C" {
fn ed25519_verify_many(
vecs: *const Elems,
num: u32, //number of vecs
message_size: u32, //size of each element inside the elems field of the vec
public_key_offset: u32,
signature_offset: u32,
signed_message_offset: u32,
signed_message_len_offset: u32,
out: *mut u8, //combined length of all the items in vecs
) -> u32;
}
#[cfg(not(feature = "cuda"))]
fn verify_packet(packet: &Packet) -> u8 {
let msg_start = TX_OFFSET + SIGNED_DATA_OFFSET;
let sig_start = TX_OFFSET + SIG_OFFSET;
let sig_end = sig_start + SIG_SIZE;
let pub_key_start = TX_OFFSET + PUB_KEY_OFFSET;
let pub_key_end = pub_key_start + PUB_KEY_SIZE;
if packet.meta.size > msg_start {
let msg_end = packet.meta.size;
return if signature::verify(
&signature::ED25519,
untrusted::Input::from(&packet.data[pub_key_start..pub_key_end]),
untrusted::Input::from(&packet.data[msg_start..msg_end]),
untrusted::Input::from(&packet.data[sig_start..sig_end]),
).is_ok()
{
1
} else {
0
};
} else {
return 0;
}
}
#[cfg(not(feature = "cuda"))]
pub fn ed25519_verify(batches: &Vec<SharedPackets>) -> Vec<Vec<u8>> {
let mut locks = Vec::new();
let mut rvs = Vec::new();
for packets in batches {
locks.push(packets.read().unwrap());
}
for p in locks {
let mut v = Vec::new();
v.resize(p.packets.len(), 0);
v = p.packets.par_iter().map(|x| verify_packet(x)).collect();
rvs.push(v);
}
rvs
}
#[cfg(feature = "cuda")]
pub fn ed25519_verify(batches: &Vec<SharedPackets>) -> Vec<Vec<u8>> {
let mut out = Vec::new();
let mut elems = Vec::new();
let mut locks = Vec::new();
let mut rvs = Vec::new();
for packets in batches {
locks.push(packets.read().unwrap());
}
let mut num = 0;
for p in locks {
elems.push(Elems {
elems: p.packets.as_ptr(),
num: p.packets.len() as u32,
});
let mut v = Vec::new();
v.resize(p.packets.len(), 0);
rvs.push(v);
num += p.packets.len();
}
out.resize(num, 0);
trace!("Starting verify num packets: {}", num);
trace!("elem len: {}", elems.len() as u32);
trace!("packet sizeof: {}", size_of::<Packet>() as u32);
trace!("pub key: {}", (TX_OFFSET + PUB_KEY_OFFSET) as u32);
trace!("sig offset: {}", (TX_OFFSET + SIG_OFFSET) as u32);
trace!("sign data: {}", (TX_OFFSET + SIGNED_DATA_OFFSET) as u32);
trace!("len offset: {}", PACKET_DATA_SIZE as u32);
unsafe {
let res = ed25519_verify_many(
elems.as_ptr(),
elems.len() as u32,
size_of::<Packet>() as u32,
(TX_OFFSET + PUB_KEY_OFFSET) as u32,
(TX_OFFSET + SIG_OFFSET) as u32,
(TX_OFFSET + SIGNED_DATA_OFFSET) as u32,
PACKET_DATA_SIZE as u32,
out.as_mut_ptr(),
);
if res != 0 {
trace!("RETURN!!!: {}", res);
}
}
trace!("done verify");
let mut num = 0;
for vs in rvs.iter_mut() {
for mut v in vs.iter_mut() {
*v = out[num];
if *v != 0 {
trace!("VERIFIED PACKET!!!!!");
}
num += 1;
}
}
rvs
}
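
Both the CPU and the CUDA paths expose the same interface: the caller passes a batch of shared packet buffers and gets back one vector of 0/1 flags per batch, index-aligned with that batch's packets. A hedged usage sketch (standalone and hypothetical; the empty batch from the recycler stands in for packets filled by streamer::receiver):

extern crate solana;

use solana::ecdsa;
use solana::packet::PacketRecycler;

fn main() {
    // Allocate one (empty) packet batch; real batches arrive from the network.
    let recycler = PacketRecycler::default();
    let batches = vec![recycler.allocate()];

    // One Vec<u8> of 0/1 flags per batch, in packet order.
    let flags = ecdsa::ed25519_verify(&batches);
    let valid: usize = flags.iter().map(|v| v.iter().filter(|&&f| f != 0).count()).sum();
    println!("{} packets passed sigverify", valid);
}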


@ -47,7 +47,7 @@ impl Event {
/// spending plan is valid.
pub fn verify(&self) -> bool {
match *self {
Event::Transaction(ref tr) => tr.verify(),
Event::Transaction(ref tr) => tr.verify_sig(),
Event::Signature { from, tx_sig, sig } => sig.verify(&from, &tx_sig),
Event::Timestamp { from, dt, sig } => sig.verify(&from, &serialize(&dt).unwrap()),
}


@ -4,13 +4,14 @@ pub mod accountant_skel;
pub mod accountant_stub;
pub mod entry;
pub mod event;
pub mod ecdsa;
pub mod hash;
pub mod historian;
pub mod ledger;
pub mod mint;
pub mod packet;
pub mod plan;
pub mod recorder;
pub mod historian;
pub mod packet;
pub mod result;
pub mod signature;
pub mod streamer;
@ -19,6 +20,7 @@ extern crate bincode;
extern crate byteorder;
extern crate chrono;
extern crate generic_array;
extern crate libc;
#[macro_use]
extern crate log;
extern crate rayon;


@ -68,7 +68,7 @@ mod tests {
fn test_create_events() {
let mut events = Mint::new(100).create_events().into_iter();
if let Event::Transaction(tr) = events.next().unwrap() {
if let Plan::Pay(payment) = tr.plan {
if let Plan::Pay(payment) = tr.data.plan {
assert_eq!(tr.from, payment.to);
}
}


@ -1,10 +1,10 @@
use std::sync::{Arc, Mutex, RwLock};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use result::{Error, Result};
use std::collections::VecDeque;
use std::fmt;
use std::io;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
use std::collections::VecDeque;
use result::{Error, Result};
use std::sync::{Arc, Mutex, RwLock};
pub type SharedPackets = Arc<RwLock<Packets>>;
pub type SharedBlob = Arc<RwLock<Blob>>;
@ -13,10 +13,11 @@ pub type BlobRecycler = Recycler<Blob>;
const NUM_PACKETS: usize = 1024 * 8;
const BLOB_SIZE: usize = 64 * 1024;
pub const PACKET_SIZE: usize = 256;
pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_SIZE) / BLOB_SIZE;
pub const PACKET_DATA_SIZE: usize = 256;
pub const NUM_BLOBS: usize = (NUM_PACKETS * PACKET_DATA_SIZE) / BLOB_SIZE;
#[derive(Clone, Default)]
#[repr(C)]
pub struct Meta {
pub size: usize,
pub addr: [u16; 8],
@ -25,8 +26,9 @@ pub struct Meta {
}
#[derive(Clone)]
#[repr(C)]
pub struct Packet {
pub data: [u8; PACKET_SIZE],
pub data: [u8; PACKET_DATA_SIZE],
pub meta: Meta,
}
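// Note: repr(C) on Meta and Packet matters for the CUDA path. The FFI call in
// src/ecdsa.rs hands the GPU a raw *const Packet plus size_of::<Packet>() and
// the field offsets, so the in-memory layout has to be the declared C layout;
// in particular meta.size sits right after the 256-byte data array, which is
// why PACKET_DATA_SIZE is passed as the signed_message_len_offset.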
@ -44,7 +46,7 @@ impl fmt::Debug for Packet {
impl Default for Packet {
fn default() -> Packet {
Packet {
data: [0u8; PACKET_SIZE],
data: [0u8; PACKET_DATA_SIZE],
meta: Meta::default(),
}
}
@ -279,11 +281,11 @@ impl Blob {
#[cfg(test)]
mod test {
use std::net::UdpSocket;
use std::io::Write;
use std::io;
use std::collections::VecDeque;
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets};
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::net::UdpSocket;
#[test]
pub fn packet_recycler_test() {
let r = PacketRecycler::default();


@ -35,6 +35,7 @@ pub struct Payment {
pub to: PublicKey,
}
#[repr(C)]
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub enum Plan {
Pay(Payment),


@ -1,12 +1,12 @@
use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS};
use result::Result;
use std::collections::VecDeque;
use std::net::UdpSocket;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::time::Duration;
use std::net::UdpSocket;
use std::thread::{spawn, JoinHandle};
use std::collections::VecDeque;
use result::Result;
use packet::{Blob, BlobRecycler, PacketRecycler, SharedBlob, SharedPackets, NUM_BLOBS};
use std::time::Duration;
pub type PacketReceiver = mpsc::Receiver<SharedPackets>;
pub type PacketSender = mpsc::Sender<SharedPackets>;
@ -67,7 +67,7 @@ pub fn responder(
r: BlobReceiver,
) -> JoinHandle<()> {
spawn(move || loop {
if recv_send(&sock, &recycler, &r).is_err() || exit.load(Ordering::Relaxed) {
if recv_send(&sock, &recycler, &r).is_err() && exit.load(Ordering::Relaxed) {
break;
}
})
@ -141,16 +141,16 @@ pub fn window(
mod bench {
extern crate test;
use self::test::Bencher;
use packet::{Packet, PacketRecycler, PACKET_DATA_SIZE};
use result::Result;
use std::net::{SocketAddr, UdpSocket};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex};
use std::thread::sleep;
use std::thread::{spawn, JoinHandle};
use std::time::Duration;
use std::time::SystemTime;
use std::sync::mpsc::channel;
use std::sync::atomic::{AtomicBool, Ordering};
use packet::{Packet, PacketRecycler, PACKET_SIZE};
use streamer::{receiver, PacketReceiver};
fn producer(
@ -163,7 +163,7 @@ mod bench {
let msgs_ = msgs.clone();
msgs.write().unwrap().packets.resize(10, Packet::default());
for w in msgs.write().unwrap().packets.iter_mut() {
w.meta.size = PACKET_SIZE;
w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr);
}
spawn(move || loop {
@ -241,15 +241,15 @@ mod bench {
#[cfg(test)]
mod test {
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_DATA_SIZE};
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::net::UdpSocket;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::channel;
use std::io::Write;
use std::io;
use std::collections::VecDeque;
use std::time::Duration;
use std::sync::Arc;
use packet::{Blob, BlobRecycler, Packet, PacketRecycler, Packets, PACKET_SIZE};
use streamer::{receiver, responder, window, BlobReceiver, PacketReceiver};
fn get_msgs(r: PacketReceiver, num: &mut usize) {
@ -288,7 +288,7 @@ mod test {
let b_ = b.clone();
let mut w = b.write().unwrap();
w.data[0] = i as u8;
w.meta.size = PACKET_SIZE;
w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr);
msgs.push_back(b_);
}
@ -338,7 +338,7 @@ mod test {
let mut w = b.write().unwrap();
w.set_index(i).unwrap();
assert_eq!(i, w.get_index().unwrap());
w.meta.size = PACKET_SIZE;
w.meta.size = PACKET_DATA_SIZE;
w.meta.set_addr(&addr);
msgs.push_back(b_);
}


@ -8,12 +8,17 @@ use rayon::prelude::*;
use signature::{KeyPair, KeyPairUtil, PublicKey, Signature, SignatureUtil};
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct Transaction {
pub from: PublicKey,
pub plan: Plan,
pub struct TransactionData {
pub tokens: i64,
pub last_id: Hash,
pub plan: Plan,
}
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct Transaction {
pub sig: Signature,
pub from: PublicKey,
pub data: TransactionData,
}
impl Transaction {
@ -22,11 +27,13 @@ impl Transaction {
let from = from_keypair.pubkey();
let plan = Plan::Pay(Payment { tokens, to });
let mut tr = Transaction {
from,
plan,
tokens,
last_id,
sig: Signature::default(),
data: TransactionData {
plan,
tokens,
last_id,
},
from: from,
};
tr.sign(from_keypair);
tr
@ -46,10 +53,12 @@ impl Transaction {
(Condition::Signature(from), Payment { tokens, to: from }),
);
let mut tr = Transaction {
from,
plan,
tokens,
last_id,
data: TransactionData {
plan,
tokens,
last_id,
},
from: from,
sig: Signature::default(),
};
tr.sign(from_keypair);
@ -57,7 +66,7 @@ impl Transaction {
}
fn get_sign_data(&self) -> Vec<u8> {
serialize(&(&self.plan, &self.tokens, &self.last_id)).unwrap()
serialize(&(&self.data)).unwrap()
}
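// With the TransactionData split, the signature covers one contiguous
// serialized blob (tokens, last_id, plan). That blob is also the region the
// packet-level verifier checks starting at ecdsa::SIGNED_DATA_OFFSET, as the
// test_layout tests below assert.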
/// Sign this transaction.
@ -66,20 +75,45 @@ impl Transaction {
self.sig = Signature::clone_from_slice(keypair.sign(&sign_data).as_ref());
}
/// Verify this transaction's signature and its spending plan.
pub fn verify(&self) -> bool {
self.sig.verify(&self.from, &self.get_sign_data()) && self.plan.verify(self.tokens)
pub fn verify_sig(&self) -> bool {
self.sig.verify(&self.from, &self.get_sign_data())
}
pub fn verify_plan(&self) -> bool {
self.data.plan.verify(self.data.tokens)
}
}
#[cfg(test)]
pub fn test_tx() -> Transaction {
let keypair1 = KeyPair::new();
let pubkey1 = keypair1.pubkey();
let zero = Hash::default();
let mut tr = Transaction::new(&keypair1, pubkey1, 42, zero);
tr.sign(&keypair1);
return tr;
}
#[cfg(test)]
pub fn memfind<A: Eq>(a: &[A], b: &[A]) -> Option<usize> {
assert!(a.len() >= b.len());
let end = a.len() - b.len() + 1;
for i in 0..end {
if a[i..i + b.len()] == b[..] {
return Some(i);
}
}
None
}
/// Verify a batch of signatures.
pub fn verify_signatures(transactions: &[Transaction]) -> bool {
transactions.par_iter().all(|tr| tr.verify())
transactions.par_iter().all(|tr| tr.verify_sig())
}
/// Verify a batch of spending plans.
pub fn verify_plans(transactions: &[Transaction]) -> bool {
transactions.par_iter().all(|tr| tr.plan.verify(tr.tokens))
transactions.par_iter().all(|tr| tr.verify_plan())
}
/// Verify a batch of transactions.
@ -91,13 +125,14 @@ pub fn verify_transactions(transactions: &[Transaction]) -> bool {
mod tests {
use super::*;
use bincode::{deserialize, serialize};
use ecdsa;
#[test]
fn test_claim() {
let keypair = KeyPair::new();
let zero = Hash::default();
let tr0 = Transaction::new(&keypair, keypair.pubkey(), 42, zero);
assert!(tr0.verify());
assert!(tr0.verify_plan());
}
#[test]
@ -107,7 +142,7 @@ mod tests {
let keypair1 = KeyPair::new();
let pubkey1 = keypair1.pubkey();
let tr0 = Transaction::new(&keypair0, pubkey1, 42, zero);
assert!(tr0.verify());
assert!(tr0.verify_plan());
}
#[test]
@ -117,10 +152,12 @@ mod tests {
to: Default::default(),
});
let claim0 = Transaction {
data: TransactionData {
plan,
tokens: 0,
last_id: Default::default(),
},
from: Default::default(),
plan,
tokens: 0,
last_id: Default::default(),
sig: Default::default(),
};
let buf = serialize(&claim0).unwrap();
@ -135,8 +172,8 @@ mod tests {
let pubkey = keypair.pubkey();
let mut tr = Transaction::new(&keypair, pubkey, 42, zero);
tr.sign(&keypair);
tr.tokens = 1_000_000; // <-- attack!
assert!(!tr.verify());
tr.data.tokens = 1_000_000; // <-- attack!
assert!(!tr.verify_plan());
}
#[test]
@ -148,10 +185,20 @@ mod tests {
let zero = Hash::default();
let mut tr = Transaction::new(&keypair0, pubkey1, 42, zero);
tr.sign(&keypair0);
if let Plan::Pay(ref mut payment) = tr.plan {
if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.to = thief_keypair.pubkey(); // <-- attack!
};
assert!(!tr.verify());
assert!(tr.verify_plan());
assert!(!tr.verify_sig());
}
#[test]
fn test_layout() {
let tr = test_tx();
let sign_data = tr.get_sign_data();
let tx = serialize(&tr).unwrap();
assert_matches!(memfind(&tx, &sign_data), Some(ecdsa::SIGNED_DATA_OFFSET));
assert_matches!(memfind(&tx, &tr.sig), Some(ecdsa::SIG_OFFSET));
assert_matches!(memfind(&tx, &tr.from), Some(ecdsa::PUB_KEY_OFFSET));
}
#[test]
@ -160,16 +207,16 @@ mod tests {
let keypair1 = KeyPair::new();
let zero = Hash::default();
let mut tr = Transaction::new(&keypair0, keypair1.pubkey(), 1, zero);
if let Plan::Pay(ref mut payment) = tr.plan {
if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.tokens = 2; // <-- attack!
}
assert!(!tr.verify());
assert!(!tr.verify_plan());
// Also, ensure all branches of the plan spend all tokens
if let Plan::Pay(ref mut payment) = tr.plan {
if let Plan::Pay(ref mut payment) = tr.data.plan {
payment.tokens = 0; // <-- whoops!
}
assert!(!tr.verify());
assert!(!tr.verify_plan());
}
#[test]