Skip to content

Commit

Permalink
feat: support visit_seq for Deserialize
Browse files Browse the repository at this point in the history
Signed-off-by: Keming <[email protected]>
  • Loading branch information
kemingy authored and sarah-quinones committed Sep 29, 2024
1 parent 0d41f79 commit 8afdba4
Showing 1 changed file with 108 additions and 93 deletions.
201 changes: 108 additions & 93 deletions src/serde/mat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,113 +88,128 @@ where
}
const FIELDS: &'static [&'static str] = &["nrows", "ncols", "data"];
struct MatVisitor<E: Entity>(PhantomData<E>);
impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
type Value = Mat<E>;
enum MatrixOrVec<E: Entity> {
Matrix(Mat<E>),
Vec(Vec<E>),
}
impl<E: Entity> MatrixOrVec<E> {
fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
match self {
MatrixOrVec::Matrix(m) => m,
MatrixOrVec::Vec(v) => Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j]),
}
}
}
struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
marker: PhantomData<&'a E>,
nrows: Option<usize>,
ncols: Option<usize>,
}
impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
Self {
marker: PhantomData,
nrows,
ncols,
}
}
}
impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_seq(self)
}
}
impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;

fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
formatter.write_str("a faer matrix")
formatter.write_str("a sequence")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'a>,
A: SeqAccess<'a>,
{
enum MatrixOrVec<E: Entity> {
Matrix(Mat<E>),
Vec(Vec<E>),
}
impl<E: Entity> MatrixOrVec<E> {
fn into_mat(self, nrows: usize, ncols: usize) -> Mat<E> {
match self {
MatrixOrVec::Matrix(m) => m,
MatrixOrVec::Vec(v) => {
Mat::from_fn(nrows, ncols, |i, j| v[i * ncols + j])
}
match (self.ncols, self.nrows) {
(Some(ncols), Some(nrows)) => {
let mut data = Mat::<E>::with_capacity(nrows, ncols);
unsafe {
data.set_dims(nrows, ncols);
}
let expected_length = nrows * ncols;
for i in 0..expected_length {
let el = seq.next_element::<E>()?.ok_or_else(|| {
serde::de::Error::invalid_length(
i,
&format!("{} elements", expected_length).as_str(),
)
})?;
data.write(i / ncols, i % ncols, el);
}
let mut additional = 0usize;
while let Some(_) = seq.next_element::<E>()? {
additional += 1;
}
if additional > 0 {
return Err(serde::de::Error::invalid_length(
additional + expected_length,
&format!("{} elements", expected_length).as_str(),
));
}
Ok(MatrixOrVec::Matrix(data))
}
}
struct MatrixOrVecDeserializer<'a, E: Entity + Deserialize<'a>> {
marker: PhantomData<&'a E>,
nrows: Option<usize>,
ncols: Option<usize>,
}
impl<'a, E: Entity + Deserialize<'a>> MatrixOrVecDeserializer<'a, E> {
fn new(nrows: Option<usize>, ncols: Option<usize>) -> Self {
Self {
marker: PhantomData,
nrows,
ncols,
_ => {
let mut data = Vec::new();
while let Some(el) = seq.next_element::<E>()? {
data.push(el);
}
Ok(MatrixOrVec::Vec(data))
}
}
impl<'a, E: Entity> DeserializeSeed<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;
}
}
impl<'a, E: Entity + Deserialize<'a>> Visitor<'a> for MatVisitor<E> {
type Value = Mat<E>;

fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'a>,
{
deserializer.deserialize_seq(self)
}
}
impl<'a, E: Entity> Visitor<'a> for MatrixOrVecDeserializer<'a, E>
where
E: Deserialize<'a>,
{
type Value = MatrixOrVec<E>;
fn expecting(&self, formatter: &mut alloc::fmt::Formatter) -> alloc::fmt::Result {
formatter.write_str("a faer matrix")
}

fn expecting(
&self,
formatter: &mut alloc::fmt::Formatter,
) -> alloc::fmt::Result {
formatter.write_str("a sequence")
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'a>,
{
let nrows = seq
.next_element::<usize>()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &"nrows"))?;
let ncols = seq
.next_element::<usize>()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &"ncols"))?;
let data = seq.next_element_seed(MatrixOrVecDeserializer::<E>::new(
Some(nrows),
Some(ncols),
))?;
let mat = data
.ok_or_else(|| serde::de::Error::missing_field("data"))?
.into_mat(nrows, ncols);
Ok(mat)
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'a>,
{
match (self.ncols, self.nrows) {
(Some(ncols), Some(nrows)) => {
let mut data = Mat::<E>::with_capacity(nrows, ncols);
unsafe {
data.set_dims(nrows, ncols);
}
let expected_length = nrows * ncols;
for i in 0..expected_length {
let el = seq.next_element::<E>()?.ok_or_else(|| {
serde::de::Error::invalid_length(
i,
&format!("{} elements", expected_length).as_str(),
)
})?;
data.write(i / ncols, i % ncols, el);
}
let mut additional = 0usize;
while let Some(_) = seq.next_element::<E>()? {
additional += 1;
}
if additional > 0 {
return Err(serde::de::Error::invalid_length(
additional + expected_length,
&format!("{} elements", expected_length).as_str(),
));
}
Ok(MatrixOrVec::Matrix(data))
}
_ => {
let mut data = Vec::new();
while let Some(el) = seq.next_element::<E>()? {
data.push(el);
}
Ok(MatrixOrVec::Vec(data))
}
}
}
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'a>,
{
let mut nrows = None;
let mut ncols = None;
let mut data: Option<MatrixOrVec<E>> = None;
Expand Down

0 comments on commit 8afdba4

Please sign in to comment.