simplifies sigverify copy_return_values (#29495)

This commit is contained in:
behzad nouri 2023-01-05 19:45:52 +00:00 committed by GitHub
parent 0581fc2def
commit b71cb9d9c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 58 additions and 27 deletions

View File

@ -312,12 +312,9 @@ pub fn verify_shreds_gpu(
trace!("out buf {:?}", out);
// Each shred has exactly one signature.
let v_sig_lens: Vec<_> = batches
.iter()
.map(|batch| vec![1u32; batch.len()])
.collect();
let v_sig_lens = batches.iter().map(|batch| repeat(1u32).take(batch.len()));
let mut rvs: Vec<_> = batches.iter().map(|batch| vec![0u8; batch.len()]).collect();
sigverify::copy_return_values(&v_sig_lens, &out, &mut rvs);
sigverify::copy_return_values(v_sig_lens, &out, &mut rvs);
inc_new_counter_debug!("ed25519_shred_verify_gpu", out.len());
rvs

View File

@ -664,26 +664,18 @@ pub fn ed25519_verify_disabled(batches: &mut [PacketBatch]) {
inc_new_counter_debug!("ed25519_verify_disabled", packet_count);
}
pub fn copy_return_values(sig_lens: &[Vec<u32>], out: &PinnedVec<u8>, rvs: &mut [Vec<u8>]) {
let mut num = 0;
for (vs, sig_vs) in rvs.iter_mut().zip(sig_lens.iter()) {
for (v, sig_v) in vs.iter_mut().zip(sig_vs.iter()) {
if *sig_v == 0 {
*v = 0;
} else {
let mut vout = 1;
for _ in 0..*sig_v {
if 0 == out[num] {
vout = 0;
}
num = num.saturating_add(1);
}
*v = vout;
}
if *v != 0 {
trace!("VERIFIED PACKET!!!!!");
}
}
pub fn copy_return_values<I, T>(sig_lens: I, out: &PinnedVec<u8>, rvs: &mut [Vec<u8>])
where
I: IntoIterator<Item = T>,
T: IntoIterator<Item = u32>,
{
debug_assert!(rvs.iter().flatten().all(|&rv| rv == 0u8));
let mut offset = 0usize;
let rvs = rvs.iter_mut().flatten();
for (k, rv) in sig_lens.into_iter().flatten().zip(rvs) {
let out = out[offset..].iter().take(k as usize).all(|&x| x == 1u8);
*rv = u8::from(k != 0u32 && out);
offset = offset.saturating_add(k as usize);
}
}
@ -796,7 +788,7 @@ pub fn ed25519_verify(
}
}
trace!("done verify");
copy_return_values(&sig_lens, &out, &mut rvs);
copy_return_values(sig_lens, &out, &mut rvs);
mark_disabled(batches, &rvs);
inc_new_counter_debug!("ed25519_verify_gpu", valid_packet_count);
}
@ -820,7 +812,10 @@ mod tests {
signature::{Keypair, Signature, Signer},
transaction::Transaction,
},
std::sync::atomic::{AtomicU64, Ordering},
std::{
iter::repeat_with,
sync::atomic::{AtomicU64, Ordering},
},
};
const SIG_OFFSET: usize = 1;
@ -831,6 +826,45 @@ mod tests {
(0..end).find(|&i| a[i..i + b.len()] == b[..])
}
#[test]
fn test_copy_return_values() {
let mut rng = rand::thread_rng();
let sig_lens: Vec<Vec<u32>> = {
let size = rng.gen_range(0, 64);
repeat_with(|| {
let size = rng.gen_range(0, 16);
repeat_with(|| rng.gen_range(0, 5)).take(size).collect()
})
.take(size)
.collect()
};
let out: Vec<Vec<Vec<bool>>> = sig_lens
.iter()
.map(|sig_lens| {
sig_lens
.iter()
.map(|&size| repeat_with(|| rng.gen()).take(size as usize).collect())
.collect()
})
.collect();
let expected: Vec<Vec<u8>> = out
.iter()
.map(|out| {
out.iter()
.map(|out| u8::from(!out.is_empty() && out.iter().all(|&k| k)))
.collect()
})
.collect();
let out =
PinnedVec::<u8>::from_vec(out.into_iter().flatten().flatten().map(u8::from).collect());
let mut rvs: Vec<Vec<u8>> = sig_lens
.iter()
.map(|sig_lens| vec![0u8; sig_lens.len()])
.collect();
copy_return_values(sig_lens, &out, &mut rvs);
assert_eq!(rvs, expected);
}
#[test]
fn test_mark_disabled() {
let batch_size = 1;