diff --git a/client/src/client.rs b/client/src/client.rs index de2b0ef..0b3dbcd 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -157,6 +157,7 @@ mod tests { .subscribe(vec![Filter::Account(AccountFilter { owner: Some(Pubkey::default()), accounts: None, + filter: None, })]) .await .unwrap(); diff --git a/common/src/filters.rs b/common/src/filters.rs index 8447eab..42e7151 100644 --- a/common/src/filters.rs +++ b/common/src/filters.rs @@ -38,12 +38,33 @@ impl Filter { } } +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)] +#[serde(rename_all = "camelCase")] +pub enum MemcmpFilterData { + Bytes(Vec), +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)] +#[serde(rename_all = "camelCase")] +pub struct MemcmpFilter { + pub offset: u64, + pub data: MemcmpFilterData, +} + +#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)] +#[serde(rename_all = "camelCase")] +pub enum AccountFilterType { + Datasize(u64), + Memcmp(MemcmpFilter), +} + // setting owner to 11111111111111111111111111111111 will subscribe to all the accounts #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] #[repr(C)] pub struct AccountFilter { pub owner: Option, pub accounts: Option>, + pub filter: Option, } impl AccountFilter { @@ -52,6 +73,30 @@ impl AccountFilter { if let Some(owner) = self.owner { // check if filter subscribes to all the accounts if owner == Pubkey::default() || owner == account.owner { + // to do move the filtering somewhere else because here we need to decode the account data + // but cannot be avoided for now, this will lag the client is abusing this filter + // lagged clients will be dropped + if let Some(filter) = &self.filter { + match filter { + AccountFilterType::Datasize(data_length) => { + return account.data_length == *data_length + } + AccountFilterType::Memcmp(memcmp) => { + if memcmp.offset > account.data_length { + return false; + } + let solana_account = account.solana_account(); + let offset = memcmp.offset as usize; + let MemcmpFilterData::Bytes(bytes) = &memcmp.data; + + if solana_account.data[offset..].len() < bytes.len() { + return false; + } + return solana_account.data[offset..offset + bytes.len()] + == bytes[..]; + } + } + } return true; } } @@ -62,3 +107,117 @@ impl AccountFilter { false } } + +#[cfg(test)] +mod tests { + use solana_sdk::{account::Account as SolanaAccount, pubkey::Pubkey}; + + use crate::{ + filters::{AccountFilter, AccountFilterType, MemcmpFilter}, + message::Message, + types::{account::Account, slot_identifier::SlotIdentifier}, + }; + + #[tokio::test] + async fn test_accounts_filter() { + let owner = Pubkey::new_unique(); + + let solana_account_1 = SolanaAccount { + lamports: 1, + data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + owner, + executable: false, + rent_epoch: 100, + }; + let solana_account_2 = SolanaAccount { + lamports: 2, + data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + owner, + executable: false, + rent_epoch: 100, + }; + let solana_account_3 = SolanaAccount { + lamports: 3, + data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20], + owner: Pubkey::new_unique(), + executable: false, + rent_epoch: 100, + }; + + let msg_1 = Message::AccountMsg(Account::new( + Pubkey::new_unique(), + solana_account_1.clone(), + crate::compression::CompressionType::Lz4Fast(8), + SlotIdentifier { slot: 0 }, + 0, + )); + + let msg_2 = Message::AccountMsg(Account::new( + Pubkey::new_unique(), + solana_account_2.clone(), + crate::compression::CompressionType::Lz4Fast(8), + SlotIdentifier { slot: 0 }, + 0, + )); + + let msg_3 = Message::AccountMsg(Account::new( + Pubkey::new_unique(), + solana_account_3.clone(), + crate::compression::CompressionType::Lz4Fast(8), + SlotIdentifier { slot: 0 }, + 0, + )); + + let f1 = AccountFilter { + owner: Some(owner), + accounts: None, + filter: None, + }; + + assert_eq!(f1.allows(&msg_1), true); + assert_eq!(f1.allows(&msg_2), true); + assert_eq!(f1.allows(&msg_3), false); + + let f2 = AccountFilter { + owner: Some(owner), + accounts: None, + filter: Some(AccountFilterType::Datasize(9)), + }; + assert_eq!(f2.allows(&msg_1), false); + assert_eq!(f2.allows(&msg_2), false); + assert_eq!(f2.allows(&msg_3), false); + + let f3 = AccountFilter { + owner: Some(owner), + accounts: None, + filter: Some(AccountFilterType::Datasize(10)), + }; + assert_eq!(f3.allows(&msg_1), true); + assert_eq!(f3.allows(&msg_2), true); + assert_eq!(f3.allows(&msg_3), false); + + let f4: AccountFilter = AccountFilter { + owner: Some(owner), + accounts: None, + filter: Some(AccountFilterType::Memcmp(MemcmpFilter { + offset: 2, + data: crate::filters::MemcmpFilterData::Bytes(vec![3, 4, 5]), + })), + }; + assert_eq!(f4.allows(&msg_1), true); + assert_eq!(f4.allows(&msg_2), false); + assert_eq!(f4.allows(&msg_3), false); + + let f5: AccountFilter = AccountFilter { + owner: Some(owner), + accounts: None, + filter: Some(AccountFilterType::Memcmp(MemcmpFilter { + offset: 2, + data: crate::filters::MemcmpFilterData::Bytes(vec![13, 14, 15]), + })), + }; + assert_eq!(f5.allows(&msg_1), false); + assert_eq!(f5.allows(&msg_2), true); + assert_eq!(f5.allows(&msg_3), false); + } +} diff --git a/common/src/quic/connection_manager.rs b/common/src/quic/connection_manager.rs index 689746d..f41476a 100644 --- a/common/src/quic/connection_manager.rs +++ b/common/src/quic/connection_manager.rs @@ -1,4 +1,4 @@ -use quinn::{Connection, Endpoint}; +use quinn::{Connection, Endpoint, VarInt}; use std::sync::Arc; use std::{collections::VecDeque, time::Duration}; use tokio::sync::Semaphore; @@ -216,7 +216,7 @@ impl ConnectionManager { let id = connection_data.id; tokio::spawn(async move { - let permit_result = semaphore.clone().try_acquire_owned(); + let permit_result = semaphore.try_acquire_owned(); let _permit = match permit_result { Ok(permit) => permit, @@ -227,10 +227,8 @@ impl ConnectionManager { id, message_type ); - semaphore - .acquire_owned() - .await - .expect("Should aquire the permit") + connection.close(VarInt::from_u32(0), b"laggy client"); + return; } }; diff --git a/common/src/types/account.rs b/common/src/types/account.rs index d289b56..ea9b6d3 100644 --- a/common/src/types/account.rs +++ b/common/src/types/account.rs @@ -14,6 +14,7 @@ pub struct Account { pub write_version: u64, pub data: Vec, pub compression_type: CompressionType, + pub data_length: u64, } impl Account { @@ -25,6 +26,7 @@ impl Account { write_version: 0, data: vec![178; data_size], compression_type: CompressionType::None, + data_length: data_size as u64, } } @@ -36,6 +38,7 @@ impl Account { write_version: u64, ) -> Self { let binary = bincode::serialize(&solana_account).expect("account should be serializable"); + let data_length = solana_account.data.len() as u64; let data = match compression_type { CompressionType::None => binary, @@ -61,6 +64,18 @@ impl Account { write_version, data, compression_type, + data_length, + } + } + + pub fn solana_account(&self) -> SolanaAccount { + match self.compression_type { + CompressionType::None => bincode::deserialize(&self.data).expect("Should deserialize"), + CompressionType::Lz4(_) | CompressionType::Lz4Fast(_) => { + let uncompressed = + lz4::block::decompress(&self.data, None).expect("should uncompress"); + bincode::deserialize(&uncompressed).expect("Should deserialize") + } } } } diff --git a/tester/src/main.rs b/tester/src/main.rs index 7560187..1a734d4 100644 --- a/tester/src/main.rs +++ b/tester/src/main.rs @@ -120,6 +120,7 @@ async fn main() { Filter::Account(AccountFilter { owner: Some(Pubkey::default()), accounts: None, + filter: None, }), Filter::Slot, Filter::BlockMeta,