patch(hermes): improve ws reliability

- Add max message size for incoming messages
- Add a per-IP rate limit on sent message bytes, with IP whitelisting
This commit is contained in:
Ali Behjati 2023-10-04 22:05:44 +02:00
parent 5fdc0d2545
commit 1a64d58834
5 changed files with 182 additions and 33 deletions

69
hermes/Cargo.lock generated
View File

@ -1798,6 +1798,24 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "governor"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4"
dependencies = [
"cfg-if",
"dashmap",
"futures",
"futures-timer",
"no-std-compat",
"nonzero_ext",
"parking_lot 0.12.1",
"quanta",
"rand 0.8.5",
"smallvec",
]
[[package]]
name = "h2"
version = "0.3.20"
@ -1858,7 +1876,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermes"
version = "0.2.0"
version = "0.2.1"
dependencies = [
"anyhow",
"async-trait",
@ -1873,13 +1891,16 @@ dependencies = [
"derive_more",
"env_logger 0.10.0",
"futures",
"governor",
"hex",
"humantime",
"ipnet",
"lazy_static",
"libc",
"libp2p",
"log",
"mock_instant",
"nonzero_ext",
"prometheus-client",
"pyth-sdk",
"pythnet-sdk",
@ -3021,6 +3042,15 @@ dependencies = [
"linked-hash-map",
]
[[package]]
name = "mach2"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8"
dependencies = [
"libc",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
@ -3317,6 +3347,12 @@ dependencies = [
"libc",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
[[package]]
name = "nohash-hasher"
version = "0.2.0"
@ -3333,6 +3369,12 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nonzero_ext"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -4052,6 +4094,22 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "quanta"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab"
dependencies = [
"crossbeam-utils",
"libc",
"mach2",
"once_cell",
"raw-cpuid",
"wasi 0.11.0+wasi-snapshot-preview1",
"web-sys",
"winapi",
]
[[package]]
name = "quick-error"
version = "1.2.3"
@ -4220,6 +4278,15 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "raw-cpuid"
version = "10.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "rayon"
version = "1.7.0"

View File

@ -1,6 +1,6 @@
[package]
name = "hermes"
version = "0.2.0"
version = "0.2.1"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021"
@ -19,10 +19,13 @@ env_logger = { version = "0.10.0" }
futures = { version = "0.3.28" }
hex = { version = "0.4.3", features = ["serde"] }
humantime = { version = "2.1.0" }
ipnet = { version = "2.8.0" }
governor = { version = "0.6.0" }
lazy_static = { version = "1.4.0" }
libc = { version = "0.2.140" }
log = { version = "0.4.17" }
mock_instant = { version = "0.3.1", features = ["sync"] }
nonzero_ext = { version = "0.3.0" }
prometheus-client = { version = "0.21.1" }
pyth-sdk = { version = "0.8.0" }
pythnet-sdk = { path = "../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] }

View File

@ -11,11 +11,15 @@ use {
routing::get,
Router,
},
ipnet::IpNet,
serde_qs::axum::QsQueryConfig,
std::sync::{
std::{
net::SocketAddr,
sync::{
atomic::Ordering,
Arc,
},
},
tokio::{
signal,
sync::mpsc::Receiver,
@ -36,10 +40,10 @@ pub struct ApiState {
}
impl ApiState {
pub fn new(state: Arc<State>) -> Self {
pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
Self {
state,
ws: Arc::new(ws::WsState::new()),
ws: Arc::new(ws::WsState::new(ws_whitelist)),
}
}
}
@ -84,7 +88,7 @@ pub async fn run(
)]
struct ApiDoc;
let state = ApiState::new(state);
let state = ApiState::new(state, opts.rpc.ws_whitelist);
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
// `with_state` method which replaces `Body` with `State` in the type signature.
@ -131,7 +135,7 @@ pub async fn run(
// Binds the axum server to the configured address and port. This is a blocking call and will
// not return until the server is shut down.
axum::Server::try_bind(&opts.rpc.addr)?
.serve(app.into_make_service())
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.with_graceful_shutdown(async {
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
// should also have triggered so we will let that one print the shutdown warning.

View File

@ -21,6 +21,7 @@ use {
WebSocket,
WebSocketUpgrade,
},
ConnectInfo,
State as AxumState,
},
response::IntoResponse,
@ -35,6 +36,13 @@ use {
SinkExt,
StreamExt,
},
governor::{
DefaultKeyedRateLimiter,
Quota,
RateLimiter,
},
ipnet::IpNet,
nonzero_ext::nonzero,
pyth_sdk::PriceIdentifier,
serde::{
Deserialize,
@ -42,6 +50,11 @@ use {
},
std::{
collections::HashMap,
net::{
IpAddr,
SocketAddr,
},
num::NonZeroU32,
sync::{
atomic::{
AtomicUsize,
@ -54,8 +67,13 @@ use {
tokio::sync::mpsc,
};
pub const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
pub const NOTIFICATIONS_CHAN_LEN: usize = 1000;
const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
const NOTIFICATIONS_CHAN_LEN: usize = 1000;
const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
/// The maximum number of bytes that can be sent per second per IP address.
/// If the limit is exceeded, the connection is closed.
const BYTES_LIMIT_PER_IP_PER_SECOND: u32 = 256 * 1024; // 256 KiB
#[derive(Clone)]
pub struct PriceFeedClientConfig {
@ -67,13 +85,19 @@ pub struct PriceFeedClientConfig {
/// Shared state of the websocket API: the connected subscribers plus the data used to rate
/// limit outgoing traffic per client IP address.
pub struct WsState {
/// Counter incremented for every new websocket connection; its previous value becomes the
/// connection's `SubscriberId`.
pub subscriber_counter: AtomicUsize,
/// Sender half of each connected subscriber's notification channel, keyed by subscriber id.
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
/// IP networks that are exempt from the outgoing bytes rate limit.
/// NOTE(review): the visible check calls `Vec::contains` with an `IpNet` built from the client
/// address, which looks like an exact-entry comparison rather than a network membership test
/// (so a whitelisted range may not match its member addresses) — confirm.
pub bytes_limit_whitelist: Vec<IpNet>,
/// Rate limiter keyed by client IP address. Each outgoing price update is charged its length
/// in bytes against a quota of `BYTES_LIMIT_PER_IP_PER_SECOND` per second.
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
}
impl WsState {
pub fn new() -> Self {
pub fn new(whitelist: Vec<IpNet>) -> Self {
Self {
subscriber_counter: AtomicUsize::new(0),
subscribers: DashMap::new(),
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
BYTES_LIMIT_PER_IP_PER_SECOND
))),
bytes_limit_whitelist: whitelist,
}
}
}
@ -118,20 +142,29 @@ enum ServerResponseMessage {
pub async fn ws_route_handler(
ws: WebSocketUpgrade,
AxumState(state): AxumState<super::ApiState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
ws.on_upgrade(|socket| websocket_handler(socket, state))
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
}
#[tracing::instrument(skip(stream, state))]
async fn websocket_handler(stream: WebSocket, state: super::ApiState) {
#[tracing::instrument(skip(stream, state, addr))]
async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
let ws_state = state.ws.clone();
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
tracing::debug!(id, "New Websocket Connection");
tracing::debug!(id, %addr, "New Websocket Connection");
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
let (sender, receiver) = stream.split();
let mut subscriber =
Subscriber::new(id, state.state.clone(), notify_receiver, receiver, sender);
let mut subscriber = Subscriber::new(
id,
addr.ip(),
state.state.clone(),
state.ws.clone(),
notify_receiver,
receiver,
sender,
);
ws_state.subscribers.insert(id, notify_sender);
subscriber.run().await;
@ -143,8 +176,10 @@ pub type SubscriberId = usize;
/// It listens to the store for updates and sends them to the client.
pub struct Subscriber {
id: SubscriberId,
ip_addr: IpAddr,
closed: bool,
store: Arc<State>,
ws_state: Arc<WsState>,
notify_receiver: mpsc::Receiver<AggregationEvent>,
receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>,
@ -156,15 +191,19 @@ pub struct Subscriber {
impl Subscriber {
pub fn new(
id: SubscriberId,
ip_addr: IpAddr,
store: Arc<State>,
ws_state: Arc<WsState>,
notify_receiver: mpsc::Receiver<AggregationEvent>,
receiver: SplitStream<WebSocket>,
sender: SplitSink<WebSocket, Message>,
) -> Self {
Self {
id,
ip_addr,
closed: false,
store,
ws_state,
notify_receiver,
receiver,
sender,
@ -243,19 +282,45 @@ impl Subscriber {
}
}
// `sender.feed` buffers a message to the client but does not flush it, so we can send
// multiple messages and flush them all at once.
self.sender
.feed(Message::Text(serde_json::to_string(
&ServerMessage::PriceUpdate {
let message = serde_json::to_string(&ServerMessage::PriceUpdate {
price_feed: RpcPriceFeed::from_price_feed_update(
update,
config.verbose,
config.binary,
),
},
)?))
})?;
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
if !self
.ws_state
.bytes_limit_whitelist
.contains(&self.ip_addr.into())
&& self.ws_state.rate_limiter.check_key_n(
&self.ip_addr,
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
) != Ok(Ok(()))
{
tracing::info!(
self.id,
ip = %self.ip_addr,
"Rate limit exceeded. Closing connection.",
);
self.sender
.send(
serde_json::to_string(&ServerResponseMessage::Err {
error: "Rate limit exceeded".to_string(),
})?
.into(),
)
.await?;
self.sender.close().await?;
self.closed = true;
return Ok(());
}
// `sender.feed` buffers a message to the client but does not flush it, so we can send
// multiple messages and flush them all at once.
self.sender.feed(message.into()).await?;
}
self.sender.flush().await?;
@ -394,4 +459,7 @@ pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
ws_state.subscribers.remove(&id);
}
});
// Drop rate limiter entries for IPs that have not been used recently, so the per-IP limiter
// map does not accumulate stale keys.
ws_state.rate_limiter.retain_recent();
}

View File

@ -1,5 +1,6 @@
use {
clap::Args,
ipnet::IpNet,
std::net::SocketAddr,
};
@ -14,4 +15,10 @@ pub struct Options {
#[arg(default_value = DEFAULT_RPC_ADDR)]
#[arg(env = "RPC_ADDR")]
pub addr: SocketAddr,
/// Whitelisted websocket ip network addresses (separated by comma).
#[arg(long = "rpc-ws-whitelist")]
#[arg(value_delimiter = ',')]
#[arg(env = "RPC_WS_WHITELIST")]
pub ws_whitelist: Vec<IpNet>,
}