Skip to content

Commit

Permalink
performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
reza-j committed Apr 2, 2024
1 parent b12baec commit 084e79b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 68 deletions.
3 changes: 2 additions & 1 deletion include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Pass/Pass.h"

#include <queue>
#include <unordered_map>

namespace mlir::pulse {

Expand Down Expand Up @@ -71,7 +72,7 @@ struct QUIRToPulsePass
uint convertedSequenceOpArgIndex;
std::map<uint, uint> circuitArgToConvertedSequenceArgMap;
SmallVector<Value> convertedPulseSequenceOpArgs;
std::vector<mlir::Attribute> convertedPulseCallSequenceOpOperandNames;
std::unordered_map<std::string, uint> operandNameToIndexMap;

// process the args of the circuit op, and add corresponding args to the
// converted pulse sequence op
Expand Down
75 changes: 20 additions & 55 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedSequenceOpArgIndex = 0;
circuitArgToConvertedSequenceArgMap.clear();
convertedPulseSequenceOpArgs.clear();
convertedPulseCallSequenceOpOperandNames.clear();
operandNameToIndexMap.clear();

// convert quir circuit args if not already converted, and add the converted
// args to the the converted pulse sequence
Expand Down Expand Up @@ -198,9 +198,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
attr.getValue());
}

pulseCalCallSequenceOp->setAttr(
"pulse.operands",
pulseCalSequenceOp->getAttrOfType<ArrayAttr>("pulse.args"));
for (auto type : pulseCalCallSequenceOp.getResultTypes())
convertedPulseSequenceOpReturnTypes.push_back(type);
for (auto val : pulseCalCallSequenceOp.getRes())
Expand Down Expand Up @@ -252,12 +249,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedPulseSequenceOp,
convertedPulseSequenceOpArgs);
convertedPulseCallSequenceOp->moveAfter(callCircuitOp);
convertedPulseSequenceOp->setAttr(
"pulse.args",
builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames));
convertedPulseCallSequenceOp->setAttr(
"pulse.operands",
builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames));

return convertedPulseCallSequenceOp;
}
Expand All @@ -274,28 +265,22 @@ void QUIRToPulsePass::processCircuitArgs(
auto *angleOp = callCircuitOp.getOperand(cnt).getDefiningOp();
LLVM_DEBUG(llvm::dbgs() << "angle argument ");
LLVM_DEBUG(angleOp->dump());
convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex,
builder.getF64Type(), dictArg,
arg.getLoc());
convertedPulseSequenceOp.getBody().addArgument(builder.getF64Type(),
arg.getLoc());
circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex;
auto convertedAngleToF64 = convertAngleToF64(angleOp, builder);
convertedSequenceOpArgIndex += 1;
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr("angle"));
convertedPulseSequenceOpArgs.push_back(convertedAngleToF64);
} else if (argumentType.isa<mlir::quir::DurationType>()) {
auto *durationOp = callCircuitOp.getOperand(cnt).getDefiningOp();
LLVM_DEBUG(llvm::dbgs() << "duration argument ");
LLVM_DEBUG(durationOp->dump());
convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex,
builder.getI64Type(), dictArg,
arg.getLoc());
convertedPulseSequenceOp.getBody().addArgument(builder.getI64Type(),
arg.getLoc());
circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex;
auto convertedDurationToI64 = convertDurationToI64(
callCircuitOp, durationOp, cnt, builder, mainFunc);
convertedSequenceOpArgIndex += 1;
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr("duration"));
convertedPulseSequenceOpArgs.push_back(convertedDurationToI64);
} else if (argumentType.isa<mlir::quir::QubitType>()) {
auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp();
Expand Down Expand Up @@ -411,23 +396,16 @@ void QUIRToPulsePass::processMixFrameOpArg(
mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) {
auto mixedFrameOp =
addMixFrameOpToIR(mixFrameName, portName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(mixFrameName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(mixFrameName));
if (operandNameToIndexMap.find(mixFrameName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[mixFrameName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(mixedFrameOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex,
builder.getType<mlir::pulse::MixedFrameType>(), DictionaryAttr{},
argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::MixedFrameType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const mixFrameOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const mixFrameOperandIndex = operandNameToIndexMap[mixFrameName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[mixFrameOperandIndex]);
}
Expand All @@ -440,22 +418,16 @@ void QUIRToPulsePass::processPortOpArg(std::string const &portName,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {
auto portOp = addPortOpToIR(portName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(portName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(portName));
if (operandNameToIndexMap.find(portName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[portName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(portOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex, builder.getType<mlir::pulse::PortType>(),
DictionaryAttr{}, argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::PortType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const portOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const portOperandIndex = operandNameToIndexMap[portName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[portOperandIndex]);
}
Expand All @@ -468,23 +440,16 @@ void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {
auto wfrOp = addWfrOpToIR(wfrName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(wfrName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(wfrName));
if (operandNameToIndexMap.find(wfrName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[wfrName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(wfrOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex,
builder.getType<mlir::pulse::WaveformType>(), DictionaryAttr{},
argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::WaveformType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const wfrOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const wfrOperandIndex = operandNameToIndexMap[wfrName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[wfrOperandIndex]);
}
Expand Down
24 changes: 12 additions & 12 deletions test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,19 @@ module {
%false = arith.constant false
pulse.return %false : i1
}
// CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) attributes {pulse.args = ["q3-drive-mixframe", "q5-drive-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} {
// CHECK: %0 = pulse.call_sequence @x_3(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1
// CHECK: %1 = pulse.call_sequence @sx_5(%arg1) {{{.*}} : (!pulse.mixed_frame) -> i1
// CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1)
// CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) {
// CHECK: %0 = pulse.call_sequence @x_3(%arg0) : (!pulse.mixed_frame) -> i1
// CHECK: %1 = pulse.call_sequence @sx_5(%arg1) : (!pulse.mixed_frame) -> i1
// CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1)
// CHECK: pulse.return %0, %1, %2#0, %2#1 : i1, i1, i1, i1

// CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) attributes {pulse.args = ["q5-drive-mixframe", "q3-5-cx-mixframe", "q5-3-cx-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} {
// CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) {
// CHECK: %cst = arith.constant 1.5707963267948966 : f64
// CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1
// CHECK: %1 = pulse.call_sequence @sx_5(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1
// CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1
// CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1
// CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1)
// CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1
// CHECK: %1 = pulse.call_sequence @sx_5(%arg0) : (!pulse.mixed_frame) -> i1
// CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1
// CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1
// CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1)
// CHECK: pulse.return %0, %1, %2, %3, %4#0, %4#1 : i1, i1, i1, i1, i1, i1

func.func @main() -> i32 attributes {quir.classicalOnly = false} {
Expand Down Expand Up @@ -121,12 +121,12 @@ module {
// CHECK-NOT: %5 = quir.declare_qubit {id = 5 : i32} : !quir.qubit<1>
%7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
// CHECK-NOT: %7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
// CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1)
// CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1)
quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
// CHECK-NOT: %quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> ()
%8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
// CHECK-NOT: %8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1)
// CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1)
// CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1)
} {qcs.shot_loop}
return %c0_i32 : i32
}
Expand Down

0 comments on commit 084e79b

Please sign in to comment.