Skip to content

Commit

Permalink
Merge pull request #402 from robertknight/tensor-serialize
Browse files Browse the repository at this point in the history
Support (de-)serialization of tensors using serde
  • Loading branch information
robertknight authored Nov 9, 2024
2 parents cf1ee74 + 5ee9972 commit dc62fbd
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 27 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ miri:
# nightly Rust.
.PHONY: test
test:
cargo test --workspace --features mmap,random,text-decoder
cargo test --workspace --features mmap,random,text-decoder,serde

.PHONY: wasm
wasm:
Expand Down
2 changes: 1 addition & 1 deletion rten-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rten = { path = "../", features = ["mmap", "random"] }
rten-generate = { path = "../rten-generate", features=["text-decoder"] }
rten-imageio = { path = "../rten-imageio" }
rten-imageproc = { path = "../rten-imageproc" }
rten-tensor = { path = "../rten-tensor" }
rten-tensor = { path = "../rten-tensor", features=["serde"] }
rten-text = { path = "../rten-text" }
smallvec = "1.13.2"

Expand Down
6 changes: 5 additions & 1 deletion rten-examples/data/dump_mel_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@


def ndarray_to_dict(array):
"""Return a JSON-serializable representation of an ndarray."""
"""
Return a JSON-serializable representation of an ndarray.
This representation is compatible with rten-tensor's serde deserialization.
"""
return {
"shape": array.shape,
"data": array.flatten().tolist(),
Expand Down
30 changes: 8 additions & 22 deletions rten-examples/src/whisper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use rten::{Dimension, FloatOperators, Model};
use rten_generate::filter::{token_id_filter, LogitsFilter};
use rten_generate::{Generator, GeneratorUtils};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Tensor};
use rten_tensor::{NdTensor, NdTensorView};
use rten_text::tokenizers::Tokenizer;
use serde::{Deserialize, Serialize};
use serde::Deserialize;

