patch(hermes): improve ws reliability
- Add max message size for incoming messages - Add sent message rate limit and ip whitelisting
This commit is contained in:
parent
5fdc0d2545
commit
1a64d58834
|
@ -1798,6 +1798,24 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "governor"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "821239e5672ff23e2a7060901fa622950bbd80b649cdaadd78d1c1767ed14eb4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"dashmap",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"no-std-compat",
|
||||
"nonzero_ext",
|
||||
"parking_lot 0.12.1",
|
||||
"quanta",
|
||||
"rand 0.8.5",
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.20"
|
||||
|
@ -1858,7 +1876,7 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
|||
|
||||
[[package]]
|
||||
name = "hermes"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
|
@ -1873,13 +1891,16 @@ dependencies = [
|
|||
"derive_more",
|
||||
"env_logger 0.10.0",
|
||||
"futures",
|
||||
"governor",
|
||||
"hex",
|
||||
"humantime",
|
||||
"ipnet",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"libp2p",
|
||||
"log",
|
||||
"mock_instant",
|
||||
"nonzero_ext",
|
||||
"prometheus-client",
|
||||
"pyth-sdk",
|
||||
"pythnet-sdk",
|
||||
|
@ -3021,6 +3042,15 @@ dependencies = [
|
|||
"linked-hash-map",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mach2"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6d0d1830bcd151a6fc4aea1369af235b36c1528fe976b8ff678683c9995eade8"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "match_cfg"
|
||||
version = "0.1.0"
|
||||
|
@ -3317,6 +3347,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "no-std-compat"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
|
||||
|
||||
[[package]]
|
||||
name = "nohash-hasher"
|
||||
version = "0.2.0"
|
||||
|
@ -3333,6 +3369,12 @@ dependencies = [
|
|||
"minimal-lexical",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nonzero_ext"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21"
|
||||
|
||||
[[package]]
|
||||
name = "nu-ansi-term"
|
||||
version = "0.46.0"
|
||||
|
@ -4052,6 +4094,22 @@ dependencies = [
|
|||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quanta"
|
||||
version = "0.11.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a17e662a7a8291a865152364c20c7abc5e60486ab2001e8ec10b24862de0b9ab"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
"libc",
|
||||
"mach2",
|
||||
"once_cell",
|
||||
"raw-cpuid",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quick-error"
|
||||
version = "1.2.3"
|
||||
|
@ -4220,6 +4278,15 @@ dependencies = [
|
|||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "raw-cpuid"
|
||||
version = "10.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.7.0"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "hermes"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
|
||||
edition = "2021"
|
||||
|
||||
|
@ -19,10 +19,13 @@ env_logger = { version = "0.10.0" }
|
|||
futures = { version = "0.3.28" }
|
||||
hex = { version = "0.4.3", features = ["serde"] }
|
||||
humantime = { version = "2.1.0" }
|
||||
ipnet = { version = "2.8.0" }
|
||||
governor = { version = "0.6.0" }
|
||||
lazy_static = { version = "1.4.0" }
|
||||
libc = { version = "0.2.140" }
|
||||
log = { version = "0.4.17" }
|
||||
mock_instant = { version = "0.3.1", features = ["sync"] }
|
||||
nonzero_ext = { version = "0.3.0" }
|
||||
prometheus-client = { version = "0.21.1" }
|
||||
pyth-sdk = { version = "0.8.0" }
|
||||
pythnet-sdk = { path = "../pythnet/pythnet_sdk/", version = "2.0.0", features = ["strum"] }
|
||||
|
|
|
@ -11,10 +11,14 @@ use {
|
|||
routing::get,
|
||||
Router,
|
||||
},
|
||||
ipnet::IpNet,
|
||||
serde_qs::axum::QsQueryConfig,
|
||||
std::sync::{
|
||||
atomic::Ordering,
|
||||
Arc,
|
||||
std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::Ordering,
|
||||
Arc,
|
||||
},
|
||||
},
|
||||
tokio::{
|
||||
signal,
|
||||
|
@ -36,10 +40,10 @@ pub struct ApiState {
|
|||
}
|
||||
|
||||
impl ApiState {
|
||||
pub fn new(state: Arc<State>) -> Self {
|
||||
pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
|
||||
Self {
|
||||
state,
|
||||
ws: Arc::new(ws::WsState::new()),
|
||||
ws: Arc::new(ws::WsState::new(ws_whitelist)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -84,7 +88,7 @@ pub async fn run(
|
|||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
let state = ApiState::new(state);
|
||||
let state = ApiState::new(state, opts.rpc.ws_whitelist);
|
||||
|
||||
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
|
||||
// `with_state` method which replaces `Body` with `State` in the type signature.
|
||||
|
@ -131,7 +135,7 @@ pub async fn run(
|
|||
// Binds the axum's server to the configured address and port. This is a blocking call and will
|
||||
// not return until the server is shutdown.
|
||||
axum::Server::try_bind(&opts.rpc.addr)?
|
||||
.serve(app.into_make_service())
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.with_graceful_shutdown(async {
|
||||
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
|
||||
// should also have triggered so we will let that one print the shutdown warning.
|
||||
|
|
|
@ -21,6 +21,7 @@ use {
|
|||
WebSocket,
|
||||
WebSocketUpgrade,
|
||||
},
|
||||
ConnectInfo,
|
||||
State as AxumState,
|
||||
},
|
||||
response::IntoResponse,
|
||||
|
@ -35,6 +36,13 @@ use {
|
|||
SinkExt,
|
||||
StreamExt,
|
||||
},
|
||||
governor::{
|
||||
DefaultKeyedRateLimiter,
|
||||
Quota,
|
||||
RateLimiter,
|
||||
},
|
||||
ipnet::IpNet,
|
||||
nonzero_ext::nonzero,
|
||||
pyth_sdk::PriceIdentifier,
|
||||
serde::{
|
||||
Deserialize,
|
||||
|
@ -42,6 +50,11 @@ use {
|
|||
},
|
||||
std::{
|
||||
collections::HashMap,
|
||||
net::{
|
||||
IpAddr,
|
||||
SocketAddr,
|
||||
},
|
||||
num::NonZeroU32,
|
||||
sync::{
|
||||
atomic::{
|
||||
AtomicUsize,
|
||||
|
@ -54,8 +67,13 @@ use {
|
|||
tokio::sync::mpsc,
|
||||
};
|
||||
|
||||
pub const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
|
||||
pub const NOTIFICATIONS_CHAN_LEN: usize = 1000;
|
||||
const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
|
||||
const NOTIFICATIONS_CHAN_LEN: usize = 1000;
|
||||
const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
|
||||
|
||||
/// The maximum number of bytes that can be sent per second per IP address.
|
||||
/// If the limit is exceeded, the connection is closed.
|
||||
const BYTES_LIMIT_PER_IP_PER_SECOND: u32 = 256 * 1024; // 256 KiB
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PriceFeedClientConfig {
|
||||
|
@ -65,15 +83,21 @@ pub struct PriceFeedClientConfig {
|
|||
}
|
||||
|
||||
pub struct WsState {
|
||||
pub subscriber_counter: AtomicUsize,
|
||||
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
||||
pub subscriber_counter: AtomicUsize,
|
||||
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
||||
pub bytes_limit_whitelist: Vec<IpNet>,
|
||||
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
|
||||
}
|
||||
|
||||
impl WsState {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(whitelist: Vec<IpNet>) -> Self {
|
||||
Self {
|
||||
subscriber_counter: AtomicUsize::new(0),
|
||||
subscribers: DashMap::new(),
|
||||
subscriber_counter: AtomicUsize::new(0),
|
||||
subscribers: DashMap::new(),
|
||||
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
|
||||
BYTES_LIMIT_PER_IP_PER_SECOND
|
||||
))),
|
||||
bytes_limit_whitelist: whitelist,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -118,20 +142,29 @@ enum ServerResponseMessage {
|
|||
pub async fn ws_route_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
AxumState(state): AxumState<super::ApiState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
) -> impl IntoResponse {
|
||||
ws.on_upgrade(|socket| websocket_handler(socket, state))
|
||||
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
|
||||
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(stream, state))]
|
||||
async fn websocket_handler(stream: WebSocket, state: super::ApiState) {
|
||||
#[tracing::instrument(skip(stream, state, addr))]
|
||||
async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
|
||||
let ws_state = state.ws.clone();
|
||||
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
|
||||
tracing::debug!(id, "New Websocket Connection");
|
||||
tracing::debug!(id, %addr, "New Websocket Connection");
|
||||
|
||||
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
|
||||
let (sender, receiver) = stream.split();
|
||||
let mut subscriber =
|
||||
Subscriber::new(id, state.state.clone(), notify_receiver, receiver, sender);
|
||||
let mut subscriber = Subscriber::new(
|
||||
id,
|
||||
addr.ip(),
|
||||
state.state.clone(),
|
||||
state.ws.clone(),
|
||||
notify_receiver,
|
||||
receiver,
|
||||
sender,
|
||||
);
|
||||
|
||||
ws_state.subscribers.insert(id, notify_sender);
|
||||
subscriber.run().await;
|
||||
|
@ -143,8 +176,10 @@ pub type SubscriberId = usize;
|
|||
/// It listens to the store for updates and sends them to the client.
|
||||
pub struct Subscriber {
|
||||
id: SubscriberId,
|
||||
ip_addr: IpAddr,
|
||||
closed: bool,
|
||||
store: Arc<State>,
|
||||
ws_state: Arc<WsState>,
|
||||
notify_receiver: mpsc::Receiver<AggregationEvent>,
|
||||
receiver: SplitStream<WebSocket>,
|
||||
sender: SplitSink<WebSocket, Message>,
|
||||
|
@ -156,15 +191,19 @@ pub struct Subscriber {
|
|||
impl Subscriber {
|
||||
pub fn new(
|
||||
id: SubscriberId,
|
||||
ip_addr: IpAddr,
|
||||
store: Arc<State>,
|
||||
ws_state: Arc<WsState>,
|
||||
notify_receiver: mpsc::Receiver<AggregationEvent>,
|
||||
receiver: SplitStream<WebSocket>,
|
||||
sender: SplitSink<WebSocket, Message>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
ip_addr,
|
||||
closed: false,
|
||||
store,
|
||||
ws_state,
|
||||
notify_receiver,
|
||||
receiver,
|
||||
sender,
|
||||
|
@ -243,19 +282,45 @@ impl Subscriber {
|
|||
}
|
||||
}
|
||||
|
||||
let message = serde_json::to_string(&ServerMessage::PriceUpdate {
|
||||
price_feed: RpcPriceFeed::from_price_feed_update(
|
||||
update,
|
||||
config.verbose,
|
||||
config.binary,
|
||||
),
|
||||
})?;
|
||||
|
||||
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
|
||||
if !self
|
||||
.ws_state
|
||||
.bytes_limit_whitelist
|
||||
.contains(&self.ip_addr.into())
|
||||
&& self.ws_state.rate_limiter.check_key_n(
|
||||
&self.ip_addr,
|
||||
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
|
||||
) != Ok(Ok(()))
|
||||
{
|
||||
tracing::info!(
|
||||
self.id,
|
||||
ip = %self.ip_addr,
|
||||
"Rate limit exceeded. Closing connection.",
|
||||
);
|
||||
self.sender
|
||||
.send(
|
||||
serde_json::to_string(&ServerResponseMessage::Err {
|
||||
error: "Rate limit exceeded".to_string(),
|
||||
})?
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
self.sender.close().await?;
|
||||
self.closed = true;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// `sender.feed` buffers a message to the client but does not flush it, so we can send
|
||||
// multiple messages and flush them all at once.
|
||||
self.sender
|
||||
.feed(Message::Text(serde_json::to_string(
|
||||
&ServerMessage::PriceUpdate {
|
||||
price_feed: RpcPriceFeed::from_price_feed_update(
|
||||
update,
|
||||
config.verbose,
|
||||
config.binary,
|
||||
),
|
||||
},
|
||||
)?))
|
||||
.await?;
|
||||
self.sender.feed(message.into()).await?;
|
||||
}
|
||||
|
||||
self.sender.flush().await?;
|
||||
|
@ -394,4 +459,7 @@ pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
|
|||
ws_state.subscribers.remove(&id);
|
||||
}
|
||||
});
|
||||
|
||||
// Clean the bytes limiting dictionary
|
||||
ws_state.rate_limiter.retain_recent();
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use {
|
||||
clap::Args,
|
||||
ipnet::IpNet,
|
||||
std::net::SocketAddr,
|
||||
};
|
||||
|
||||
|
@ -14,4 +15,10 @@ pub struct Options {
|
|||
#[arg(default_value = DEFAULT_RPC_ADDR)]
|
||||
#[arg(env = "RPC_ADDR")]
|
||||
pub addr: SocketAddr,
|
||||
|
||||
/// Whitelisted websocket ip network addresses (separated by comma).
|
||||
#[arg(long = "rpc-ws-whitelist")]
|
||||
#[arg(value_delimiter = ',')]
|
||||
#[arg(env = "RPC_WS_WHITELIST")]
|
||||
pub ws_whitelist: Vec<IpNet>,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue