From f39fff7652eadaa7e678d4e47219f5dcf055eba3 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Mon, 28 Oct 2024 07:56:37 +0000 Subject: [PATCH 1/2] Support (de)serializing tensors via Serde Add a `serde` feature to the rten-tensor crate. When enabled, the `TensorBase` struct implements serde's `Serialize` and `Deserialize` traits. --- Cargo.lock | 2 + Makefile | 2 +- rten-tensor/Cargo.toml | 7 + rten-tensor/src/copy.rs | 2 +- rten-tensor/src/impl_serialize.rs | 230 ++++++++++++++++++++++++++++++ rten-tensor/src/lib.rs | 16 +++ rten-tensor/src/tensor.rs | 2 +- 7 files changed, 258 insertions(+), 3 deletions(-) create mode 100644 rten-tensor/src/impl_serialize.rs diff --git a/Cargo.lock b/Cargo.lock index 6a3d6193..476ccf02 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -466,6 +466,8 @@ version = "0.14.0" name = "rten-tensor" version = "0.14.0" dependencies = [ + "serde", + "serde_json", "smallvec", ] diff --git a/Makefile b/Makefile index 14ae3373..5d64c901 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ miri: # nightly Rust. .PHONY: test test: - cargo test --workspace --features mmap,random,text-decoder + cargo test --workspace --features mmap,random,text-decoder,serde .PHONY: wasm wasm: diff --git a/rten-tensor/Cargo.toml b/rten-tensor/Cargo.toml index a7c94490..1a3232d5 100644 --- a/rten-tensor/Cargo.toml +++ b/rten-tensor/Cargo.toml @@ -10,8 +10,12 @@ repository = "https://github.com/robertknight/rten" include = ["/src", "/README.md"] [dependencies] +serde = { workspace = true, optional = true } smallvec = { version = "1.10.0", features=["union", "const_generics", "const_new"] } +[dev-dependencies] +serde_json = { workspace = true } + [lib] crate-type = ["lib"] @@ -19,3 +23,6 @@ crate-type = ["lib"] # See comments about `needless_range_loop` in root Cargo.toml. needless_range_loop = "allow" manual_memcpy = "allow" + +[features] +serde = ["dep:serde"] diff --git a/rten-tensor/src/copy.rs b/rten-tensor/src/copy.rs index cb3ab2e0..54e1a82b 100644 --- a/rten-tensor/src/copy.rs +++ b/rten-tensor/src/copy.rs @@ -415,7 +415,7 @@ fn copy_range_into_slice_inner( let ranges: [IndexRange; 4] = ranges.try_into().unwrap(); // Check output length is correct. - let sliced_len = ranges.iter().map(|s| s.steps()).product(); + let sliced_len: usize = ranges.iter().map(|s| s.steps()).product(); assert_eq!(dest.len(), sliced_len, "output too short"); let mut dest_offset = 0; diff --git a/rten-tensor/src/impl_serialize.rs b/rten-tensor/src/impl_serialize.rs new file mode 100644 index 00000000..ca4c6660 --- /dev/null +++ b/rten-tensor/src/impl_serialize.rs @@ -0,0 +1,230 @@ +use std::fmt; + +use serde::de::{Deserialize, Deserializer, Error, MapAccess, Visitor}; +use serde::ser::{Serialize, SerializeStruct, Serializer}; + +use crate::iterators::Iter; +use crate::{AsView, Layout, MutLayout, Storage, TensorBase}; + +struct TensorData<'a, T> { + iter: Iter<'a, T>, +} + +impl<'a, T> Serialize for TensorData<'a, T> +where + T: Serialize, +{ + fn serialize(&self, serializer: Sr) -> Result + where + Sr: Serializer, + { + serializer.collect_seq(self.iter.clone()) + } +} + +impl Serialize for TensorBase +where + S::Elem: Serialize, +{ + fn serialize(&self, serializer: Sr) -> Result + where + Sr: Serializer, + { + let mut tensor = serializer.serialize_struct("Tensor", 2)?; + tensor.serialize_field("shape", self.shape().as_ref())?; + tensor.serialize_field("data", &TensorData { iter: self.iter() })?; + tensor.end() + } +} + +struct TensorVisitor { + data_marker: std::marker::PhantomData, + layout_marker: std::marker::PhantomData, +} + +impl<'de, T, L> Visitor<'de> for TensorVisitor +where + T: Deserialize<'de>, + L: MutLayout, + for<'a> L::Index<'a>: TryFrom<&'a [usize]>, +{ + type Value = TensorBase, L>; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a tensor with \"shape\" and \"data\" fields") + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let mut data: Option> = None; + let mut shape: Option> = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "data" => { + if data.is_some() { + return Err(A::Error::duplicate_field("data")); + } + data = Some(map.next_value()?); + } + "shape" => { + if shape.is_some() { + return Err(A::Error::duplicate_field("shape")); + } + shape = Some(map.next_value()?); + } + _ => { + return Err(A::Error::unknown_field(&key, &["data", "shape"])); + } + } + } + + let Some(shape) = shape else { + return Err(A::Error::missing_field("shape")); + }; + let Some(data) = data else { + return Err(A::Error::missing_field("data")); + }; + + let Ok(shape_ref): Result, _> = shape.as_slice().try_into() else { + return Err(A::Error::custom("incorrect shape length for tensor rank")); + }; + + TensorBase::try_from_data(shape_ref, data) + .map_err(|_| A::Error::custom("data length does not match shape product")) + } +} + +impl<'de, T, L: MutLayout> Deserialize<'de> for TensorBase, L> +where + T: Deserialize<'de>, + for<'a> L::Index<'a>: TryFrom<&'a [usize]>, +{ + fn deserialize(deserializer: D) -> Result, L>, D::Error> + where + D: Deserializer<'de>, + { + deserializer.deserialize_struct( + "Tensor", + &["shape", "data"], + TensorVisitor:: { + data_marker: std::marker::PhantomData, + layout_marker: std::marker::PhantomData, + }, + ) + } +} + +#[cfg(test)] +mod tests { + use crate::{NdTensor, Tensor}; + + #[test] + fn test_deserialize_serialize_dynamic_rank() { + struct Case<'a> { + json: &'a str, + expected: Result, String>, + } + + let cases = [ + Case { + json: "[]", + expected: Err(format!( + "expected a tensor with \"shape\" and \"data\" fields" + )), + }, + Case { + json: r#"{"data":[]}"#, + expected: Err(format!("missing field `shape`")), + }, + Case { + json: r#"{"data":[], "data": []}"#, + expected: Err(format!("duplicate field `data`")), + }, + Case { + json: r#"{"shape":[]}"#, + expected: Err(format!("missing field `data`")), + }, + Case { + json: r#"{"shape":[], "shape": []}"#, + expected: Err(format!("duplicate field `shape`")), + }, + Case { + json: r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 2]}"#, + expected: Ok(Tensor::from([[1.0, 0.5], [2.0, 1.5]])), + }, + Case { + json: r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 3]}"#, + expected: Err(format!("data length does not match shape product")), + }, + ]; + + for Case { json, expected } in cases { + let actual: Result, String> = + serde_json::from_str(&json).map_err(|e| e.to_string()); + match (actual, expected) { + (Ok(actual), Ok(expected)) => { + assert_eq!(actual, expected); + + // Verify that serializing the result produces the original + // JSON. + let actual_json = serde_json::to_value(actual).unwrap(); + let expected_json: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(actual_json, expected_json); + } + (Err(actual_err), Err(expected_err)) => assert!( + actual_err.contains(&expected_err), + "expected \"{}\" to contain \"{}\"", + actual_err, + expected_err + ), + (actual, expected) => assert_eq!(actual, expected), + } + } + } + + #[test] + fn test_deserialize_serialize_static_rank() { + struct Case<'a> { + json: &'a str, + expected: Result, String>, + } + + let cases = [ + Case { + json: r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 2]}"#, + expected: Ok(NdTensor::from([[1.0, 0.5], [2.0, 1.5]])), + }, + Case { + json: r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [1, 2, 2]}"#, + expected: Err(format!("incorrect shape length for tensor rank")), + }, + ]; + + for Case { json, expected } in cases { + let actual: Result, String> = + serde_json::from_str(&json).map_err(|e| e.to_string()); + + match (actual, expected) { + (Ok(actual), Ok(expected)) => { + assert_eq!(actual, expected); + + // Verify that serializing the result produces the original + // JSON. + let actual_json = serde_json::to_value(actual).unwrap(); + let expected_json: serde_json::Value = serde_json::from_str(&json).unwrap(); + assert_eq!(actual_json, expected_json); + } + (Err(actual_err), Err(expected_err)) => assert!( + actual_err.contains(&expected_err), + "expected \"{}\" to contain \"{}\"", + actual_err, + expected_err + ), + (actual, expected) => assert_eq!(actual, expected), + } + } + } +} diff --git a/rten-tensor/src/lib.rs b/rten-tensor/src/lib.rs index 2039a733..771ef439 100644 --- a/rten-tensor/src/lib.rs +++ b/rten-tensor/src/lib.rs @@ -37,6 +37,20 @@ //! let transposed_elems: Vec<_> = tensor.transposed().iter().copied().collect(); //! assert_eq!(transposed_elems, [1, 3, 2, 4]); //! ``` +//! +//! # Serialization +//! +//! Tensors can be serialized and deserialized using [serde](https://serde.rs) +//! if the `serde` feature is enabled. The serialized representation of a +//! tensor includes its shape and elements in row-major (C) order. The JSON +//! serialization of a matrix (`NdTensor`) looks like this for example: +//! +//! ```json +//! { +//! "shape": [2, 2], +//! "data": [0.5, 1.0, 1.5, 2.0] +//! } +//! ``` mod copy; pub mod errors; @@ -50,6 +64,8 @@ mod storage; pub mod type_num; mod impl_debug; +#[cfg(feature = "serde")] +mod impl_serialize; mod tensor; /// Trait for sources of random data for tensors, for use with [`Tensor::rand`]. diff --git a/rten-tensor/src/tensor.rs b/rten-tensor/src/tensor.rs index cc65c133..e7d10dcb 100644 --- a/rten-tensor/src/tensor.rs +++ b/rten-tensor/src/tensor.rs @@ -2593,7 +2593,7 @@ mod tests { fn test_from_nested_array() { // Scalar let x = NdTensor::from(5); - assert_eq!(x.shape(), []); + assert!(x.shape().is_empty()); assert_eq!(x.data(), Some([5].as_slice())); // 1D From 5ee9972afa38d4023fb3bb2e8e55cbc59c63f4f3 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sat, 9 Nov 2024 13:17:45 +0000 Subject: [PATCH 2/2] Use serde deserialization of tensors in Whisper example --- rten-examples/Cargo.toml | 2 +- rten-examples/data/dump_mel_filters.py | 6 +++++- rten-examples/src/whisper.rs | 30 +++++++------------------- 3 files changed, 14 insertions(+), 24 deletions(-) diff --git a/rten-examples/Cargo.toml b/rten-examples/Cargo.toml index b49bceea..6390e232 100644 --- a/rten-examples/Cargo.toml +++ b/rten-examples/Cargo.toml @@ -21,7 +21,7 @@ rten = { path = "../", features = ["mmap", "random"] } rten-generate = { path = "../rten-generate", features=["text-decoder"] } rten-imageio = { path = "../rten-imageio" } rten-imageproc = { path = "../rten-imageproc" } -rten-tensor = { path = "../rten-tensor" } +rten-tensor = { path = "../rten-tensor", features=["serde"] } rten-text = { path = "../rten-text" } smallvec = "1.13.2" diff --git a/rten-examples/data/dump_mel_filters.py b/rten-examples/data/dump_mel_filters.py index 6b47a8dd..0b5b0a13 100644 --- a/rten-examples/data/dump_mel_filters.py +++ b/rten-examples/data/dump_mel_filters.py @@ -5,7 +5,11 @@ def ndarray_to_dict(array): - """Return a JSON-serializable representation of an ndarray.""" + """ + Return a JSON-serializable representation of an ndarray. + + This representation is compatible with rten-tensor's serde deserialization. + """ return { "shape": array.shape, "data": array.flatten().tolist(), diff --git a/rten-examples/src/whisper.rs b/rten-examples/src/whisper.rs index 3eda0167..bae84281 100644 --- a/rten-examples/src/whisper.rs +++ b/rten-examples/src/whisper.rs @@ -8,9 +8,9 @@ use rten::{Dimension, FloatOperators, Model}; use rten_generate::filter::{token_id_filter, LogitsFilter}; use rten_generate::{Generator, GeneratorUtils}; use rten_tensor::prelude::*; -use rten_tensor::{NdTensor, NdTensorView, Tensor}; +use rten_tensor::{NdTensor, NdTensorView}; use rten_text::tokenizers::Tokenizer; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; struct Args { /// Path to Whisper encoder model. @@ -145,25 +145,13 @@ fn stft( output } -#[derive(Deserialize, Serialize)] -struct TensorData { - shape: Vec, - data: Vec, -} - -impl TensorData { - fn to_tensor(&self) -> Tensor { - Tensor::from_data(&self.shape, self.data.clone()) - } -} - /// JSON-serialized mel filter bank. /// /// See `data/dump_mel_filters.py`. #[derive(Deserialize)] struct MelFilters { - mel_80: TensorData, - mel_128: TensorData, + mel_80: NdTensor, + mel_128: NdTensor, } fn resource_path(path: &str) -> PathBuf { @@ -203,13 +191,11 @@ fn log_mel_spectrogram( // Get power spectrum of input. let magnitudes: NdTensor = audio_fft.map(|&x| x.norm_sqr()); - let mel_filters: NdTensor = match (n_mels, sample_rate, n_fft) { - (80, 16_000, 400) => mel_filter_map.mel_80.to_tensor(), - (128, 16_000, 400) => mel_filter_map.mel_128.to_tensor(), + let mel_filters: NdTensorView = match (n_mels, sample_rate, n_fft) { + (80, 16_000, 400) => mel_filter_map.mel_80.view(), + (128, 16_000, 400) => mel_filter_map.mel_128.view(), _ => return Err("unsupported mel filter parameters".into()), - } - .to_tensor() - .try_into()?; + }; // Convert from hz to mels. let mels = mel_filters.matmul(magnitudes.as_dyn()).unwrap();