Cuda hardware acceleration

This commit is contained in:
Hanh 2022-08-02 22:34:15 +08:00
parent 1eadcce883
commit 979795a82c
8 changed files with 307 additions and 6 deletions

View File

@ -84,12 +84,16 @@ dotenv = { version = "0.15.0", optional = true }
node-bindgen = { version = "4.0", optional = true }
rustacuda = { version = "0.1.3", optional = true }
rustacuda_core = { version = "0.1.2", optional = true }
[features]
ledger = ["ledger-apdu", "hmac", "ed25519-bip32", "ledger-transport-hid"]
ledger_sapling = ["ledger"]
dart_ffi = ["allo-isolate", "once_cell", "android_logger"]
rpc = ["rocket", "dotenv"]
nodejs = ["node-bindgen"]
cuda = ["rustacuda", "rustacuda_core"]
# librustzcash synced to 35023ed8ca2fb1061e78fd740b640d4eefcc5edd
@ -128,3 +132,4 @@ cbindgen = "0.19.0"
[dev-dependencies]
criterion = "0.3.4"

View File

@ -6,6 +6,8 @@ use jubjub::{AffinePoint, ExtendedPoint};
use rayon::prelude::IntoParallelIterator;
use rayon::prelude::*;
use zcash_primitives::sapling::Node;
#[cfg(feature = "cuda")]
use crate::cuda::CUDA_PROCESSOR;
#[inline(always)]
fn batch_node_combine1(depth: usize, left: &Node, right: &Node) -> ExtendedPoint {
@ -165,7 +167,7 @@ impl CTreeBuilder {
}
}
fn combine_level(commitments: &mut [Node], offset: Option<Node>, n: usize, depth: usize) -> usize {
fn combine_level_soft(commitments: &mut [Node], offset: Option<Node>, n: usize, depth: usize) -> usize {
assert_eq!(n % 2, 0);
let nn = n / 2;
@ -203,6 +205,36 @@ fn combine_level(commitments: &mut [Node], offset: Option<Node>, n: usize, depth
nn
}
#[cfg(feature = "cuda")]
fn combine_level_cuda(commitments: &mut [Node], offset: Option<Node>, n: usize, depth: usize) -> usize {
assert_eq!(n % 2, 0);
if n == 0 { return 0; }
let mut hasher = CUDA_PROCESSOR.lock().unwrap();
if let Some(hasher) = hasher.as_mut() {
let nn = n / 2;
let hashes: Vec<_> = (0..n)
.map(|i| CTreeBuilder::get(commitments, i, &offset).repr)
.collect();
let new_hashes = hasher.batch_hash_cuda(depth as u8, &hashes).unwrap();
for i in 0..nn {
commitments[i] = Node::new(new_hashes[i]);
}
nn
}
else {
combine_level_soft(commitments, offset, n, depth)
}
}
fn combine_level(commitments: &mut [Node], offset: Option<Node>, n: usize, depth: usize) -> usize {
#[cfg(feature = "cuda")]
return combine_level_cuda(commitments, offset, n, depth);
#[allow(unreachable_code)]
combine_level_soft(commitments, offset, n, depth)
}
struct WitnessBuilder {
witness: Witness,
p: usize,

View File

@ -26,6 +26,8 @@ use zcash_primitives::sapling::note_encryption::SaplingDomain;
use zcash_primitives::sapling::{Node, Note, PaymentAddress};
use zcash_primitives::transaction::components::sapling::CompactOutputDescription;
use zcash_primitives::zip32::ExtendedFullViewingKey;
#[cfg(feature = "cuda")]
use crate::cuda::CUDA_PROCESSOR;
pub async fn get_latest_height(
client: &mut CompactTxStreamerClient<Channel>,
@ -324,7 +326,7 @@ fn decrypt_notes<'a, N: Parameters>(
count_outputs += 1;
}
} else {
log::info!("Spam Filter tx {}", hex::encode(&vtx.hash));
// log::info!("Spam Filter tx {}", hex::encode(&vtx.hash));
count_outputs += vtx.outputs.len() as u32;
}
}
@ -376,6 +378,18 @@ impl DecryptNode {
&self,
network: &Network,
blocks: &'a [CompactBlock],
) -> Vec<DecryptedBlock<'a>> {
#[cfg(feature = "cuda")]
return self.cuda_decrypt_blocks(network, blocks);
#[allow(unreachable_code)]
self.decrypt_blocks_soft(network, blocks)
}
pub fn decrypt_blocks_soft<'a>(
&self,
network: &Network,
blocks: &'a [CompactBlock],
) -> Vec<DecryptedBlock<'a>> {
let vks: Vec<_> = self.vks.iter().collect();
let mut decrypted_blocks: Vec<DecryptedBlock> = blocks
@ -385,6 +399,23 @@ impl DecryptNode {
decrypted_blocks.sort_by(|a, b| a.height.cmp(&b.height));
decrypted_blocks
}
#[cfg(feature = "cuda")]
pub fn cuda_decrypt_blocks<'a>(
&self,
network: &Network,
blocks: &'a [CompactBlock],
) -> Vec<DecryptedBlock<'a>> {
let mut cuda_processor = CUDA_PROCESSOR.lock().unwrap();
if let Some(cuda_processor) = cuda_processor.as_mut() {
let mut decrypted_blocks = vec![];
for (account, vk) in self.vks.iter() {
decrypted_blocks.extend(cuda_processor.trial_decrypt(network, *account, &vk.fvk, blocks).unwrap());
}
return decrypted_blocks;
}
self.decrypt_blocks_soft(network, blocks)
}
}
#[allow(dead_code)]

