Adding memcmp filter for accounts
This commit is contained in:
parent
d21031afbf
commit
b2db9d6622
|
@ -157,6 +157,7 @@ mod tests {
|
|||
.subscribe(vec![Filter::Account(AccountFilter {
|
||||
owner: Some(Pubkey::default()),
|
||||
accounts: None,
|
||||
filter: None,
|
||||
})])
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
@ -38,12 +38,33 @@ impl Filter {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum MemcmpFilterData {
|
||||
Bytes(Vec<u8>),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MemcmpFilter {
|
||||
pub offset: u64,
|
||||
pub data: MemcmpFilterData,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum AccountFilterType {
|
||||
Datasize(u64),
|
||||
Memcmp(MemcmpFilter),
|
||||
}
|
||||
|
||||
// setting owner to 11111111111111111111111111111111 will subscribe to all the accounts
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
|
||||
#[repr(C)]
|
||||
pub struct AccountFilter {
|
||||
pub owner: Option<Pubkey>,
|
||||
pub accounts: Option<HashSet<Pubkey>>,
|
||||
pub filter: Option<AccountFilterType>,
|
||||
}
|
||||
|
||||
impl AccountFilter {
|
||||
|
@ -52,6 +73,30 @@ impl AccountFilter {
|
|||
if let Some(owner) = self.owner {
|
||||
// check if filter subscribes to all the accounts
|
||||
if owner == Pubkey::default() || owner == account.owner {
|
||||
// to do move the filtering somewhere else because here we need to decode the account data
|
||||
// but cannot be avoided for now, this will lag the client is abusing this filter
|
||||
// lagged clients will be dropped
|
||||
if let Some(filter) = &self.filter {
|
||||
match filter {
|
||||
AccountFilterType::Datasize(data_length) => {
|
||||
return account.data_length == *data_length
|
||||
}
|
||||
AccountFilterType::Memcmp(memcmp) => {
|
||||
if memcmp.offset > account.data_length {
|
||||
return false;
|
||||
}
|
||||
let solana_account = account.solana_account();
|
||||
let offset = memcmp.offset as usize;
|
||||
let MemcmpFilterData::Bytes(bytes) = &memcmp.data;
|
||||
|
||||
if solana_account.data[offset..].len() < bytes.len() {
|
||||
return false;
|
||||
}
|
||||
return solana_account.data[offset..offset + bytes.len()]
|
||||
== bytes[..];
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -62,3 +107,117 @@ impl AccountFilter {
|
|||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use solana_sdk::{account::Account as SolanaAccount, pubkey::Pubkey};
|
||||
|
||||
use crate::{
|
||||
filters::{AccountFilter, AccountFilterType, MemcmpFilter},
|
||||
message::Message,
|
||||
types::{account::Account, slot_identifier::SlotIdentifier},
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_accounts_filter() {
|
||||
let owner = Pubkey::new_unique();
|
||||
|
||||
let solana_account_1 = SolanaAccount {
|
||||
lamports: 1,
|
||||
data: vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||
owner,
|
||||
executable: false,
|
||||
rent_epoch: 100,
|
||||
};
|
||||
let solana_account_2 = SolanaAccount {
|
||||
lamports: 2,
|
||||
data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
||||
owner,
|
||||
executable: false,
|
||||
rent_epoch: 100,
|
||||
};
|
||||
let solana_account_3 = SolanaAccount {
|
||||
lamports: 3,
|
||||
data: vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
|
||||
owner: Pubkey::new_unique(),
|
||||
executable: false,
|
||||
rent_epoch: 100,
|
||||
};
|
||||
|
||||
let msg_1 = Message::AccountMsg(Account::new(
|
||||
Pubkey::new_unique(),
|
||||
solana_account_1.clone(),
|
||||
crate::compression::CompressionType::Lz4Fast(8),
|
||||
SlotIdentifier { slot: 0 },
|
||||
0,
|
||||
));
|
||||
|
||||
let msg_2 = Message::AccountMsg(Account::new(
|
||||
Pubkey::new_unique(),
|
||||
solana_account_2.clone(),
|
||||
crate::compression::CompressionType::Lz4Fast(8),
|
||||
SlotIdentifier { slot: 0 },
|
||||
0,
|
||||
));
|
||||
|
||||
let msg_3 = Message::AccountMsg(Account::new(
|
||||
Pubkey::new_unique(),
|
||||
solana_account_3.clone(),
|
||||
crate::compression::CompressionType::Lz4Fast(8),
|
||||
SlotIdentifier { slot: 0 },
|
||||
0,
|
||||
));
|
||||
|
||||
let f1 = AccountFilter {
|
||||
owner: Some(owner),
|
||||
accounts: None,
|
||||
filter: None,
|
||||
};
|
||||
|
||||
assert_eq!(f1.allows(&msg_1), true);
|
||||
assert_eq!(f1.allows(&msg_2), true);
|
||||
assert_eq!(f1.allows(&msg_3), false);
|
||||
|
||||
let f2 = AccountFilter {
|
||||
owner: Some(owner),
|
||||
accounts: None,
|
||||
filter: Some(AccountFilterType::Datasize(9)),
|
||||
};
|
||||
assert_eq!(f2.allows(&msg_1), false);
|
||||
assert_eq!(f2.allows(&msg_2), false);
|
||||
assert_eq!(f2.allows(&msg_3), false);
|
||||
|
||||
let f3 = AccountFilter {
|
||||
owner: Some(owner),
|
||||
accounts: None,
|
||||
filter: Some(AccountFilterType::Datasize(10)),
|
||||
};
|
||||
assert_eq!(f3.allows(&msg_1), true);
|
||||
assert_eq!(f3.allows(&msg_2), true);
|
||||
assert_eq!(f3.allows(&msg_3), false);
|
||||
|
||||
let f4: AccountFilter = AccountFilter {
|
||||
owner: Some(owner),
|
||||
accounts: None,
|
||||
filter: Some(AccountFilterType::Memcmp(MemcmpFilter {
|
||||
offset: 2,
|
||||
data: crate::filters::MemcmpFilterData::Bytes(vec![3, 4, 5]),
|
||||
})),
|
||||
};
|
||||
assert_eq!(f4.allows(&msg_1), true);
|
||||
assert_eq!(f4.allows(&msg_2), false);
|
||||
assert_eq!(f4.allows(&msg_3), false);
|
||||
|
||||
let f5: AccountFilter = AccountFilter {
|
||||
owner: Some(owner),
|
||||
accounts: None,
|
||||
filter: Some(AccountFilterType::Memcmp(MemcmpFilter {
|
||||
offset: 2,
|
||||
data: crate::filters::MemcmpFilterData::Bytes(vec![13, 14, 15]),
|
||||
})),
|
||||
};
|
||||
assert_eq!(f5.allows(&msg_1), false);
|
||||
assert_eq!(f5.allows(&msg_2), true);
|
||||
assert_eq!(f5.allows(&msg_3), false);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use quinn::{Connection, Endpoint};
|
||||
use quinn::{Connection, Endpoint, VarInt};
|
||||
use std::sync::Arc;
|
||||
use std::{collections::VecDeque, time::Duration};
|
||||
use tokio::sync::Semaphore;
|
||||
|
@ -216,7 +216,7 @@ impl ConnectionManager {
|
|||
let id = connection_data.id;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let permit_result = semaphore.clone().try_acquire_owned();
|
||||
let permit_result = semaphore.try_acquire_owned();
|
||||
|
||||
let _permit = match permit_result {
|
||||
Ok(permit) => permit,
|
||||
|
@ -227,10 +227,8 @@ impl ConnectionManager {
|
|||
id,
|
||||
message_type
|
||||
);
|
||||
semaphore
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("Should aquire the permit")
|
||||
connection.close(VarInt::from_u32(0), b"laggy client");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ pub struct Account {
|
|||
pub write_version: u64,
|
||||
pub data: Vec<u8>,
|
||||
pub compression_type: CompressionType,
|
||||
pub data_length: u64,
|
||||
}
|
||||
|
||||
impl Account {
|
||||
|
@ -25,6 +26,7 @@ impl Account {
|
|||
write_version: 0,
|
||||
data: vec![178; data_size],
|
||||
compression_type: CompressionType::None,
|
||||
data_length: data_size as u64,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -36,6 +38,7 @@ impl Account {
|
|||
write_version: u64,
|
||||
) -> Self {
|
||||
let binary = bincode::serialize(&solana_account).expect("account should be serializable");
|
||||
let data_length = solana_account.data.len() as u64;
|
||||
|
||||
let data = match compression_type {
|
||||
CompressionType::None => binary,
|
||||
|
@ -61,6 +64,18 @@ impl Account {
|
|||
write_version,
|
||||
data,
|
||||
compression_type,
|
||||
data_length,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn solana_account(&self) -> SolanaAccount {
|
||||
match self.compression_type {
|
||||
CompressionType::None => bincode::deserialize(&self.data).expect("Should deserialize"),
|
||||
CompressionType::Lz4(_) | CompressionType::Lz4Fast(_) => {
|
||||
let uncompressed =
|
||||
lz4::block::decompress(&self.data, None).expect("should uncompress");
|
||||
bincode::deserialize(&uncompressed).expect("Should deserialize")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -120,6 +120,7 @@ async fn main() {
|
|||
Filter::Account(AccountFilter {
|
||||
owner: Some(Pubkey::default()),
|
||||
accounts: None,
|
||||
filter: None,
|
||||
}),
|
||||
Filter::Slot,
|
||||
Filter::BlockMeta,
|
||||
|
|
Loading…
Reference in New Issue