Skip to content

Commit

Permalink
performance improvements for quir to pulse (#314)
Browse files Browse the repository at this point in the history
improves the performance of quir to pulse
  • Loading branch information
reza-j authored Apr 3, 2024
1 parent 0d18a61 commit cf0b387
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 125 deletions.
21 changes: 9 additions & 12 deletions include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "mlir/Pass/Pass.h"

#include <queue>
#include <unordered_map>

namespace mlir::pulse {

Expand Down Expand Up @@ -70,9 +71,9 @@ struct QUIRToPulsePass
// will be reset every time convertCircuitToSequence is called and will be
// used by several functions that are called within that function
uint convertedSequenceOpArgIndex;
std::map<uint, uint> circuitArgToConvertedSequenceArgMap;
std::unordered_map<uint, uint> circuitArgToConvertedSequenceArgMap;
SmallVector<Value> convertedPulseSequenceOpArgs;
std::vector<mlir::Attribute> convertedPulseCallSequenceOpOperandNames;
std::unordered_map<std::string, uint> operandNameToIndexMap;

// process the args of the circuit op, and add corresponding args to the
// converted pulse sequence op
Expand Down Expand Up @@ -126,14 +127,15 @@ struct QUIRToPulsePass
mlir::func::FuncOp &mainFunc);
// map of the hashed location of quir angle/duration ops to their converted
// pulse ops
std::map<std::string, mlir::Value> classicalQUIROpLocToConvertedPulseOpMap;
std::unordered_map<std::string, mlir::Value>
classicalQUIROpLocToConvertedPulseOpMap;

// port name to Port_CreateOp map
std::map<std::string, mlir::pulse::Port_CreateOp> openedPorts;
std::unordered_map<std::string, mlir::pulse::Port_CreateOp> openedPorts;
// mixframe name to MixFrameOp map
std::map<std::string, mlir::pulse::MixFrameOp> openedMixFrames;
std::unordered_map<std::string, mlir::pulse::MixFrameOp> openedMixFrames;
// waveform name to Waveform_CreateOp map
std::map<std::string, mlir::pulse::Waveform_CreateOp> openedWfrs;
std::unordered_map<std::string, mlir::pulse::Waveform_CreateOp> openedWfrs;
// add a port to IR if it's not already added and return the Port_CreateOp
mlir::pulse::Port_CreateOp addPortOpToIR(std::string const &portName,
mlir::func::FuncOp &mainFunc,
Expand All @@ -149,14 +151,9 @@ struct QUIRToPulsePass
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);

void addCircuitToEraseList(mlir::Operation *op);
void addCircuitOperandToEraseList(mlir::Operation *op);
std::vector<mlir::Operation *> quirCircuitEraseList;
std::vector<mlir::Operation *> quirCircuitOperandEraseList;

// parse the waveform containers and add them to pulseNameToWaveformMap
void parsePulseWaveformContainerOps(std::string &waveformContainerPath);
std::map<std::string, Waveform_CreateOp> pulseNameToWaveformMap;
std::unordered_map<std::string, Waveform_CreateOp> pulseNameToWaveformMap;

qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};
};
Expand Down
130 changes: 31 additions & 99 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <memory>
#include <queue>
#include <string>
Expand Down Expand Up @@ -109,26 +107,17 @@ void QUIRToPulsePass::runOnOperation() {
callCircOp->erase();
});

// erase the quir circuits
LLVM_DEBUG(llvm::dbgs() << "\nErasing quir circuits:\n");
for (auto *op : quirCircuitEraseList) {
LLVM_DEBUG(op->dump());
op->erase();
}

// erase quir barriers before erasing the operands
moduleOp->walk([&](mlir::quir::BarrierOp barrierOp) { barrierOp->erase(); });

// erase the quir circuit operands
LLVM_DEBUG(llvm::dbgs() << "\nErasing quir circuit operands:\n");
for (auto *op : quirCircuitOperandEraseList) {
LLVM_DEBUG(op->dump());
op->erase();
}
// erase circuit ops
moduleOp->walk([&](CircuitOp circOp) { circOp->erase(); });

