Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract Circuits Performance Improvement #291

Merged
merged 16 commits into from
Mar 13, 2024
13 changes: 8 additions & 5 deletions include/Dialect/QUIR/Transforms/ExtractCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

#include "Dialect/QUIR/IR/QUIROps.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

Expand All @@ -41,27 +42,29 @@ struct ExtractCircuitsPass
llvm::StringRef getName() const override;

private:
void processOps(mlir::Operation *currentOp, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
void processRegion(mlir::Region &region, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
void processBlock(mlir::Block &block, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder);
void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp,
OpBuilder topLevelBuilder, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
uint64_t circuitCount;
uint64_t circuitCount = 0;
llvm::StringMap<Operation *> circuitOpsMap;

mlir::quir::CircuitOp currentCircuitOp;
mlir::quir::CircuitOp currentCircuitOp = nullptr;
mlir::quir::CallCircuitOp newCallCircuitOp;

llvm::SmallVector<Type> inputTypes;
llvm::SmallVector<Value> inputValues;
llvm::SmallVector<Type> outputTypes;
llvm::SmallVector<Value> outputValues;
std::vector<int> phyiscalIds;
std::unordered_map<uint32_t, int> argToId;

std::unordered_map<uint32_t, BlockArgument> circuitArguments;
std::unordered_map<Operation *, uint32_t> circuitOperands;
llvm::SmallVector<OpResult> originalResults;

Expand Down
224 changes: 82 additions & 142 deletions lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "Dialect/QUIR/Utils/Utils.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
Expand All @@ -47,7 +46,9 @@

#include <algorithm>
#include <cassert>
#include <optional>
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/Region.h>
#include <string>
#include <sys/types.h>
#include <vector>
Expand All @@ -63,25 +64,11 @@ llvm::cl::opt<bool>
llvm::cl::init(false));

// NOLINTNEXTLINE(misc-use-anonymous-namespace)
static std::optional<Operation *> localNextQuantumOpOrNull(Operation *op) {
Operation *nextOp = op;
while (nextOp) {
if (isQuantumOp(nextOp) && nextOp != op)
return nextOp;
if (nextOp->hasTrait<::mlir::RegionBranchOpInterface::Trait>()) {
// control flow found, no next quantum op
return std::nullopt;
}
if (isa<qcs::ParallelControlFlowOp>(nextOp))
return std::nullopt;
if (isa<oq3::CBitInsertBitOp>(nextOp))
return std::nullopt;
if (isa<quir::SwitchOp>(nextOp))
return std::nullopt;
nextOp = nextOp->getNextNode();
}
return std::nullopt;
} // localNextQuantumOpOrNull
static bool terminatesCircuit(Operation &op) {
return (op.hasTrait<::mlir::RegionBranchOpInterface::Trait>() ||
isa<qcs::ParallelControlFlowOp>(op) ||
isa<oq3::CBitInsertBitOp>(op) || isa<quir::SwitchOp>(op));
} // terminatesCircuit

OpBuilder ExtractCircuitsPass::startCircuit(Location location,
OpBuilder topLevelBuilder) {
Expand All @@ -91,9 +78,9 @@ OpBuilder ExtractCircuitsPass::startCircuit(Location location,
outputTypes.clear();
outputValues.clear();
originalResults.clear();
circuitArguments.clear();
circuitOperands.clear();
phyiscalIds.clear();
argToId.clear();

std::string const circuitName = "circuit_";
std::string newName = circuitName + std::to_string(circuitCount++);
Expand Down Expand Up @@ -132,19 +119,14 @@ void ExtractCircuitsPass::addToCircuit(
if (search == circuitOperands.end()) {
argumentIndex = inputValues.size();
inputValues.push_back(operand);
inputTypes.push_back(operand.getType());
circuitOperands[defOp] = argumentIndex;

currentCircuitOp.insertArgument(argumentIndex, operand.getType(), {},
currentOp->getLoc());
currentCircuitOp.getBody().addArgument(operand.getType(),
currentOp->getLoc());
if (isa<quir::DeclareQubitOp>(defOp)) {
auto physicalId = defOp->getAttrOfType<IntegerAttr>("id");
phyiscalIds.push_back(physicalId.getInt());
currentCircuitOp.setArgAttrs(
argumentIndex,
ArrayRef({NamedAttribute(
StringAttr::get(&getContext(),
mlir::quir::getPhysicalIdAttrName()),
physicalId)}));
auto id = defOp->getAttrOfType<IntegerAttr>("id").getInt();
phyiscalIds.push_back(id);
argToId[argumentIndex] = id;
}
} else {
argumentIndex = search->second;
Expand Down Expand Up @@ -175,17 +157,25 @@ void ExtractCircuitsPass::endCircuit(
// change the input / output types for the quir.circuit
auto opType = currentCircuitOp.getFunctionType();
currentCircuitOp.setType(topLevelBuilder.getFunctionType(
/*inputs=*/opType.getInputs(),
/*inputs=*/ArrayRef<Type>(inputTypes),
/*results=*/ArrayRef<Type>(outputTypes)));

for (const auto &[key, value] : argToId)
currentCircuitOp.setArgAttrs(
key,
ArrayRef({NamedAttribute(
StringAttr::get(&getContext(), mlir::quir::getPhysicalIdAttrName()),
topLevelBuilder.getI32IntegerAttr(value))}));

std::sort(phyiscalIds.begin(), phyiscalIds.end());
currentCircuitOp->setAttr(
mlir::quir::getPhysicalIdsAttrName(),
topLevelBuilder.getI32ArrayAttr(ArrayRef<int>(phyiscalIds)));

// insert call_circuit
// NOLINTNEXTLINE(misc-const-correctness)
OpBuilder builder(firstOp);
OpBuilder builder(lastOp);
builder.setInsertionPointAfter(lastOp);
newCallCircuitOp = builder.create<mlir::quir::CallCircuitOp>(
currentCircuitOp->getLoc(), currentCircuitOp.getSymName(),
TypeRange(outputTypes), ValueRange(inputValues));
Expand All @@ -207,118 +197,72 @@ void ExtractCircuitsPass::endCircuit(
LLVM_DEBUG(op->dump());
op->erase();
}

currentCircuitOp = nullptr;
}

void ExtractCircuitsPass::processOps(Operation *currentOp,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
void ExtractCircuitsPass::processRegion(mlir::Region &region,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
for (mlir::Block &block : region.getBlocks())
processBlock(block, topLevelBuilder, circuitBuilder);
}

void ExtractCircuitsPass::processBlock(mlir::Block &block,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
llvm::SmallVector<Operation *> eraseList;

Operation *firstQuantumOp = nullptr;

// Handle Shot Loop delay differently
if (isa<quir::DelayOp>(currentOp) &&
isa<qcs::ShotInitOp>(currentOp->getNextNode())) {
// skip past shot init
currentOp = currentOp->getNextNode()->getNextNode();
}

while (currentOp) {

// Walk through current block of operations and pull out quantum
// operations into quir.circuits:
//
// 1. Identify first quantum operation
// 2. Start new circuit and clone quantum operation into circuit
// 2.a. startCircuit will create a new unique quir.circuit
// 3. Walk forward node by node
// 4. If node is a quantum operation clone into circuit
// 5. If not quantum or if control flow - end circuit
// 5.a. endCircuit will finish circuit, adjust circuit input / output,
// create call_circuit and erase original operations
// 6. If control flow - recursively call processOps for each region of
// control flow

// do not assume first operation is quantum and find first quantum operation
if (!firstQuantumOp) {

if (isQuantumOp(currentOp)) {
firstQuantumOp = currentOp;
} else {
// walk forward for first quantum operation or control flow
auto firstOrNull = localNextQuantumOpOrNull(currentOp);
if (firstOrNull) {
currentOp = firstOrNull.value();
firstQuantumOp = currentOp;
}
}
if (firstQuantumOp)
Operation *lastQuantumOp = nullptr;

// Walk through current block of operations and pull out quantum
// operations into quir.circuits:
//
// 1. Identify first quantum operation
// 2. Start new circuit and clone quantum operation into circuit
// 2.a. startCircuit will create a new unique quir.circuit
// 3. Walk forward node by node
// 4. If node is a quantum operation clone into circuit
// 5. If not quantum or if control flow - end circuit
// 5.a. endCircuit will finish circuit, adjust circuit input / output,
// create call_circuit and erase original operations
// 6. If control flow - recursively call processRegion for each region of
// control flow
for (Operation &currentOp : llvm::make_early_inc_range(block)) {
// Handle Shot Loop delay differently
if (isa<quir::DelayOp>(currentOp) &&
isa<qcs::ShotInitOp>(currentOp.getNextNode())) {
// skip past shot init
continue;
}
if (isQuantumOp(&currentOp)) {
// Start building circuit if not already
lastQuantumOp = &currentOp;
if (!currentCircuitOp) {
firstQuantumOp = lastQuantumOp;
circuitBuilder =
startCircuit(firstQuantumOp->getLoc(), topLevelBuilder);
}

// if operation is a quantum operation clone into circuit
if (isQuantumOp(currentOp))
addToCircuit(currentOp, circuitBuilder, eraseList);

// walk forward for next operation
auto nextOpOrNull = localNextQuantumOpOrNull(currentOp);
if (nextOpOrNull) {
currentOp = nextOpOrNull.value();
}
addToCircuit(&currentOp, circuitBuilder, eraseList);
continue;
}
} if (terminatesCircuit(currentOp)) {
// next operation was not quantum so if there is a circuit builder in
// progress there is an in progress circuit to be ended.
if (currentCircuitOp) {
endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder,
circuitBuilder, eraseList);
}

// next operation was not quantum so if there is a firstQuantumOp there is
// an in progress circuit to be ended.
if (firstQuantumOp) {
Operation *lastOp = currentOp;
// nextOpOrNull was null so advance one node
currentOp = currentOp->getNextNode();
endCircuit(firstQuantumOp, lastOp, topLevelBuilder, circuitBuilder,
eraseList);
}
firstQuantumOp = nullptr;

if (!currentOp)
break;

// handle control flow -- and recursively call processOps for control flow
// regions

if (isa<scf::IfOp>(currentOp)) {
auto ifOp = static_cast<scf::IfOp>(currentOp);
if (!ifOp.getThenRegion().empty())
processOps(&ifOp.getThenRegion().front().front(), topLevelBuilder,
circuitBuilder);
if (!ifOp.getElseRegion().empty())
processOps(&ifOp.getElseRegion().front().front(), topLevelBuilder,
circuitBuilder);
} else if (isa<scf::ForOp>(currentOp)) {
auto forOp = static_cast<scf::ForOp>(currentOp);
processOps(&forOp.getBody()->front(), topLevelBuilder, circuitBuilder);
} else if (isa<scf::WhileOp>(currentOp)) {
auto whileOp = static_cast<scf::WhileOp>(currentOp);
if (!whileOp.getBefore().empty())
processOps(&whileOp.getBefore().front().front(), topLevelBuilder,
circuitBuilder);
if (!whileOp.getAfter().empty())
processOps(&whileOp.getAfter().front().front(), topLevelBuilder,
circuitBuilder);
} else if (isa<quir::SwitchOp>(currentOp)) {
// NOLINTNEXTLINE(llvm-qualified-auto)
auto switchOp = static_cast<quir::SwitchOp>(currentOp);
for (auto &region : switchOp.getCaseRegions())
processOps(&region.front().front(), topLevelBuilder, circuitBuilder);
} else if (isa<qcs::ParallelControlFlowOp>(currentOp)) {
// NOLINTNEXTLINE(llvm-qualified-auto)
auto parOp = static_cast<qcs::ParallelControlFlowOp>(currentOp);
processOps(&parOp.getBody()->front(), topLevelBuilder, circuitBuilder);
} else if (currentOp->hasTrait<::mlir::RegionBranchOpInterface::Trait>()) {
currentOp->dump();
assert(false && "Unhandled control flow");
// handle control flow by recursively calling processBlock for control
// flow regions
for (mlir::Region &region : currentOp.getRegions())
processRegion(region, topLevelBuilder, circuitBuilder);
}
currentOp = currentOp->getNextNode();
}
// End of block complete the circuit
if (currentCircuitOp) {
endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder,
eraseList);
}
}

Expand All @@ -327,9 +271,6 @@ void ExtractCircuitsPass::runOnOperation() {
if (!enableCircuits)
return;

circuitCount = 0;
currentCircuitOp = nullptr;

Operation *moduleOp = getOperation();

llvm::StringMap<Operation *> circuitOpsMap;
Expand All @@ -343,8 +284,7 @@ void ExtractCircuitsPass::runOnOperation() {
assert(mainFunc && "could not find the main func");

auto const builder = OpBuilder(mainFunc);
auto *firstOp = &mainFunc.getBody().front().front();
processOps(firstOp, builder, builder);
processRegion(mainFunc.getRegion(), builder, builder);
} // runOnOperation

llvm::StringRef ExtractCircuitsPass::getArgument() const {
Expand Down
Loading