diff --git a/qbraid_qir/qasm3/convert.py b/qbraid_qir/qasm3/convert.py index c8f0ce1..1dd03f5 100644 --- a/qbraid_qir/qasm3/convert.py +++ b/qbraid_qir/qasm3/convert.py @@ -52,9 +52,10 @@ def qasm3_to_qir( elif not isinstance(program, str): raise TypeError("Input quantum program must be of type openqasm3.ast.Program or str.") - qasm3_module = pyqasm.load(program) - qasm3_module.unroll() + external_gates: list[str] = kwargs.get("external_gates", []) + qasm3_module = pyqasm.load(program) + qasm3_module.unroll(external_gates=external_gates) if name is None: name = generate_module_id() llvm_module = qir_module(Context(), name) diff --git a/qbraid_qir/qasm3/visitor.py b/qbraid_qir/qasm3/visitor.py index 8a06291..0f35020 100644 --- a/qbraid_qir/qasm3/visitor.py +++ b/qbraid_qir/qasm3/visitor.py @@ -38,12 +38,16 @@ class QasmQIRVisitor: Args: initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True. record_output (bool): If True, output of the circuit will be recorded. Defaults to True. + external_gates (list[str]): List of custom gates that should not be unrolled. + Instead, these gates are marked for external linkage, as + qir-functions with the name "__quantum__qis____body" """ def __init__( self, initialize_runtime: bool = True, record_output: bool = True, + external_gates: list[str] | None = None, ): self._llvm_module: pyqir.Module self._builder: pyqir.Builder @@ -57,6 +61,12 @@ def __init__( self._initialize_runtime: bool = initialize_runtime self._record_output: bool = record_output + if external_gates is None: + external_gates = [] + self._external_gates_map: dict[str, pyqir.Function | None] = { + external_gate: None for external_gate in external_gates + } + def visit_qasm3_module(self, module: QasmQIRModule) -> None: """ Visit a Qasm3 module. @@ -319,6 +329,55 @@ def _visit_basic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: else: qir_func(self._builder, *qubit_subset) + def _visit_external_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: + """Visit an external gate operation element. + + Args: + operation (qasm3_ast.QuantumGate): The gate operation to visit. + + + Returns: + None + + Raises: + Qasm3ConversionError: If the number of qubits is invalid. + + """ + logger.debug("Visiting external gate operation '%s'", str(operation)) + op_name: str = operation.name.name + op_qubits = self._get_op_bits(operation) + op_qubit_count = len(op_qubits) + + if len(operation.modifiers) > 0: + raise_qasm3_error( + "Modifiers on externally linked gates are not supported in pyqir", + err_type=NotImplementedError, + ) + + context = self._llvm_module.context + qir_function = self._external_gates_map[op_name] + if qir_function is None: + # First time seeing this external gate -> define new function + qir_function_arguments = [pyqir.Type.double(context)] * len(operation.arguments) + qir_function_arguments += [pyqir.qubit_type(context)] * op_qubit_count + + qir_function = pyqir.Function( + pyqir.FunctionType(pyqir.Type.void(context), qir_function_arguments), + pyqir.Linkage.EXTERNAL, + f"__quantum__qis__{op_name}__body", + self._llvm_module, + ) + self._external_gates_map[op_name] = qir_function + + op_parameters = None + if len(operation.arguments) > 0: # parametric gate + op_parameters = self._get_op_parameters(operation) + + if op_parameters is not None: + self._builder.call(qir_function, [*op_parameters, *op_qubits]) + else: + self._builder.call(qir_function, op_qubits) + def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None: """Visit a gate operation element. @@ -328,8 +387,10 @@ def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> Non Returns: None """ - # TODO: maybe needs to be extended for custom gates - self._visit_basic_gate_operation(operation) + if operation.name.name in self._external_gates_map: + self._visit_external_gate_operation(operation) + else: + self._visit_basic_gate_operation(operation) def _get_branch_params(self, condition: Any) -> tuple[str, int, bool]: """ diff --git a/tests/cirq_qir/test_cirq_to_qir.py b/tests/cirq_qir/test_cirq_to_qir.py index 4756c9a..f9973a1 100644 --- a/tests/cirq_qir/test_cirq_to_qir.py +++ b/tests/cirq_qir/test_cirq_to_qir.py @@ -143,7 +143,7 @@ def test_triple_qubit_gates(circuit_name, request): check_attributes(generated_qir, 3, 3) func = get_entry_point_body(generated_qir) assert func[0] == initialize_call_string() - assert func[1] == generic_op_call_string(qir_op, [0, 1, 2]) + assert func[1] == generic_op_call_string(qir_op, [], [0, 1, 2]) assert func[5] == return_string() assert len(func) == 6 diff --git a/tests/qasm3_qir/converter/test_gates.py b/tests/qasm3_qir/converter/test_gates.py index 518458c..1abe394 100644 --- a/tests/qasm3_qir/converter/test_gates.py +++ b/tests/qasm3_qir/converter/test_gates.py @@ -25,6 +25,8 @@ from tests.qir_utils import ( check_attributes, check_custom_qasm_gate_op, + check_custom_qasm_gate_op_with_external_gates, + check_generic_gate_op, check_single_qubit_gate_op, check_single_qubit_rotation_op, check_three_qubit_gate_op, @@ -144,6 +146,20 @@ def test_qasm_u3_gates(): check_single_qubit_rotation_op(generated_qir, 1, [0], [0.5, 0.5, 0.5], "u3") +def test_qasm_u3_gates_external(): + qasm3_string = """ + OPENQASM 3; + include "stdgates.inc"; + + qubit[2] q1; + u3(0.5, 0.5, 0.5) q1[0]; + """ + result = qasm3_to_qir(qasm3_string, external_gates=["u3"]) + generated_qir = str(result).splitlines() + check_attributes(generated_qir, 2, 0) + check_generic_gate_op(generated_qir, 1, [0], ["5.000000e-01"] * 3, "u3") + + def test_qasm_u2_gates(): qasm3_string = """ OPENQASM 3; @@ -171,6 +187,19 @@ def test_custom_ops(test_name, request): check_custom_qasm_gate_op(generated_qir, gate_type) +@pytest.mark.parametrize("test_name", custom_op_tests) +def test_custom_ops_with_external_gates(test_name, request): + qasm3_string = request.getfixturevalue(test_name) + gate_type = test_name.removeprefix("Fixture_") + result = qasm3_to_qir(qasm3_string, external_gates=["custom", "custom1"]) + + generated_qir = str(result).splitlines() + check_attributes(generated_qir, 2, 0) + + # Check for custom gate definition + check_custom_qasm_gate_op_with_external_gates(generated_qir, gate_type) + + def test_pow_gate_modifier(): qasm3_string = """ OPENQASM 3; diff --git a/tests/qir_utils.py b/tests/qir_utils.py index 4738ee0..474ebf2 100644 --- a/tests/qir_utils.py +++ b/tests/qir_utils.py @@ -103,9 +103,11 @@ def reset_call_string(qb: int) -> str: return f"call void @__quantum__qis__reset__body({_qubit_string(qb)})" -def generic_op_call_string(name: str, qbs: list[int]) -> str: - args = ", ".join(_qubit_string(qb) for qb in qbs) - return f"call void @__quantum__qis__{name}__body({args})" +def generic_op_call_string(name: str, angles: list[str], qubits: list[int]) -> str: + angles = ["double " + angle for angle in angles] + qubits = [_qubit_string(q) for q in qubits] + parameters = ", ".join(angles + qubits) + return f"call void @__quantum__qis__{name}__body({parameters})" def return_string() -> str: @@ -235,6 +237,31 @@ def check_single_qubit_gate_op( ), f"Incorrect single qubit gate count: {expected_ops} expected, {op_count} actual" +def check_generic_gate_op( + qir: list[str], expected_ops: int, qubit_list: list[int], param_list: list[str], gate_name: str +): + entry_body = get_entry_point_body(qir) + op_count = 0 + + for line in entry_body: + gate_call_id = ( + f"qis__{gate_name}" if "dg" not in gate_name else f"qis__{gate_name.removesuffix('dg')}" + ) + if line.strip().startswith("call") and gate_call_id in line: + expected_line = generic_op_call_string(gate_name, param_list, qubit_list) + assert line.strip() == expected_line, ( + "Incorrect single qubit gate call in qir" + + f"Expected {expected_line}, found {line.strip()}" + ) + op_count += 1 + + if op_count == expected_ops: + break + + if op_count != expected_ops: + assert False, f"Incorrect gate count: {expected_ops} expected, {op_count} actual" + + def check_two_qubit_gate_op( qir: list[str], expected_ops: int, qubit_lists: list[int], gate_name: str ): @@ -346,7 +373,7 @@ def check_three_qubit_gate_op( for line in entry_body: if line.strip().startswith("call") and f"qis__{gate_name}" in line: assert line.strip() == generic_op_call_string( - gate_name, qubit_lists[q_id] + gate_name, [], qubit_lists[q_id] ), f"Incorrect three qubit gate call in qir - {line}" op_count += 1 q_id += 1 @@ -427,6 +454,23 @@ def check_custom_qasm_gate_op(qir: list[str], test_type: str): assert False, f"Unknown test type {test_type} for custom ops" +def check_custom_qasm_gate_op_with_external_gates(qir: list[str], test_type: str): + if test_type == "simple": + check_generic_gate_op(qir, 1, [0, 1], ["1.100000e+00"], "custom") + elif test_type == "nested": + check_generic_gate_op( + qir, 1, [0, 1], ["4.800000e+00", "1.000000e-01", "3.000000e-01"], "custom" + ) + elif test_type == "complex": + # Only custom1 is external, custom2 and custom3 should be unrolled + check_generic_gate_op(qir, 1, [0], [], "custom1") + check_generic_gate_op(qir, 1, [0], ["1.000000e-01"], "ry") + check_generic_gate_op(qir, 1, [0], ["2.000000e-01"], "rz") + check_generic_gate_op(qir, 1, [0, 1], [], "cnot") + else: + assert False, f"Unknown test type {test_type} for custom ops" + + def check_expressions( qir: list[str], expected_ops: int, gates: list[str], expression_values, qubits: list[int] ):