This commit is contained in:
Hanh 2022-08-18 11:42:06 +08:00
parent f15e8a73dd
commit ba1283f8ef
5 changed files with 412 additions and 9 deletions

View File

@ -87,6 +87,10 @@ node-bindgen = { version = "4.0", optional = true }
rustacuda = { version = "0.1.3", optional = true } rustacuda = { version = "0.1.3", optional = true }
rustacuda_core = { version = "0.1.2", optional = true } rustacuda_core = { version = "0.1.2", optional = true }
ash = { version = "0.37", features = [ "linked" ], optional = true }
metal = { version = "0.24", optional = true }
objc = { version = "0.2", features = [ "objc_exception" ], optional = true }
block = { version = "0.1.6", optional = true }
[features] [features]
ledger = ["ledger-apdu", "hmac", "ed25519-bip32", "ledger-transport-hid"] ledger = ["ledger-apdu", "hmac", "ed25519-bip32", "ledger-transport-hid"]
@ -95,6 +99,8 @@ dart_ffi = ["allo-isolate", "once_cell", "android_logger"]
rpc = ["rocket", "dotenv"] rpc = ["rocket", "dotenv"]
nodejs = ["node-bindgen"] nodejs = ["node-bindgen"]
cuda = ["rustacuda", "rustacuda_core"] cuda = ["rustacuda", "rustacuda_core"]
vulkan = ["ash"]
apple_metal = ["metal", "objc", "block"]
# librustzcash synced to 35023ed8ca2fb1061e78fd740b640d4eefcc5edd # librustzcash synced to 35023ed8ca2fb1061e78fd740b640d4eefcc5edd

View File

@ -33,7 +33,10 @@ use crate::{advance_tree, has_cuda};
use crate::gpu::trial_decrypt; use crate::gpu::trial_decrypt;
#[cfg(feature = "cuda")] #[cfg(feature = "cuda")]
use crate::gpu::cuda::{CUDA_CONTEXT, CudaProcessor}; use crate::gpu::cuda::{CUDA_CONTEXT, CudaProcessor};
#[cfg(feature = "vulkan")]
use crate::gpu::vulkan::VulkanProcessor; use crate::gpu::vulkan::VulkanProcessor;
#[cfg(feature = "apple_metal")]
use crate::gpu::metal::MetalProcessor;
pub static DOWNLOADED_BYTES: AtomicUsize = AtomicUsize::new(0); pub static DOWNLOADED_BYTES: AtomicUsize = AtomicUsize::new(0);
pub static TRIAL_DECRYPTIONS: AtomicUsize = AtomicUsize::new(0); pub static TRIAL_DECRYPTIONS: AtomicUsize = AtomicUsize::new(0);
@ -422,6 +425,9 @@ impl DecryptNode {
#[cfg(feature = "vulkan")] #[cfg(feature = "vulkan")]
return self.vulkan_decrypt_blocks(network, blocks); return self.vulkan_decrypt_blocks(network, blocks);
#[cfg(feature = "apple_metal")]
return self.metal_decrypt_blocks(network, blocks);
#[allow(unreachable_code)] #[allow(unreachable_code)]
self.decrypt_blocks_soft(network, blocks) self.decrypt_blocks_soft(network, blocks)
} }
@ -469,6 +475,19 @@ impl DecryptNode {
let processor = VulkanProcessor::setup_decrypt(network, blocks, Path::new("")).unwrap(); let processor = VulkanProcessor::setup_decrypt(network, blocks, Path::new("")).unwrap();
trial_decrypt(processor, self.vks.iter()).unwrap() trial_decrypt(processor, self.vks.iter()).unwrap()
} }
#[cfg(feature = "apple_metal")]
pub fn metal_decrypt_blocks(
&self,
network: &Network,
blocks: Vec<CompactBlock>,
) -> Vec<DecryptedBlock> {
if blocks.is_empty() {
return vec![];
}
let processor = MetalProcessor::setup_decrypt(network, blocks).unwrap();
trial_decrypt(processor, self.vks.iter()).unwrap()
}
} }
#[allow(dead_code)] #[allow(dead_code)]

View File

