Skip to content

Commit

Permalink
fix: serialization round-trips (#948)
Browse files Browse the repository at this point in the history
We add a check during validation, conditioned on `cfg(test)`, that the
hugr in question successfully round-trips.

---------

Co-authored-by: Agustín Borgna <[email protected]>
  • Loading branch information
doug-q and aborgna-q authored Apr 19, 2024
1 parent 1398bd4 commit 9d2de07
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
31 changes: 22 additions & 9 deletions hugr/src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@ pub enum HUGRSerializationError {
#[error("Failed to build edge when deserializing: {0:?}.")]
LinkError(#[from] LinkError),
/// Edges without port offsets cannot be present in operations without non-dataflow ports.
#[error("Cannot connect an edge without port offset to node {node:?} with operation type {op_type:?}.")]
#[error("Cannot connect an {dir:?} edge without port offset to node {node:?} with operation type {op_type:?}.")]
MissingPortOffset {
/// The node that has the port without offset.
node: Node,
/// The direction of the port without an offset
dir: Direction,
/// The operation type of the node.
op_type: OpType,
},
Expand Down Expand Up @@ -232,6 +234,7 @@ impl TryFrom<SerHugrV1> for Hugr {
.other_port(dir)
.ok_or(HUGRSerializationError::MissingPortOffset {
node,
dir,
op_type: op_type.clone(),
})?
.index()
Expand Down Expand Up @@ -329,10 +332,20 @@ pub mod test {
}

/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
/// Checks the serialized json against the in-tree schema.
///
/// Returns the deserialized HUGR.
pub fn check_hugr_roundtrip(hugr: &Hugr) -> Hugr {
let new_hugr: Hugr = ser_roundtrip_validate(hugr, Some(&SCHEMA));
pub fn check_hugr_schema_roundtrip(hugr: &Hugr) -> Hugr {
check_hugr_roundtrip(hugr, true)
}

/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
///
/// If `check_schema` is true, checks the serialized json against the in-tree schema.
///
/// Returns the deserialized HUGR.
pub fn check_hugr_roundtrip(hugr: &Hugr, check_schema: bool) -> Hugr {
let new_hugr: Hugr = ser_roundtrip_validate(hugr, check_schema.then_some(&SCHEMA));

// Original HUGR, with canonicalized node indices
//
Expand Down Expand Up @@ -418,7 +431,7 @@ pub mod test {
metadata: Default::default(),
};

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
}

#[test]
Expand Down Expand Up @@ -452,7 +465,7 @@ pub mod test {
module_builder.finish_prelude_hugr().unwrap()
};

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
}

#[test]
Expand All @@ -468,7 +481,7 @@ pub mod test {
}
let hugr = dfg.finish_hugr_with_outputs(params, &EMPTY_REG)?;

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
Ok(())
}

Expand All @@ -491,7 +504,7 @@ pub mod test {

let hugr = dfg.finish_hugr_with_outputs([wire], &PRELUDE_REGISTRY)?;

check_hugr_roundtrip(&hugr);
check_hugr_schema_roundtrip(&hugr);
Ok(())
}

Expand All @@ -502,7 +515,7 @@ pub mod test {
let op = bldr.add_dataflow_op(Noop { ty: fn_ty }, bldr.input_wires())?;
let h = bldr.finish_prelude_hugr_with_outputs(op.outputs())?;

check_hugr_roundtrip(&h);
check_hugr_schema_roundtrip(&h);
Ok(())
}

Expand All @@ -520,7 +533,7 @@ pub mod test {
hugr.remove_node(old_in);
hugr.update_validate(&PRELUDE_REGISTRY)?;

let new_hugr: Hugr = check_hugr_roundtrip(&hugr);
let new_hugr: Hugr = check_hugr_schema_roundtrip(&hugr);
new_hugr.validate(&EMPTY_REG).unwrap_err();
new_hugr.validate(&PRELUDE_REGISTRY)?;
Ok(())
Expand Down
9 changes: 9 additions & 0 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Hierarchy and children. No type variables declared outside the root.
self.validate_subtree(self.hugr.root(), &[])?;

// In tests we take the opportunity to verify that the hugr
// serialization round-trips.
//
// TODO: We should also verify that the serialized hugr matches the
// in-tree schema. For now, our serialized hugr does not match the
// schema. When this is fixed we should pass true below.
#[cfg(test)]
crate::hugr::serialize::test::check_hugr_roundtrip(self.hugr, false);

Ok(())
}

Expand Down
12 changes: 6 additions & 6 deletions hugr/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ impl OpType {
///
/// Returns None if there is no such port, or if the operation defines multiple non-dataflow ports.
pub fn other_port(&self, dir: Direction) -> Option<Port> {
let df_count = self.value_port_count(dir);
let non_df_count = self.non_df_port_count(dir);
if self.other_port_kind(dir).is_some() && non_df_count == 1 {
// if there is a static input it comes before the non_df_ports
let static_input =
(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize;

Some(Port::new(dir, self.value_port_count(dir) + static_input))
// if there is a static input it comes before the non_df_ports
let static_input =
(dir == Direction::Incoming && OpTag::StaticInput.is_superset(self.tag())) as usize;
if self.other_port_kind(dir).is_some() && non_df_count >= 1 {
Some(Port::new(dir, df_count + static_input))
} else {
None
}
Expand Down
6 changes: 5 additions & 1 deletion hugr/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ mod test {

use super::*;

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
/// A custom constant value used in testing
pub(crate) struct CustomTestValue(pub CustomType);

Expand All @@ -322,6 +322,10 @@ mod test {
fn get_type(&self) -> Type {
self.0.clone().into()
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}
}

/// A [`CustomSerialized`] encoding a [`FLOAT64_TYPE`] float constant used in testing.
Expand Down

0 comments on commit 9d2de07

Please sign in to comment.