feat(hermes): use ip from request headers for ratelimiting
This commit is contained in:
parent
b158f28c58
commit
d11216f309
|
@ -481,7 +481,7 @@ dependencies = [
|
|||
"async-trait",
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.4",
|
||||
"bitflags 1.3.2",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
|
@ -566,9 +566,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
|||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.2"
|
||||
version = "0.21.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
|
@ -1898,13 +1898,13 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
|
|||
|
||||
[[package]]
|
||||
name = "hermes"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum-macros",
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.4",
|
||||
"borsh 0.10.3",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
|
@ -4503,7 +4503,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
|
||||
dependencies = [
|
||||
"async-compression",
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.4",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
|
@ -4744,7 +4744,7 @@ version = "1.0.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -6277,7 +6277,7 @@ dependencies = [
|
|||
"async-stream",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.4",
|
||||
"bytes",
|
||||
"h2",
|
||||
"http",
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "hermes"
|
||||
version = "0.3.2"
|
||||
version = "0.3.3"
|
||||
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
|
||||
edition = "2021"
|
||||
|
||||
|
|
|
@ -13,12 +13,9 @@ use {
|
|||
},
|
||||
ipnet::IpNet,
|
||||
serde_qs::axum::QsQueryConfig,
|
||||
std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::Ordering,
|
||||
Arc,
|
||||
},
|
||||
std::sync::{
|
||||
atomic::Ordering,
|
||||
Arc,
|
||||
},
|
||||
tokio::{
|
||||
signal,
|
||||
|
@ -40,10 +37,14 @@ pub struct ApiState {
|
|||
}
|
||||
|
||||
impl ApiState {
|
||||
pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
|
||||
pub fn new(
|
||||
state: Arc<State>,
|
||||
ws_whitelist: Vec<IpNet>,
|
||||
requester_ip_header_name: String,
|
||||
) -> Self {
|
||||
Self {
|
||||
state,
|
||||
ws: Arc::new(ws::WsState::new(ws_whitelist)),
|
||||
ws: Arc::new(ws::WsState::new(ws_whitelist, requester_ip_header_name)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -88,7 +89,11 @@ pub async fn run(
|
|||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
let state = ApiState::new(state, opts.rpc.ws_whitelist);
|
||||
let state = ApiState::new(
|
||||
state,
|
||||
opts.rpc.ws_whitelist,
|
||||
opts.rpc.requester_ip_header_name,
|
||||
);
|
||||
|
||||
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
|
||||
// `with_state` method which replaces `Body` with `State` in the type signature.
|
||||
|
@ -135,7 +140,7 @@ pub async fn run(
|
|||
// Binds the axum's server to the configured address and port. This is a blocking call and will
|
||||
// not return until the server is shutdown.
|
||||
axum::Server::try_bind(&opts.rpc.addr)?
|
||||
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
|
||||
.serve(app.into_make_service())
|
||||
.with_graceful_shutdown(async {
|
||||
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
|
||||
// should also have triggered so we will let that one print the shutdown warning.
|
||||
|
|
|
@ -21,9 +21,9 @@ use {
|
|||
WebSocket,
|
||||
WebSocketUpgrade,
|
||||
},
|
||||
ConnectInfo,
|
||||
State as AxumState,
|
||||
},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
},
|
||||
dashmap::DashMap,
|
||||
|
@ -50,10 +50,7 @@ use {
|
|||
},
|
||||
std::{
|
||||
collections::HashMap,
|
||||
net::{
|
||||
IpAddr,
|
||||
SocketAddr,
|
||||
},
|
||||
net::IpAddr,
|
||||
num::NonZeroU32,
|
||||
sync::{
|
||||
atomic::{
|
||||
|
@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
|
|||
}
|
||||
|
||||
pub struct WsState {
|
||||
pub subscriber_counter: AtomicUsize,
|
||||
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
||||
pub bytes_limit_whitelist: Vec<IpNet>,
|
||||
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
|
||||
pub subscriber_counter: AtomicUsize,
|
||||
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
|
||||
pub bytes_limit_whitelist: Vec<IpNet>,
|
||||
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
|
||||
pub requester_ip_header_name: String,
|
||||
}
|
||||
|
||||
impl WsState {
|
||||
pub fn new(whitelist: Vec<IpNet>) -> Self {
|
||||
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self {
|
||||
Self {
|
||||
subscriber_counter: AtomicUsize::new(0),
|
||||
subscribers: DashMap::new(),
|
||||
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
|
||||
subscriber_counter: AtomicUsize::new(0),
|
||||
subscribers: DashMap::new(),
|
||||
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
|
||||
BYTES_LIMIT_PER_IP_PER_SECOND
|
||||
))),
|
||||
bytes_limit_whitelist: whitelist,
|
||||
requester_ip_header_name,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -142,23 +141,33 @@ enum ServerResponseMessage {
|
|||
pub async fn ws_route_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
AxumState(state): AxumState<super::ApiState>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
let requester_ip = headers
|
||||
.get(state.ws.requester_ip_header_name.as_str())
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
|
||||
.and_then(|value| value.parse().ok());
|
||||
|
||||
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
|
||||
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
|
||||
.on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(stream, state, addr))]
|
||||
async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
|
||||
#[tracing::instrument(skip(stream, state, subscriber_ip))]
|
||||
async fn websocket_handler(
|
||||
stream: WebSocket,
|
||||
state: super::ApiState,
|
||||
subscriber_ip: Option<IpAddr>,
|
||||
) {
|
||||
let ws_state = state.ws.clone();
|
||||
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
|
||||
tracing::debug!(id, %addr, "New Websocket Connection");
|
||||
tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
|
||||
|
||||
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
|
||||
let (sender, receiver) = stream.split();
|
||||
let mut subscriber = Subscriber::new(
|
||||
id,
|
||||
addr.ip(),
|
||||
subscriber_ip,
|
||||
state.state.clone(),
|
||||
state.ws.clone(),
|
||||
notify_receiver,
|
||||
|
@ -176,7 +185,7 @@ pub type SubscriberId = usize;
|
|||
/// It listens to the store for updates and sends them to the client.
|
||||
pub struct Subscriber {
|
||||
id: SubscriberId,
|
||||
ip_addr: IpAddr,
|
||||
ip_addr: Option<IpAddr>,
|
||||
closed: bool,
|
||||
store: Arc<State>,
|
||||
ws_state: Arc<WsState>,
|
||||
|
@ -191,7 +200,7 @@ pub struct Subscriber {
|
|||
impl Subscriber {
|
||||
pub fn new(
|
||||
id: SubscriberId,
|
||||
ip_addr: IpAddr,
|
||||
ip_addr: Option<IpAddr>,
|
||||
store: Arc<State>,
|
||||
ws_state: Arc<WsState>,
|
||||
notify_receiver: mpsc::Receiver<AggregationEvent>,
|
||||
|
@ -291,32 +300,36 @@ impl Subscriber {
|
|||
})?;
|
||||
|
||||
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
|
||||
if !self
|
||||
.ws_state
|
||||
.bytes_limit_whitelist
|
||||
.iter()
|
||||
.any(|ip_net| ip_net.contains(&self.ip_addr))
|
||||
&& self.ws_state.rate_limiter.check_key_n(
|
||||
&self.ip_addr,
|
||||
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
|
||||
) != Ok(Ok(()))
|
||||
{
|
||||
tracing::info!(
|
||||
self.id,
|
||||
ip = %self.ip_addr,
|
||||
"Rate limit exceeded. Closing connection.",
|
||||
);
|
||||
self.sender
|
||||
.send(
|
||||
serde_json::to_string(&ServerResponseMessage::Err {
|
||||
error: "Rate limit exceeded".to_string(),
|
||||
})?
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
self.sender.close().await?;
|
||||
self.closed = true;
|
||||
return Ok(());
|
||||
// If the ip address is None no rate limiting is applied.
|
||||
if let Some(ip_addr) = self.ip_addr {
|
||||
if !self
|
||||
.ws_state
|
||||
.bytes_limit_whitelist
|
||||
.iter()
|
||||
.any(|ip_net| ip_net.contains(&ip_addr))
|
||||
&& self.ws_state.rate_limiter.check_key_n(
|
||||
&ip_addr,
|
||||
NonZeroU32::new(message.len().try_into()?)
|
||||
.ok_or(anyhow!("Empty message"))?,
|
||||
) != Ok(Ok(()))
|
||||
{
|
||||
tracing::info!(
|
||||
self.id,
|
||||
ip = %ip_addr,
|
||||
"Rate limit exceeded. Closing connection.",
|
||||
);
|
||||
self.sender
|
||||
.send(
|
||||
serde_json::to_string(&ServerResponseMessage::Err {
|
||||
error: "Rate limit exceeded".to_string(),
|
||||
})?
|
||||
.into(),
|
||||
)
|
||||
.await?;
|
||||
self.sender.close().await?;
|
||||
self.closed = true;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// `sender.feed` buffers a message to the client but does not flush it, so we can send
|
||||
|
|
|
@ -5,6 +5,7 @@ use {
|
|||
};
|
||||
|
||||
const DEFAULT_RPC_ADDR: &str = "127.0.0.1:33999";
|
||||
const DEFAULT_RPC_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For";
|
||||
|
||||
#[derive(Args, Clone, Debug)]
|
||||
#[command(next_help_heading = "RPC Options")]
|
||||
|
@ -21,4 +22,10 @@ pub struct Options {
|
|||
#[arg(value_delimiter = ',')]
|
||||
#[arg(env = "RPC_WS_WHITELIST")]
|
||||
pub ws_whitelist: Vec<IpNet>,
|
||||
|
||||
/// Header name (case insensitive) to fetch requester IP from.
|
||||
#[arg(long = "rpc-requester-ip-header-name")]
|
||||
#[arg(default_value = DEFAULT_RPC_REQUESTER_IP_HEADER_NAME)]
|
||||
#[arg(env = "RPC_REQUESTER_IP_HEADER_NAME")]
|
||||
pub requester_ip_header_name: String,
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue