feat(hermes): use ip from request headers for ratelimiting

This commit is contained in:
Ali Behjati 2023-10-12 18:59:06 +02:00
parent b158f28c58
commit d11216f309
5 changed files with 91 additions and 66 deletions

16
hermes/Cargo.lock generated
View File

@ -481,7 +481,7 @@ dependencies = [
"async-trait",
"axum-core",
"axum-macros",
"base64 0.21.2",
"base64 0.21.4",
"bitflags 1.3.2",
"bytes",
"futures-util",
@ -566,9 +566,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.21.2"
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2"
[[package]]
name = "base64ct"
@ -1898,13 +1898,13 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermes"
version = "0.3.2"
version = "0.3.3"
dependencies = [
"anyhow",
"async-trait",
"axum",
"axum-macros",
"base64 0.21.2",
"base64 0.21.4",
"borsh 0.10.3",
"byteorder",
"chrono",
@ -4503,7 +4503,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
dependencies = [
"async-compression",
"base64 0.21.2",
"base64 0.21.4",
"bytes",
"encoding_rs",
"futures-core",
@ -4744,7 +4744,7 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
dependencies = [
"base64 0.21.2",
"base64 0.21.4",
]
[[package]]
@ -6277,7 +6277,7 @@ dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.21.2",
"base64 0.21.4",
"bytes",
"h2",
"http",

View File

@ -1,6 +1,6 @@
[package]
name = "hermes"
version = "0.3.2"
version = "0.3.3"
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
edition = "2021"

View File

@ -13,12 +13,9 @@ use {
},
ipnet::IpNet,
serde_qs::axum::QsQueryConfig,
std::{
net::SocketAddr,
sync::{
atomic::Ordering,
Arc,
},
std::sync::{
atomic::Ordering,
Arc,
},
tokio::{
signal,
@ -40,10 +37,14 @@ pub struct ApiState {
}
impl ApiState {
pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
pub fn new(
state: Arc<State>,
ws_whitelist: Vec<IpNet>,
requester_ip_header_name: String,
) -> Self {
Self {
state,
ws: Arc::new(ws::WsState::new(ws_whitelist)),
ws: Arc::new(ws::WsState::new(ws_whitelist, requester_ip_header_name)),
}
}
}
@ -88,7 +89,11 @@ pub async fn run(
)]
struct ApiDoc;
let state = ApiState::new(state, opts.rpc.ws_whitelist);
let state = ApiState::new(
state,
opts.rpc.ws_whitelist,
opts.rpc.requester_ip_header_name,
);
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
// `with_state` method which replaces `Body` with `State` in the type signature.
@ -135,7 +140,7 @@ pub async fn run(
// Binds the axum's server to the configured address and port. This is a blocking call and will
// not return until the server is shutdown.
axum::Server::try_bind(&opts.rpc.addr)?
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
.serve(app.into_make_service())
.with_graceful_shutdown(async {
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
// should also have triggered so we will let that one print the shutdown warning.

View File

@ -21,9 +21,9 @@ use {
WebSocket,
WebSocketUpgrade,
},
ConnectInfo,
State as AxumState,
},
http::HeaderMap,
response::IntoResponse,
},
dashmap::DashMap,
@ -50,10 +50,7 @@ use {
},
std::{
collections::HashMap,
net::{
IpAddr,
SocketAddr,
},
net::IpAddr,
num::NonZeroU32,
sync::{
atomic::{
@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
}
pub struct WsState {
pub subscriber_counter: AtomicUsize,
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
pub bytes_limit_whitelist: Vec<IpNet>,
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
pub subscriber_counter: AtomicUsize,
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
pub bytes_limit_whitelist: Vec<IpNet>,
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
pub requester_ip_header_name: String,
}
impl WsState {
pub fn new(whitelist: Vec<IpNet>) -> Self {
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self {
Self {
subscriber_counter: AtomicUsize::new(0),
subscribers: DashMap::new(),
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
subscriber_counter: AtomicUsize::new(0),
subscribers: DashMap::new(),
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
BYTES_LIMIT_PER_IP_PER_SECOND
))),
bytes_limit_whitelist: whitelist,
requester_ip_header_name,
}
}
}
@ -142,23 +141,33 @@ enum ServerResponseMessage {
pub async fn ws_route_handler(
ws: WebSocketUpgrade,
AxumState(state): AxumState<super::ApiState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
headers: HeaderMap,
) -> impl IntoResponse {
let requester_ip = headers
.get(state.ws.requester_ip_header_name.as_str())
.and_then(|value| value.to_str().ok())
.and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
.and_then(|value| value.parse().ok());
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
.on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
}
#[tracing::instrument(skip(stream, state, addr))]
async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
#[tracing::instrument(skip(stream, state, subscriber_ip))]
async fn websocket_handler(
stream: WebSocket,
state: super::ApiState,
subscriber_ip: Option<IpAddr>,
) {
let ws_state = state.ws.clone();
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
tracing::debug!(id, %addr, "New Websocket Connection");
tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
let (sender, receiver) = stream.split();
let mut subscriber = Subscriber::new(
id,
addr.ip(),
subscriber_ip,
state.state.clone(),
state.ws.clone(),
notify_receiver,
@ -176,7 +185,7 @@ pub type SubscriberId = usize;
/// It listens to the store for updates and sends them to the client.
pub struct Subscriber {
id: SubscriberId,
ip_addr: IpAddr,
ip_addr: Option<IpAddr>,
closed: bool,
store: Arc<State>,
ws_state: Arc<WsState>,
@ -191,7 +200,7 @@ pub struct Subscriber {
impl Subscriber {
pub fn new(
id: SubscriberId,
ip_addr: IpAddr,
ip_addr: Option<IpAddr>,
store: Arc<State>,
ws_state: Arc<WsState>,
notify_receiver: mpsc::Receiver<AggregationEvent>,
@ -291,32 +300,36 @@ impl Subscriber {
})?;
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
if !self
.ws_state
.bytes_limit_whitelist
.iter()
.any(|ip_net| ip_net.contains(&self.ip_addr))
&& self.ws_state.rate_limiter.check_key_n(
&self.ip_addr,
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
) != Ok(Ok(()))
{
tracing::info!(
self.id,
ip = %self.ip_addr,
"Rate limit exceeded. Closing connection.",
);
self.sender
.send(
serde_json::to_string(&ServerResponseMessage::Err {
error: "Rate limit exceeded".to_string(),
})?
.into(),
)
.await?;
self.sender.close().await?;
self.closed = true;
return Ok(());
// If the ip address is None no rate limiting is applied.
if let Some(ip_addr) = self.ip_addr {
if !self
.ws_state
.bytes_limit_whitelist
.iter()
.any(|ip_net| ip_net.contains(&ip_addr))
&& self.ws_state.rate_limiter.check_key_n(
&ip_addr,
NonZeroU32::new(message.len().try_into()?)
.ok_or(anyhow!("Empty message"))?,
) != Ok(Ok(()))
{
tracing::info!(
self.id,
ip = %ip_addr,
"Rate limit exceeded. Closing connection.",
);
self.sender
.send(
serde_json::to_string(&ServerResponseMessage::Err {
error: "Rate limit exceeded".to_string(),
})?
.into(),
)
.await?;
self.sender.close().await?;
self.closed = true;
return Ok(());
}
}
// `sender.feed` buffers a message to the client but does not flush it, so we can send

View File

@ -5,6 +5,7 @@ use {
};
const DEFAULT_RPC_ADDR: &str = "127.0.0.1:33999";
const DEFAULT_RPC_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For";
#[derive(Args, Clone, Debug)]
#[command(next_help_heading = "RPC Options")]
@ -21,4 +22,10 @@ pub struct Options {
#[arg(value_delimiter = ',')]
#[arg(env = "RPC_WS_WHITELIST")]
pub ws_whitelist: Vec<IpNet>,
/// Header name (case insensitive) to fetch requester IP from.
#[arg(long = "rpc-requester-ip-header-name")]
#[arg(default_value = DEFAULT_RPC_REQUESTER_IP_HEADER_NAME)]
#[arg(env = "RPC_REQUESTER_IP_HEADER_NAME")]
pub requester_ip_header_name: String,
}