Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaces the uses of quir call with the converted pulse call in QuirToPulse #273

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ struct QUIRToPulsePass
mlir::Operation *mainFuncFirstOp;

// convert quir circuit to pulse sequence
void convertCircuitToSequence(mlir::quir::CallCircuitOp callCircuitOp,
mlir::func::FuncOp &mainFunc,
ModuleOp moduleOp);
mlir::pulse::CallSequenceOp
convertCircuitToSequence(mlir::quir::CallCircuitOp &callCircuitOp,
mlir::func::FuncOp &mainFunc, ModuleOp &moduleOp);
// helper datastructure for converting quir circuit to pulse sequence; these
// will be reset every time convertCircuitToSequence is called and will be
// used by several functions that are called within that function
Expand All @@ -75,51 +75,51 @@ struct QUIRToPulsePass

// process the args of the circuit op, and add corresponding args to the
// converted pulse sequence op
void processCircuitArgs(mlir::quir::CallCircuitOp callCircuitOp,
mlir::quir::CircuitOp circuitOp,
SequenceOp convertedPulseSequenceOp,
void processCircuitArgs(mlir::quir::CallCircuitOp &callCircuitOp,
mlir::quir::CircuitOp &circuitOp,
SequenceOp &convertedPulseSequenceOp,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);

// process the args of the pulse cal sequence op corresponding to quirOp
void processPulseCalArgs(mlir::Operation *quirOp,
SequenceOp pulseCalSequenceOp,
SequenceOp &pulseCalSequenceOp,
SmallVector<Value> &pulseCalSeqArgs,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);
void getQUIROpClassicalOperands(mlir::Operation *quirOp,
std::queue<Value> &angleOperands,
std::queue<Value> &durationOperands);
void processMixFrameOpArg(std::string const &mixFrameName,
std::string const &portName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &quirOpPulseCalSeqArgs,
Value argumentValue, mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);
void processPortOpArg(std::string const &portName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &quirOpPulseCalSeqArgs,
Value argumentValue, mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);
void processWfrOpArg(std::string const &wfrName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &quirOpPulseCalSeqArgs,
Value argumentValue, mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder);
void processAngleArg(Value nextAngleOperand,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &quirOpPulseCalSeqArgs,
mlir::OpBuilder &builder);
void processDurationArg(Value frontDurOperand,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &quirOpPulseCalSeqArgs,
mlir::OpBuilder &builder);

// convert angle to F64
mlir::Value convertAngleToF64(Operation *angleOp, mlir::OpBuilder &builder);
// convert duration to I64
mlir::Value convertDurationToI64(mlir::quir::CallCircuitOp callCircuitOp,
mlir::Value convertDurationToI64(mlir::quir::CallCircuitOp &callCircuitOp,
Operation *durOp, uint &cnt,
mlir::OpBuilder &builder,
mlir::func::FuncOp &mainFunc);
Expand Down Expand Up @@ -149,18 +149,16 @@ struct QUIRToPulsePass
mlir::OpBuilder &builder);

void addCircuitToEraseList(mlir::Operation *op);
void addCallCircuitToEraseList(mlir::Operation *op);
void addCircuitOperandToEraseList(mlir::Operation *op);
std::vector<mlir::Operation *> quirCircuitEraseList;
std::vector<mlir::Operation *> quirCallCircuitEraseList;
std::vector<mlir::Operation *> quirCircuitOperandEraseList;

// parse the waveform containers and add them to pulseNameToWaveformMap
void parsePulseWaveformContainerOps(std::string &waveformContainerPath);
std::map<std::string, Waveform_CreateOp> pulseNameToWaveformMap;

llvm::StringMap<Operation *> symbolMap;
mlir::quir::CircuitOp getCircuitOp(mlir::quir::CallCircuitOp callCircuitOp);
mlir::quir::CircuitOp getCircuitOp(mlir::quir::CallCircuitOp &callCircuitOp);
mlir::pulse::SequenceOp getSequenceOp(std::string const &symbolName);
};
} // namespace mlir::pulse
Expand Down
65 changes: 34 additions & 31 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "Dialect/Pulse/IR/PulseOps.h"
#include "Dialect/Pulse/IR/PulseTypes.h"
#include "Dialect/QCS/IR/QCSOps.h"
#include "Dialect/QUIR/IR/QUIRDialect.h"
#include "Dialect/QUIR/IR/QUIREnums.h"
#include "Dialect/QUIR/IR/QUIROps.h"
#include "Dialect/QUIR/IR/QUIRTypes.h"
Expand Down Expand Up @@ -100,16 +101,15 @@ void QUIRToPulsePass::runOnOperation() {

// convert all QUIR circuits to Pulse sequences
moduleOp->walk([&](CallCircuitOp callCircOp) {
convertCircuitToSequence(callCircOp, mainFunc, moduleOp);
if (isa<CircuitOp>(callCircOp->getParentOp()))
return;
auto convertedPulseCallSequenceOp =
convertCircuitToSequence(callCircOp, mainFunc, moduleOp);
if (!callCircOp->use_empty())
callCircOp->replaceAllUsesWith(convertedPulseCallSequenceOp);
callCircOp->erase();
});

// first erase the quir call circuits
LLVM_DEBUG(llvm::dbgs() << "\nErasing quir call circuits:\n");
for (auto *op : quirCallCircuitEraseList) {
LLVM_DEBUG(op->dump());
op->erase();
}

// erase the quir circuits
LLVM_DEBUG(llvm::dbgs() << "\nErasing quir circuits:\n");
for (auto *op : quirCircuitEraseList) {
Expand All @@ -133,17 +133,17 @@ void QUIRToPulsePass::runOnOperation() {
});
}

