diff --git a/include/Conversion/QUIRToPulse/QUIRToPulse.h b/include/Conversion/QUIRToPulse/QUIRToPulse.h index ae2eda352..7e2d498f9 100644 --- a/include/Conversion/QUIRToPulse/QUIRToPulse.h +++ b/include/Conversion/QUIRToPulse/QUIRToPulse.h @@ -29,6 +29,7 @@ #include "mlir/Pass/Pass.h" #include +#include namespace mlir::pulse { @@ -71,7 +72,7 @@ struct QUIRToPulsePass uint convertedSequenceOpArgIndex; std::map circuitArgToConvertedSequenceArgMap; SmallVector convertedPulseSequenceOpArgs; - std::vector convertedPulseCallSequenceOpOperandNames; + std::unordered_map operandNameToIndexMap; // process the args of the circuit op, and add corresponding args to the // converted pulse sequence op diff --git a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp index 97178294f..eadeb9c21 100644 --- a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp +++ b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp @@ -163,7 +163,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedSequenceOpArgIndex = 0; circuitArgToConvertedSequenceArgMap.clear(); convertedPulseSequenceOpArgs.clear(); - convertedPulseCallSequenceOpOperandNames.clear(); + operandNameToIndexMap.clear(); // convert quir circuit args if not already converted, and add the converted // args to the the converted pulse sequence @@ -198,9 +198,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, attr.getValue()); } - pulseCalCallSequenceOp->setAttr( - "pulse.operands", - pulseCalSequenceOp->getAttrOfType("pulse.args")); for (auto type : pulseCalCallSequenceOp.getResultTypes()) convertedPulseSequenceOpReturnTypes.push_back(type); for (auto val : pulseCalCallSequenceOp.getRes()) @@ -252,12 +249,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp, convertedPulseSequenceOp, convertedPulseSequenceOpArgs); convertedPulseCallSequenceOp->moveAfter(callCircuitOp); - convertedPulseSequenceOp->setAttr( - "pulse.args", - builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames)); - convertedPulseCallSequenceOp->setAttr( - "pulse.operands", - builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames)); return convertedPulseCallSequenceOp; } @@ -274,28 +265,22 @@ void QUIRToPulsePass::processCircuitArgs( auto *angleOp = callCircuitOp.getOperand(cnt).getDefiningOp(); LLVM_DEBUG(llvm::dbgs() << "angle argument "); LLVM_DEBUG(angleOp->dump()); - convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex, - builder.getF64Type(), dictArg, - arg.getLoc()); + convertedPulseSequenceOp.getBody().addArgument(builder.getF64Type(), + arg.getLoc()); circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex; auto convertedAngleToF64 = convertAngleToF64(angleOp, builder); convertedSequenceOpArgIndex += 1; - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr("angle")); convertedPulseSequenceOpArgs.push_back(convertedAngleToF64); } else if (argumentType.isa()) { auto *durationOp = callCircuitOp.getOperand(cnt).getDefiningOp(); LLVM_DEBUG(llvm::dbgs() << "duration argument "); LLVM_DEBUG(durationOp->dump()); - convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex, - builder.getI64Type(), dictArg, - arg.getLoc()); + convertedPulseSequenceOp.getBody().addArgument(builder.getI64Type(), + arg.getLoc()); circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex; auto convertedDurationToI64 = convertDurationToI64( callCircuitOp, durationOp, cnt, builder, mainFunc); convertedSequenceOpArgIndex += 1; - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr("duration")); convertedPulseSequenceOpArgs.push_back(convertedDurationToI64); } else if (argumentType.isa()) { auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp(); @@ -411,23 +396,16 @@ void QUIRToPulsePass::processMixFrameOpArg( mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto mixedFrameOp = addMixFrameOpToIR(mixFrameName, portName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(mixFrameName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(mixFrameName)); + if (operandNameToIndexMap.find(mixFrameName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[mixFrameName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(mixedFrameOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, - builder.getType(), DictionaryAttr{}, - argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const mixFrameOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const mixFrameOperandIndex = operandNameToIndexMap[mixFrameName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[mixFrameOperandIndex]); } @@ -440,22 +418,16 @@ void QUIRToPulsePass::processPortOpArg(std::string const &portName, mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto portOp = addPortOpToIR(portName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(portName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(portName)); + if (operandNameToIndexMap.find(portName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[portName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(portOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, builder.getType(), - DictionaryAttr{}, argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const portOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const portOperandIndex = operandNameToIndexMap[portName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[portOperandIndex]); } @@ -468,23 +440,16 @@ void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName, mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) { auto wfrOp = addWfrOpToIR(wfrName, mainFunc, builder); - auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(), - convertedPulseCallSequenceOpOperandNames.end(), - builder.getStringAttr(wfrName)); - if (it == convertedPulseCallSequenceOpOperandNames.end()) { - convertedPulseCallSequenceOpOperandNames.push_back( - builder.getStringAttr(wfrName)); + if (operandNameToIndexMap.find(wfrName) == operandNameToIndexMap.end()) { + operandNameToIndexMap[wfrName] = convertedSequenceOpArgIndex; convertedPulseSequenceOpArgs.push_back(wfrOp); - convertedPulseSequenceOp.insertArgument( - convertedSequenceOpArgIndex, - builder.getType(), DictionaryAttr{}, - argumentValue.getLoc()); + convertedPulseSequenceOp.getBody().addArgument( + builder.getType(), argumentValue.getLoc()); pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]); convertedSequenceOpArgIndex += 1; } else { - uint const wfrOperandIndex = - std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it); + uint const wfrOperandIndex = operandNameToIndexMap[wfrName]; pulseCalSequenceArgs.push_back( convertedPulseSequenceOp.getArguments()[wfrOperandIndex]); } diff --git a/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir b/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir index 687e4b973..2aa74f235 100644 --- a/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir +++ b/test/Conversion/QUIRToPulse/convert-quir-to-pulse.mlir @@ -71,19 +71,19 @@ module { %false = arith.constant false pulse.return %false : i1 } - // CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) attributes {pulse.args = ["q3-drive-mixframe", "q5-drive-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} { - // CHECK: %0 = pulse.call_sequence @x_3(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %1 = pulse.call_sequence @sx_5(%arg1) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) + // CHECK: pulse.sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame) -> (i1, i1, i1, i1) { + // CHECK: %0 = pulse.call_sequence @x_3(%arg0) : (!pulse.mixed_frame) -> i1 + // CHECK: %1 = pulse.call_sequence @sx_5(%arg1) : (!pulse.mixed_frame) -> i1 + // CHECK: %2:2 = pulse.call_sequence @measure_3_5(%arg2, %arg3, %arg4, %arg5) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) // CHECK: pulse.return %0, %1, %2#0, %2#1 : i1, i1, i1, i1 - // CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) attributes {pulse.args = ["q5-drive-mixframe", "q3-5-cx-mixframe", "q5-3-cx-mixframe", "q3-readout-mixframe", "q3-capture-mixframe", "q5-readout-mixframe", "q5-capture-mixframe"]} { + // CHECK: pulse.sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%arg0: !pulse.mixed_frame, %arg1: !pulse.mixed_frame, %arg2: !pulse.mixed_frame, %arg3: !pulse.mixed_frame, %arg4: !pulse.mixed_frame, %arg5: !pulse.mixed_frame, %arg6: !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) { // CHECK: %cst = arith.constant 1.5707963267948966 : f64 - // CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1 - // CHECK: %1 = pulse.call_sequence @sx_5(%arg0) {{{.*}} : (!pulse.mixed_frame) -> i1 - // CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) {{{.*}} : (f64, !pulse.mixed_frame) -> i1 - // CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1 - // CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) + // CHECK: %0 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1 + // CHECK: %1 = pulse.call_sequence @sx_5(%arg0) : (!pulse.mixed_frame) -> i1 + // CHECK: %2 = pulse.call_sequence @rz_5(%cst, %arg0) : (f64, !pulse.mixed_frame) -> i1 + // CHECK: %3 = pulse.call_sequence @cx_5_3(%arg1, %arg2) : (!pulse.mixed_frame, !pulse.mixed_frame) -> i1 + // CHECK: %4:2 = pulse.call_sequence @measure_3_5(%arg3, %arg4, %arg5, %arg6) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1) // CHECK: pulse.return %0, %1, %2, %3, %4#0, %4#1 : i1, i1, i1, i1, i1, i1 func.func @main() -> i32 attributes {quir.classicalOnly = false} { @@ -121,12 +121,12 @@ module { // CHECK-NOT: %5 = quir.declare_qubit {id = 5 : i32} : !quir.qubit<1> %7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: %7:2 = quir.call_circuit @circuit_0_q5_q3_circuit_1_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1) + // CHECK: %14:4 = pulse.call_sequence @circuit_0_q5_q3_circuit_1_q5_sequence(%1, %3, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1) quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> () // CHECK-NOT: %quir.barrier %3, %5 : (!quir.qubit<1>, !quir.qubit<1>) -> () %8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) // CHECK-NOT: %8:2 = quir.call_circuit @circuit_2_q5_q3_circuit_3_q5(%5, %3) : (!quir.qubit<1>, !quir.qubit<1>) -> (i1, i1) - // CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) {{{.*}} : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) + // CHECK: %15:6 = pulse.call_sequence @circuit_2_q5_q3_circuit_3_q5_sequence(%3, %12, %13, %5, %7, %9, %11) : (!pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame, !pulse.mixed_frame) -> (i1, i1, i1, i1, i1, i1) } {qcs.shot_loop} return %c0_i32 : i32 }