simplifies sigverify copy_return_values (#29495)
This commit is contained in:
parent
0581fc2def
commit
b71cb9d9c7
|
@ -312,12 +312,9 @@ pub fn verify_shreds_gpu(
|
||||||
trace!("out buf {:?}", out);
|
trace!("out buf {:?}", out);
|
||||||
|
|
||||||
// Each shred has exactly one signature.
|
// Each shred has exactly one signature.
|
||||||
let v_sig_lens: Vec<_> = batches
|
let v_sig_lens = batches.iter().map(|batch| repeat(1u32).take(batch.len()));
|
||||||
.iter()
|
|
||||||
.map(|batch| vec![1u32; batch.len()])
|
|
||||||
.collect();
|
|
||||||
let mut rvs: Vec<_> = batches.iter().map(|batch| vec![0u8; batch.len()]).collect();
|
let mut rvs: Vec<_> = batches.iter().map(|batch| vec![0u8; batch.len()]).collect();
|
||||||
sigverify::copy_return_values(&v_sig_lens, &out, &mut rvs);
|
sigverify::copy_return_values(v_sig_lens, &out, &mut rvs);
|
||||||
|
|
||||||
inc_new_counter_debug!("ed25519_shred_verify_gpu", out.len());
|
inc_new_counter_debug!("ed25519_shred_verify_gpu", out.len());
|
||||||
rvs
|
rvs
|
||||||
|
|
|
@ -664,26 +664,18 @@ pub fn ed25519_verify_disabled(batches: &mut [PacketBatch]) {
|
||||||
inc_new_counter_debug!("ed25519_verify_disabled", packet_count);
|
inc_new_counter_debug!("ed25519_verify_disabled", packet_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn copy_return_values(sig_lens: &[Vec<u32>], out: &PinnedVec<u8>, rvs: &mut [Vec<u8>]) {
|
pub fn copy_return_values<I, T>(sig_lens: I, out: &PinnedVec<u8>, rvs: &mut [Vec<u8>])
|
||||||
let mut num = 0;
|
where
|
||||||
for (vs, sig_vs) in rvs.iter_mut().zip(sig_lens.iter()) {
|
I: IntoIterator<Item = T>,
|
||||||
for (v, sig_v) in vs.iter_mut().zip(sig_vs.iter()) {
|
T: IntoIterator<Item = u32>,
|
||||||
if *sig_v == 0 {
|
{
|
||||||
*v = 0;
|
debug_assert!(rvs.iter().flatten().all(|&rv| rv == 0u8));
|
||||||
} else {
|
let mut offset = 0usize;
|
||||||
let mut vout = 1;
|
let rvs = rvs.iter_mut().flatten();
|
||||||
for _ in 0..*sig_v {
|
for (k, rv) in sig_lens.into_iter().flatten().zip(rvs) {
|
||||||
if 0 == out[num] {
|
let out = out[offset..].iter().take(k as usize).all(|&x| x == 1u8);
|
||||||
vout = 0;
|
*rv = u8::from(k != 0u32 && out);
|
||||||
}
|
offset = offset.saturating_add(k as usize);
|
||||||
num = num.saturating_add(1);
|
|
||||||
}
|
|
||||||
*v = vout;
|
|
||||||
}
|
|
||||||
if *v != 0 {
|
|
||||||
trace!("VERIFIED PACKET!!!!!");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -796,7 +788,7 @@ pub fn ed25519_verify(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
trace!("done verify");
|
trace!("done verify");
|
||||||
copy_return_values(&sig_lens, &out, &mut rvs);
|
copy_return_values(sig_lens, &out, &mut rvs);
|
||||||
mark_disabled(batches, &rvs);
|
mark_disabled(batches, &rvs);
|
||||||
inc_new_counter_debug!("ed25519_verify_gpu", valid_packet_count);
|
inc_new_counter_debug!("ed25519_verify_gpu", valid_packet_count);
|
||||||
}
|
}
|
||||||
|
@ -820,7 +812,10 @@ mod tests {
|
||||||
signature::{Keypair, Signature, Signer},
|
signature::{Keypair, Signature, Signer},
|
||||||
transaction::Transaction,
|
transaction::Transaction,
|
||||||
},
|
},
|
||||||
std::sync::atomic::{AtomicU64, Ordering},
|
std::{
|
||||||
|
iter::repeat_with,
|
||||||
|
sync::atomic::{AtomicU64, Ordering},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const SIG_OFFSET: usize = 1;
|
const SIG_OFFSET: usize = 1;
|
||||||
|
@ -831,6 +826,45 @@ mod tests {
|
||||||
(0..end).find(|&i| a[i..i + b.len()] == b[..])
|
(0..end).find(|&i| a[i..i + b.len()] == b[..])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_copy_return_values() {
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let sig_lens: Vec<Vec<u32>> = {
|
||||||
|
let size = rng.gen_range(0, 64);
|
||||||
|
repeat_with(|| {
|
||||||
|
let size = rng.gen_range(0, 16);
|
||||||
|
repeat_with(|| rng.gen_range(0, 5)).take(size).collect()
|
||||||
|
})
|
||||||
|
.take(size)
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
let out: Vec<Vec<Vec<bool>>> = sig_lens
|
||||||
|
.iter()
|
||||||
|
.map(|sig_lens| {
|
||||||
|
sig_lens
|
||||||
|
.iter()
|
||||||
|
.map(|&size| repeat_with(|| rng.gen()).take(size as usize).collect())
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let expected: Vec<Vec<u8>> = out
|
||||||
|
.iter()
|
||||||
|
.map(|out| {
|
||||||
|
out.iter()
|
||||||
|
.map(|out| u8::from(!out.is_empty() && out.iter().all(|&k| k)))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let out =
|
||||||
|
PinnedVec::<u8>::from_vec(out.into_iter().flatten().flatten().map(u8::from).collect());
|
||||||
|
let mut rvs: Vec<Vec<u8>> = sig_lens
|
||||||
|
.iter()
|
||||||
|
.map(|sig_lens| vec![0u8; sig_lens.len()])
|
||||||
|
.collect();
|
||||||
|
copy_return_values(sig_lens, &out, &mut rvs);
|
||||||
|
assert_eq!(rvs, expected);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mark_disabled() {
|
fn test_mark_disabled() {
|
||||||
let batch_size = 1;
|
let batch_size = 1;
|
||||||
|
|
Loading…
Reference in New Issue