patch(hermes): improve ws reliability

- Add max message size for incoming messages
- Add sent message rate limit and ip whitelisting
This commit is contained in:
Ali Behjati 2023-10-04 22:05:44 +02:00
parent 5fdc0d2545
commit 1a64d58834
5 changed files with 182 additions and 33 deletions

69
hermes/Cargo.lock generated
View File

@ -1798,6 +1798,24 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "governor"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot 0.12.1",
"quanta",
"rand 0.8.5",
"smallvec",
]
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.3.20" version = "0.3.20"
@ -1858,7 +1876,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]] [[package]]
name = "hermes" name = "hermes"
version = "0.2.0" version = "0.2.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
@ -1873,13 +1891,16 @@ dependencies = [
"derive_more", "derive_more",
"env_logger 0.10.0", "env_logger 0.10.0",
"futures", "futures",
"governor",
"hex", "hex",
"humantime", "humantime",
"ipnet",
"lazy_static", "lazy_static",
"libc", "libc",
"libp2p", "libp2p",
"log", "log",
"mock_instant", "mock_instant",
"nonzero_ext",
"prometheus-client", "prometheus-client",
"pyth-sdk", "pyth-sdk",
"pythnet-sdk", "pythnet-sdk",
@ -3021,6 +3042,15 @@ dependencies = [
"linked-hash-map", "linked-hash-map",
] ]
[[package]]
name = "mach2"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "match_cfg" name = "match_cfg"
version = "0.1.0" version = "0.1.0"
@ -3317,6 +3347,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]] [[package]]
name = "nohash-hasher" name = "nohash-hasher"
version = "0.2.0" version = "0.2.0"
@ -3333,6 +3369,12 @@ dependencies = [
"minimal-lexical", "minimal-lexical",
] ]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]] [[package]]
name = "nu-ansi-term" name = "nu-ansi-term"
version = "0.46.0" version = "0.46.0"
@ -4052,6 +4094,22 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "quanta"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab"
dependencies = [
"crossbeam-utils",
"libc",
"mach2",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]] [[package]]
name = "quick-error" name = "quick-error"
version = "1.2.3" version = "1.2.3"
@ -4220,6 +4278,15 @@ dependencies = [
"rand_core 0.6.4", "rand_core 0.6.4",
] ]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags 1.3.2",
]
[[package]] [[package]]
name = "rayon" name = "rayon"
version = "1.7.0" version = "1.7.0"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "hermes" name = "hermes"
version = "0.2.0" version = "0.2.1"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle." description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021" edition = "2021"
@ -19,10 +19,13 @@ env_logger = { version = "0.10.0" }
futures = { version = "0.3.28" } futures = { version = "0.3.28" }
hex = { version = "0.4.3", features = ["serde"] } hex = { version = "0.4.3", features = ["serde"] }
humantime = { version = "2.1.0" } humantime = { version = "2.1.0" }
ipnet = { version = "2.8.0" }
governor = { version = "0.6.0" }
lazy_static = { version = "1.4.0" } lazy_static = { version = "1.4.0" }
libc = { version = "0.2.140" } libc = { version = "0.2.140" }
log = { version = "0.4.17" } log = { version = "0.4.17" }
mock_instant = { version = "0.3.1", features = ["sync"] } mock_instant = { version = "0.3.1", features = ["sync"] }
nonzero_ext = { version = "0.3.0" }
prometheus-client = { version = "0.21.1" } prometheus-client = { version = "0.21.1" }
pyth-sdk = { version = "0.8.0" } pyth-sdk = { version = "0.8.0" }
pythnet-sdk = { path = "../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] } pythnet-sdk = { path = "../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] }

View File

