Adding memcmp filter for accounts

This commit is contained in:
godmodegalactus 2024-05-17 17:20:46 +02:00
parent d21031afbf
commit b2db9d6622
No known key found for this signature in database
GPG Key ID: 22DA4A30887FDA3C
5 changed files with 180 additions and 6 deletions

View File

@ -157,6 +157,7 @@ mod tests {
.subscribe(vec![Filter::Account(AccountFilter {
owner: Some(Pubkey::default()),
accounts: None,
filter: None,
})])
.await
.unwrap();

View File

@ -38,12 +38,33 @@ impl Filter {
}
}
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
#[serde(rename_all = "camelCase")]
pub enum MemcmpFilterData {
Bytes(Vec<u8>),
}
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
#[serde(rename_all = "camelCase")]
pub struct MemcmpFilter {
pub offset: u64,
pub data: MemcmpFilterData,
}
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
#[serde(rename_all = "camelCase")]
pub enum AccountFilterType {
Datasize(u64),
Memcmp(MemcmpFilter),
}
// setting owner to 11111111111111111111111111111111 will subscribe to all the accounts
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[repr(C)]
pub struct AccountFilter {
pub owner: Option<Pubkey>,
pub accounts: Option<HashSet<Pubkey>>,
pub filter: Option<AccountFilterType>,
}
impl AccountFilter {
@ -52,6 +73,30 @@ impl AccountFilter {
if let Some(owner) = self.owner {
// check if filter subscribes to all the accounts
if owner == Pubkey::default() || owner == account.owner {
// to do move the filtering somewhere else because here we need to decode the account data
// but cannot be avoided for now, this will lag the client is abusing this filter
// lagged clients will be dropped
if let Some(filter) = &self.filter {
match filter {
AccountFilterType::Datasize(data_length) => {
return account.data_length == *data_length
}
AccountFilterType::Memcmp(memcmp) => {
if memcmp.offset > account.data_length {
return false;
}
let solana_account = account.solana_account();
let offset = memcmp.offset as usize;
let MemcmpFilterData::Bytes(bytes) = &memcmp.data;
if solana_account.data[offset..].len() < bytes.len() {
return false;
}
return solana_account.data[offset..offset + bytes.len()]
== bytes[..];
}
}
}
return true;
}
}
@ -62,3 +107,117 @@ impl AccountFilter {
false
}
}
#[cfg(test)]
mod tests {
use solana_sdk::{account::Account as SolanaAccount, pubkey::Pubkey};
use crate::{
filters::{AccountFilter, AccountFilterType, MemcmpFilter},
message::Message,
types::{account::Account, slot_identifier::SlotIdentifier},
};
#[tokio::test]
async fn test_accounts_filter() {
let owner = Pubkey::new_unique();
let solana_account_1 = SolanaAccount {
lamports: 1,
data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
owner,
executable: false,
rent_epoch: 100,
};
let solana_account_2 = SolanaAccount {
lamports: 2,
data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
owner,
executable: false,
rent_epoch: 100,
};
let solana_account_3 = SolanaAccount {
lamports: 3,
data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
owner: Pubkey::new_unique(),
executable: false,
rent_epoch: 100,
};
let msg_1 = Message::AccountMsg(Account::new(
Pubkey::new_unique(),
solana_account_1.clone(),
crate::compression::CompressionType::Lz4Fast(8),
SlotIdentifier { slot: 0 },
0,
));
let msg_2 = Message::AccountMsg(Account::new(
Pubkey::new_unique(),
solana_account_2.clone(),
crate::compression::CompressionType::Lz4Fast(8),
SlotIdentifier { slot: 0 },
0,
));
let msg_3 = Message::AccountMsg(Account::new(
Pubkey::new_unique(),
solana_account_3.clone(),
crate::compression::CompressionType::Lz4Fast(8),
SlotIdentifier { slot: 0 },
0,
));
let f1 = AccountFilter {
owner: Some(owner),
accounts: None,
filter: None,
};
assert_eq!(f1.allows(&msg_1), true);
assert_eq!(f1.allows(&msg_2), true);
assert_eq!(f1.allows(&msg_3), false);
let f2 = AccountFilter {
owner: Some(owner),
accounts: None,
filter: Some(AccountFilterType::Datasize(9)),
};
assert_eq!(f2.allows(&msg_1), false);
assert_eq!(f2.allows(&msg_2), false);
assert_eq!(f2.allows(&msg_3), false);
let f3 = AccountFilter {
owner: Some(owner),
accounts: None,
filter: Some(AccountFilterType::Datasize(10)),
};
assert_eq!(f3.allows(&msg_1), true);
assert_eq!(f3.allows(&msg_2), true);
assert_eq!(f3.allows(&msg_3), false);
let f4: AccountFilter = AccountFilter {
owner: Some(owner),
accounts: None,
filter: Some(AccountFilterType::Memcmp(MemcmpFilter {
offset: 2,
data: crate::filters::MemcmpFilterData::Bytes(vec![3, 4, 5]),
})),
};
assert_eq!(f4.allows(&msg_1), true);
assert_eq!(f4.allows(&msg_2), false);
assert_eq!(f4.allows(&msg_3), false);
let f5: AccountFilter = AccountFilter {
owner: Some(owner),
accounts: None,
filter: Some(AccountFilterType::Memcmp(MemcmpFilter {
offset: 2,
data: crate::filters::MemcmpFilterData::Bytes(vec![13, 14, 15]),
})),
};
assert_eq!(f5.allows(&msg_1), false);
assert_eq!(f5.allows(&msg_2), true);
assert_eq!(f5.allows(&msg_3), false);
}
}

View File

@ -1,4 +1,4 @@
use quinn::{Connection, Endpoint};
use quinn::{Connection, Endpoint, VarInt};
use std::sync::Arc;
use std::{collections::VecDeque, time::Duration};
use tokio::sync::Semaphore;
@ -216,7 +216,7 @@ impl ConnectionManager {
let id = connection_data.id;
tokio::spawn(async move {
let permit_result = semaphore.clone().try_acquire_owned();
let permit_result = semaphore.try_acquire_owned();
let _permit = match permit_result {
Ok(permit) => permit,
@ -227,10 +227,8 @@ impl ConnectionManager {
id,
message_type
);
semaphore
.acquire_owned()
.await
.expect("Should aquire the permit")
connection.close(VarInt::from_u32(0), b"laggy client");
return;
}
};

View File

@ -14,6 +14,7 @@ pub struct Account {
pub write_version: u64,
pub data: Vec<u8>,
pub compression_type: CompressionType,
pub data_length: u64,
}
impl Account {
@ -25,6 +26,7 @@ impl Account {
write_version: 0,
data: vec![178; data_size],
compression_type: CompressionType::None,
data_length: data_size as u64,
}
}
@ -36,6 +38,7 @@ impl Account {
write_version: u64,
) -> Self {
let binary = bincode::serialize(&solana_account).expect("account should be serializable");
let data_length = solana_account.data.len() as u64;
let data = match compression_type {
CompressionType::None => binary,
@ -61,6 +64,18 @@ impl Account {
write_version,
data,
compression_type,
data_length,
}
}
pub fn solana_account(&self) -> SolanaAccount {
match self.compression_type {
CompressionType::None => bincode::deserialize(&self.data).expect("Should deserialize"),
CompressionType::Lz4(_) | CompressionType::Lz4Fast(_) => {
let uncompressed =
lz4::block::decompress(&self.data, None).expect("should uncompress");
bincode::deserialize(&uncompressed).expect("Should deserialize")
}
}
}
}

View File

@ -120,6 +120,7 @@ async fn main() {
Filter::Account(AccountFilter {
owner: Some(Pubkey::default()),
accounts: None,
filter: None,
}),
Filter::Slot,
Filter::BlockMeta,