Skip to content

Commit

Permalink
Merge pull request #1 from taalexander/taa-update-ce-op-walking
Browse files Browse the repository at this point in the history
Update operation walking to be more efficient.
  • Loading branch information
bcdonovan authored Mar 12, 2024
2 parents 603cea8 + 6473e9c commit 2bfa779
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 130 deletions.
10 changes: 6 additions & 4 deletions include/Dialect/QUIR/Transforms/ExtractCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,20 @@ struct ExtractCircuitsPass
llvm::StringRef getName() const override;

private:
void processOps(mlir::Operation *currentOp, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
void processRegion(mlir::Region &region, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
void processBlock(mlir::Block &block, OpBuilder topLevelBuilder,
OpBuilder circuitBuilder);
OpBuilder startCircuit(mlir::Location location, OpBuilder topLevelBuilder);
void endCircuit(mlir::Operation *firstOp, mlir::Operation *lastOp,
OpBuilder topLevelBuilder, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
uint64_t circuitCount;
uint64_t circuitCount = 0;
llvm::StringMap<Operation *> circuitOpsMap;

mlir::quir::CircuitOp currentCircuitOp;
mlir::quir::CircuitOp currentCircuitOp = nullptr;
mlir::quir::CallCircuitOp newCallCircuitOp;

llvm::SmallVector<Type> inputTypes;
Expand Down
187 changes: 61 additions & 126 deletions lib/Dialect/QUIR/Transforms/ExtractCircuits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,11 @@ llvm::cl::opt<bool>
llvm::cl::init(false));

// NOLINTNEXTLINE(misc-use-anonymous-namespace)
static std::optional<Operation *> localNextQuantumOpOrNull(Operation *op) {
Operation *nextOp = op;
while (nextOp) {
if (isQuantumOp(nextOp) && nextOp != op)
return nextOp;
if (nextOp->hasTrait<::mlir::RegionBranchOpInterface::Trait>()) {
// control flow found, no next quantum op
return std::nullopt;
}
if (isa<qcs::ParallelControlFlowOp>(nextOp))
return std::nullopt;
if (isa<oq3::CBitInsertBitOp>(nextOp))
return std::nullopt;
if (isa<quir::SwitchOp>(nextOp))
return std::nullopt;
nextOp = nextOp->getNextNode();
}
return std::nullopt;
} // localNextQuantumOpOrNull
static bool terminatesCircuit(Operation &op) {
return (op.hasTrait<::mlir::RegionBranchOpInterface::Trait>() ||
isa<qcs::ParallelControlFlowOp>(op) ||
isa<oq3::CBitInsertBitOp>(op) || isa<quir::SwitchOp>(op));
} // terminatesCircuit

OpBuilder ExtractCircuitsPass::startCircuit(Location location,
OpBuilder topLevelBuilder) {
Expand Down Expand Up @@ -210,118 +196,71 @@ void ExtractCircuitsPass::endCircuit(
LLVM_DEBUG(op->dump());
op->erase();
}

currentCircuitOp = nullptr;
}

void ExtractCircuitsPass::processOps(Operation *currentOp,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
void ExtractCircuitsPass::processRegion(mlir::Region &region,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
for (mlir::Block &block : region.getBlocks())
processBlock(block, topLevelBuilder, circuitBuilder);
}

void ExtractCircuitsPass::processBlock(mlir::Block &block,
OpBuilder topLevelBuilder,
OpBuilder circuitBuilder) {
llvm::SmallVector<Operation *> eraseList;

Operation *firstQuantumOp = nullptr;

// Handle Shot Loop delay differently
if (isa<quir::DelayOp>(currentOp) &&
isa<qcs::ShotInitOp>(currentOp->getNextNode())) {
// skip past shot init
currentOp = currentOp->getNextNode()->getNextNode();
}

while (currentOp) {

// Walk through current block of operations and pull out quantum
// operations into quir.circuits:
//
// 1. Identify first quantum operation
// 2. Start new circuit and clone quantum operation into circuit
// 2.a. startCircuit will create a new unique quir.circuit
// 3. Walk forward node by node
// 4. If node is a quantum operation clone into circuit
// 5. If not quantum or if control flow - end circuit
// 5.a. endCircuit will finish circuit, adjust circuit input / output,
// create call_circuit and erase original operations
// 6. If control flow - recursively call processOps for each region of
// control flow

// do not assume first operation is quantum and find first quantum operation
if (!firstQuantumOp) {

if (isQuantumOp(currentOp)) {
firstQuantumOp = currentOp;
} else {
// walk forward for first quantum operation or control flow
auto firstOrNull = localNextQuantumOpOrNull(currentOp);
if (firstOrNull) {
currentOp = firstOrNull.value();
firstQuantumOp = currentOp;
}
}
if (firstQuantumOp)
Operation *lastQuantumOp = nullptr;

// Walk through current block of operations and pull out quantum
// operations into quir.circuits:
//
// 1. Identify first quantum operation
// 2. Start new circuit and clone quantum operation into circuit
// 2.a. startCircuit will create a new unique quir.circuit
// 3. Walk forward node by node
// 4. If node is a quantum operation clone into circuit
// 5. If not quantum or if control flow - end circuit
// 5.a. endCircuit will finish circuit, adjust circuit input / output,
// create call_circuit and erase original operations
// 6. If control flow - recursively call processRegion for each region of
// control flow
for (Operation &currentOp : llvm::make_early_inc_range(block)) {
// Handle Shot Loop delay differently
if (isa<quir::DelayOp>(currentOp) &&
isa<qcs::ShotInitOp>(currentOp.getNextNode())) {
// skip past shot init
continue;
} else if (isQuantumOp(&currentOp)) {
// Start building circuit if not already
lastQuantumOp = &currentOp;
if (!currentCircuitOp) {
firstQuantumOp = lastQuantumOp;
circuitBuilder =
startCircuit(firstQuantumOp->getLoc(), topLevelBuilder);
}

// if operation is a quantum operation clone into circuit
if (isQuantumOp(currentOp))
addToCircuit(currentOp, circuitBuilder, eraseList);

// walk forward for next operation
auto nextOpOrNull = localNextQuantumOpOrNull(currentOp);
if (nextOpOrNull) {
currentOp = nextOpOrNull.value();
}
addToCircuit(&currentOp, circuitBuilder, eraseList);
continue;
}
} else if (terminatesCircuit(currentOp)) {
// next operation was not quantum so if there is a circuit builder in
// progress there is an in progress circuit to be ended.
if (currentCircuitOp) {
endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder,
circuitBuilder, eraseList);
}

// next operation was not quantum so if there is a firstQuantumOp there is
// an in progress circuit to be ended.
if (firstQuantumOp) {
Operation *lastOp = currentOp;
// nextOpOrNull was null so advance one node
currentOp = currentOp->getNextNode();
endCircuit(firstQuantumOp, lastOp, topLevelBuilder, circuitBuilder,
eraseList);
}
firstQuantumOp = nullptr;

if (!currentOp)
break;

// handle control flow -- and recursively call processOps for control flow
// regions

if (isa<scf::IfOp>(currentOp)) {
auto ifOp = static_cast<scf::IfOp>(currentOp);
if (!ifOp.getThenRegion().empty())
processOps(&ifOp.getThenRegion().front().front(), topLevelBuilder,
circuitBuilder);
if (!ifOp.getElseRegion().empty())
processOps(&ifOp.getElseRegion().front().front(), topLevelBuilder,
circuitBuilder);
} else if (isa<scf::ForOp>(currentOp)) {
auto forOp = static_cast<scf::ForOp>(currentOp);
processOps(&forOp.getBody()->front(), topLevelBuilder, circuitBuilder);
} else if (isa<scf::WhileOp>(currentOp)) {
auto whileOp = static_cast<scf::WhileOp>(currentOp);
if (!whileOp.getBefore().empty())
processOps(&whileOp.getBefore().front().front(), topLevelBuilder,
circuitBuilder);
if (!whileOp.getAfter().empty())
processOps(&whileOp.getAfter().front().front(), topLevelBuilder,
circuitBuilder);
} else if (isa<quir::SwitchOp>(currentOp)) {
// NOLINTNEXTLINE(llvm-qualified-auto)
auto switchOp = static_cast<quir::SwitchOp>(currentOp);
for (auto &region : switchOp.getCaseRegions())
processOps(&region.front().front(), topLevelBuilder, circuitBuilder);
} else if (isa<qcs::ParallelControlFlowOp>(currentOp)) {
// NOLINTNEXTLINE(llvm-qualified-auto)
auto parOp = static_cast<qcs::ParallelControlFlowOp>(currentOp);
processOps(&parOp.getBody()->front(), topLevelBuilder, circuitBuilder);
} else if (currentOp->hasTrait<::mlir::RegionBranchOpInterface::Trait>()) {
currentOp->dump();
assert(false && "Unhandled control flow");
// handle control flow by recursively calling processBlock for control
// flow regions
for (mlir::Region &region : currentOp.getRegions())
processRegion(region, topLevelBuilder, circuitBuilder);
}
currentOp = currentOp->getNextNode();
}
// End of block complete the circuit
if (currentCircuitOp) {
endCircuit(firstQuantumOp, lastQuantumOp, topLevelBuilder, circuitBuilder,
eraseList);
}
}

Expand All @@ -330,9 +269,6 @@ void ExtractCircuitsPass::runOnOperation() {
if (!enableCircuits)
return;

circuitCount = 0;
currentCircuitOp = nullptr;

Operation *moduleOp = getOperation();

llvm::StringMap<Operation *> circuitOpsMap;
Expand All @@ -346,8 +282,7 @@ void ExtractCircuitsPass::runOnOperation() {
assert(mainFunc && "could not find the main func");

auto const builder = OpBuilder(mainFunc);
auto *firstOp = &mainFunc.getBody().front().front();
processOps(firstOp, builder, builder);
processRegion(mainFunc.getRegion(), builder, builder);
} // runOnOperation

llvm::StringRef ExtractCircuitsPass::getArgument() const {
Expand Down

0 comments on commit 2bfa779

Please sign in to comment.