diff --git a/hermes/Cargo.toml b/hermes/Cargo.toml index de3c045268..2b97d06280 100644 --- a/hermes/Cargo.toml +++ b/hermes/Cargo.toml @@ -17,7 +17,7 @@ dashmap = { version = "5.4.0" } derive_more = { version = "0.99.17" } env_logger = { version = "0.10.0" } futures = { version = "0.3.28" } -hex = { version = "0.4.3" } +hex = { version = "0.4.3", features = ["serde"] } humantime = { version = "2.1.0" } lazy_static = { version = "1.4.0" } libc = { version = "0.2.140" } diff --git a/hermes/src/api/rest/get_vaa_ccip.rs b/hermes/src/api/rest/get_vaa_ccip.rs index c727972044..88c9550770 100644 --- a/hermes/src/api/rest/get_vaa_ccip.rs +++ b/hermes/src/api/rest/get_vaa_ccip.rs @@ -5,7 +5,6 @@ use { UnixTimestamp, }, api::rest::RestError, - impl_deserialize_for_hex_string_wrapper, }, anyhow::Result, axum::{ @@ -17,6 +16,10 @@ use { DerefMut, }, pyth_sdk::PriceIdentifier, + serde::{ + Deserialize, + Serialize, + }, serde_qs::axum::QsQuery, utoipa::{ IntoParams, @@ -24,9 +27,8 @@ use { }, }; -#[derive(Debug, Clone, Deref, DerefMut, ToSchema)] -pub struct GetVaaCcipInput([u8; 40]); -impl_deserialize_for_hex_string_wrapper!(GetVaaCcipInput, 40); +#[derive(Clone, Debug, Deref, DerefMut, Deserialize, Serialize, ToSchema)] +pub struct GetVaaCcipInput(#[serde(with = "crate::serde::hex")] [u8; 40]); #[derive(Debug, serde::Deserialize, IntoParams)] #[into_params(parameter_in=Query)] diff --git a/hermes/src/api/types.rs b/hermes/src/api/types.rs index ab2a5e3852..8555e97f4b 100644 --- a/hermes/src/api/types.rs +++ b/hermes/src/api/types.rs @@ -6,7 +6,6 @@ use { UnixTimestamp, }, doc_examples, - impl_deserialize_for_hex_string_wrapper, }, base64::{ engine::general_purpose::STANDARD as base64_standard_engine, @@ -21,6 +20,10 @@ use { DerefMut, }, pyth_sdk::PriceIdentifier, + serde::{ + Deserialize, + Serialize, + }, utoipa::ToSchema, wormhole_sdk::Chain, }; @@ -33,11 +36,9 @@ use { /// * e62df6c8b4a85fe1a67db44dc12de5db330f7ac66b72dc658afedf0f4a415b43 /// /// See https://pyth.network/developers/price-feed-ids for a list of all price feed ids. -#[derive(Debug, Clone, Deref, DerefMut, ToSchema)] +#[derive(Clone, Debug, Deref, DerefMut, Deserialize, Serialize, ToSchema)] #[schema(value_type=String, example=doc_examples::price_feed_id_example)] -pub struct PriceIdInput([u8; 32]); -// TODO: Use const generics instead of macro. -impl_deserialize_for_hex_string_wrapper!(PriceIdInput, 32); +pub struct PriceIdInput(#[serde(with = "crate::serde::hex")] [u8; 32]); impl From for PriceIdentifier { fn from(id: PriceIdInput) -> Self { diff --git a/hermes/src/macros.rs b/hermes/src/macros.rs deleted file mode 100644 index 760cd8a0c2..0000000000 --- a/hermes/src/macros.rs +++ /dev/null @@ -1,41 +0,0 @@ -#[macro_export] -/// A macro that generates Deserialize from string for a struct S that wraps [u8; N] where N is a -/// compile-time constant. This macro deserializes a string with or without leading 0x and supports -/// both lower case and upper case hex characters. -macro_rules! impl_deserialize_for_hex_string_wrapper { - ($struct_name:ident, $array_size:expr) => { - impl<'de> serde::Deserialize<'de> for $struct_name { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct HexVisitor; - - impl<'de> serde::de::Visitor<'de> for HexVisitor { - type Value = [u8; $array_size]; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "a hex string of length {}", $array_size * 2) - } - - fn visit_str(self, s: &str) -> Result - where - E: serde::de::Error, - { - let s = s.trim_start_matches("0x"); - let bytes = hex::decode(s) - .map_err(|_| E::invalid_value(serde::de::Unexpected::Str(s), &self))?; - if bytes.len() != $array_size { - return Err(E::invalid_length(bytes.len(), &self)); - } - let mut array = [0_u8; $array_size]; - array.copy_from_slice(&bytes); - Ok(array) - } - } - - deserializer.deserialize_str(HexVisitor).map($struct_name) - } - } - }; -} diff --git a/hermes/src/main.rs b/hermes/src/main.rs index 0006ba8dfa..6e2fa4b0ba 100644 --- a/hermes/src/main.rs +++ b/hermes/src/main.rs @@ -20,8 +20,8 @@ mod aggregate; mod api; mod config; mod doc_examples; -mod macros; mod network; +mod serde; mod state; mod wormhole; diff --git a/hermes/src/serde.rs b/hermes/src/serde.rs new file mode 100644 index 0000000000..2efc327f10 --- /dev/null +++ b/hermes/src/serde.rs @@ -0,0 +1,83 @@ +pub mod hex { + use { + hex::FromHex, + serde::{ + de::IntoDeserializer, + Deserialize, + Deserializer, + Serializer, + }, + }; + + pub fn serialize(b: &[u8; N], s: S) -> Result + where + S: Serializer, + { + s.serialize_str(hex::encode(b).as_str()) + } + + pub fn deserialize<'de, D, R>(d: D) -> Result + where + D: Deserializer<'de>, + R: FromHex, + ::Error: std::fmt::Display, + { + let s: String = Deserialize::deserialize(d)?; + let p = s.starts_with("0x") || s.starts_with("0X"); + let s = if p { &s[2..] } else { &s[..] }; + hex::serde::deserialize(s.into_deserializer()) + } + + #[cfg(test)] + mod tests { + use serde::Deserialize; + + #[derive(Debug, Deserialize, PartialEq)] + struct H(#[serde(with = "super")] [u8; 32]); + + #[test] + fn test_deserialize() { + let e = H([ + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, + 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, + 0x89, 0xab, 0xcd, 0xef, + ]); + + let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\""; + let u = "\"0x0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF\""; + assert_eq!(serde_json::from_str::(l).unwrap(), e); + assert_eq!(serde_json::from_str::(u).unwrap(), e); + + let l = "\"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef\""; + let u = "\"0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF\""; + assert_eq!(serde_json::from_str::(l).unwrap(), e); + assert_eq!(serde_json::from_str::(u).unwrap(), e); + } + + #[test] + fn test_deserialize_invalid_length() { + let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\""; + let u = "\"0X0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDE\""; + assert!(serde_json::from_str::(l).is_err()); + assert!(serde_json::from_str::(u).is_err()); + + let l = "\"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcde\""; + let u = "\"0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDE\""; + assert!(serde_json::from_str::(l).is_err()); + assert!(serde_json::from_str::(u).is_err()); + } + + #[test] + fn test_deserialize_invalid_hex() { + let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\""; + let u = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\""; + assert!(serde_json::from_str::(l).is_err()); + assert!(serde_json::from_str::(u).is_err()); + + let l = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\""; + let u = "\"0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdeg\""; + assert!(serde_json::from_str::(l).is_err()); + assert!(serde_json::from_str::(u).is_err()); + } + } +}