Skip to content

Commit

Permalink
refactor: Separate extension validation from the rest (#1011)
Browse files Browse the repository at this point in the history
A further refactor to validation. Note that the CI will fail without
#1010 because this change raises the hugr validity errors before the
extension errors. This seems like the win #943 is meant to give us 😁

Resolves #943
  • Loading branch information
croyzor authored May 8, 2024
1 parent 2ac1ebe commit 4ea7dfa
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
10 changes: 8 additions & 2 deletions hugr/src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ pub mod serialize;
pub mod validate;
pub mod views;

#[cfg(feature = "extension_inference")]
use std::collections::HashMap;
use std::collections::VecDeque;
use std::iter;

Expand Down Expand Up @@ -196,8 +198,12 @@ impl Hugr {
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
resolve_extension_ops(self, extension_registry)?;
self.infer_extensions()?;
self.validate(extension_registry)?;
self.validate_no_extensions(extension_registry)?;
#[cfg(feature = "extension_inference")]
{
self.infer_extensions()?;
self.validate_extensions(HashMap::new())?;
}
Ok(())
}

Expand Down
78 changes: 50 additions & 28 deletions hugr/src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ struct ValidationContext<'a, 'b> {
hugr: &'a Hugr,
/// Dominator tree for each CFG region, using the container node as index.
dominators: HashMap<Node, Dominators<Node>>,
/// Context for the extension validation.
#[allow(dead_code)]
extension_validator: ExtensionValidator,
/// Registry of available Extensions
extension_registry: &'b ExtensionRegistry,
}
Expand All @@ -48,7 +45,51 @@ impl Hugr {
/// TODO: Add a version of validation which allows for open extension
/// variables (see github issue #457)
pub fn validate(&self, extension_registry: &ExtensionRegistry) -> Result<(), ValidationError> {
self.validate_with_extension_closure(HashMap::new(), extension_registry)
#[cfg(feature = "extension_inference")]
self.validate_with_extension_closure(HashMap::new(), extension_registry)?;
#[cfg(not(feature = "extension_inference"))]
self.validate_no_extensions(extension_registry)?;
Ok(())
}

/// Check the validity of the HUGR, but don't check consistency of extension
/// requirements between connected nodes or between parents and children.
pub fn validate_no_extensions(
&self,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, extension_registry);
validator.validate()
}

/// Validate extensions on the input and output edges of nodes. Check that
/// the target ends of edges require the extensions from the sources, and
/// check extension deltas from parent nodes are reflected in their children
pub fn validate_extensions(&self, closure: ExtensionSolution) -> Result<(), ValidationError> {
let validator = ExtensionValidator::new(self, closure);
for src_node in self.nodes() {
let node_type = self.get_nodetype(src_node);

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.get_io(src_node) {
validator.validate_io_extensions(src_node, input, output)?;
}
}

for src_port in self.node_outputs(src_node) {
for (tgt_node, tgt_port) in self.linked_inputs(src_node, src_port) {
validator.check_extensions_compatible(
&(src_node, src_port.into()),
&(tgt_node, tgt_port.into()),
)?;
}
}
}
Ok(())
}

/// Check the validity of a hugr, taking an argument of a closure for the
Expand All @@ -58,8 +99,10 @@ impl Hugr {
closure: ExtensionSolution,
extension_registry: &ExtensionRegistry,
) -> Result<(), ValidationError> {
let mut validator = ValidationContext::new(self, closure, extension_registry);
validator.validate()
let mut validator = ValidationContext::new(self, extension_registry);
validator.validate()?;
self.validate_extensions(closure)?;
Ok(())
}
}

Expand All @@ -68,15 +111,10 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Allow unused "extension_closure" variable for when
// the "extension_inference" feature is disabled.
#[allow(unused_variables)]
pub fn new(
hugr: &'a Hugr,
extension_closure: ExtensionSolution,
extension_registry: &'b ExtensionRegistry,
) -> Self {
pub fn new(hugr: &'a Hugr, extension_registry: &'b ExtensionRegistry) -> Self {
Self {
hugr,
dominators: HashMap::new(),
extension_validator: ExtensionValidator::new(hugr, extension_closure),
extension_registry,
}
}
Expand Down Expand Up @@ -176,18 +214,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
// Secondly that the node has correct children
self.validate_children(node, node_type)?;

// FuncDefns have no resources since they're static nodes, but the
// functions they define can have any extension delta.
#[cfg(feature = "extension_inference")]
if node_type.tag() != OpTag::FuncDefn {
// If this is a container with I/O nodes, check that the extension they
// define match the extensions of the container.
if let Some([input, output]) = self.hugr.get_io(node) {
self.extension_validator
.validate_io_extensions(node, input, output)?;
}
}

Ok(())
}

Expand Down Expand Up @@ -247,10 +273,6 @@ impl<'a, 'b> ValidationContext<'a, 'b> {
let other_node: Node = self.hugr.graph.port_node(link).unwrap().into();
let other_offset = self.hugr.graph.port_offset(link).unwrap().into();

#[cfg(feature = "extension_inference")]
self.extension_validator
.check_extensions_compatible(&(node, port), &(other_node, other_offset))?;

let other_op = self.hugr.get_optype(other_node);
let Some(other_kind) = other_op.port_kind(other_offset) else {
panic!("The number of ports in {other_node} does not match the operation definition. This should have been caught by `validate_node`.");
Expand Down

0 comments on commit 4ea7dfa

Please sign in to comment.