From 946536f42317c64b821bbaa03656d7a9e3311399 Mon Sep 17 00:00:00 2001 From: GroovieGermanikus Date: Tue, 8 Aug 2023 13:13:27 +0200 Subject: [PATCH] use TryFrom+TryInto for (de-)serialization if proxy_request_format --- core/src/proxy_request_format.rs | 38 ++++++++++++------- .../src/inbound/proxy_listener.rs | 2 +- .../src/proxy_request_format.rs | 37 ++++++++++-------- .../tests/proxy_request_format.rs | 14 +++++-- 4 files changed, 58 insertions(+), 33 deletions(-) diff --git a/core/src/proxy_request_format.rs b/core/src/proxy_request_format.rs index 8970becd..70f51c7b 100644 --- a/core/src/proxy_request_format.rs +++ b/core/src/proxy_request_format.rs @@ -46,20 +46,6 @@ impl TpuForwardingRequest { } } - pub fn serialize_wire_format(&self) -> Vec { - bincode::serialize(&self).expect("Expect to serialize transactions") - } - - pub fn deserialize_from_raw_request(raw_proxy_request: &[u8]) -> TpuForwardingRequest { - let request = bincode::deserialize::(raw_proxy_request) - .context("deserialize proxy request") - .unwrap(); - - assert_eq!(request.format_version, 2301); - - request - } - pub fn get_tpu_socket_addr(&self) -> SocketAddr { self.tpu_socket_addr } @@ -72,3 +58,27 @@ impl TpuForwardingRequest { self.transactions.clone() } } + +impl TryInto> for TpuForwardingRequest { + type Error = anyhow::Error; + + fn try_into(self) -> Result, Self::Error> { + bincode::serialize(&self) + .context("serialize proxy request") + .map_err(anyhow::Error::from) + } +} + +impl TryFrom<&[u8]> for TpuForwardingRequest { + type Error = anyhow::Error; + + fn try_from(value: &[u8]) -> Result { + let request = bincode::deserialize::(value) + .context("deserialize proxy request") + .map_err(anyhow::Error::from); + if let Ok(ref req) = request { + assert_eq!(req.format_version, FORMAT_VERSION1); + } + request + } +} diff --git a/quic-forward-proxy/src/inbound/proxy_listener.rs b/quic-forward-proxy/src/inbound/proxy_listener.rs index e458d7cd..e5966478 100644 --- a/quic-forward-proxy/src/inbound/proxy_listener.rs +++ b/quic-forward-proxy/src/inbound/proxy_listener.rs @@ -133,7 +133,7 @@ impl ProxyListener { let raw_request = recv_stream.read_to_end(10_000_000).await.unwrap(); let proxy_request = - TpuForwardingRequest::deserialize_from_raw_request(&raw_request); + TpuForwardingRequest::try_from(raw_request.as_slice()).unwrap(); trace!("proxy request details: {}", proxy_request); let _tpu_identity = proxy_request.get_identity_tpunode(); diff --git a/quic-forward-proxy/src/proxy_request_format.rs b/quic-forward-proxy/src/proxy_request_format.rs index 9b8b5761..ec4dfcd1 100644 --- a/quic-forward-proxy/src/proxy_request_format.rs +++ b/quic-forward-proxy/src/proxy_request_format.rs @@ -47,21 +47,6 @@ impl TpuForwardingRequest { } } - pub fn serialize_wire_format(&self) -> Vec { - bincode::serialize(&self).expect("Expect to serialize transactions") - } - - // TODO reame - pub fn deserialize_from_raw_request(raw_proxy_request: &[u8]) -> TpuForwardingRequest { - let request = bincode::deserialize::(raw_proxy_request) - .context("deserialize proxy request") - .unwrap(); - - assert_eq!(request.format_version, FORMAT_VERSION1); - - request - } - pub fn get_tpu_socket_addr(&self) -> SocketAddr { self.tpu_socket_addr } @@ -74,3 +59,25 @@ impl TpuForwardingRequest { self.transactions.clone() } } + +impl TryInto> for TpuForwardingRequest { + type Error = anyhow::Error; + + fn try_into(self) -> Result, Self::Error> { + bincode::serialize(&self).map_err(anyhow::Error::from) + } +} + +impl TryFrom<&[u8]> for TpuForwardingRequest { + type Error = anyhow::Error; + + fn try_from(value: &[u8]) -> Result { + let request = bincode::deserialize::(value) + .context("deserialize proxy request") + .map_err(anyhow::Error::from); + if let Ok(ref req) = request { + assert_eq!(req.format_version, FORMAT_VERSION1); + } + request + } +} diff --git a/quic-forward-proxy/tests/proxy_request_format.rs b/quic-forward-proxy/tests/proxy_request_format.rs index bb752be8..2256a6f5 100644 --- a/quic-forward-proxy/tests/proxy_request_format.rs +++ b/quic-forward-proxy/tests/proxy_request_format.rs @@ -15,17 +15,25 @@ fn roundtrip() { let tx = Transaction::new_with_payer(&[memo_ix], Some(&payer_pubkey)); - let wire_data = TpuForwardingRequest::new( + let wire_data: Vec = TpuForwardingRequest::new( "127.0.0.1:5454".parse().unwrap(), Pubkey::from_str("Bm8rtweCQ19ksNebrLY92H7x4bCaeDJSSmEeWqkdCeop").unwrap(), vec![tx.into()], ) - .serialize_wire_format(); + .try_into() + .unwrap(); println!("wire_data: {:02X?}", wire_data); - let request = TpuForwardingRequest::deserialize_from_raw_request(&wire_data); + let request = TpuForwardingRequest::try_from(wire_data.as_slice()).unwrap(); assert!(request.get_tpu_socket_addr().is_ipv4()); assert_eq!(request.get_transactions().len(), 1); } + +#[test] +fn deserialize_error() { + let value: &[u8] = &[1, 2, 3, 4]; + let result = TpuForwardingRequest::try_from(value); + assert_eq!(result.unwrap_err().to_string(), "deserialize proxy request"); +}