diff --git a/hermes/Cargo.lock b/hermes/Cargo.lock index 1b2d6517..93e0ca1f 100644 --- a/hermes/Cargo.lock +++ b/hermes/Cargo.lock @@ -481,7 +481,7 @@ dependencies = [ "async-trait", "axum-core", "axum-macros", - "base64 0.21.2", + "base64 0.21.4", "bitflags 1.3.2", "bytes", "futures-util", @@ -566,9 +566,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "base64ct" @@ -1898,13 +1898,13 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermes" -version = "0.3.2" +version = "0.3.3" dependencies = [ "anyhow", "async-trait", "axum", "axum-macros", - "base64 0.21.2", + "base64 0.21.4", "borsh 0.10.3", "byteorder", "chrono", @@ -4503,7 +4503,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" dependencies = [ "async-compression", - "base64 0.21.2", + "base64 0.21.4", "bytes", "encoding_rs", "futures-core", @@ -4744,7 +4744,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64 0.21.2", + "base64 0.21.4", ] [[package]] @@ -6277,7 +6277,7 @@ dependencies = [ "async-stream", "async-trait", "axum", - "base64 0.21.2", + "base64 0.21.4", "bytes", "h2", "http", diff --git a/hermes/Cargo.toml b/hermes/Cargo.toml index ff30faf4..5e7a7e4a 100644 --- a/hermes/Cargo.toml +++ b/hermes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "hermes" -version = "0.3.2" +version = "0.3.3" description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle." edition = "2021" diff --git a/hermes/src/api.rs b/hermes/src/api.rs index 380546e5..abab1796 100644 --- a/hermes/src/api.rs +++ b/hermes/src/api.rs @@ -13,12 +13,9 @@ use { }, ipnet::IpNet, serde_qs::axum::QsQueryConfig, - std::{ - net::SocketAddr, - sync::{ - atomic::Ordering, - Arc, - }, + std::sync::{ + atomic::Ordering, + Arc, }, tokio::{ signal, @@ -40,10 +37,14 @@ pub struct ApiState { } impl ApiState { - pub fn new(state: Arc, ws_whitelist: Vec) -> Self { + pub fn new( + state: Arc, + ws_whitelist: Vec, + requester_ip_header_name: String, + ) -> Self { Self { state, - ws: Arc::new(ws::WsState::new(ws_whitelist)), + ws: Arc::new(ws::WsState::new(ws_whitelist, requester_ip_header_name)), } } } @@ -88,7 +89,11 @@ pub async fn run( )] struct ApiDoc; - let state = ApiState::new(state, opts.rpc.ws_whitelist); + let state = ApiState::new( + state, + opts.rpc.ws_whitelist, + opts.rpc.requester_ip_header_name, + ); // Initialize Axum Router. Note the type here is a `Router` due to the use of the // `with_state` method which replaces `Body` with `State` in the type signature. @@ -135,7 +140,7 @@ pub async fn run( // Binds the axum's server to the configured address and port. This is a blocking call and will // not return until the server is shutdown. axum::Server::try_bind(&opts.rpc.addr)? - .serve(app.into_make_service_with_connect_info::()) + .serve(app.into_make_service()) .with_graceful_shutdown(async { // Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler // should also have triggered so we will let that one print the shutdown warning. diff --git a/hermes/src/api/ws.rs b/hermes/src/api/ws.rs index f903c30f..2326be4c 100644 --- a/hermes/src/api/ws.rs +++ b/hermes/src/api/ws.rs @@ -21,9 +21,9 @@ use { WebSocket, WebSocketUpgrade, }, - ConnectInfo, State as AxumState, }, + http::HeaderMap, response::IntoResponse, }, dashmap::DashMap, @@ -50,10 +50,7 @@ use { }, std::{ collections::HashMap, - net::{ - IpAddr, - SocketAddr, - }, + net::IpAddr, num::NonZeroU32, sync::{ atomic::{ @@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig { } pub struct WsState { - pub subscriber_counter: AtomicUsize, - pub subscribers: DashMap>, - pub bytes_limit_whitelist: Vec, - pub rate_limiter: DefaultKeyedRateLimiter, + pub subscriber_counter: AtomicUsize, + pub subscribers: DashMap>, + pub bytes_limit_whitelist: Vec, + pub rate_limiter: DefaultKeyedRateLimiter, + pub requester_ip_header_name: String, } impl WsState { - pub fn new(whitelist: Vec) -> Self { + pub fn new(whitelist: Vec, requester_ip_header_name: String) -> Self { Self { - subscriber_counter: AtomicUsize::new(0), - subscribers: DashMap::new(), - rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!( + subscriber_counter: AtomicUsize::new(0), + subscribers: DashMap::new(), + rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!( BYTES_LIMIT_PER_IP_PER_SECOND ))), bytes_limit_whitelist: whitelist, + requester_ip_header_name, } } } @@ -142,23 +141,33 @@ enum ServerResponseMessage { pub async fn ws_route_handler( ws: WebSocketUpgrade, AxumState(state): AxumState, - ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, ) -> impl IntoResponse { + let requester_ip = headers + .get(state.ws.requester_ip_header_name.as_str()) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple + .and_then(|value| value.parse().ok()); + ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE) - .on_upgrade(move |socket| websocket_handler(socket, state, addr)) + .on_upgrade(move |socket| websocket_handler(socket, state, requester_ip)) } -#[tracing::instrument(skip(stream, state, addr))] -async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) { +#[tracing::instrument(skip(stream, state, subscriber_ip))] +async fn websocket_handler( + stream: WebSocket, + state: super::ApiState, + subscriber_ip: Option, +) { let ws_state = state.ws.clone(); let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst); - tracing::debug!(id, %addr, "New Websocket Connection"); + tracing::debug!(id, ?subscriber_ip, "New Websocket Connection"); let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN); let (sender, receiver) = stream.split(); let mut subscriber = Subscriber::new( id, - addr.ip(), + subscriber_ip, state.state.clone(), state.ws.clone(), notify_receiver, @@ -176,7 +185,7 @@ pub type SubscriberId = usize; /// It listens to the store for updates and sends them to the client. pub struct Subscriber { id: SubscriberId, - ip_addr: IpAddr, + ip_addr: Option, closed: bool, store: Arc, ws_state: Arc, @@ -191,7 +200,7 @@ pub struct Subscriber { impl Subscriber { pub fn new( id: SubscriberId, - ip_addr: IpAddr, + ip_addr: Option, store: Arc, ws_state: Arc, notify_receiver: mpsc::Receiver, @@ -291,32 +300,36 @@ impl Subscriber { })?; // Close the connection if rate limit is exceeded and the ip is not whitelisted. - if !self - .ws_state - .bytes_limit_whitelist - .iter() - .any(|ip_net| ip_net.contains(&self.ip_addr)) - && self.ws_state.rate_limiter.check_key_n( - &self.ip_addr, - NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?, - ) != Ok(Ok(())) - { - tracing::info!( - self.id, - ip = %self.ip_addr, - "Rate limit exceeded. Closing connection.", - ); - self.sender - .send( - serde_json::to_string(&ServerResponseMessage::Err { - error: "Rate limit exceeded".to_string(), - })? - .into(), - ) - .await?; - self.sender.close().await?; - self.closed = true; - return Ok(()); + // If the ip address is None no rate limiting is applied. + if let Some(ip_addr) = self.ip_addr { + if !self + .ws_state + .bytes_limit_whitelist + .iter() + .any(|ip_net| ip_net.contains(&ip_addr)) + && self.ws_state.rate_limiter.check_key_n( + &ip_addr, + NonZeroU32::new(message.len().try_into()?) + .ok_or(anyhow!("Empty message"))?, + ) != Ok(Ok(())) + { + tracing::info!( + self.id, + ip = %ip_addr, + "Rate limit exceeded. Closing connection.", + ); + self.sender + .send( + serde_json::to_string(&ServerResponseMessage::Err { + error: "Rate limit exceeded".to_string(), + })? + .into(), + ) + .await?; + self.sender.close().await?; + self.closed = true; + return Ok(()); + } } // `sender.feed` buffers a message to the client but does not flush it, so we can send diff --git a/hermes/src/config/rpc.rs b/hermes/src/config/rpc.rs index 7d03a3bb..8a23b85c 100644 --- a/hermes/src/config/rpc.rs +++ b/hermes/src/config/rpc.rs @@ -5,6 +5,7 @@ use { }; const DEFAULT_RPC_ADDR: &str = "127.0.0.1:33999"; +const DEFAULT_RPC_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For"; #[derive(Args, Clone, Debug)] #[command(next_help_heading = "RPC Options")] @@ -21,4 +22,10 @@ pub struct Options { #[arg(value_delimiter = ',')] #[arg(env = "RPC_WS_WHITELIST")] pub ws_whitelist: Vec, + + /// Header name (case insensitive) to fetch requester IP from. + #[arg(long = "rpc-requester-ip-header-name")] + #[arg(default_value = DEFAULT_RPC_REQUESTER_IP_HEADER_NAME)] + #[arg(env = "RPC_REQUESTER_IP_HEADER_NAME")] + pub requester_ip_header_name: String, }