diff --git a/ciborium/src/de/mod.rs b/ciborium/src/de/mod.rs index 18742e4..7280552 100644 --- a/ciborium/src/de/mod.rs +++ b/ciborium/src/de/mod.rs @@ -8,12 +8,11 @@ pub use error::Error; use alloc::{string::String, vec::Vec}; +use crate::{simple_type::SimpleTypeAccess, tag::TagAccess}; use ciborium_io::Read; use ciborium_ll::*; use serde::de::{self, value::BytesDeserializer, Deserializer as _}; -use crate::tag::TagAccess; - trait Expected { fn expected(self, kind: &'static str) -> E; } @@ -213,8 +212,15 @@ where Header::Simple(simple::FALSE) => self.deserialize_bool(visitor), Header::Simple(simple::TRUE) => self.deserialize_bool(visitor), Header::Simple(simple::NULL) => self.deserialize_option(visitor), - Header::Simple(simple::UNDEFINED) => self.deserialize_option(visitor), - h @ Header::Simple(..) => Err(h.expected("known simple value")), + Header::Simple(v @ simple::UNDEFINED) => { + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + // Those have to be registered via Standard Actions or are reserved so we should error whenever we + // encounter one. This crate should be updated once new entries in this range are added + // in the IANA registry + h @ Header::Simple(0..=31) => Err(h.expected("known simple value")), + // However we should support arbitrary simple types + Header::Simple(v) => visitor.visit_enum(SimpleTypeAccess::new(self, v)), h @ Header::Break => Err(h.expected("non-break")), } @@ -604,6 +610,19 @@ where let access = TagAccess::new(me, tag); visitor.visit_enum(access) }); + } else if name == "@@ST@@" { + return match self.decoder.pull()? { + Header::Simple(v @ simple::UNDEFINED) => { + self.decoder.push(Header::Positive(v as u64)); + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + h @ Header::Simple(0..=31) => Err(h.expected("known simple value")), + Header::Simple(v) => { + self.decoder.push(Header::Positive(v as u64)); + visitor.visit_enum(SimpleTypeAccess::new(self, v)) + } + h => Err(h.expected("known simple value")), + }; } loop { diff --git a/ciborium/src/lib.rs b/ciborium/src/lib.rs index f143943..d4d0735 100644 --- a/ciborium/src/lib.rs +++ b/ciborium/src/lib.rs @@ -94,6 +94,7 @@ extern crate alloc; pub mod de; pub mod ser; +pub mod simple_type; pub mod tag; pub mod value; diff --git a/ciborium/src/ser/mod.rs b/ciborium/src/ser/mod.rs index 03dd1da..0628d06 100644 --- a/ciborium/src/ser/mod.rs +++ b/ciborium/src/ser/mod.rs @@ -214,7 +214,16 @@ where variant: &'static str, value: &U, ) -> Result<(), Self::Error> { - if name != "@@TAG@@" || variant != "@@UNTAGGED@@" { + if name == "@@ST@@" && variant == "@@SIMPLETYPE@@" { + use serde::ser::Error as _; + + let v = crate::Value::serialized(value).map_err(Error::custom)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + return Ok(self.0.push(Header::Simple(v))?); + } else if name != "@@TAG@@" || variant != "@@UNTAGGED@@" { self.0.push(Header::Map(Some(1)))?; self.serialize_str(variant)?; } diff --git a/ciborium/src/simple_type.rs b/ciborium/src/simple_type.rs new file mode 100644 index 0000000..cbe0d42 --- /dev/null +++ b/ciborium/src/simple_type.rs @@ -0,0 +1,141 @@ +//! Contains helper types for dealing with CBOR simple types + +use serde::{de, de::Error as _, forward_to_deserialize_any, ser, Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename = "@@ST@@")] +enum Internal { + /// The integer can either be 23, or (32..=255) + #[serde(rename = "@@SIMPLETYPE@@")] + SimpleType(u8), +} + +/// A CBOR simple value +/// See https://datatracker.ietf.org/doc/html/rfc8949#section-3.3 +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct SimpleType(pub u8); + +impl<'de> Deserialize<'de> for SimpleType { + #[inline] + fn deserialize>(deserializer: D) -> Result { + match Internal::deserialize(deserializer)? { + Internal::SimpleType(t) => Ok(SimpleType(t)), + } + } +} + +impl Serialize for SimpleType { + #[inline] + fn serialize(&self, serializer: S) -> Result { + Internal::SimpleType(self.0).serialize(serializer) + } +} + +pub(crate) struct SimpleTypeAccess { + parent: Option, + state: usize, + typ: u8, +} + +impl SimpleTypeAccess { + pub fn new(parent: D, typ: u8) -> Self { + Self { + parent: Some(parent), + state: 0, + typ, + } + } +} + +impl<'de, D: de::Deserializer<'de>> de::Deserializer<'de> for &mut SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn deserialize_any>(self, visitor: V) -> Result { + self.state += 1; + match self.state { + 1 => visitor.visit_str("@@SIMPLETYPE@@"), + _ => visitor.visit_u8(self.typ), + } + } + + forward_to_deserialize_any! { + i8 i16 i32 i64 i128 + u8 u16 u32 u64 u128 + bool f32 f64 + char str string + bytes byte_buf + seq map + struct tuple tuple_struct + identifier ignored_any + option unit unit_struct newtype_struct enum + } +} + +impl<'de, D: de::Deserializer<'de>> de::EnumAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + type Variant = Self; + + #[inline] + fn variant_seed>( + mut self, + seed: V, + ) -> Result<(V::Value, Self::Variant), Self::Error> { + let variant = seed.deserialize(&mut self)?; + Ok((variant, self)) + } +} + +impl<'de, D: de::Deserializer<'de>> de::VariantAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn unit_variant(self) -> Result<(), Self::Error> { + Err(Self::Error::custom("expected simple type")) + } + + #[inline] + fn newtype_variant_seed>( + mut self, + seed: U, + ) -> Result { + seed.deserialize(self.parent.take().unwrap()) + } + + #[inline] + fn tuple_variant>( + self, + _len: usize, + visitor: V, + ) -> Result { + visitor.visit_seq(self) + } + + #[inline] + fn struct_variant>( + self, + _fields: &'static [&'static str], + _visitor: V, + ) -> Result { + Err(Self::Error::custom("expected simple_type")) + } +} + +impl<'de, D: de::Deserializer<'de>> de::SeqAccess<'de> for SimpleTypeAccess { + type Error = D::Error; + + #[inline] + fn next_element_seed>( + &mut self, + seed: T, + ) -> Result, Self::Error> { + if self.state < 2 { + return Ok(Some(seed.deserialize(self)?)); + } + + Ok(match self.parent.take() { + Some(x) => Some(seed.deserialize(x)?), + None => None, + }) + } +} diff --git a/ciborium/src/value/de.rs b/ciborium/src/value/de.rs index f58a017..6f115ce 100644 --- a/ciborium/src/value/de.rs +++ b/ciborium/src/value/de.rs @@ -1,13 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 -use crate::tag::TagAccess; +use crate::{simple_type::SimpleTypeAccess, tag::TagAccess}; use super::{Error, Integer, Value}; use alloc::{boxed::Box, string::String, vec::Vec}; use core::iter::Peekable; -use ciborium_ll::tag; +use ciborium_ll::{simple, tag}; use serde::de::{self, Deserializer as _}; impl<'a> From for de::Unexpected<'a> { @@ -36,6 +36,7 @@ impl<'a> From<&'a Value> for de::Unexpected<'a> { Value::Map(..) => Self::Map, Value::Null => Self::Other("null"), Value::Tag(..) => Self::Other("tag"), + Value::Simple(..) => Self::Other("simple"), } } } @@ -218,6 +219,7 @@ impl<'a> Deserializer<&'a Value> { .map(|x| x ^ !0) .map_err(|_| err()) .and_then(|x| x.try_into().map_err(|_| err()))?, + Value::Simple(x) => i128::from(*x).try_into().map_err(|_| err())?, _ => return Err(de::Error::invalid_type(self.0.into(), &"(big)int")), }) } @@ -228,6 +230,7 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { #[inline] fn deserialize_any>(self, visitor: V) -> Result { + use serde::ser::Error as _; match self.0 { Value::Bytes(x) => visitor.visit_bytes(x), Value::Text(x) => visitor.visit_str(x), @@ -235,6 +238,11 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { Value::Map(x) => visitor.visit_map(Deserializer(x.iter().peekable())), Value::Bool(x) => visitor.visit_bool(*x), Value::Null => visitor.visit_none(), + Value::Simple(v @ simple::UNDEFINED) => { + visitor.visit_enum(SimpleTypeAccess::new(self, *v)) + } + Value::Simple(0..=31) => Err(Self::Error::custom("Unsupported simple type")), + Value::Simple(v) => visitor.visit_enum(SimpleTypeAccess::new(self, *v)), Value::Tag(t, v) => { let parent: Deserializer<&Value> = Deserializer(v); @@ -493,6 +501,18 @@ impl<'a, 'de> de::Deserializer<'de> for Deserializer<&'a Value> { let parent: Deserializer<&Value> = Deserializer(val); let access = TagAccess::new(parent, tag); return visitor.visit_enum(access); + } else if name == "@@ST@@" { + use serde::ser::Error as _; + return match self.0 { + Value::Simple(v @ simple::UNDEFINED) => { + visitor.visit_enum(SimpleTypeAccess::new(Deserializer(self.0), *v)) + } + Value::Simple(0..=31) => return Err(Error::custom("Unsupported simple type")), + Value::Simple(v) => { + visitor.visit_enum(SimpleTypeAccess::new(Deserializer(self.0), *v)) + } + _ => Err(Error::custom("Implementation error for simple type")), + }; } match self.0 { diff --git a/ciborium/src/value/mod.rs b/ciborium/src/value/mod.rs index 7233026..3bfa6a8 100644 --- a/ciborium/src/value/mod.rs +++ b/ciborium/src/value/mod.rs @@ -45,6 +45,9 @@ pub enum Value { /// A map Map(Vec<(Value, Value)>), + + /// A CBOR "Simple Value" other than true, false, or null + Simple(u8), } impl Value { diff --git a/ciborium/src/value/ser.rs b/ciborium/src/value/ser.rs index 99e6587..7fb32b9 100644 --- a/ciborium/src/value/ser.rs +++ b/ciborium/src/value/ser.rs @@ -5,15 +5,23 @@ use super::{Error, Value}; use alloc::{vec, vec::Vec}; use ::serde::ser::{self, SerializeMap as _, SerializeSeq as _, SerializeTupleVariant as _}; +use ciborium_ll::simple; impl ser::Serialize for Value { #[inline] fn serialize(&self, serializer: S) -> Result { + use serde::ser::Error as _; + match self { Value::Bytes(x) => serializer.serialize_bytes(x), Value::Bool(x) => serializer.serialize_bool(*x), Value::Text(x) => serializer.serialize_str(x), Value::Null => serializer.serialize_unit(), + Value::Simple(x @ simple::UNDEFINED) => { + serializer.serialize_newtype_struct("@@SIMPLETYPE@@", x) + } + Value::Simple(0..=31) => Err(S::Error::custom("Unsupported simple type")), + Value::Simple(x) => serializer.serialize_newtype_struct("@@SIMPLETYPE@@", x), Value::Tag(t, v) => { let mut acc = serializer.serialize_tuple_variant("@@TAG@@", 0, "@@TAGGED@@", 2)?; @@ -190,6 +198,16 @@ impl ser::Serializer for Serializer<()> { ) -> Result { Ok(match (name, variant) { ("@@TAG@@", "@@UNTAGGED@@") => Value::serialized(value)?, + ("@@ST@@", "@@SIMPLETYPE@@") => { + use serde::ser::Error as _; + + let v = Value::serialized(value)?; + let v = v + .as_integer() + .ok_or_else(|| Error::custom("Internal error handling simple types"))?; + let v = u8::try_from(v).map_err(Error::custom)?; + Value::Simple(v) + } _ => vec![(variant.into(), Value::serialized(value)?)].into(), }) } diff --git a/ciborium/tests/simple_type.rs b/ciborium/tests/simple_type.rs new file mode 100644 index 0000000..8c7ab70 --- /dev/null +++ b/ciborium/tests/simple_type.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 + +extern crate alloc; + +use ciborium::{de::from_reader, ser::into_writer, simple_type::SimpleType, value::Value}; +use rstest::rstest; +use serde::{de::DeserializeOwned, Serialize}; + +use core::fmt::Debug; + +#[rstest(item, bytes, value, encode, success, + case(SimpleType(0), "e0", Value::Simple(0), true, false), // Registered via Standard Actions + case(SimpleType(19), "f3", Value::Simple(19), true, false), // Registered via Standard Actions + case(SimpleType(23), "f7", Value::Simple(23), true, true), // CBOR simple value "undefined" + case(SimpleType(32), "f820", Value::Simple(32), true, true), + case(SimpleType(255), "f8ff", Value::Simple(255), true, true), + case(vec![SimpleType(255)], "81f8ff", Value::Array(vec![Value::Simple(255)]), true, true), +)] +fn test( + item: T, + bytes: &str, + value: Value, + encode: bool, + success: bool, +) { + let bytes = hex::decode(bytes).unwrap(); + + if encode { + // Encode into bytes + let mut encoded = Vec::new(); + into_writer(&item, &mut encoded).unwrap(); + assert_eq!(bytes, encoded); + + // Encode into value + assert_eq!(value, Value::serialized(&item).unwrap()); + } + + // Decode from bytes + match from_reader(&bytes[..]) { + Ok(x) if success => assert_eq!(item, x), + Ok(..) => panic!("unexpected success"), + Err(e) if success => panic!("{:?}", e), + Err(..) => (), + } + + // Decode from value + match value.deserialized() { + Ok(x) if success => assert_eq!(item, x), + Ok(..) => panic!("unexpected success"), + Err(e) if success => panic!("{:?}", e), + Err(..) => (), + } +}