Skip to content

Commit

Permalink
feat: implement Encode and Decode for primitive collection types
Browse files Browse the repository at this point in the history
  • Loading branch information
xJonathanLEI committed Oct 21, 2024
1 parent d7a5feb commit 94f6643
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 4 deletions.
18 changes: 15 additions & 3 deletions examples/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,34 @@ use starknet::{
struct CairoType {
a: Felt,
b: Option<u32>,
c: bool,
c: Vec<bool>,
d: [u8; 2],
}

fn main() {
let instance = CairoType {
a: felt!("123456789"),
b: Some(100),
c: false,
c: vec![false, true],
d: [3, 4],
};

let mut serialized = vec![];
instance.encode(&mut serialized).unwrap();

assert_eq!(
serialized,
[felt!("123456789"), felt!("0"), felt!("100"), felt!("0")]
[
felt!("123456789"),
felt!("0"),
felt!("100"),
felt!("2"),
felt!("0"),
felt!("1"),
felt!("2"),
felt!("3"),
felt!("4"),
]
);

let restored = CairoType::decode(&serialized).unwrap();
Expand Down
183 changes: 182 additions & 1 deletion starknet-core/src/codec.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alloc::{boxed::Box, fmt::Formatter, format, string::*, vec::*};
use core::fmt::Display;
use core::{fmt::Display, mem::MaybeUninit};

use num_traits::ToPrimitive;

Expand Down Expand Up @@ -139,6 +139,51 @@ where
}
}

impl<T> Encode for Vec<T>
where
T: Encode,
{
fn encode<W: FeltWriter>(&self, writer: &mut W) -> Result<(), Error> {
writer.write(Felt::from(self.len()));

for item in self {
item.encode(writer)?;
}

Ok(())
}
}

impl<T, const N: usize> Encode for [T; N]
where
T: Encode,
{
fn encode<W: FeltWriter>(&self, writer: &mut W) -> Result<(), Error> {
writer.write(Felt::from(N));

for item in self {
item.encode(writer)?;
}

Ok(())
}
}

impl<T> Encode for [T]
where
T: Encode,
{
fn encode<W: FeltWriter>(&self, writer: &mut W) -> Result<(), Error> {
writer.write(Felt::from(self.len()));

for item in self {
item.encode(writer)?;
}

Ok(())
}
}

impl<'a> Decode<'a> for Felt {
fn decode_iter<T>(iter: &mut T) -> Result<Self, Error>
where
Expand Down Expand Up @@ -263,6 +308,56 @@ where
}
}

impl<'a, T> Decode<'a> for Vec<T>
where
T: Decode<'a>,
{
fn decode_iter<I>(iter: &mut I) -> Result<Self, Error>
where
I: Iterator<Item = &'a Felt>,
{
let length = iter.next().ok_or_else(Error::input_exhausted)?;
let length = length
.to_usize()
.ok_or_else(|| Error::value_out_of_range(length, "usize"))?;

let mut result = Self::with_capacity(length);

for _ in 0..length {
result.push(T::decode_iter(iter)?);
}

Ok(result)
}
}

impl<'a, T, const N: usize> Decode<'a> for [T; N]
where
T: Decode<'a> + Sized,
{
fn decode_iter<I>(iter: &mut I) -> Result<Self, Error>
where
I: Iterator<Item = &'a Felt>,
{
let length = iter.next().ok_or_else(Error::input_exhausted)?;
let length = length
.to_usize()
.ok_or_else(|| Error::value_out_of_range(length, "usize"))?;

if length != N {
return Err(Error::length_mismatch(N, length));
}

let mut result: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };

for elem in &mut result[..] {
*elem = MaybeUninit::new(T::decode_iter(iter)?);
}

Ok(unsafe { core::mem::transmute_copy::<_, [T; N]>(&result) })
}
}

impl Error {
/// Creates an [`Error`] which indicates that the input stream has ended prematurely.
pub fn input_exhausted() -> Self {
Expand All @@ -273,6 +368,14 @@ impl Error {
}
}

/// Creates an [`Error`] which indicates that the length (likely prefix) is different from the
/// expected value.
pub fn length_mismatch(expected: usize, actual: usize) -> Self {
Self {
repr: format!("expecting length `{}` but got `{}`", expected, actual).into_boxed_str(),
}
}

/// Creates an [`Error`] which indicates that the input value is out of range.
pub fn value_out_of_range<V>(value: V, type_name: &str) -> Self
where
Expand Down Expand Up @@ -426,6 +529,54 @@ mod tests {
assert_eq!(serialized, vec![Felt::from_str("1").unwrap()]);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_encode_vec() {
let mut serialized = Vec::<Felt>::new();
vec![Some(10u32), None].encode(&mut serialized).unwrap();
assert_eq!(
serialized,
vec![
Felt::from_str("2").unwrap(),
Felt::from_str("0").unwrap(),
Felt::from_str("10").unwrap(),
Felt::from_str("1").unwrap()
]
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_encode_array() {
let mut serialized = Vec::<Felt>::new();
<[Option<u32>; 2]>::encode(&[Some(10u32), None], &mut serialized).unwrap();
assert_eq!(
serialized,
vec![
Felt::from_str("2").unwrap(),
Felt::from_str("0").unwrap(),
Felt::from_str("10").unwrap(),
Felt::from_str("1").unwrap()
]
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_encode_slice() {
let mut serialized = Vec::<Felt>::new();
<[Option<u32>]>::encode(&[Some(10u32), None], &mut serialized).unwrap();
assert_eq!(
serialized,
vec![
Felt::from_str("2").unwrap(),
Felt::from_str("0").unwrap(),
Felt::from_str("10").unwrap(),
Felt::from_str("1").unwrap()
]
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_derive_encode_struct_named() {
Expand Down Expand Up @@ -639,6 +790,36 @@ mod tests {
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_decode_vec() {
assert_eq!(
vec![Some(10u32), None],
Vec::<Option::<u32>>::decode(&[
Felt::from_str("2").unwrap(),
Felt::from_str("0").unwrap(),
Felt::from_str("10").unwrap(),
Felt::from_str("1").unwrap()
])
.unwrap()
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_decode_array() {
assert_eq!(
[Some(10u32), None],
<[Option<u32>; 2]>::decode(&[
Felt::from_str("2").unwrap(),
Felt::from_str("0").unwrap(),
Felt::from_str("10").unwrap(),
Felt::from_str("1").unwrap()
])
.unwrap()
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
fn test_derive_decode_struct_named() {
Expand Down

0 comments on commit 94f6643

Please sign in to comment.