@ -11,11 +11,15 @@ use {
routing::get, routing::get,
Router, Router,
}, },
ipnet::IpNet,
serde_qs::axum::QsQueryConfig, serde_qs::axum::QsQueryConfig,
std::sync::{ std::{
net::SocketAddr,
sync::{
atomic::Ordering, atomic::Ordering,
Arc, Arc,
}, },
},
tokio::{ tokio::{
signal, signal,
sync::mpsc::Receiver, sync::mpsc::Receiver,
@ -36,10 +40,10 @@ pub struct ApiState {
} }
impl ApiState { impl ApiState {
pub fn new(state: Arc<State>) -> Self { pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
Self { Self {
state, state,
ws: Arc::new(ws::WsState::new()), ws: Arc::new(ws::WsState::new(ws_whitelist)),
} }
} }
} }
@ -84,7 +88,7 @@ pub async fn run(
)] )]
struct ApiDoc; struct ApiDoc;
let state = ApiState::new(state); let state = ApiState::new(state, opts.rpc.ws_whitelist);
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the // Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
// `with_state` method which replaces `Body` with `State` in the type signature. // `with_state` method which replaces `Body` with `State` in the type signature.
@ -131,7 +135,7 @@ pub async fn run(
// Binds the axum's server to the configured address and port. This is a blocking call and will // Binds the axum's server to the configured address and port. This is a blocking call and will
// not return until the server is shutdown. // not return until the server is shutdown.
axum::Server::try_bind(&opts.rpc.addr)? axum::Server::try_bind(&opts.rpc.addr)?
.serve(app.into_make_service()) .serve(app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(async { .with_graceful_shutdown(async {
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler // Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
// should also have triggered so we will let that one print the shutdown warning. // should also have triggered so we will let that one print the shutdown warning.

View File

@ -21,6 +21,7 @@ use {
WebSocket, WebSocket,
WebSocketUpgrade, WebSocketUpgrade,
}, },
ConnectInfo,
State as AxumState, State as AxumState,
}, },
response::IntoResponse, response::IntoResponse,
@ -35,6 +36,13 @@ use {
SinkExt, SinkExt,
StreamExt, StreamExt,
}, },
governor::{
DefaultKeyedRateLimiter,
Quota,
RateLimiter,
},
ipnet::IpNet,
nonzero_ext::nonzero,
pyth_sdk::PriceIdentifier, pyth_sdk::PriceIdentifier,
serde::{ serde::{
Deserialize, Deserialize,
@ -42,6 +50,11 @@ use {
}, },
std::{ std::{
collections::HashMap, collections::HashMap,
net::{
IpAddr,
SocketAddr,
},
num::NonZeroU32,
sync::{ sync::{
atomic::{ atomic::{
AtomicUsize, AtomicUsize,
@ -54,8 +67,13 @@ use {
tokio::sync::mpsc, tokio::sync::mpsc,
}; };
pub const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30); const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
pub const NOTIFICATIONS_CHAN_LEN: usize = 1000; const NOTIFICATIONS_CHAN_LEN: usize = 1000;
const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
/// The maximum number of bytes that can be sent per second per IP address.
/// If the limit is exceeded, the connection is closed.
const BYTES_LIMIT_PER_IP_PER_SECOND: u32 = 256 * 1024; // 256 KiB
#[derive(Clone)] #[derive(Clone)]
pub struct PriceFeedClientConfig { pub struct PriceFeedClientConfig {
@ -67,13 +85,19 @@ pub struct PriceFeedClientConfig {
pub struct WsState { pub struct WsState {
pub subscriber_counter: AtomicUsize, pub subscriber_counter: AtomicUsize,
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>, pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
pub bytes_limit_whitelist: Vec<IpNet>,
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
} }
impl WsState { impl WsState {
pub fn new() -> Self { pub fn new(whitelist: Vec<IpNet>) -> Self {
Self { Self {
subscriber_counter: AtomicUsize::new(0), subscriber_counter: AtomicUsize::new(0),
subscribers: DashMap::new(), subscribers: DashMap::new(),
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
BYTES_LIMIT_PER_IP_PER_SECOND
))),
bytes_limit_whitelist: whitelist,
} }
} }
} }
@ -118,20 +142,29 @@ enum ServerResponseMessage {
pub async fn ws_route_handler( pub async fn ws_route_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
AxumState(state): AxumState<super::ApiState>, AxumState(state): AxumState<super::ApiState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse { ) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket_handler(socket, state)) ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
} }
#[tracing::instrument(skip(stream, state))] #[tracing::instrument(skip(stream, state, addr))]
async fn websocket_handler(stream: WebSocket, state: super::ApiState) { async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
let ws_state = state.ws.clone(); let ws_state = state.ws.clone();
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst); let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
tracing::debug!(id, "New Websocket Connection"); tracing::debug!(id, %addr, "New Websocket Connection");
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN); let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
let (sender, receiver) = stream.split(); let (sender, receiver) = stream.split();
let mut subscriber = let mut subscriber = Subscriber::new(
Subscriber::new(id, state.state.clone(), notify_receiver, receiver, sender); id,
addr.ip(),
state.state.clone(),
state.ws.clone(),
notify_receiver,
receiver,
sender,
);
ws_state.subscribers.insert(id, notify_sender); ws_state.subscribers.insert(id, notify_sender);
subscriber.run().await; subscriber.run().await;
@ -143,8 +176,10 @@ pub type SubscriberId = usize;
/// It listens to the store for updates and sends them to the client. /// It listens to the store for updates and sends them to the client.
pub struct Subscriber { pub struct Subscriber {
id: SubscriberId, id: SubscriberId,
ip_addr: IpAddr,
closed: bool, closed: bool,
store: Arc<State>, store: Arc<State>,
ws_state: Arc<WsState>,
notify_receiver: mpsc::Receiver<AggregationEvent>, notify_receiver: mpsc::Receiver<AggregationEvent>,
receiver: SplitStream<WebSocket>, receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>, sender: SplitSink<WebSocket, Message>,
@ -156,15 +191,19 @@ pub struct Subscriber {
impl Subscriber { impl Subscriber {
pub fn new( pub fn new(
id: SubscriberId, id: SubscriberId,
ip_addr: IpAddr,
store: Arc<State>, store: Arc<State>,
ws_state: Arc<WsState>,
notify_receiver: mpsc::Receiver<AggregationEvent>, notify_receiver: mpsc::Receiver<AggregationEvent>,
receiver: SplitStream<WebSocket>, receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>, sender: SplitSink<WebSocket, Message>,
) -> Self { ) -> Self {
Self { Self {
id, id,
ip_addr,
closed: false, closed: false,
store, store,
ws_state,
notify_receiver, notify_receiver,
receiver, receiver,
sender, sender,
@ -243,19 +282,45 @@ impl Subscriber {
} }
} }
// `sender.feed` buffers a message to the client but does not flush it, so we can send let message = serde_json::to_string(&ServerMessage::PriceUpdate {
// multiple messages and flush them all at once.
self.sender
.feed(Message::Text(serde_json::to_string(
&ServerMessage::PriceUpdate {
price_feed: RpcPriceFeed::from_price_feed_update( price_feed: RpcPriceFeed::from_price_feed_update(
update, update,
config.verbose, config.verbose,
config.binary, config.binary,
), ),
}, })?;
)?))
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
if !self
.ws_state
.bytes_limit_whitelist
.contains(&self.ip_addr.into())
&& self.ws_state.rate_limiter.check_key_n(
&self.ip_addr,
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
) != Ok(Ok(()))
{
tracing::info!(
self.id,
ip = %self.ip_addr,
"Rate limit exceeded. Closing connection.",
);
self.sender
.send(
serde_json::to_string(&ServerResponseMessage::Err {
error: "Rate limit exceeded".to_string(),
})?
.into(),
)
.await?; .await?;
self.sender.close().await?;
self.closed = true;
return Ok(());
}
// `sender.feed` buffers a message to the client but does not flush it, so we can send
// multiple messages and flush them all at once.
self.sender.feed(message.into()).await?;
} }
self.sender.flush().await?; self.sender.flush().await?;
@ -394,4 +459,7 @@ pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
ws_state.subscribers.remove(&id); ws_state.subscribers.remove(&id);
} }
}); });
// Clean the bytes limiting dictionary
ws_state.rate_limiter.retain_recent();
} }

View File

@ -1,5 +1,6 @@
use { use {
clap::Args, clap::Args,
ipnet::IpNet,
std::net::SocketAddr, std::net::SocketAddr,
}; };
@ -14,4 +15,10 @@ pub struct Options {
#[arg(default_value = DEFAULT_RPC_ADDR)] #[arg(default_value = DEFAULT_RPC_ADDR)]
#[arg(env = "RPC_ADDR")] #[arg(env = "RPC_ADDR")]
pub addr: SocketAddr, pub addr: SocketAddr,
/// Whitelisted websocket ip network addresses (separated by comma).
#[arg(long = "rpc-ws-whitelist")]
#[arg(value_delimiter = ',')]
#[arg(env = "RPC_WS_WHITELIST")]
pub ws_whitelist: Vec<IpNet>,
} }