@ -14,6 +14,9 @@ pub mod cuda;
#[cfg(feature = "vulkan")] #[cfg(feature = "vulkan")]
pub mod vulkan; pub mod vulkan;
#[cfg(feature = "apple_metal")]
pub mod metal;
pub trait GPUProcessor { pub trait GPUProcessor {
fn decrypt_account(&mut self, ivk: &SaplingIvk) -> Result<()>; fn decrypt_account(&mut self, ivk: &SaplingIvk) -> Result<()>;
fn get_decrypted_blocks(self) -> Result<Vec<DecryptedBlock>>; fn get_decrypted_blocks(self) -> Result<Vec<DecryptedBlock>>;

384
src/gpu/metal.rs Normal file
View File

@ -0,0 +1,384 @@
use std::convert::TryInto;
use std::sync::Mutex;
use std::mem;
use std::ptr::slice_from_raw_parts;
use std::time::SystemTime;
use jubjub::Fq;
use metal::*;
use objc::rc::autoreleasepool;
use rand::rngs::OsRng;
use ff::Field;
use lazy_static::lazy_static;
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaChaRng;
use zcash_client_backend::encoding::decode_extended_full_viewing_key;
use zcash_note_encryption::Domain;
use zcash_primitives::consensus::{BlockHeight, MainNetwork, Network, Parameters};
use zcash_primitives::sapling::note_encryption::SaplingDomain;
use zcash_primitives::sapling::SaplingIvk;
use crate::chain::DecryptedBlock;
use crate::CompactBlock;
use crate::gpu::{collect_nf, GPUProcessor};
lazy_static! {
pub static ref METAL_CONTEXT: Mutex<MetalContext> =
Mutex::new(MetalContext::new());
}
pub const N: usize = 200_000;
const WIDTH: u64 = 256;
#[derive(Clone)]
pub struct CompactOutput {
pub height: u32,
pub epk: [u8; 32],
pub cmu: [u8; 32],
pub ciphertext: [u8; 52],
}
#[repr(C)]
#[derive(Clone)]
struct Data {
key: [u8; 32],
epk: [u8; 32],
cipher: [u8; 64],
}
impl Default for Data {
fn default() -> Self {
Data {
key: [0; 32],
epk: [0; 32],
cipher: [0; 64],
}
}
}
pub struct MetalContext {
device: Device,
command_queue: CommandQueue,
kernel: Function,
ivk_buffer: Buffer,
data_buffer: Buffer,
}
unsafe impl Send for MetalContext {}
impl MetalContext {
pub fn new() -> Self {
let library_data = include_bytes!("./metal/main.metallib");
let device = Device::system_default().expect("no device found");
let command_queue = device.new_command_queue();
let library = device.new_library_with_data(&*library_data).unwrap();
let kernel = library.get_function("decrypt", None).unwrap();
let ivk_buffer = device.new_buffer(32, MTLResourceOptions::CPUCacheModeDefaultCache);
let data_buffer = device.new_buffer((N * MetalProcessor::buffer_stride()) as u64, MTLResourceOptions::CPUCacheModeDefaultCache);
MetalContext {
device,
command_queue,
kernel,
ivk_buffer,
data_buffer,
}
}
}
pub struct MetalProcessor {
network: Network,
decrypted_blocks: Vec<DecryptedBlock>,
encrypted_data: Vec<Data>,
decrypted_data: Vec<u8>,
n: usize,
}
impl MetalProcessor {
pub fn setup_decrypt(network: &Network, blocks: Vec<CompactBlock>) -> anyhow::Result<Self> {
log::info!("Metal::setup_decrypt");
let decrypted_blocks = collect_nf(blocks)?;
let mut encrypted_data: Vec<Data> = vec![];
for db in decrypted_blocks.iter() {
let b = &db.compact_block;
for tx in b.vtx.iter() {
for co in tx.outputs.iter() {
let mut cipher = [0u8; 64];
cipher[0..52].copy_from_slice(&co.ciphertext);
let data = Data {
key: [0u8; 32],
epk: co.clone().epk.try_into().unwrap(),
cipher,
};
encrypted_data.push(data);
}
}
}
let n = encrypted_data.len();
let mp = MetalProcessor {
network: network.clone(),
decrypted_blocks,
encrypted_data,
decrypted_data: vec![0u8; N * Self::buffer_stride()],
n
};
Ok(mp)
}
}
impl GPUProcessor for MetalProcessor {
fn decrypt_account(&mut self, ivk: &SaplingIvk) -> anyhow::Result<()> {
unsafe {
let mc = METAL_CONTEXT.lock().unwrap();
let mut ivk_fr = ivk.0;
ivk_fr = ivk_fr.double(); // multiply by cofactor
ivk_fr = ivk_fr.double();
ivk_fr = ivk_fr.double();
let ivk = ivk_fr.to_bytes();
mc.ivk_buffer.contents().copy_from(ivk.as_ptr().cast(), 32);
mc.data_buffer.contents().copy_from(self.encrypted_data.as_ptr().cast(), N * Self::buffer_stride());
let command_buffer = mc.command_queue.new_command_buffer();
let argument_encoder = mc.kernel.new_argument_encoder(0);
let arg_buffer = mc.device.new_buffer(
argument_encoder.encoded_length(),
MTLResourceOptions::empty(),
);
argument_encoder.set_argument_buffer(&arg_buffer, 0);
argument_encoder.set_buffer(0, &mc.ivk_buffer, 0);
argument_encoder.set_buffer(1, &mc.data_buffer, 0);
let encoder = command_buffer.new_compute_command_encoder();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&mc.kernel));
let pipeline_state = mc.device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&arg_buffer), 0);
encoder.set_buffer(1, Some(&mc.data_buffer), 0);
encoder.use_resource(&mc.ivk_buffer, MTLResourceUsage::Read);
encoder.use_resource(&mc.data_buffer, MTLResourceUsage::Read | MTLResourceUsage::Write);
let width = WIDTH.into();
let thread_group_count = MTLSize {
width: N as u64 / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
unsafe {
let results = mc.data_buffer.contents() as *mut u8;
let res = std::slice::from_raw_parts::<u8>(results.cast(), N * Self::buffer_stride());
self.decrypted_data.copy_from_slice(&res);
}
Ok(())
}
}
fn get_decrypted_blocks(self) -> anyhow::Result<Vec<DecryptedBlock>> {
Ok(self.decrypted_blocks)
}
fn network(&self) -> Network {
self.network
}
fn borrow_buffers(&mut self) -> (&[u8], &mut [DecryptedBlock]) {
(&self.decrypted_data, &mut self.decrypted_blocks)
}
fn buffer_stride() -> usize {
mem::size_of::<Data>()
}
}
const TEST_FVK: &str = "zxviews1q0kl7tavzyqqpq8efe0vpgzwc37zj0zr9j2quurncpsy74tdvh9c3racve9yfv6gkssvekw4sz6ueenvup6whupguzkg5rgp0kma37r4uxz9472w4zwra4jv6fm5dc2cevfpjsxdgndagslmgdwudhv4stklzfeszrlcnsqxyr2qt8tsf4yxs3he4rzllcly7xagfmnlycvvnvhhr9l9j6ad693rkueqys9f7mkc7aacxwp3tfc9hpvlckxnj4nwu6jef2x98jefhcgmpkrmn";
const ZECPAGES_FVK: &str = "zxviews1q0duytgcqqqqpqre26wkl45gvwwwd706xw608hucmvfalr759ejwf7qshjf5r9aa7323zulvz6plhttp5mltqcgs9t039cx2d09mgq05ts63n8u35hyv6h9nc9ctqqtue2u7cer2mqegunuulq2luhq3ywjcz35yyljewa4mgkgjzyfwh6fr6jd0dzd44ghk0nxdv2hnv4j5nxfwv24rwdmgllhe0p8568sgqt9ckt02v2kxf5ahtql6s0ltjpkckw8gtymxtxuu9gcr0swvz";
pub fn test_co() -> CompactOutput {
let mut cmu = hex::decode("263a4c43290ce7d644c0a3ab694bb4710a4c3b20a528e2297ac1d360b017f704").unwrap();
cmu.reverse(); // epk was is given in MSB
let mut epk = hex::decode("d8360fc851709bb8d53e1f7ad2bab2c28c70d2c3c570af6620599f078ab37e02").unwrap();
epk.reverse();
let ciphertext = hex::decode("c9c2479a4c936b25c4848a15fc5debad377f0305f7e744cfb550bc09da12922669b6a4d82d2c8d56d9c804682bae459474467aad").unwrap();
CompactOutput {
height: 500_000,
epk: epk.try_into().unwrap(),
cmu: cmu.try_into().unwrap(),
ciphertext: ciphertext.try_into().unwrap(),
}
}
fn main() {
env_logger::init();
let library_data = include_bytes!("./metal/main.metallib");
let mut rng = ChaChaRng::from_seed([0; 32]);
autoreleasepool(|| {
let device = Device::system_default().expect("no device found");
let command_queue = device.new_command_queue();
let library = device.new_library_with_data(&*library_data).unwrap();
let kernel = library.get_function("decrypt", None).unwrap();
let fvk = decode_extended_full_viewing_key(Network::MainNetwork.hrp_sapling_extended_full_viewing_key(),
ZECPAGES_FVK).unwrap().unwrap();
let ivk = fvk.fvk.vk.ivk();
let mut ivk_fr = ivk.0;
ivk_fr = ivk_fr.double(); // multiply by cofactor (8)
ivk_fr = ivk_fr.double();
ivk_fr = ivk_fr.double();
let ivk8 = ivk_fr.to_bytes();
println!("ivk8: {}", hex::encode(&ivk8));
// let ivk8 = hex::decode("40c075fe695bf7135f70dc098fca6fab6a26774f8a070472579d00309386be1b").unwrap();
// let mut ivk8 = [0u8; 32];
// ivk8[0] = 1;
// let x = Fq::random(&mut rng);
// ivk8.copy_from_slice(&x.to_bytes());
// let mut test_data = vec![Data::default(); n];
// for i in 0..n {
// test_data[i].epk.copy_from_slice(&epk);
// test_data[i].cipher[0..52].copy_from_slice(&ciphertext);
// }
let mut test_data: Vec<Data> = vec![];
let notes = vec![test_co()];
for n in notes.iter() {
let mut cipher = [0u8; 64];
cipher[0..52].copy_from_slice(&n.ciphertext);
let data = Data {
key: [0u8; 32],
epk: n.epk,
cipher,
};
test_data.push(data);
}
let n = notes.len();
let ivk_buffer = device.new_buffer_with_data(
unsafe { mem::transmute(ivk8.as_ptr()) },
32u64,
MTLResourceOptions::CPUCacheModeDefaultCache,
);
let data_buffer = {
device.new_buffer_with_data(
unsafe { mem::transmute(test_data.as_ptr()) },
(test_data.len() * mem::size_of::<Data>()) as u64,
MTLResourceOptions::CPUCacheModeDefaultCache,
)
};
let ptr = data_buffer.contents() as *mut u8;
unsafe {
let res: &[Data] = std::slice::from_raw_parts::<Data>(ptr.cast(), 1).try_into().unwrap();
println!("Before {}", hex::encode(&res[0].epk));
}
let command_buffer = command_queue.new_command_buffer();
let argument_encoder = kernel.new_argument_encoder(0);
let arg_buffer = device.new_buffer(
argument_encoder.encoded_length(),
MTLResourceOptions::empty(),
);
argument_encoder.set_argument_buffer(&arg_buffer, 0);
argument_encoder.set_buffer(0, &ivk_buffer, 0);
argument_encoder.set_buffer(1, &data_buffer, 0);
let encoder = command_buffer.new_compute_command_encoder();
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&kernel));
let pipeline_state = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
encoder.set_compute_pipeline_state(&pipeline_state);
encoder.set_buffer(0, Some(&arg_buffer), 0);
encoder.set_buffer(1, Some(&data_buffer), 0);
encoder.use_resource(&ivk_buffer, MTLResourceUsage::Read);
encoder.use_resource(&data_buffer, MTLResourceUsage::Read | MTLResourceUsage::Write);
let width = 256;
let thread_group_count = MTLSize {
width: (test_data.len() as u64 + width - 1) / width,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
log::info!("Start - n = {}, n_groups = {}", n, thread_group_count.width);
let stopwatch = SystemTime::now();
command_buffer.commit();
command_buffer.wait_until_completed();
log::info!("Finish - {}", stopwatch.elapsed().unwrap().as_millis());
let ptr = data_buffer.contents() as *mut u8;
unsafe {
let res: &[Data] = std::slice::from_raw_parts::<Data>(ptr.cast(), n).try_into().unwrap();
let mut count = 0;
for i in 0..n {
let d = &res[i];
// let product = Fq::from_bytes(&ivk8).unwrap() - Fq::from_bytes(&d.epk).unwrap();
// let x = Fq::from_bytes(&d.key).unwrap();
// println!("{} {} {}", hex::encode(&d.epk), hex::encode(product.to_bytes()), hex::encode(&d.key));
// assert_eq!(product, Fq::from_bytes(&d.key).unwrap());
// println!("{}", hex::encode(&d.cipher));
let pt = &d.cipher;
let domain = SaplingDomain::for_height(MainNetwork, BlockHeight::from_u32(500_000));
if let Some((note, pa)) = domain.parse_note_plaintext_without_memo_ivk(&ivk, pt) {
if note.cmu().to_bytes() == notes[i].cmu.as_slice() {
// log::info!("{:?}", note);
// println!("{:?}", encode_payment_address(NETWORK.hrp_sapling_payment_address(), &pa));
count += 1;
}
}
}
log::info!("COUNT = {}", count);
}
});
}

View File

@ -57,15 +57,6 @@ async fn main() -> anyhow::Result<()> {
env_logger::init(); env_logger::init();
let _ = dotenv::dotenv(); let _ = dotenv::dotenv();
let server = get_best_server(&[
"https://lwdv3.zecwallet.co:443".to_string(),
"https://zuul.free2z.cash:9067".to_string(),
"https://mainnet.lightwalletd.com:9067".to_string(),
])
.await
.unwrap();
log::info!("Best server = {}", server);
let rocket = rocket::build(); let rocket = rocket::build();
let figment = rocket.figment(); let figment = rocket.figment();
let zec: HashMap<String, String> = figment.extract_inner("zec")?; let zec: HashMap<String, String> = figment.extract_inner("zec")?;