struct Args {
/// Path to Whisper encoder model.
Expand Down Expand Up @@ -145,25 +145,13 @@ fn stft(
output
}

#[derive(Deserialize, Serialize)]
struct TensorData {
shape: Vec<usize>,
data: Vec<f32>,
}

impl TensorData {
fn to_tensor(&self) -> Tensor {
Tensor::from_data(&self.shape, self.data.clone())
}
}

/// JSON-serialized mel filter bank.
///
/// See `data/dump_mel_filters.py`.
#[derive(Deserialize)]
struct MelFilters {
mel_80: TensorData,
mel_128: TensorData,
mel_80: NdTensor<f32, 2>,
mel_128: NdTensor<f32, 2>,
}

fn resource_path(path: &str) -> PathBuf {
Expand Down Expand Up @@ -203,13 +191,11 @@ fn log_mel_spectrogram(
// Get power spectrum of input.
let magnitudes: NdTensor<f32, 2> = audio_fft.map(|&x| x.norm_sqr());

let mel_filters: NdTensor<f32, 2> = match (n_mels, sample_rate, n_fft) {
(80, 16_000, 400) => mel_filter_map.mel_80.to_tensor(),
(128, 16_000, 400) => mel_filter_map.mel_128.to_tensor(),
let mel_filters: NdTensorView<f32, 2> = match (n_mels, sample_rate, n_fft) {
(80, 16_000, 400) => mel_filter_map.mel_80.view(),
(128, 16_000, 400) => mel_filter_map.mel_128.view(),
_ => return Err("unsupported mel filter parameters".into()),
}
.to_tensor()
.try_into()?;
};

// Convert from hz to mels.
let mels = mel_filters.matmul(magnitudes.as_dyn()).unwrap();
Expand Down
7 changes: 7 additions & 0 deletions rten-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@ repository = "https://github.com/robertknight/rten"
include = ["/src", "/README.md"]

[dependencies]
serde = { workspace = true, optional = true }
smallvec = { version = "1.10.0", features=["union", "const_generics", "const_new"] }

[dev-dependencies]
serde_json = { workspace = true }

[lib]
crate-type = ["lib"]

[lints.clippy]
# See comments about `needless_range_loop` in root Cargo.toml.
needless_range_loop = "allow"
manual_memcpy = "allow"

[features]
serde = ["dep:serde"]
2 changes: 1 addition & 1 deletion rten-tensor/src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ fn copy_range_into_slice_inner<T: Clone>(
let ranges: [IndexRange; 4] = ranges.try_into().unwrap();

// Check output length is correct.
let sliced_len = ranges.iter().map(|s| s.steps()).product();
let sliced_len: usize = ranges.iter().map(|s| s.steps()).product();
assert_eq!(dest.len(), sliced_len, "output too short");

let mut dest_offset = 0;
Expand Down
230 changes: 230 additions & 0 deletions rten-tensor/src/impl_serialize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
use std::fmt;

use serde::de::{Deserialize, Deserializer, Error, MapAccess, Visitor};
use serde::ser::{Serialize, SerializeStruct, Serializer};

use crate::iterators::Iter;
use crate::{AsView, Layout, MutLayout, Storage, TensorBase};

struct TensorData<'a, T> {
iter: Iter<'a, T>,
}

impl<'a, T> Serialize for TensorData<'a, T>
where
T: Serialize,
{
fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
where
Sr: Serializer,
{
serializer.collect_seq(self.iter.clone())
}
}

impl<S: Storage, L: MutLayout> Serialize for TensorBase<S, L>
where
S::Elem: Serialize,
{
fn serialize<Sr>(&self, serializer: Sr) -> Result<Sr::Ok, Sr::Error>
where
Sr: Serializer,
{
let mut tensor = serializer.serialize_struct("Tensor", 2)?;
tensor.serialize_field("shape", self.shape().as_ref())?;
tensor.serialize_field("data", &TensorData { iter: self.iter() })?;
tensor.end()
}
}

struct TensorVisitor<T, L> {
data_marker: std::marker::PhantomData<T>,
layout_marker: std::marker::PhantomData<L>,
}

impl<'de, T, L> Visitor<'de> for TensorVisitor<T, L>
where
T: Deserialize<'de>,
L: MutLayout,
for<'a> L::Index<'a>: TryFrom<&'a [usize]>,
{
type Value = TensorBase<Vec<T>, L>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a tensor with \"shape\" and \"data\" fields")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut data: Option<Vec<T>> = None;
let mut shape: Option<Vec<usize>> = None;

while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"data" => {
if data.is_some() {
return Err(A::Error::duplicate_field("data"));
}
data = Some(map.next_value()?);
}
"shape" => {
if shape.is_some() {
return Err(A::Error::duplicate_field("shape"));
}
shape = Some(map.next_value()?);
}
_ => {
return Err(A::Error::unknown_field(&key, &["data", "shape"]));
}
}
}

let Some(shape) = shape else {
return Err(A::Error::missing_field("shape"));
};
let Some(data) = data else {
return Err(A::Error::missing_field("data"));
};

let Ok(shape_ref): Result<L::Index<'_>, _> = shape.as_slice().try_into() else {
return Err(A::Error::custom("incorrect shape length for tensor rank"));
};

TensorBase::try_from_data(shape_ref, data)
.map_err(|_| A::Error::custom("data length does not match shape product"))
}
}

