From 5689cfc1584442d68ad84ce552773d9371b4ced7 Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Tue, 26 Nov 2024 18:28:00 +0000 Subject: [PATCH] Example bodged rfc8949 implementation --- ciborium/src/lib.rs | 8 +- ciborium/src/ser/mod.rs | 299 ++++++++++++++++++++++++-------- ciborium/src/value/canonical.rs | 8 + ciborium/tests/canonical.rs | 26 +++ 4 files changed, 262 insertions(+), 79 deletions(-) diff --git a/ciborium/src/lib.rs b/ciborium/src/lib.rs index f143943..f0edf0c 100644 --- a/ciborium/src/lib.rs +++ b/ciborium/src/lib.rs @@ -99,12 +99,14 @@ pub mod value; // Re-export the [items recommended by serde](https://serde.rs/conventions.html). #[doc(inline)] -pub use crate::de::from_reader; +pub use crate::de::{from_reader, from_reader_with_buffer, Deserializer}; + #[doc(inline)] -pub use crate::de::from_reader_with_buffer; +pub use crate::ser::{into_writer, Serializer}; #[doc(inline)] -pub use crate::ser::into_writer; +#[cfg(feature = "std")] +pub use crate::ser::{into_writer_canonical, to_vec, to_vec_canonical}; #[doc(inline)] pub use crate::value::Value; diff --git a/ciborium/src/ser/mod.rs b/ciborium/src/ser/mod.rs index 03dd1da..dcdca1e 100644 --- a/ciborium/src/ser/mod.rs +++ b/ciborium/src/ser/mod.rs @@ -7,24 +7,57 @@ mod error; pub use error::Error; use alloc::string::ToString; - use ciborium_io::Write; use ciborium_ll::*; use serde::{ser, Serialize as _}; -struct Serializer(Encoder); +/// A serializer for CBOR. +pub struct Serializer { + encoder: Encoder, + + /// Whether to canonically sort map keys in output according to [RFC 8949]'s deterministic + /// encoding spec. + /// + /// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html#name-deterministically-encoded-c + #[cfg(feature = "std")] + canonical: bool, +} + +impl Serializer { + /// Create a new CBOR serializer. + /// + /// `canonical` determines whether to canonically sort map keys in output according to + /// [RFC 8949]'s deterministic encoding spec. Requires the `std` feature to sort keys. + /// + /// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html#name-deterministically-encoded-c + #[cfg(feature = "std")] + pub fn new(encoder: impl Into>, canonical: bool) -> Self { + Self { + encoder: encoder.into(), + canonical + } + } +} impl From for Serializer { #[inline] fn from(writer: W) -> Self { - Self(writer.into()) + Self { + encoder: writer.into(), + #[cfg(feature = "std")] + canonical: false, + } } } impl From> for Serializer { #[inline] fn from(writer: Encoder) -> Self { - Self(writer) + Self { + encoder: writer, + #[cfg(feature = "std")] + canonical: false, + } } } @@ -45,7 +78,7 @@ where #[inline] fn serialize_bool(self, v: bool) -> Result<(), Self::Error> { - Ok(self.0.push(match v { + Ok(self.encoder.push(match v { false => Header::Simple(simple::FALSE), true => Header::Simple(simple::TRUE), })?) @@ -68,7 +101,7 @@ where #[inline] fn serialize_i64(self, v: i64) -> Result<(), Self::Error> { - Ok(self.0.push(match v.is_negative() { + Ok(self.encoder.push(match v.is_negative() { false => Header::Positive(v as u64), true => Header::Negative(v as u64 ^ !0), })?) @@ -82,8 +115,8 @@ where }; match (tag, u64::try_from(raw)) { - (tag::BIGPOS, Ok(x)) => return Ok(self.0.push(Header::Positive(x))?), - (tag::BIGNEG, Ok(x)) => return Ok(self.0.push(Header::Negative(x))?), + (tag::BIGPOS, Ok(x)) => return Ok(self.encoder.push(Header::Positive(x))?), + (tag::BIGNEG, Ok(x)) => return Ok(self.encoder.push(Header::Negative(x))?), _ => {} } @@ -95,9 +128,9 @@ where slice = &slice[1..]; } - self.0.push(Header::Tag(tag))?; - self.0.push(Header::Bytes(Some(slice.len())))?; - Ok(self.0.write_all(slice)?) + self.encoder.push(Header::Tag(tag))?; + self.encoder.push(Header::Bytes(Some(slice.len())))?; + Ok(self.encoder.write_all(slice)?) } #[inline] @@ -117,7 +150,7 @@ where #[inline] fn serialize_u64(self, v: u64) -> Result<(), Self::Error> { - Ok(self.0.push(Header::Positive(v))?) + Ok(self.encoder.push(Header::Positive(v))?) } #[inline] @@ -134,9 +167,9 @@ where slice = &slice[1..]; } - self.0.push(Header::Tag(tag::BIGPOS))?; - self.0.push(Header::Bytes(Some(slice.len())))?; - Ok(self.0.write_all(slice)?) + self.encoder.push(Header::Tag(tag::BIGPOS))?; + self.encoder.push(Header::Bytes(Some(slice.len())))?; + Ok(self.encoder.write_all(slice)?) } #[inline] @@ -146,7 +179,7 @@ where #[inline] fn serialize_f64(self, v: f64) -> Result<(), Self::Error> { - Ok(self.0.push(Header::Float(v))?) + Ok(self.encoder.push(Header::Float(v))?) } #[inline] @@ -157,19 +190,19 @@ where #[inline] fn serialize_str(self, v: &str) -> Result<(), Self::Error> { let bytes = v.as_bytes(); - self.0.push(Header::Text(bytes.len().into()))?; - Ok(self.0.write_all(bytes)?) + self.encoder.push(Header::Text(bytes.len().into()))?; + Ok(self.encoder.write_all(bytes)?) } #[inline] fn serialize_bytes(self, v: &[u8]) -> Result<(), Self::Error> { - self.0.push(Header::Bytes(v.len().into()))?; - Ok(self.0.write_all(v)?) + self.encoder.push(Header::Bytes(v.len().into()))?; + Ok(self.encoder.write_all(v)?) } #[inline] fn serialize_none(self) -> Result<(), Self::Error> { - Ok(self.0.push(Header::Simple(simple::NULL))?) + Ok(self.encoder.push(Header::Simple(simple::NULL))?) } #[inline] @@ -215,7 +248,7 @@ where value: &U, ) -> Result<(), Self::Error> { if name != "@@TAG@@" || variant != "@@UNTAGGED@@" { - self.0.push(Header::Map(Some(1)))?; + self.encoder.push(Header::Map(Some(1)))?; self.serialize_str(variant)?; } @@ -224,12 +257,8 @@ where #[inline] fn serialize_seq(self, length: Option) -> Result { - self.0.push(Header::Array(length))?; - Ok(CollectionSerializer { - encoder: self, - ending: length.is_none(), - tag: false, - }) + self.encoder.push(Header::Array(length))?; + Ok(CollectionSerializer::new(self, length.is_none(), false)) } #[inline] @@ -255,33 +284,21 @@ where length: usize, ) -> Result { match (name, variant) { - ("@@TAG@@", "@@TAGGED@@") => Ok(CollectionSerializer { - encoder: self, - ending: false, - tag: true, - }), + ("@@TAG@@", "@@TAGGED@@") => Ok(CollectionSerializer::new(self, false, true)), _ => { - self.0.push(Header::Map(Some(1)))?; + self.encoder.push(Header::Map(Some(1)))?; self.serialize_str(variant)?; - self.0.push(Header::Array(Some(length)))?; - Ok(CollectionSerializer { - encoder: self, - ending: false, - tag: false, - }) + self.encoder.push(Header::Array(Some(length)))?; + Ok(CollectionSerializer::new(self, false, false)) } } } #[inline] fn serialize_map(self, length: Option) -> Result { - self.0.push(Header::Map(length))?; - Ok(CollectionSerializer { - encoder: self, - ending: length.is_none(), - tag: false, - }) + self.encoder.push(Header::Map(length))?; + Ok(CollectionSerializer::new(self, length.is_none(), false)) } #[inline] @@ -290,12 +307,8 @@ where _name: &'static str, length: usize, ) -> Result { - self.0.push(Header::Map(Some(length)))?; - Ok(CollectionSerializer { - encoder: self, - ending: false, - tag: false, - }) + self.encoder.push(Header::Map(Some(length)))?; + Ok(CollectionSerializer::new(self, false, false)) } #[inline] @@ -306,14 +319,10 @@ where variant: &'static str, length: usize, ) -> Result { - self.0.push(Header::Map(Some(1)))?; + self.encoder.push(Header::Map(Some(1)))?; self.serialize_str(variant)?; - self.0.push(Header::Map(Some(length)))?; - Ok(CollectionSerializer { - encoder: self, - ending: false, - tag: false, - }) + self.encoder.push(Header::Map(Some(length)))?; + Ok(CollectionSerializer::new(self, false, false)) } #[inline] @@ -327,7 +336,42 @@ macro_rules! end { #[inline] fn end(self) -> Result<(), Self::Error> { if self.ending { - self.encoder.0.push(Header::Break)?; + self.serializer.encoder.push(Header::Break)?; + } + + Ok(()) + } + }; +} + +/// `end_map` outputs cached keys if we're serializing in canonical mode +macro_rules! end_map { + () => { + #[inline] + fn end(self) -> Result<(), Self::Error> { + #[cfg(feature = "std")] + if self.serializer.canonical { + // keys get sorted in lexicographical byte order + let keys = self.cache_keys; + let values = self.cache_values; + + debug_assert_eq!( + keys.len(), values.len(), + "ciborium error: canonicalization failed, keys and values must have same length."); + + let mut pairs = std::collections::BTreeMap::new(); + for (key, value) in keys.iter().zip(values.iter()) { + pairs.insert(key, value); + } + + for (key, value) in pairs { + self.serializer.encoder.write_all(&key)?; + self.serializer.encoder.write_all(&value)?; + } + } + + if self.ending { + self.serializer.encoder.push(Header::Break)?; } Ok(()) @@ -335,10 +379,39 @@ macro_rules! end { }; } -struct CollectionSerializer<'a, W> { - encoder: &'a mut Serializer, +/// An internal struct for serializing collections. +/// +/// Not to be used externally, only exposed as part of the [Serializer] type. +#[doc(hidden)] +pub struct CollectionSerializer<'a, W> { + serializer: &'a mut Serializer, ending: bool, tag: bool, + + #[cfg(feature = "std")] + cache_keys: Vec>, + #[cfg(feature = "std")] + cache_values: Vec>, +} + +impl<'a, W> CollectionSerializer<'a, W> { + pub fn new(serializer: &'a mut Serializer, ending: bool, tag: bool) -> Self { + #[cfg(feature = "std")] + let capacity = match serializer.canonical { + true => 4, + false => 0, + }; + + Self { + serializer, + ending, + tag, + #[cfg(feature = "std")] + cache_keys: Vec::with_capacity(capacity), + #[cfg(feature = "std")] + cache_values: Vec::with_capacity(capacity), + } + } } impl<'a, W: Write> ser::SerializeSeq for CollectionSerializer<'a, W> @@ -353,7 +426,7 @@ where &mut self, value: &U, ) -> Result<(), Self::Error> { - value.serialize(&mut *self.encoder) + value.serialize(&mut *self.serializer) } end!(); @@ -371,7 +444,7 @@ where &mut self, value: &U, ) -> Result<(), Self::Error> { - value.serialize(&mut *self.encoder) + value.serialize(&mut *self.serializer) } end!(); @@ -389,7 +462,7 @@ where &mut self, value: &U, ) -> Result<(), Self::Error> { - value.serialize(&mut *self.encoder) + value.serialize(&mut *self.serializer) } end!(); @@ -408,12 +481,12 @@ where value: &U, ) -> Result<(), Self::Error> { if !self.tag { - return value.serialize(&mut *self.encoder); + return value.serialize(&mut *self.serializer); } self.tag = false; match value.serialize(crate::tag::Serializer) { - Ok(x) => Ok(self.encoder.0.push(Header::Tag(x))?), + Ok(x) => Ok(self.serializer.encoder.push(Header::Tag(x))?), _ => Err(Error::Value("expected tag".into())), } } @@ -430,7 +503,14 @@ where #[inline] fn serialize_key(&mut self, key: &U) -> Result<(), Self::Error> { - key.serialize(&mut *self.encoder) + #[cfg(feature = "std")] + if self.serializer.canonical { + let key_bytes = to_vec(key).map_err(|e| Error::Value(e.to_string()))?; + self.cache_keys.push(key_bytes); + return Ok(()); + } + + key.serialize(&mut *self.serializer) } #[inline] @@ -438,10 +518,17 @@ where &mut self, value: &U, ) -> Result<(), Self::Error> { - value.serialize(&mut *self.encoder) + #[cfg(feature = "std")] + if self.serializer.canonical { + let value_bytes = to_vec(value).map_err(|e| Error::Value(e.to_string()))?; + self.cache_values.push(value_bytes); + return Ok(()); + } + + value.serialize(&mut *self.serializer) } - end!(); + end_map!(); } impl<'a, W: Write> ser::SerializeStruct for CollectionSerializer<'a, W> @@ -457,12 +544,21 @@ where key: &'static str, value: &U, ) -> Result<(), Self::Error> { - key.serialize(&mut *self.encoder)?; - value.serialize(&mut *self.encoder)?; + #[cfg(feature = "std")] + if self.serializer.canonical { + let key_bytes = to_vec(key).map_err(|e| Error::Value(e.to_string()))?; + self.cache_keys.push(key_bytes); + let value_bytes = to_vec(value).map_err(|e| Error::Value(e.to_string()))?; + self.cache_values.push(value_bytes); + return Ok(()); + } + + key.serialize(&mut *self.serializer)?; + value.serialize(&mut *self.serializer)?; Ok(()) } - end!(); + end_map!(); } impl<'a, W: Write> ser::SerializeStructVariant for CollectionSerializer<'a, W> @@ -478,11 +574,44 @@ where key: &'static str, value: &U, ) -> Result<(), Self::Error> { - key.serialize(&mut *self.encoder)?; - value.serialize(&mut *self.encoder) + #[cfg(feature = "std")] + if self.serializer.canonical { + let key_bytes = to_vec(key).map_err(|e| Error::Value(e.to_string()))?; + self.cache_keys.push(key_bytes); + let value_bytes = to_vec(value).map_err(|e| Error::Value(e.to_string()))?; + self.cache_values.push(value_bytes); + return Ok(()); + } + + key.serialize(&mut *self.serializer)?; + value.serialize(&mut *self.serializer) } - end!(); + end_map!(); +} + +/// Serializes as CBOR into `Vec`. +#[cfg(feature = "std")] +#[inline] +pub fn to_vec(value: &T) -> Result, Error> { + let mut buffer = std::vec::Vec::with_capacity(1024); + let mut serializer = Serializer::new(&mut buffer, false); + value.serialize(&mut serializer)?; + Ok(buffer) +} + +/// Canonically serializes as CBOR into `Vec`. +/// +/// This will sort map keys in output according to [RFC 8949]'s deterministic encoding spec. +/// +/// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html#name-deterministically-encoded-c +#[cfg(feature = "std")] +#[inline] +pub fn to_vec_canonical(value: &T) -> Result, Error> { + let mut buffer = std::vec::Vec::with_capacity(1024); + let mut serializer = Serializer::new(&mut buffer, true); + value.serialize(&mut serializer)?; + Ok(buffer) } /// Serializes as CBOR into a type with [`impl ciborium_io::Write`](ciborium_io::Write) @@ -497,3 +626,21 @@ where let mut encoder = Serializer::from(writer); value.serialize(&mut encoder) } + +/// Canonically serializes as CBOR into a type with [`impl ciborium_io::Write`](ciborium_io::Write) +/// +/// This will sort map keys in output according to [RFC 8949]'s deterministic encoding spec. +/// +/// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html#name-deterministically-encoded-c +#[cfg(feature = "std")] +#[inline] +pub fn into_writer_canonical( + value: &T, + writer: W, +) -> Result<(), Error> +where + W::Error: core::fmt::Debug, +{ + let mut encoder = Serializer::new(writer, true); + value.serialize(&mut encoder) +} \ No newline at end of file diff --git a/ciborium/src/value/canonical.rs b/ciborium/src/value/canonical.rs index 072e1cf..3afc58b 100644 --- a/ciborium/src/value/canonical.rs +++ b/ciborium/src/value/canonical.rs @@ -59,6 +59,14 @@ pub fn cmp_value(v1: &Value, v2: &Value) -> Ordering { } } +pub fn cmp_value_rfc8949(v1: &Value, v2: &Value) -> Ordering { + let mut bytes1 = Vec::new(); + let _ = crate::ser::into_writer(v1, &mut bytes1); + let mut bytes2 = Vec::new(); + let _ = crate::ser::into_writer(v2, &mut bytes2); + bytes1.cmp(&bytes2) +} + /// A CBOR Value that impl Ord and Eq to allow sorting of values as defined in both /// RFC 7049 Section 3.9 (regarding key sorting) and RFC 8949 4.2.3 (as errata). /// diff --git a/ciborium/tests/canonical.rs b/ciborium/tests/canonical.rs index d4aa33c..15d1438 100644 --- a/ciborium/tests/canonical.rs +++ b/ciborium/tests/canonical.rs @@ -63,6 +63,32 @@ fn map() { ); } +/// Match [RFC 8949] deterministic ordering example. +/// +/// The RFC specifies lexicographic byte ordering of serialized keys. +/// +/// [RFC 8949]: https://www.rfc-editor.org/rfc/rfc8949.html#name-core-deterministic-encoding +#[test] +#[cfg(feature = "std")] +fn map_canonical() { + let mut map = BTreeMap::new(); + map.insert(cval!(false), val!(2)); + map.insert(cval!([-1]), val!(5)); + map.insert(cval!(-1), val!(1)); + map.insert(cval!(10), val!(0)); + map.insert(cval!(100), val!(3)); + map.insert(cval!([100]), val!(7)); + map.insert(cval!("z"), val!(4)); + map.insert(cval!("aa"), val!(6)); + + let bytes1 = ciborium::ser::to_vec_canonical(&map).unwrap(); + + assert_eq!( + hex::encode(&bytes1), + "a80a001864032001617a046261610681186407812005f402" + ); +} + #[test] fn negative_numbers() { let mut array: Vec = vec![