Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for externally linked gates #182

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions qbraid_qir/qasm3/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def qasm3_to_qir(
elif not isinstance(program, str):
raise TypeError("Input quantum program must be of type openqasm3.ast.Program or str.")

qasm3_module = pyqasm.load(program)
qasm3_module.unroll()
external_gates: list[str] = kwargs.get("external_gates", [])

qasm3_module = pyqasm.load(program)
qasm3_module.unroll(external_gates=external_gates)
if name is None:
name = generate_module_id()
llvm_module = qir_module(Context(), name)
Expand Down
65 changes: 63 additions & 2 deletions qbraid_qir/qasm3/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ class QasmQIRVisitor:
Args:
initialize_runtime (bool): If True, quantum runtime will be initialized. Defaults to True.
record_output (bool): If True, output of the circuit will be recorded. Defaults to True.
external_gates (list[str]): List of custom gates that should not be unrolled.
Instead, these gates are marked for external linkage, as
qir-functions with the name "__quantum__qis__<GateName>__body"
"""

def __init__(
self,
initialize_runtime: bool = True,
record_output: bool = True,
external_gates: list[str] | None = None,
):
self._llvm_module: pyqir.Module
self._builder: pyqir.Builder
Expand All @@ -57,6 +61,12 @@ def __init__(
self._initialize_runtime: bool = initialize_runtime
self._record_output: bool = record_output

if external_gates is None:
external_gates = []
self._external_gates_map: dict[str, pyqir.Function | None] = {
external_gate: None for external_gate in external_gates
}

def visit_qasm3_module(self, module: QasmQIRModule) -> None:
"""
Visit a Qasm3 module.
Expand Down Expand Up @@ -319,6 +329,55 @@ def _visit_basic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
else:
qir_func(self._builder, *qubit_subset)

def _visit_external_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
"""Visit an external gate operation element.

Args:
operation (qasm3_ast.QuantumGate): The gate operation to visit.


Returns:
None

Raises:
Qasm3ConversionError: If the number of qubits is invalid.

"""
logger.debug("Visiting external gate operation '%s'", str(operation))
op_name: str = operation.name.name
op_qubits = self._get_op_bits(operation)
op_qubit_count = len(op_qubits)

if len(operation.modifiers) > 0:
raise_qasm3_error(
"Modifiers on externally linked gates are not supported in pyqir",
err_type=NotImplementedError,
)

context = self._llvm_module.context
qir_function = self._external_gates_map[op_name]
if qir_function is None:
# First time seeing this external gate -> define new function
qir_function_arguments = [pyqir.Type.double(context)] * len(operation.arguments)
qir_function_arguments += [pyqir.qubit_type(context)] * op_qubit_count

qir_function = pyqir.Function(
pyqir.FunctionType(pyqir.Type.void(context), qir_function_arguments),
pyqir.Linkage.EXTERNAL,
f"__quantum__qis__{op_name}__body",
self._llvm_module,
)
self._external_gates_map[op_name] = qir_function

op_parameters = None
if len(operation.arguments) > 0: # parametric gate
op_parameters = self._get_op_parameters(operation)

if op_parameters is not None:
self._builder.call(qir_function, [*op_parameters, *op_qubits])
else:
self._builder.call(qir_function, op_qubits)

def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> None:
"""Visit a gate operation element.

Expand All @@ -328,8 +387,10 @@ def _visit_generic_gate_operation(self, operation: qasm3_ast.QuantumGate) -> Non
Returns:
None
"""
# TODO: maybe needs to be extended for custom gates
self._visit_basic_gate_operation(operation)
if operation.name.name in self._external_gates_map:
self._visit_external_gate_operation(operation)
else:
self._visit_basic_gate_operation(operation)

TheGupta2012 marked this conversation as resolved.
Show resolved Hide resolved
def _get_branch_params(self, condition: Any) -> tuple[str, int, bool]:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/cirq_qir/test_cirq_to_qir.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_triple_qubit_gates(circuit_name, request):
check_attributes(generated_qir, 3, 3)
func = get_entry_point_body(generated_qir)
assert func[0] == initialize_call_string()
assert func[1] == generic_op_call_string(qir_op, [0, 1, 2])
assert func[1] == generic_op_call_string(qir_op, [], [0, 1, 2])
assert func[5] == return_string()
assert len(func) == 6

Expand Down
29 changes: 29 additions & 0 deletions tests/qasm3_qir/converter/test_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from tests.qir_utils import (
check_attributes,
check_custom_qasm_gate_op,
check_custom_qasm_gate_op_with_external_gates,
check_generic_gate_op,
check_single_qubit_gate_op,
check_single_qubit_rotation_op,
check_three_qubit_gate_op,
Expand Down Expand Up @@ -144,6 +146,20 @@ def test_qasm_u3_gates():
check_single_qubit_rotation_op(generated_qir, 1, [0], [0.5, 0.5, 0.5], "u3")


def test_qasm_u3_gates_external():
qasm3_string = """
OPENQASM 3;
include "stdgates.inc";

qubit[2] q1;
u3(0.5, 0.5, 0.5) q1[0];
"""
result = qasm3_to_qir(qasm3_string, external_gates=["u3"])
generated_qir = str(result).splitlines()
check_attributes(generated_qir, 2, 0)
check_generic_gate_op(generated_qir, 1, [0], ["5.000000e-01"] * 3, "u3")


def test_qasm_u2_gates():
qasm3_string = """
OPENQASM 3;
Expand Down Expand Up @@ -171,6 +187,19 @@ def test_custom_ops(test_name, request):
check_custom_qasm_gate_op(generated_qir, gate_type)


@pytest.mark.parametrize("test_name", custom_op_tests)
def test_custom_ops_with_external_gates(test_name, request):
qasm3_string = request.getfixturevalue(test_name)
gate_type = test_name.removeprefix("Fixture_")
result = qasm3_to_qir(qasm3_string, external_gates=["custom", "custom1"])

generated_qir = str(result).splitlines()
check_attributes(generated_qir, 2, 0)

# Check for custom gate definition
check_custom_qasm_gate_op_with_external_gates(generated_qir, gate_type)


def test_pow_gate_modifier():
qasm3_string = """
OPENQASM 3;
Expand Down
52 changes: 48 additions & 4 deletions tests/qir_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,11 @@ def reset_call_string(qb: int) -> str:
return f"call void @__quantum__qis__reset__body({_qubit_string(qb)})"


def generic_op_call_string(name: str, qbs: list[int]) -> str:
args = ", ".join(_qubit_string(qb) for qb in qbs)
return f"call void @__quantum__qis__{name}__body({args})"
def generic_op_call_string(name: str, angles: list[str], qubits: list[int]) -> str:
angles = ["double " + angle for angle in angles]
qubits = [_qubit_string(q) for q in qubits]
parameters = ", ".join(angles + qubits)
return f"call void @__quantum__qis__{name}__body({parameters})"


def return_string() -> str:
Expand Down Expand Up @@ -235,6 +237,31 @@ def check_single_qubit_gate_op(
), f"Incorrect single qubit gate count: {expected_ops} expected, {op_count} actual"


def check_generic_gate_op(
qir: list[str], expected_ops: int, qubit_list: list[int], param_list: list[str], gate_name: str
):
entry_body = get_entry_point_body(qir)
op_count = 0

for line in entry_body:
gate_call_id = (
f"qis__{gate_name}" if "dg" not in gate_name else f"qis__{gate_name.removesuffix('dg')}"
)
if line.strip().startswith("call") and gate_call_id in line:
expected_line = generic_op_call_string(gate_name, param_list, qubit_list)
assert line.strip() == expected_line, (
"Incorrect single qubit gate call in qir"
+ f"Expected {expected_line}, found {line.strip()}"
)
op_count += 1

if op_count == expected_ops:
break

if op_count != expected_ops:
assert False, f"Incorrect gate count: {expected_ops} expected, {op_count} actual"


def check_two_qubit_gate_op(
qir: list[str], expected_ops: int, qubit_lists: list[int], gate_name: str
):
Expand Down Expand Up @@ -346,7 +373,7 @@ def check_three_qubit_gate_op(
for line in entry_body:
if line.strip().startswith("call") and f"qis__{gate_name}" in line:
assert line.strip() == generic_op_call_string(
gate_name, qubit_lists[q_id]
gate_name, [], qubit_lists[q_id]
), f"Incorrect three qubit gate call in qir - {line}"
op_count += 1
q_id += 1
Expand Down Expand Up @@ -427,6 +454,23 @@ def check_custom_qasm_gate_op(qir: list[str], test_type: str):
assert False, f"Unknown test type {test_type} for custom ops"


def check_custom_qasm_gate_op_with_external_gates(qir: list[str], test_type: str):
if test_type == "simple":
check_generic_gate_op(qir, 1, [0, 1], ["1.100000e+00"], "custom")
elif test_type == "nested":
check_generic_gate_op(
qir, 1, [0, 1], ["4.800000e+00", "1.000000e-01", "3.000000e-01"], "custom"
)
elif test_type == "complex":
# Only custom1 is external, custom2 and custom3 should be unrolled
check_generic_gate_op(qir, 1, [0], [], "custom1")
check_generic_gate_op(qir, 1, [0], ["1.000000e-01"], "ry")
check_generic_gate_op(qir, 1, [0], ["2.000000e-01"], "rz")
check_generic_gate_op(qir, 1, [0, 1], [], "cnot")
else:
assert False, f"Unknown test type {test_type} for custom ops"


def check_expressions(
qir: list[str], expected_ops: int, gates: list[str], expression_values, qubits: list[int]
):
Expand Down
Loading