/// Deserializes an owned tensor from a struct named `Tensor` with `shape`
/// and `data` fields.
impl<'de, T, L> Deserialize<'de> for TensorBase<Vec<T>, L>
where
    T: Deserialize<'de>,
    L: MutLayout,
    for<'a> L::Index<'a>: TryFrom<&'a [usize]>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names passed to the deserializer as a hint.
        const FIELDS: &[&str] = &["shape", "data"];

        let visitor = TensorVisitor::<T, L> {
            data_marker: std::marker::PhantomData,
            layout_marker: std::marker::PhantomData,
        };
        deserializer.deserialize_struct("Tensor", FIELDS, visitor)
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use serde::de::DeserializeOwned;
    use serde::Serialize;

    use crate::{NdTensor, Tensor};

    /// Deserialize `json` as a `T` and compare the outcome with `expected`.
    ///
    /// On success the tensor must equal the expected tensor, and serializing
    /// it again must reproduce the original JSON value. On failure the error
    /// message must contain the expected message.
    fn check_case<T>(json: &str, expected: Result<T, String>)
    where
        T: DeserializeOwned + Serialize + PartialEq + Debug,
    {
        let actual: Result<T, String> = serde_json::from_str(json).map_err(|e| e.to_string());

        match (actual, expected) {
            (Ok(actual), Ok(expected)) => {
                assert_eq!(actual, expected);

                // Verify that serializing the result produces the original
                // JSON.
                let actual_json = serde_json::to_value(actual).unwrap();
                let expected_json: serde_json::Value = serde_json::from_str(json).unwrap();
                assert_eq!(actual_json, expected_json);
            }
            (Err(actual_err), Err(expected_err)) => assert!(
                actual_err.contains(&expected_err),
                "expected \"{}\" to contain \"{}\"",
                actual_err,
                expected_err
            ),
            // One side succeeded and the other failed. Let `assert_eq` report
            // the mismatch.
            (actual, expected) => assert_eq!(actual, expected),
        }
    }

    #[test]
    fn test_deserialize_serialize_dynamic_rank() {
        let cases: [(&str, Result<Tensor<f32>, String>); 7] = [
            (
                "[]",
                Err("expected a tensor with \"shape\" and \"data\" fields".to_string()),
            ),
            (
                r#"{"data":[]}"#,
                Err("missing field `shape`".to_string()),
            ),
            (
                r#"{"data":[], "data": []}"#,
                Err("duplicate field `data`".to_string()),
            ),
            (
                r#"{"shape":[]}"#,
                Err("missing field `data`".to_string()),
            ),
            (
                r#"{"shape":[], "shape": []}"#,
                Err("duplicate field `shape`".to_string()),
            ),
            (
                r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 2]}"#,
                Ok(Tensor::from([[1.0, 0.5], [2.0, 1.5]])),
            ),
            (
                r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 3]}"#,
                Err("data length does not match shape product".to_string()),
            ),
        ];

        for (json, expected) in cases {
            check_case(json, expected);
        }
    }

    #[test]
    fn test_deserialize_serialize_static_rank() {
        let cases: [(&str, Result<NdTensor<f32, 2>, String>); 2] = [
            (
                r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [2, 2]}"#,
                Ok(NdTensor::from([[1.0, 0.5], [2.0, 1.5]])),
            ),
            (
                r#"{"data": [1.0, 0.5, 2.0, 1.5], "shape": [1, 2, 2]}"#,
                Err("incorrect shape length for tensor rank".to_string()),
            ),
        ];

        for (json, expected) in cases {
            check_case(json, expected);
        }
    }
}
16 changes: 16 additions & 0 deletions rten-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@
//! let transposed_elems: Vec<_> = tensor.transposed().iter().copied().collect();
//! assert_eq!(transposed_elems, [1, 3, 2, 4]);
//! ```
//!
//! # Serialization
//!
//! Tensors can be serialized and deserialized using [serde](https://serde.rs)
//! if the `serde` feature is enabled. The serialized representation of a
//! tensor includes its shape and elements in row-major (C) order. For example,
//! the JSON serialization of a matrix (`NdTensor<f32, 2>`) looks like this:
//!
//! ```json
//! {
//! "shape": [2, 2],
//! "data": [0.5, 1.0, 1.5, 2.0]
//! }
//! ```

mod copy;
pub mod errors;
Expand All @@ -50,6 +64,8 @@ mod storage;
pub mod type_num;

mod impl_debug;
#[cfg(feature = "serde")]
mod impl_serialize;
mod tensor;

/// Trait for sources of random data for tensors, for use with [`Tensor::rand`].
Expand Down
2 changes: 1 addition & 1 deletion rten-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2593,7 +2593,7 @@ mod tests {
fn test_from_nested_array() {
// Scalar
let x = NdTensor::from(5);
assert_eq!(x.shape(), []);
assert!(x.shape().is_empty());
assert_eq!(x.data(), Some([5].as_slice()));

// 1D
Expand Down

0 comments on commit dc62fbd

Please sign in to comment.