Skip to content

Commit

Permalink
fix: add op's extension to signature check in resolve_opaque_op (#1317
Browse files Browse the repository at this point in the history
)

We also remove an unsafe `unwrap` in `resolve_opaque_op` while we are
here.

---------

Co-authored-by: Alec Edgington <[email protected]>
Co-authored-by: Alan Lawrence <[email protected]>
  • Loading branch information
3 people authored Jul 17, 2024
1 parent b832274 commit 01da7ba
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 55 deletions.
9 changes: 7 additions & 2 deletions hugr-core/src/builder/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ mod test {
use super::*;
use cool_asserts::assert_matches;

use crate::extension::{ExtensionId, ExtensionSet};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
use crate::utils::test_quantum_extension::{
self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64,
Expand Down Expand Up @@ -295,16 +296,20 @@ mod test {

#[test]
fn with_nonlinear_and_outputs() {
let missing_ext: ExtensionId = "MissingExt".try_into().unwrap();
let my_custom_op = CustomOp::new_opaque(OpaqueOp::new(
"MissingRsrc".try_into().unwrap(),
missing_ext.clone(),
"MyOp",
"unknown op".to_string(),
vec![],
Signature::new(vec![QB, NAT], vec![QB]),
));
let build_res = build_main(
Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID)
.with_extension_delta(ExtensionSet::from_iter([
test_quantum_extension::EXTENSION_ID,
missing_ext,
]))
.into(),
|mut f_build| {
let [q0, q1, angle]: [Wire; 3] = f_build.input_wires_arr();
Expand Down
68 changes: 43 additions & 25 deletions hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -640,38 +640,56 @@ mod test {
}

#[test]
fn test_invalid() -> Result<(), Box<dyn std::error::Error>> {
fn test_invalid() {
let unknown_ext: ExtensionId = "unknown_ext".try_into().unwrap();
let utou = Signature::new_endo(vec![USIZE_T]);
let mk_op = |s| {
CustomOp::new_opaque(OpaqueOp::new(
ExtensionId::new("unknown_ext").unwrap(),
unknown_ext.clone(),
s,
String::new(),
vec![],
utou.clone(),
))
};
let mut h = DFGBuilder::new(Signature::new(
type_row![USIZE_T, BOOL_T],
type_row![USIZE_T],
))?;
let mut h = DFGBuilder::new(
Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T])
.with_extension_delta(unknown_ext.clone()),
)
.unwrap();
let [i, b] = h.input_wires_arr();
let mut cond = h.conditional_builder(
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
)?;
let mut case1 = cond.case_builder(0)?;
let foo = case1.add_dataflow_op(mk_op("foo"), case1.input_wires())?;
let case1 = case1.finish_with_outputs(foo.outputs())?.node();
let mut case2 = cond.case_builder(1)?;
let bar = case2.add_dataflow_op(mk_op("bar"), case2.input_wires())?;
let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs())?;
let baz = baz_dfg.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())?;
let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs())?;
let case2 = case2.finish_with_outputs(baz_dfg.outputs())?.node();
let cond = cond.finish_sub_container()?;
let h = h.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)?;
let mut cond = h
.conditional_builder_exts(
(vec![type_row![]; 2], b),
[(USIZE_T, i)],
type_row![USIZE_T],
unknown_ext.clone(),
)
.unwrap();
let mut case1 = cond.case_builder(0).unwrap();
let foo = case1
.add_dataflow_op(mk_op("foo"), case1.input_wires())
.unwrap();
let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node();
let mut case2 = cond.case_builder(1).unwrap();
let bar = case2
.add_dataflow_op(mk_op("bar"), case2.input_wires())
.unwrap();
let mut baz_dfg = case2
.dfg_builder(
utou.clone().with_extension_delta(unknown_ext.clone()),
bar.outputs(),
)
.unwrap();
let baz = baz_dfg
.add_dataflow_op(mk_op("baz"), baz_dfg.input_wires())
.unwrap();
let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap();
let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node();
let cond = cond.finish_sub_container().unwrap();
let h = h
.finish_hugr_with_outputs(cond.outputs(), &PRELUDE_REGISTRY)
.unwrap();

let mut r_hugr = Hugr::new(h.get_optype(cond.node()).clone());
let r1 = r_hugr.add_node_with_parent(
Expand All @@ -698,7 +716,7 @@ mod test {
rep.verify(&h).unwrap();
{
let mut target = h.clone();
let node_map = rep.clone().apply(&mut target)?;
let node_map = rep.clone().apply(&mut target).unwrap();
let new_case2 = *node_map.get(&r2).unwrap();
assert_eq!(target.get_parent(baz.node()), Some(new_case2));
}
Expand All @@ -713,7 +731,8 @@ mod test {
// Root node type needs to be that of common parent of the removed nodes:
let mut rep2 = rep.clone();
rep2.replacement
.replace_op(rep2.replacement.root(), h.root_type().clone())?;
.replace_op(rep2.replacement.root(), h.root_type().clone())
.unwrap();
assert_eq!(
check_same_errors(rep2),
ReplaceError::WrongRootNodeTag {
Expand Down Expand Up @@ -812,6 +831,5 @@ mod test {
}),
ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge)
);
Ok(())
}
}
34 changes: 26 additions & 8 deletions hugr-core/src/ops/custom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl DataflowOpTrait for CustomOp {
/// The signature of the operation.
fn signature(&self) -> Signature {
match self {
Self::Opaque(op) => op.signature.clone(),
Self::Opaque(op) => op.signature(),
Self::Extension(ext_op) => ext_op.signature(),
}
}
Expand Down Expand Up @@ -276,7 +276,15 @@ impl DataflowOpTrait for ExtensionOp {
}
}

/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`]
/// An opaquely-serialized op that refers to an as-yet-unresolved [`OpDef`].
///
/// All [CustomOp]s are serialised as `OpaqueOp`s.
///
/// The signature of a [CustomOp] always includes that op's extension. We do not
/// require that the `signature` field of [OpaqueOp] contains `extension`,
/// instead we are careful to add it whenever we look at the `signature` of an
/// `OpaqueOp`. This is a small efficiency in serialisation and allows us to
/// be more liberal in deserialisation.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(Arbitrary))]
pub struct OpaqueOp {
Expand All @@ -286,6 +294,9 @@ pub struct OpaqueOp {
#[cfg_attr(test, proptest(strategy = "any_nonempty_string()"))]
description: String, // cache in advance so description() can return &str
args: Vec<TypeArg>,
// note that the `signature` field might not include `extension`. Thus this must
// remain private, and should be accessed through
// `DataflowOpTrait::signature`.
signature: Signature,
}

Expand Down Expand Up @@ -343,7 +354,9 @@ impl DataflowOpTrait for OpaqueOp {
}

fn signature(&self) -> Signature {
self.signature.clone()
self.signature
.clone()
.with_extension_delta(self.extension().clone())
}
}

Expand Down Expand Up @@ -392,9 +405,8 @@ pub fn resolve_opaque_op(
r.name().clone(),
));
};
let ext_op =
ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry).unwrap();
if opaque.signature != ext_op.signature {
let ext_op = ExtensionOp::new(def.clone(), opaque.args.clone(), extension_registry)?;
if opaque.signature() != ext_op.signature() {
return Err(CustomOpError::SignatureMismatch {
extension: opaque.extension.clone(),
op: def.name().clone(),
Expand Down Expand Up @@ -425,10 +437,14 @@ pub enum CustomOpError {
stored: Signature,
computed: Signature,
},
/// An error in computing the signature of the ExtensionOp
#[error(transparent)]
SignatureError(#[from] SignatureError),
}

#[cfg(test)]
mod test {

use crate::{
extension::prelude::{BOOL_T, QB_T, USIZE_T},
std_extensions::arithmetic::{
Expand All @@ -453,13 +469,15 @@ mod test {
assert_eq!(op.name(), "res.op");
assert_eq!(DataflowOpTrait::description(&op), "desc");
assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]);
assert_eq!(op.signature(), sig);
assert_eq!(
op.signature(),
sig.with_extension_delta(op.extension().clone())
);
assert!(op.is_opaque());
assert!(!op.is_extension_op());
}

#[test]
#[should_panic] // https://github.com/CQCL/hugr/issues/1315
fn resolve_opaque_op() {
let registry = &INT_OPS_REGISTRY;
let i0 = &INT_TYPES[0];
Expand Down
4 changes: 1 addition & 3 deletions hugr-py/src/hugr/std/int.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ def to_custom(self) -> Custom:
return Custom(
"idivmod_u",
tys.FunctionType(
input=[int_t(self.arg1)] * 2,
output=[int_t(self.arg2)] * 2,
extension_reqs=[OPS_EXTENSION],
input=[int_t(self.arg1)] * 2, output=[int_t(self.arg2)] * 2
),
extension=OPS_EXTENSION,
args=[tys.BoundedNatArg(n=self.arg1), tys.BoundedNatArg(n=self.arg2)],
Expand Down
3 changes: 1 addition & 2 deletions hugr-py/src/hugr/std/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ class _NotDef(AsCustomOp):
"""Not operation."""

def to_custom(self) -> Custom:
sig = tys.FunctionType.endo([tys.Bool], [EXTENSION_ID])
return Custom("Not", sig, extension=EXTENSION_ID)
return Custom("Not", tys.FunctionType.endo([tys.Bool]), extension=EXTENSION_ID)

def __call__(self, a: ComWire) -> Command:
return DataflowOp.__call__(self, a)
Expand Down
20 changes: 5 additions & 15 deletions hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from hugr.hugr import Hugr
from hugr.ops import AsCustomOp, Command, Custom, DataflowOp
from hugr.serialization.serial_hugr import SerialHugr
from hugr.std.float import FLOAT_EXT_ID, FLOAT_T
from hugr.std.float import FLOAT_T

if TYPE_CHECKING:
from hugr.ops import ComWire
Expand Down Expand Up @@ -48,7 +48,7 @@ def __call__(self, q: ComWire) -> Command:
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo([tys.Qubit], extension_reqs=[QUANTUM_EXTENSION_ID]),
tys.FunctionType.endo([tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -70,9 +70,7 @@ class _Enum(Enum):
def to_custom(self) -> Custom:
return Custom(
self._enum.value,
tys.FunctionType.endo(
[tys.Qubit] * 2, extension_reqs=[QUANTUM_EXTENSION_ID]
),
tys.FunctionType.endo([tys.Qubit] * 2),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -92,11 +90,7 @@ class MeasureDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Measure",
tys.FunctionType(
[tys.Qubit],
[tys.Qubit, tys.Bool],
extension_reqs=[QUANTUM_EXTENSION_ID],
),
tys.FunctionType([tys.Qubit], [tys.Qubit, tys.Bool]),
extension=QUANTUM_EXTENSION_ID,
)

Expand All @@ -112,11 +106,7 @@ class RzDef(AsCustomOp):
def to_custom(self) -> Custom:
return Custom(
"Rz",
tys.FunctionType(
[tys.Qubit, FLOAT_T],
[tys.Qubit],
extension_reqs=[QUANTUM_EXTENSION_ID, FLOAT_EXT_ID],
),
tys.FunctionType([tys.Qubit, FLOAT_T], [tys.Qubit]),
extension=QUANTUM_EXTENSION_ID,
)

Expand Down

0 comments on commit 01da7ba

Please sign in to comment.