void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp,
mlir::func::FuncOp &mainFunc,
ModuleOp moduleOp) {
mlir::pulse::CallSequenceOp
QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp &callCircuitOp,
mlir::func::FuncOp &mainFunc,
ModuleOp &moduleOp) {
mlir::OpBuilder builder(mainFunc);

auto circuitOp = getCircuitOp(callCircuitOp);
std::string const circName = circuitOp.getSymName().str();
LLVM_DEBUG(llvm::dbgs() << "\nConverting QUIR circuit " << circName << ":\n");
assert(callCircuitOp && "callCircuit op is null");
assert(circuitOp && "circuit op is null");
addCallCircuitToEraseList(callCircuitOp);
addCircuitToEraseList(circuitOp);

// build an empty pulse sequence
Expand Down Expand Up @@ -188,6 +188,15 @@ void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp,
auto pulseCalCallSequenceOp =
entryBuilder.create<mlir::pulse::CallSequenceOp>(
quirOp->getLoc(), pulseCalSequenceOp, pulseCalSequenceArgs);

// copy the quir attributes of measure op (if any)
if (isa<MeasureOp>(quirOp)) {
for (auto attr : quirOp->getAttrs())
if (llvm::isa<mlir::quir::QUIRDialect>(attr.getNameDialect()))
pulseCalCallSequenceOp->setAttr(attr.getName().str(),
attr.getValue());
}

pulseCalCallSequenceOp->setAttr(
"pulse.operands",
pulseCalSequenceOp->getAttrOfType<ArrayAttr>("pulse.args"));
Expand Down Expand Up @@ -248,11 +257,13 @@ void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp,
convertedPulseCallSequenceOp->setAttr(
"pulse.operands",
builder.getArrayAttr(convertedPulseCallSequenceOpOperandNames));

return convertedPulseCallSequenceOp;
}