// erase the rest of quir.declare_qubits (unused in the input program)
moduleOp->walk([&](mlir::quir::DeclareQubitOp declareQubitOp) {
declareQubitOp->erase();
// erase qubit ops and constant angle ops
moduleOp->walk([&](Operation *op) {
if (isa<mlir::quir::DeclareQubitOp>(op))
op->erase();
else if (auto castOp = dyn_cast<mlir::quir::ConstantOp>(op)) {
if (castOp.getType().isa<::mlir::quir::AngleType>())
op->erase();
}
});
}

Expand All @@ -144,7 +133,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
LLVM_DEBUG(llvm::dbgs() << "\nConverting QUIR circuit " << circName << ":\n");
assert(callCircuitOp && "callCircuit op is null");
assert(circuitOp && "circuit op is null");
addCircuitToEraseList(circuitOp);

// build an empty pulse sequence
SmallVector<Value> arguments;
Expand All @@ -162,7 +150,7 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedSequenceOpArgIndex = 0;
circuitArgToConvertedSequenceArgMap.clear();
convertedPulseSequenceOpArgs.clear();
convertedPulseCallSequenceOpOperandNames.clear();
operandNameToIndexMap.clear();

// convert quir circuit args if not already converted, and add the converted
// args to the the converted pulse sequence
Expand Down Expand Up @@ -200,9 +188,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
attr.getValue());
}

pulseCalCallSequenceOp->setAttr(
"pulse.operands",
pulseCalSequenceOp->getAttrOfType<ArrayAttr>("pulse.args"));
for (auto type : pulseCalCallSequenceOp.getResultTypes())
convertedPulseSequenceOpReturnTypes.push_back(type);
for (auto val : pulseCalCallSequenceOp.getRes())
Expand Down Expand Up @@ -254,12 +239,6 @@ QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
convertedPulseSequenceOp,
convertedPulseSequenceOpArgs);
convertedPulseCallSequenceOp->moveAfter(callCircuitOp);
convertedPulseSequenceOp->setAttr(
"pulse.args",
builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames));
convertedPulseCallSequenceOp->setAttr(
"pulse.operands",
builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames));

return convertedPulseCallSequenceOp;
}
Expand All @@ -276,35 +255,26 @@ void QUIRToPulsePass::processCircuitArgs(
auto *angleOp = callCircuitOp.getOperand(cnt).getDefiningOp();
LLVM_DEBUG(llvm::dbgs() << "angle argument ");
LLVM_DEBUG(angleOp->dump());
convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex,
builder.getF64Type(), dictArg,
arg.getLoc());
convertedPulseSequenceOp.getBody().addArgument(builder.getF64Type(),
arg.getLoc());
circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex;
auto convertedAngleToF64 = convertAngleToF64(angleOp, builder);
convertedSequenceOpArgIndex += 1;
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr("angle"));
convertedPulseSequenceOpArgs.push_back(convertedAngleToF64);
} else if (argumentType.isa<mlir::quir::DurationType>()) {
auto *durationOp = callCircuitOp.getOperand(cnt).getDefiningOp();
LLVM_DEBUG(llvm::dbgs() << "duration argument ");
LLVM_DEBUG(durationOp->dump());
convertedPulseSequenceOp.insertArgument(convertedSequenceOpArgIndex,
builder.getI64Type(), dictArg,
arg.getLoc());
convertedPulseSequenceOp.getBody().addArgument(builder.getI64Type(),
arg.getLoc());
circuitArgToConvertedSequenceArgMap[cnt] = convertedSequenceOpArgIndex;
auto convertedDurationToI64 = convertDurationToI64(
callCircuitOp, durationOp, cnt, builder, mainFunc);
convertedSequenceOpArgIndex += 1;
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr("duration"));
convertedPulseSequenceOpArgs.push_back(convertedDurationToI64);
} else if (argumentType.isa<mlir::quir::QubitType>()) {
auto *qubitOp = callCircuitOp.getOperand(cnt).getDefiningOp();
addCircuitOperandToEraseList(qubitOp);
}