9
src/cuda.rs Normal file
View File

@ -0,0 +1,9 @@
use std::sync::Mutex;
use lazy_static::lazy_static;
mod processor;
use processor::CudaProcessor;
lazy_static! {
pub static ref CUDA_PROCESSOR: Mutex<Option<CudaProcessor>> = Mutex::new(CudaProcessor::new().ok());
}

221
src/cuda/processor.rs Normal file
View File

@ -0,0 +1,221 @@
use std::convert::TryInto;
use std::ffi::CString;
use jubjub::{Fq, ExtendedPoint};
use rustacuda::launch;
use rustacuda::prelude::*;
use group::Curve;
use ff::BatchInverter;
use crate::{Hash, GENERATORS_EXP};
use anyhow::Result;
use rustacuda::context::CurrentContext;
use crate::lw_rpc::CompactBlock;
use zcash_note_encryption::Domain;
use zcash_primitives::consensus::{BlockHeight, Network};
use zcash_primitives::sapling::note_encryption::SaplingDomain;
use zcash_primitives::sapling::SaplingIvk;
use zcash_primitives::zip32::ExtendedFullViewingKey;
use crate::chain::{DecryptedBlock, DecryptedNote, Nf};
const THREADS_PER_BLOCK: usize = 256usize;
const BUFFER_SIZE: usize = 128usize;
pub struct CudaProcessor {
device: Device,
context: Context,
hash_module: Module,
trial_decrypt_module: Module,
stream: Stream,
generators: DeviceBuffer<u8>,
}
unsafe impl Send for CudaProcessor {}
impl CudaProcessor {
pub fn new() -> Result<Self> {
rustacuda::init(rustacuda::CudaFlags::empty())?;
let device = Device::get_device(0)?;
let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO, device)?;
let ptx = CString::new(include_str!("../cuda/hash.ptx"))?;
let hash_module = Module::load_from_string(&ptx)?;
let ptx = CString::new(include_str!("../cuda/trial_decrypt.ptx"))?;
let trial_decrypt_module = Module::load_from_string(&ptx)?;
let stream = Stream::new(StreamFlags::DEFAULT, None)?;
log::info!("Prepare Generators");
let generators_len = GENERATORS_EXP.len();
let mut gens = vec![0u8; generators_len * 128];
for i in 0..generators_len {
GENERATORS_EXP[i].copy_to_slice(&mut gens[i * 128..(i + 1) * 128]);
}
let mut generators = DeviceBuffer::from_slice(&gens)?;
Ok(CudaProcessor {
device,
context,
hash_module,
trial_decrypt_module,
stream,
generators,
})
}
pub fn batch_hash_cuda(&mut self, depth: u8, data: &[Hash]) -> Result<Vec<Hash>> {
CurrentContext::set_current(&self.context)?;
log::info!("cuda - pedersen hash");
let n = data.len() / 2;
let mut in_data = DeviceBuffer::from_slice(data)?;
let mut out_data = unsafe { DeviceBuffer::<u8>::zeroed(n * 32 * 2)? };
unsafe {
// Launch the kernel again using the `function` form:
let function_name = CString::new("pedersen_hash")?;
let hash = self.hash_module.get_function(&function_name)?;
let blocks = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let stream = &self.stream;
let result = launch!(hash<<<(blocks as u32, 1, 1), (THREADS_PER_BLOCK as u32, 1, 1), 1024, stream>>>(
n,
depth,
self.generators.as_device_ptr(),
in_data.as_device_ptr(),
out_data.as_device_ptr()
));
result?;
}
self.stream.synchronize()?;
let mut res = vec![0u8; n * 32 * 2];
out_data.copy_to(&mut res)?;
let mut p = vec![];
let mut q: Vec<AffinePoint> = vec![AffinePoint::default(); n];
for i in 0..n {
let b = i * 64;
let u = Fq::from_bytes(&res[b..b + 32].try_into().unwrap()).unwrap();
let z = Fq::from_bytes(&res[b + 32..b + 64].try_into().unwrap()).unwrap();
q[i].u = z;
p.push(u);
}
BatchInverter::invert_with_internal_scratch(&mut q, |q| &mut q.u, |q| &mut q.v);
let mut out = vec![];
for i in 0..n {
let hash: Hash = (p[i] * &q[i].u).to_bytes();
// println!("{} {} {} {}", i, hex::encode(&data[i * 2]), hex::encode(&data[i * 2 + 1]), hex::encode(&hash));
out.push(hash);
}
Ok(out)
}
pub fn trial_decrypt<'a>(&mut self, network: &Network,
account: u32,
fvk: &ExtendedFullViewingKey,
blocks: &'a [CompactBlock]) -> Result<Vec<DecryptedBlock<'a>>> {
CurrentContext::set_current(&self.context)?;
log::info!("cuda - trial decrypt");
let ivk = fvk.fvk.vk.ivk();
let mut ivk_fr = ivk.0;
ivk_fr = ivk_fr.double(); // multiply by cofactor
ivk_fr = ivk_fr.double();
ivk_fr = ivk_fr.double();
let n = blocks.iter().map(|b|
b.vtx.iter().map(|tx| tx.outputs.len()).sum::<usize>()
).sum::<usize>();
let block_count = (n + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
let mut data_buffer = vec![0u8; n*BUFFER_SIZE];
let mut i = 0;
for b in blocks.iter() {
for tx in b.vtx.iter() {
for co in tx.outputs.iter() {
data_buffer[i*BUFFER_SIZE..i*BUFFER_SIZE+32].copy_from_slice(&co.epk);
data_buffer[i*BUFFER_SIZE+64..i*BUFFER_SIZE+116].copy_from_slice(&co.ciphertext);
i += 1;
}
}
}
let mut ivk_device_buffer = DeviceBuffer::from_slice(&ivk_fr.to_bytes())?;
let mut data_device_buffer = DeviceBuffer::from_slice(&data_buffer)?;
unsafe {
// Launch the kernel again using the `function` form:
let function_name = CString::new("trial_decrypt_full").unwrap();
let trial_decrypt_full = self.trial_decrypt_module.get_function(&function_name).unwrap();
let stream = &self.stream;
let result = launch!(trial_decrypt_full<<<(block_count as u32, 1, 1), (THREADS_PER_BLOCK as u32, 1, 1), 0, stream>>>(
n,
ivk_device_buffer.as_device_ptr(),
data_device_buffer.as_device_ptr()
));
result?;
}
self.stream.synchronize()?;
data_device_buffer.copy_to(&mut data_buffer)?;
let mut i = 0;
let mut decrypted_blocks = vec![];
for b in blocks.iter() {
let mut decrypted_notes = vec![];
let mut spends = vec![];
let mut count_outputs: u32 = 0;
let domain = SaplingDomain::for_height(*network, BlockHeight::from_u32(b.height as u32));
for (tx_index, tx) in b.vtx.iter().enumerate() {
for cs in tx.spends.iter() {
let mut nf = [0u8; 32];
nf.copy_from_slice(&cs.nf);
spends.push(Nf(nf));
}
for (output_index, co) in tx.outputs.iter().enumerate() {
let plaintext = &data_buffer[i*BUFFER_SIZE+64..i*BUFFER_SIZE+116];
if let Some((note, pa)) = domain.parse_note_plaintext_without_memo_ivk(&ivk, plaintext) {
let cmu = note.cmu().to_bytes();
if &cmu == co.cmu.as_slice() {
decrypted_notes.push(DecryptedNote {
account,
ivk: fvk.clone(),
note,
pa,
position_in_block: count_outputs as usize,
viewonly: false,
height: b.height as u32,
txid: tx.hash.clone(),
tx_index,
output_index
});
}
}
count_outputs += 1;
i += 1;
}
}
decrypted_blocks.push(DecryptedBlock {
height: b.height as u32,
notes: decrypted_notes,
count_outputs,
spends,
compact_block: b,
elapsed: 0
});
}
Ok(decrypted_blocks)
}
}
#[derive(Default, Clone)]
struct AffinePoint {
u: Fq,
v: Fq,
}

