fix(message-buffer): address PR feedback from 779, update tests (#784)

* fix(message-buffer): address PR feedback from 779, update tests

Add new unit test for reading with cursor, update test setup

* refactor(message-buffer): refactor test methods, add big endian doc on MessageBuffer
This commit is contained in:
swimricky 2023-04-26 14:16:50 -07:00 committed by GitHub
parent d16594ca6b
commit 04576df743
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 160 additions and 103 deletions

View File

@ -891,6 +891,7 @@ version = "0.1.0"
dependencies = [
"anchor-lang",
"bytemuck",
"byteorder",
]
[[package]]

View File

@ -19,3 +19,6 @@ default = []
anchor-lang = { version = "0.27.0" }
# needed for the new #[account(zero_copy)] in anchor 0.27.0
bytemuck = { version = "1.4.0", features = ["derive", "min_const_generics"]}
[dev-dependencies]
byteorder = "1.4.3"

View File

@ -75,6 +75,9 @@ pub fn create_buffer<'info>(
*message_buffer = MessageBuffer::new(bump);
}
loader.exit(&crate::ID)?;
} else {
// FIXME: change this to be emit!(Event)
msg!("Buffer account already initialized");
}
Ok(())

View File

@ -1,6 +1,5 @@
use {
crate::{
instructions::verify_message_buffer,
state::*,
MessageBufferError,
MESSAGE,
@ -23,7 +22,7 @@ pub fn delete_buffer<'info>(
.whitelist
.is_allowed_program_auth(&allowed_program_auth)?;
verify_message_buffer(message_buffer_account_info)?;
MessageBuffer::check_discriminator(message_buffer_account_info)?;
let expected_key = Pubkey::create_program_address(
&[

View File

@ -1,13 +1,6 @@
use {
crate::{
state::MessageBuffer,
MessageBufferError,
},
anchor_lang::{
prelude::*,
system_program,
Discriminator,
},
use anchor_lang::{
prelude::*,
system_program,
};
pub use {
create_buffer::*,
@ -23,34 +16,12 @@ mod put_all;
mod resize_buffer;
// String constants for deriving PDAs.
//
// An authorized program's message buffer will have PDA seeds [authorized_program_pda, MESSAGE, base_account_key],
// where authorized_program_pda is the
// where authorized_program_pda is the where `allowed_program_auth`
// is the whitelisted pubkey who authorized this call.
pub const MESSAGE: &str = "message";
pub const FUND: &str = "fund";
pub fn is_uninitialized_account(ai: &AccountInfo) -> bool {
ai.data_is_empty() && ai.owner == &system_program::ID
}
/// Verify message buffer account is initialized and has the correct discriminator.
///
/// Note: manually checking because using anchor's `AccountLoader.load()`
/// will panic since the `AccountInfo.data_len()` will not match the
/// size of the `MessageBuffer` since the `MessageBuffer` struct does not
/// include the messages.
pub fn verify_message_buffer(message_buffer_account_info: &AccountInfo) -> Result<()> {
if is_uninitialized_account(message_buffer_account_info) {
return err!(MessageBufferError::MessageBufferUninitialized);
}
let data = message_buffer_account_info.try_borrow_data()?;
if data.len() < MessageBuffer::discriminator().len() {
return Err(ErrorCode::AccountDiscriminatorNotFound.into());
}
let disc_bytes = &data[0..8];
if disc_bytes != &MessageBuffer::discriminator() {
return Err(ErrorCode::AccountDiscriminatorMismatch.into());
}
Ok(())
}

View File

@ -1,6 +1,5 @@
use {
crate::{
instructions::verify_message_buffer,
state::*,
MessageBufferError,
},
@ -20,7 +19,7 @@ pub fn put_all<'info>(
.first()
.ok_or(MessageBufferError::MessageBufferNotProvided)?;
verify_message_buffer(message_buffer_account_info)?;
MessageBuffer::check_discriminator(message_buffer_account_info)?;
let account_data = &mut message_buffer_account_info.try_borrow_mut_data()?;
let header_end_index = mem::size_of::<MessageBuffer>() + 8;

View File

@ -1,6 +1,5 @@
use {
crate::{
instructions::verify_message_buffer,
state::*,
MessageBufferError,
MESSAGE,
@ -30,7 +29,7 @@ pub fn resize_buffer<'info>(
ctx.accounts
.whitelist
.is_allowed_program_auth(&allowed_program_auth)?;
verify_message_buffer(message_buffer_account_info)?;
MessageBuffer::check_discriminator(message_buffer_account_info)?;
require_gte!(
target_size,
@ -62,7 +61,10 @@ pub fn resize_buffer<'info>(
MessageBufferError::InvalidPDA
);
if target_size_delta > 0 {
// allow for delta == 0 in case Rent requirements have changed
// and additional lamports need to be transferred.
// the realloc step will be a no-op in this case.
if target_size_delta >= 0 {
let target_rent = Rent::get()?.minimum_balance(target_size);
if message_buffer_account_info.lamports() < target_rent {
system_program::transfer(

View File

@ -1,9 +1,13 @@
use {
crate::{
accumulator_input_seeds,
instructions,
MessageBufferError,
},
anchor_lang::prelude::*,
anchor_lang::{
prelude::*,
Discriminator,
},
};
/// A MessageBuffer will have the following structure
@ -22,6 +26,13 @@ use {
/// A `MessageBuffer` AccountInfo.data will look like:
/// [ <discrimintator>, <buffer_header>, <messages> ]
/// (0..8) (8..header_len) (header_len...accountInfo.data.len)
///
///<br>
///
/// NOTE: The defined fields are read as *Little Endian*. The actual messages
/// are read as *Big Endian*. The MessageBuffer fields are only ever read
/// by the Pythnet validator & Hermes so don't need to be in Big Endian
/// for cross-platform compatibility.
#[account(zero_copy)]
#[derive(InitSpace, Debug)]
pub struct MessageBuffer {
@ -121,6 +132,28 @@ impl MessageBuffer {
require_keys_eq!(expected_key, key);
Ok(())
}
/// Verify message buffer account is initialized and has the correct discriminator.
///
/// Note: manually checking because using anchor's `AccountLoader.load()`
/// will panic since the `AccountInfo.data_len()` will not match the
/// size of the `MessageBuffer` since the `MessageBuffer` struct does not
/// include the messages.
pub fn check_discriminator(message_buffer_account_info: &AccountInfo) -> Result<()> {
if instructions::is_uninitialized_account(message_buffer_account_info) {
return err!(MessageBufferError::MessageBufferUninitialized);
}
let data = message_buffer_account_info.try_borrow_data()?;
if data.len() < MessageBuffer::discriminator().len() {
return Err(ErrorCode::AccountDiscriminatorNotFound.into());
}
let disc_bytes = &data[0..8];
if disc_bytes != &MessageBuffer::discriminator() {
return Err(ErrorCode::AccountDiscriminatorMismatch.into());
}
Ok(())
}
}
#[cfg(test)]
@ -154,6 +187,23 @@ mod test {
sighash
}
fn generate_message_buffer_bytes(_data_bytes: &Vec<Vec<u8>>) -> Vec<u8> {
let message_buffer = &mut MessageBuffer::new(0);
let header_len = message_buffer.header_len as usize;
let account_info_data = &mut vec![];
let discriminator = &mut sighash("accounts", "MessageBuffer");
let destination = &mut vec![0u8; 10_240 - header_len];
account_info_data.write_all(discriminator).unwrap();
account_info_data
.write_all(bytes_of_mut(message_buffer))
.unwrap();
account_info_data.write_all(destination).unwrap();
account_info_data.to_vec()
}
#[test]
fn test_sizes_and_alignments() {
@ -169,25 +219,19 @@ mod test {
let data = vec![vec![12, 34], vec![56, 78, 90]];
let data_bytes: Vec<Vec<u8>> = data.into_iter().map(data_bytes).collect();
let message_buffer = &mut MessageBuffer::new(0);
let header_len = message_buffer.header_len as usize;
let message_buffer_bytes = bytes_of_mut(message_buffer);
// assuming account_info.data.len() == 10KB
let messages = &mut vec![0u8; 10_240 - header_len];
let account_info_data = &mut generate_message_buffer_bytes(&data_bytes);
let account_info_data = &mut vec![];
let discriminator = &mut sighash("accounts", "MessageBuffer");
account_info_data.write_all(discriminator).unwrap();
account_info_data.write_all(message_buffer_bytes).unwrap();
account_info_data
.write_all(messages.as_mut_slice())
.unwrap();
let header_len = MessageBuffer::HEADER_LEN as usize;
let _account_data_len = account_info_data.len();
let destination = &mut account_info_data[(message_buffer.header_len as usize)..];
let (header_bytes, body_bytes) = account_info_data.split_at_mut(header_len);
let message_buffer: &mut MessageBuffer = bytemuck::from_bytes_mut(&mut header_bytes[8..]);
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(destination, &data_bytes);
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(body_bytes, &data_bytes);
let message_buffer: &MessageBuffer =
bytemuck::from_bytes(&account_info_data.as_slice()[8..header_len]);
assert_eq!(num_msgs, 2);
assert_eq!(num_bytes, 5);
@ -197,26 +241,22 @@ mod test {
assert_eq!(message_buffer.end_offsets[1], 5);
// let account_data = bytes_of(accumulator_input);
// // The header_len field represents the size of all data prior to the message bytes.
// // This includes the account discriminator, which is not part of the header struct.
// // Subtract the size of the discriminator (8 bytes) to compensate
// let header_len = message_buffer.header_len as usize - 8;
let header_len = message_buffer.header_len as usize;
let iter = message_buffer.end_offsets.iter().take_while(|x| **x != 0);
let mut start = header_len;
let mut data_iter = data_bytes.iter();
let read_data = &mut vec![];
for offset in iter {
let end_offset = header_len + *offset as usize;
let message_buffer_data = &account_info_data[start..end_offset];
let expected_data = data_iter.next().unwrap();
assert_eq!(message_buffer_data, expected_data.as_slice());
read_data.push(message_buffer_data);
start = end_offset;
}
println!("read_data: {:?}", read_data);
assert_eq!(read_data.len(), num_msgs);
for d in read_data.iter() {
let expected_data = data_iter.next().unwrap();
assert_eq!(d, &expected_data.as_slice());
}
}
#[test]
@ -225,25 +265,18 @@ mod test {
let data_bytes: Vec<Vec<u8>> = data.into_iter().map(data_bytes).collect();
let message_buffer = &mut MessageBuffer::new(0);
let header_len = message_buffer.header_len as usize;
let message_buffer_bytes = bytes_of_mut(message_buffer);
// assuming account_info.data.len() == 10KB
let messages = &mut vec![0u8; 10_240 - header_len];
let account_info_data = &mut generate_message_buffer_bytes(&data_bytes);
let account_info_data = &mut vec![];
let discriminator = &mut sighash("accounts", "MessageBuffer");
account_info_data.write_all(discriminator).unwrap();
account_info_data.write_all(message_buffer_bytes).unwrap();
account_info_data
.write_all(messages.as_mut_slice())
.unwrap();
let header_len = MessageBuffer::HEADER_LEN as usize;
let _account_data_len = account_info_data.len();
let (header_bytes, body_bytes) = account_info_data.split_at_mut(header_len);
let message_buffer: &mut MessageBuffer = bytemuck::from_bytes_mut(&mut header_bytes[8..]);
let destination = &mut account_info_data[(message_buffer.header_len as usize)..];
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(body_bytes, &data_bytes);
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(destination, &data_bytes);
let message_buffer: &MessageBuffer =
bytemuck::from_bytes(&account_info_data.as_slice()[8..header_len]);
assert_eq!(num_msgs, 2);
assert_eq!(
@ -265,7 +298,8 @@ mod test {
assert_eq!(message_buffer.end_offsets[2], 0);
}
//
#[test]
fn test_put_all_long_vec() {
let data = vec![
@ -277,29 +311,21 @@ mod test {
];
let data_bytes: Vec<Vec<u8>> = data.into_iter().map(data_bytes).collect();
// let message_buffer = &mut MessageBufferTemp::new(0);
// let (num_msgs, num_bytes) = message_buffer.put_all(&data_bytes);
let message_buffer = &mut MessageBuffer::new(0);
let header_len = message_buffer.header_len as usize;
let account_info_data = &mut generate_message_buffer_bytes(&data_bytes);
let message_buffer_bytes = bytes_of_mut(message_buffer);
// assuming account_info.data.len() == 10KB
let messages = &mut vec![0u8; 10_240 - header_len];
let header_len = MessageBuffer::HEADER_LEN as usize;
let account_info_data = &mut vec![];
let discriminator = &mut sighash("accounts", "MessageBuffer");
account_info_data.write_all(discriminator).unwrap();
account_info_data.write_all(message_buffer_bytes).unwrap();
account_info_data
.write_all(messages.as_mut_slice())
.unwrap();
let _account_data_len = account_info_data.len();
let (header_bytes, body_bytes) = account_info_data.split_at_mut(header_len);
let message_buffer: &mut MessageBuffer = bytemuck::from_bytes_mut(&mut header_bytes[8..]);
let destination = &mut account_info_data[(message_buffer.header_len as usize)..];
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(body_bytes, &data_bytes);
let message_buffer: &MessageBuffer =
bytemuck::from_bytes(&account_info_data.as_slice()[8..header_len]);
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(destination, &data_bytes);
assert_eq!(num_msgs, 3);
assert_eq!(
@ -325,4 +351,57 @@ mod test {
assert_eq!(message_buffer.end_offsets[3], 0);
assert_eq!(message_buffer.end_offsets[4], 0);
}
#[test]
pub fn test_cursor_read() {
use byteorder::{
LittleEndian,
ReadBytesExt,
};
let data = vec![vec![12, 34], vec![56, 78, 90]];
let data_bytes: Vec<Vec<u8>> = data.into_iter().map(data_bytes).collect();
let account_info_data = &mut generate_message_buffer_bytes(&data_bytes);
let header_len = MessageBuffer::HEADER_LEN as usize;
let (header_bytes, body_bytes) = account_info_data.split_at_mut(header_len);
let message_buffer: &mut MessageBuffer = bytemuck::from_bytes_mut(&mut header_bytes[8..]);
let (num_msgs, num_bytes) = message_buffer.put_all_in_buffer(body_bytes, &data_bytes);
assert_eq!(num_msgs, 2);
assert_eq!(num_bytes, 5);
let message_buffer: &MessageBuffer =
bytemuck::from_bytes(&account_info_data.as_slice()[8..header_len]);
assert_eq!(message_buffer.end_offsets[0], 2);
assert_eq!(message_buffer.end_offsets[1], 5);
let mut cursor = std::io::Cursor::new(&account_info_data[10..]);
let header_len = cursor.read_u16::<LittleEndian>().unwrap();
println!("header_len: {}", header_len);
let mut current_msg_start = header_len;
let mut end_offset = cursor.read_u16::<LittleEndian>().unwrap();
let mut data_iter = data_bytes.iter();
println!("init header_end: {}", end_offset);
let read_data = &mut vec![];
while end_offset != 0 {
let current_msg_end = header_len + end_offset;
let accumulator_input_data =
&account_info_data[current_msg_start as usize..current_msg_end as usize];
end_offset = cursor.read_u16::<LittleEndian>().unwrap();
current_msg_start = current_msg_end;
read_data.push(accumulator_input_data);
}
println!("read_data: {:?}", read_data);
for d in read_data.iter() {
let expected_data = data_iter.next().unwrap();
assert_eq!(d, &expected_data.as_slice());
}
assert_eq!(read_data.len(), 2);
}
}