else
} else
llvm_unreachable("unkown circuit argument.");
}
}
Expand Down Expand Up @@ -413,23 +383,16 @@ void QUIRToPulsePass::processMixFrameOpArg(
mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) {
auto mixedFrameOp =
addMixFrameOpToIR(mixFrameName, portName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(mixFrameName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(mixFrameName));
if (operandNameToIndexMap.find(mixFrameName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[mixFrameName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(mixedFrameOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex,
builder.getType<mlir::pulse::MixedFrameType>(), DictionaryAttr{},
argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::MixedFrameType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const mixFrameOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const mixFrameOperandIndex = operandNameToIndexMap[mixFrameName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[mixFrameOperandIndex]);
}
Expand All @@ -442,22 +405,16 @@ void QUIRToPulsePass::processPortOpArg(std::string const &portName,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {
auto portOp = addPortOpToIR(portName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(portName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(portName));
if (operandNameToIndexMap.find(portName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[portName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(portOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex, builder.getType<mlir::pulse::PortType>(),
DictionaryAttr{}, argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::PortType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const portOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const portOperandIndex = operandNameToIndexMap[portName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[portOperandIndex]);
}
Expand All @@ -470,23 +427,16 @@ void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {
auto wfrOp = addWfrOpToIR(wfrName, mainFunc, builder);
auto it = std::find(convertedPulseCallSequenceOpOperandNames.begin(),
convertedPulseCallSequenceOpOperandNames.end(),
builder.getStringAttr(wfrName));
if (it == convertedPulseCallSequenceOpOperandNames.end()) {
convertedPulseCallSequenceOpOperandNames.push_back(
builder.getStringAttr(wfrName));
if (operandNameToIndexMap.find(wfrName) == operandNameToIndexMap.end()) {
operandNameToIndexMap[wfrName] = convertedSequenceOpArgIndex;
convertedPulseSequenceOpArgs.push_back(wfrOp);
convertedPulseSequenceOp.insertArgument(
convertedSequenceOpArgIndex,
builder.getType<mlir::pulse::WaveformType>(), DictionaryAttr{},
argumentValue.getLoc());
convertedPulseSequenceOp.getBody().addArgument(
builder.getType<mlir::pulse::WaveformType>(), argumentValue.getLoc());
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[convertedSequenceOpArgIndex]);
convertedSequenceOpArgIndex += 1;
} else {
uint const wfrOperandIndex =
std::distance(convertedPulseCallSequenceOpOperandNames.begin(), it);
uint const wfrOperandIndex = operandNameToIndexMap[wfrName];
pulseCalSequenceArgs.push_back(
convertedPulseSequenceOp.getArguments()[wfrOperandIndex]);
}
Expand Down Expand Up @@ -561,7 +511,6 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
if (classicalQUIROpLocToConvertedPulseOpMap.find(angleLocHash) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(angleOp)) {
addCircuitOperandToEraseList(angleOp);
double const angleVal =
castOp.getAngleValueFromConstant().convertToDouble();
auto f64Angle = builder.create<mlir::arith::ConstantOp>(
Expand All @@ -575,7 +524,6 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
angleCastedOp->moveAfter(castOp);
classicalQUIROpLocToConvertedPulseOpMap[angleLocHash] = angleCastedOp;
} else if (auto castOp = dyn_cast<oq3::CastOp>(angleOp)) {
addCircuitOperandToEraseList(angleOp);
auto castOpArg = castOp.getArg();
if (auto paramCastOp =
dyn_cast<qcs::ParameterLoadOp>(castOpArg.getDefiningOp())) {
Expand All @@ -600,7 +548,6 @@ mlir::Value QUIRToPulsePass::convertDurationToI64(
if (classicalQUIROpLocToConvertedPulseOpMap.find(durLocHash) ==
classicalQUIROpLocToConvertedPulseOpMap.end()) {
if (auto castOp = dyn_cast<quir::ConstantOp>(durationOp)) {
addCircuitOperandToEraseList(durationOp);
auto durVal =
quir::getDuration(castOp).get().getDuration().convertToDouble();
assert(castOp.getType().dyn_cast<DurationType>().getUnits() ==
Expand Down Expand Up @@ -658,21 +605,6 @@ QUIRToPulsePass::addWfrOpToIR(std::string const &wfrName,
return openedWfrs[wfrName];
}

void QUIRToPulsePass::addCircuitToEraseList(mlir::Operation *op) {
assert(op && "caller requested adding a null op to erase list");
if (std::find(quirCircuitEraseList.begin(), quirCircuitEraseList.end(), op) ==
quirCircuitEraseList.end())
quirCircuitEraseList.push_back(op);
}

void QUIRToPulsePass::addCircuitOperandToEraseList(mlir::Operation *op) {
assert(op && "caller requested adding a null op to erase list");
if (std::find(quirCircuitOperandEraseList.begin(),
quirCircuitOperandEraseList.end(),
op) == quirCircuitOperandEraseList.end())
quirCircuitOperandEraseList.push_back(op);
}

void QUIRToPulsePass::parsePulseWaveformContainerOps(
std::string &waveformContainerPath) {
std::string errorMessage;
Expand Down
Loading

0 comments on commit cf0b387

Please sign in to comment.