View File

@ -8,7 +8,7 @@ use zcash_params::GENERATORS;
use zcash_primitives::constants::PEDERSEN_HASH_CHUNKS_PER_GENERATOR;
lazy_static! {
static ref GENERATORS_EXP: Vec<ExtendedNielsPoint> = read_generators_bin();
pub static ref GENERATORS_EXP: Vec<ExtendedNielsPoint> = read_generators_bin();
}
fn read_generators_bin() -> Vec<ExtendedNielsPoint> {
@ -47,7 +47,7 @@ macro_rules! accumulate_scalar {
};
}
type Hash = [u8; 32];
pub type Hash = [u8; 32];
pub fn pedersen_hash(depth: u8, left: &Hash, right: &Hash) -> Hash {
let p = pedersen_hash_inner(depth, left, right);

View File

@ -73,7 +73,7 @@ pub use crate::coinconfig::{
pub use crate::commitment::{CTree, Witness};
pub use crate::db::{AccountRec, DbAdapter, TxRec};
pub use crate::fountain::{put_drop, FountainCodes, RaptorQDrops};
pub use crate::hash::pedersen_hash;
pub use crate::hash::{Hash, pedersen_hash, GENERATORS_EXP};
pub use crate::key::{generate_random_enc_key, KeyHelpers};
pub use crate::lw_rpc::compact_tx_streamer_client::CompactTxStreamerClient;
pub use crate::lw_rpc::*;
@ -94,3 +94,6 @@ pub use crate::ledger::sweep_ledger;
#[cfg(feature = "nodejs")]
pub mod nodejs;
#[cfg(feature = "cuda")]
mod cuda;

View File

@ -136,7 +136,7 @@ pub fn list_accounts() -> Result<Json<Vec<AccountRec>>, Error> {
#[post("/sync?<offset>")]
pub async fn sync(offset: Option<u32>) -> Result<(), Error> {
let c = CoinConfig::get_active();
warp_api_ffi::api::sync::coin_sync(c.coin, true, offset.unwrap_or(0), u32::MAX, |_| {}, &SYNC_CANCELED)
warp_api_ffi::api::sync::coin_sync(c.coin, true, offset.unwrap_or(0), 50, |_| {}, &SYNC_CANCELED)
.await?;
Ok(())
}