diff --git a/pytket/pytket/circuit/decompose_classical.py b/pytket/pytket/circuit/decompose_classical.py index 5419a1470a..11a7f1a71c 100644 --- a/pytket/pytket/circuit/decompose_classical.py +++ b/pytket/pytket/circuit/decompose_classical.py @@ -337,30 +337,34 @@ def decompose_expr(self, expr: ClExpr, out_var: Variable | None) -> Variable: self.add_var(out_var) match op: case ClOp.BitAnd: - self.circ.add_c_and(*terms, out_var, **self.kwargs) + self.circ.add_c_and(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitNot: - self.circ.add_c_not(*terms, out_var, **self.kwargs) + self.circ.add_c_not(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitOne: + assert isinstance(out_var, Bit) self.circ.add_c_setbits([True], [out_var], **self.kwargs) case ClOp.BitOr: - self.circ.add_c_or(*terms, out_var, **self.kwargs) + self.circ.add_c_or(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitXor: - self.circ.add_c_xor(*terms, out_var, **self.kwargs) + self.circ.add_c_xor(*terms, out_var, **self.kwargs) # type: ignore case ClOp.BitZero: + assert isinstance(out_var, Bit) self.circ.add_c_setbits([False], [out_var], **self.kwargs) case ClOp.RegAnd: - self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) + self.circ.add_c_and_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegNot: - self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) + self.circ.add_c_not_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegOne: + assert isinstance(out_var, BitRegister) self.circ.add_c_setbits( [True] * out_var.size, out_var.to_list(), **self.kwargs ) case ClOp.RegOr: - self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) + self.circ.add_c_or_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegXor: - self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) + self.circ.add_c_xor_to_registers(*terms, out_var, **self.kwargs) # type: ignore case ClOp.RegZero: + assert isinstance(out_var, BitRegister) self.circ.add_c_setbits( [False] * out_var.size, out_var.to_list(), **self.kwargs ) @@ -483,14 +487,15 @@ def _decompose_expressions(circ: Circuit) -> tuple[Circuit, bool]: reg_posn = wexpr.reg_posn output_posn = wexpr.output_posn assert len(output_posn) > 0 - output0: Bit = args[output_posn[0]] + output0 = args[output_posn[0]] + assert isinstance(output0, Bit) out_var: Variable = ( BitRegister(output0.reg_name, len(output_posn)) if has_reg_output(expr.op) else output0 ) decomposer = ClExprDecomposer( - newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs + newcirc, bit_posn, reg_posn, args, bit_heap, reg_heap, kwargs # type: ignore ) comp_var = decomposer.decompose_expr(expr, out_var) if comp_var != out_var: diff --git a/pytket/tests/qasm_test.py b/pytket/tests/qasm_test.py index b0a8864a0f..a3824c03fe 100644 --- a/pytket/tests/qasm_test.py +++ b/pytket/tests/qasm_test.py @@ -1238,7 +1238,8 @@ def test_multibitop() -> None: test_hqs_conditional_params() test_barrier() test_barrier_2() - test_decomposable_extended() + test_decomposable_extended(True) + test_decomposable_extended(False) test_alternate_encoding() test_header_stops_gate_definition() test_tk2_definition()