void QUIRToPulsePass::processCircuitArgs(
mlir::quir::CallCircuitOp callCircuitOp, mlir::quir::CircuitOp circuitOp,
SequenceOp convertedPulseSequenceOp, mlir::func::FuncOp &mainFunc,
mlir::quir::CallCircuitOp &callCircuitOp, mlir::quir::CircuitOp &circuitOp,
SequenceOp &convertedPulseSequenceOp, mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {
for (uint cnt = 0; cnt < circuitOp.getNumArguments(); cnt++) {
auto arg = circuitOp.getArgument(cnt);
Expand Down Expand Up @@ -296,9 +307,9 @@ void QUIRToPulsePass::processCircuitArgs(
}

void QUIRToPulsePass::processPulseCalArgs(
mlir::Operation *quirOp, SequenceOp pulseCalSequenceOp,
mlir::Operation *quirOp, SequenceOp &pulseCalSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs,
SequenceOp convertedPulseSequenceOp, mlir::func::FuncOp &mainFunc,
SequenceOp &convertedPulseSequenceOp, mlir::func::FuncOp &mainFunc,
mlir::OpBuilder &builder) {

// get the classical operands of the quir op
Expand Down Expand Up @@ -394,7 +405,7 @@ void QUIRToPulsePass::getQUIROpClassicalOperands(

void QUIRToPulsePass::processMixFrameOpArg(
std::string const &mixFrameName, std::string const &portName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs, Value argumentValue,
mlir::func::FuncOp &mainFunc, mlir::OpBuilder &builder) {
auto mixedFrameOp =
Expand Down Expand Up @@ -422,7 +433,7 @@ void QUIRToPulsePass::processMixFrameOpArg(
}

void QUIRToPulsePass::processPortOpArg(std::string const &portName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs,
Value argumentValue,
mlir::func::FuncOp &mainFunc,
Expand Down Expand Up @@ -450,7 +461,7 @@ void QUIRToPulsePass::processPortOpArg(std::string const &portName,
}

void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs,
Value argumentValue,
mlir::func::FuncOp &mainFunc,
Expand Down Expand Up @@ -479,7 +490,7 @@ void QUIRToPulsePass::processWfrOpArg(std::string const &wfrName,
}

void QUIRToPulsePass::processAngleArg(Value nextAngleOperand,
SequenceOp convertedPulseSequenceOp,
SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs,
mlir::OpBuilder &entryBuilder) {
if (nextAngleOperand.isa<BlockArgument>()) {
Expand Down Expand Up @@ -507,7 +518,7 @@ void QUIRToPulsePass::processAngleArg(Value nextAngleOperand,
}

void QUIRToPulsePass::processDurationArg(
Value nextDurationOperand, SequenceOp convertedPulseSequenceOp,
Value nextDurationOperand, SequenceOp &convertedPulseSequenceOp,
SmallVector<Value> &pulseCalSequenceArgs, mlir::OpBuilder &entryBuilder) {
if (nextDurationOperand.isa<BlockArgument>()) {
uint const circNum =
Expand Down Expand Up @@ -578,7 +589,7 @@ mlir::Value QUIRToPulsePass::convertAngleToF64(Operation *angleOp,
}

mlir::Value QUIRToPulsePass::convertDurationToI64(
mlir::quir::CallCircuitOp callCircuitOp, Operation *durationOp, uint &cnt,
mlir::quir::CallCircuitOp &callCircuitOp, Operation *durationOp, uint &cnt,
mlir::OpBuilder &builder, mlir::func::FuncOp &mainFunc) {
assert(durationOp && "duration op is null");
std::string const durLocHash =
Expand Down Expand Up @@ -651,14 +662,6 @@ void QUIRToPulsePass::addCircuitToEraseList(mlir::Operation *op) {
quirCircuitEraseList.push_back(op);
}

void QUIRToPulsePass::addCallCircuitToEraseList(mlir::Operation *op) {
assert(op && "caller requested adding a null op to erase list");
if (std::find(quirCallCircuitEraseList.begin(),
quirCallCircuitEraseList.end(),
op) == quirCallCircuitEraseList.end())
quirCallCircuitEraseList.push_back(op);
}

void QUIRToPulsePass::addCircuitOperandToEraseList(mlir::Operation *op) {
assert(op && "caller requested adding a null op to erase list");
if (std::find(quirCircuitOperandEraseList.begin(),
Expand Down Expand Up @@ -691,7 +694,7 @@ void QUIRToPulsePass::parsePulseWaveformContainerOps(
}

mlir::quir::CircuitOp
QUIRToPulsePass::getCircuitOp(CallCircuitOp callCircuitOp) {
QUIRToPulsePass::getCircuitOp(CallCircuitOp &callCircuitOp) {
auto search = symbolMap.find(callCircuitOp.getCallee());
assert(search != symbolMap.end() && "matching circuit not found");
auto circuitOp = dyn_cast<CircuitOp>(search->second);
Expand Down
Loading