Skip to content

Commit

Permalink
Pulse ALAP scheduling (#183)
Browse files Browse the repository at this point in the history
This PR adds pulse alap scheduling for pulse sequences of quantum gates
inside a circuit, based on the availability of involved ports

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Thomas Alexander <[email protected]>
  • Loading branch information
3 people authored Dec 28, 2023
1 parent 503eb93 commit c2637fe
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 286 deletions.
1 change: 1 addition & 0 deletions include/Dialect/Pulse/IR/PulseInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ llvm::Optional<uint64_t> getSetupLatency(mlir::Operation *op);
void setSetupLatency(mlir::Operation *op, uint64_t setupLatency);
llvm::Expected<uint64_t> getDuration(mlir::Operation *op,
mlir::Operation *callSequenceOp = nullptr);
llvm::Expected<mlir::ArrayAttr> getPorts(mlir::Operation *op);
void setDuration(mlir::Operation *op, uint64_t duration);

} // namespace mlir::pulse::interfaces_impl
Expand Down
15 changes: 15 additions & 0 deletions include/Dialect/Pulse/IR/PulseInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,28 @@ def PulseOpSchedulingInterface : OpInterface<"PulseOpSchedulingInterface"> {
return PulseOpSchedulingInterface::setDuration($_op, other);
}]
>,
InterfaceMethod<
/*desc=*/"Get the ports of a pulse operation",
/*retTy=*/"::llvm::Expected<mlir::ArrayAttr>",
/*methodName=*/"getPorts",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
// By default, return the pulse.argPorts attribute
return PulseOpSchedulingInterface::getPorts($_op);
}]
>,
];

let extraSharedClassDeclaration = [{
static llvm::Optional<int64_t> getTimepoint(mlir::Operation *op) {
return interfaces_impl::getTimepoint(op);
}

static llvm::Expected<mlir::ArrayAttr> getPorts(mlir::Operation *op) {
return interfaces_impl::getPorts(op);
}

static void setTimepoint(mlir::Operation *op, int64_t timepoint) {
return interfaces_impl::setTimepoint(op, timepoint);
}
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/Pulse/IR/PulseOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,7 @@ def Pulse_CallSequenceOp : Pulse_Op<"call_sequence", [CallOpInterface, MemRefsNo

def Pulse_SequenceOp : Pulse_Op<"sequence", [
AutomaticAllocationScope, CallableOpInterface,
DeclareOpInterfaceMethods<PulseOpSchedulingInterface, ["getDuration"]>,
FunctionOpInterface, IsolatedFromAbove, Symbol, SequenceAllowed
]> {
let summary = "An operation with a name containing a single `SSACFG` region corresponding to a pulse sequence execution";
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/Pulse/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "Dialect/Pulse/Transforms/MergeDelays.h"
#include "Dialect/Pulse/Transforms/RemoveUnusedArguments.h"
#include "Dialect/Pulse/Transforms/SchedulePort.h"
#include "Dialect/Pulse/Transforms/Scheduling.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
Expand Down
122 changes: 55 additions & 67 deletions include/Dialect/Pulse/Transforms/Scheduling.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- Scheduling.h - Add absolute timing to defcal calls. ------*- C++ -*-===//
//===- scheduling.h --- quantum circuits pulse scheduling -------*- C++ -*-===//
//
// (C) Copyright IBM 2023.
//
Expand All @@ -14,72 +14,60 @@
//
//===----------------------------------------------------------------------===//
///
/// This file declares the pass for adding absolute timing to defcal calls.
/// This file implements the pass for scheduling the quantum circuits at pulse
/// level, based on the availability of involved ports
///
//===----------------------------------------------------------------------===//

//#ifndef PULSE_SCHEDULING_H
//#define PULSE_SCHEDULING_H

//#include <unordered_map>
//#include <unordered_set>

//#include "Dialect/Pulse/IR/PulseOps.h"
//#include "mlir/Pass/Pass.h"

// namespace mlir::pulse {

//// This pass applies absolute timing to each relevant Pulse IR instruction.
//// Timing is calculated on a per frame basis.
///*** Steps:
// * 1. Identify each defcal gate call.
// * 2. Find associated defcal body.
// * 3. Compute and store duration of each waveform and initialze time on each
// *frame. 4. For each play/delay instruction, increment the frame timing. Add
// *time attribute to instruction. For each barrier instruction, resolve to
// delays *on push forward basis (frames will be delayed to maximum time amongst
// all *frames).
// ***/
// struct SchedulingPass : public PassWrapper<SchedulingPass, OperationPass<>> {

// std::unordered_set<uint>
// scheduledDefCals; // hashes of defcal's that have already been scheduled
// std::unordered_map<uint, uint>
// pulseDurations; // mapping of waveform hashes to durations
// std::unordered_map<uint, uint>
// frameTimes; // mapping of frame hashes to time on that frame

// // Hash an operation based on the result.
// auto getResultHash(Operation *op) -> uint;

// // Check if the pulse hash is cached in pulse durations.
// // If it is cached, the hash will be found in pulseDurations.
// auto pulseCached(llvm::hash_code hash) -> bool;

// // Get hash and time of a frame as a std::pair
// auto getFrameHashAndTime(mlir::Value &frame) -> std::pair<uint, uint>;

// // Get the maximum time among a set of frames
// auto getMaxTime(mlir::OperandRange &frames) -> uint;

// // Process each operation in the defcal
// template <class WaveformOp>
// void processOp(WaveformOp &wfrOp);

// void processOp(Frame_CreateOp &frameOp);

// void processOp(DelayOp &delayOp);
// void processOp(BarrierOp &barrierOp);
// void processOp(PlayOp &playOp);
// void processOp(CaptureOp &captureOp);

// // Schedule the defcal
// void schedule(Operation *defCalOp);

// // Entry point for the pass
// void runOnOperation() override;

//}; // end struct SchedulingPass
//} // namespace mlir::pulse

//#endif // PULSE_SCHEDULING_H
#ifndef SCHEDULING_PULSE_SEQUENCES_H
#define SCHEDULING_PULSE_SEQUENCES_H

#include "Dialect/Pulse/IR/PulseOps.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"

namespace mlir::pulse {

struct quantumCircuitPulseSchedulingPass
: public PassWrapper<quantumCircuitPulseSchedulingPass,
OperationPass<ModuleOp>> {
public:
enum SchedulingMethod { ALAP, ASAP };
SchedulingMethod SCHEDULING_METHOD = ALAP;

// this pass can optionally receive an string specifying the scheduling
// method; default method is alap scheduling
quantumCircuitPulseSchedulingPass() = default;
quantumCircuitPulseSchedulingPass(
const quantumCircuitPulseSchedulingPass &pass)
: PassWrapper(pass) {}
quantumCircuitPulseSchedulingPass(SchedulingMethod inSchedulingMethod) {
SCHEDULING_METHOD = inSchedulingMethod;
}

void runOnOperation() override;

llvm::StringRef getArgument() const override;
llvm::StringRef getDescription() const override;

// optionally, one can override the scheduling method with this option
Option<std::string> schedulingMethod{
*this, "scheduling-method",
llvm::cl::desc("an string to specify scheduling method"),
llvm::cl::value_desc("filename"), llvm::cl::init("")};

private:
// map to keep track of next availability of ports
std::map<std::string, int> portNameToNextAvailabilityMap;

void scheduleAlap(mlir::pulse::CallSequenceOp quantumCircuitCallSequenceOp);
int getNextAvailableTimeOfPorts(mlir::ArrayAttr ports);
void updatePortAvailabilityMap(mlir::ArrayAttr ports,
int updatedAvailableTime);
static mlir::pulse::SequenceOp
getSequenceOp(mlir::pulse::CallSequenceOp callSequenceOp);
};
} // namespace mlir::pulse

#endif // SCHEDULING_PULSE_SEQUENCES_H
3 changes: 1 addition & 2 deletions lib/Conversion/QUIRToPulse/LoadPulseCals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::BarrierOp barrierOp,
void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp,
CallCircuitOp callCircuitOp,
FuncOp funcOp) {

OpBuilder builder(funcOp.body());

std::vector<Value> qubitOperands;
Expand All @@ -264,7 +263,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp,
delayOp->setAttr("pulse.calName", builder.getStringAttr(gateMangledName));
if (pulseCalsNameToSequenceMap.find(gateMangledName) !=
pulseCalsNameToSequenceMap.end()) {
// found a pulse calibration for the barrier gate
// found a pulse calibration for the delay gate
addPulseCalToModule(funcOp, pulseCalsNameToSequenceMap[gateMangledName]);
return;
}
Expand Down
23 changes: 23 additions & 0 deletions lib/Conversion/QUIRToPulse/QUIRToPulse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,29 @@ void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp,
convertedPulseSequenceOpReturnTypes.push_back(type);
for (auto val : pulseCalCallSequenceOp.res())
convertedPulseSequenceOpReturnValues.push_back(val);

// add starting timepoint for delayOp
if (auto delayOp = dyn_cast<mlir::quir::DelayOp>(quirOp)) {
uint64_t durValue = 0;
if (delayOp.time().isa<BlockArgument>()) {
uint argNum = delayOp.time().dyn_cast<BlockArgument>().getArgNumber();
auto durOpConstantOp = callCircuitOp.getOperand(argNum)
.getDefiningOp<mlir::quir::ConstantOp>();
auto durOp = quir::getDuration(durOpConstantOp).get();
durValue = static_cast<uint>(durOp.getDuration().convertToDouble());
assert(durOp.getType().dyn_cast<DurationType>().getUnits() ==
TimeUnits::dt &&
"this pass only accepts durations with dt unit");
} else {
auto durOp = quir::getDuration(delayOp).get();
durValue = static_cast<uint>(durOp.getDuration().convertToDouble());
assert(durOp.getType().dyn_cast<DurationType>().getUnits() ==
TimeUnits::dt &&
"this pass only accepts durations with dt unit");
}
PulseOpSchedulingInterface::setDuration(pulseCalCallSequenceOp,
durValue);
}
} else
assert(((isa<quir::ConstantOp>(quirOp) or isa<quir::ReturnOp>(quirOp) or
isa<quir::CircuitOp>(quirOp))) &&
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Pulse/IR/PulseInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@ interfaces_impl::getDuration(Operation *op, Operation *callSequenceOp) {
"Operation does not have a pulse.duration attribute.");
}

llvm::Expected<mlir::ArrayAttr> interfaces_impl::getPorts(mlir::Operation *op) {
if (op->hasAttrOfType<ArrayAttr>("pulse.argPorts"))
return op->getAttrOfType<ArrayAttr>("pulse.argPorts");
return llvm::createStringError(
llvm::inconvertibleErrorCode(),
"Operation does not have a pulse.argPorts attribute.");
}

void interfaces_impl::setDuration(Operation *op, uint64_t duration) {
mlir::OpBuilder builder(op);
op->setAttr("pulse.duration", builder.getI64IntegerAttr(duration));
Expand Down
17 changes: 17 additions & 0 deletions lib/Dialect/Pulse/IR/PulseOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,23 @@ CallSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//
//===----------------------------------------------------------------------===//

llvm::Expected<uint64_t>
SequenceOp::getDuration(mlir::Operation *callSequenceOp = nullptr) {
// first, check if the sequence has duration attribute. If not, also check if
// the call sequence has duration attribute; e.g., for sequences that receives
// delay arguments, duration of the sequence can vary depending on the
// argument, so we look at the duration of call sequence as well
if ((*this)->hasAttr("pulse.duration"))
return static_cast<uint64_t>(
(*this)->getAttrOfType<IntegerAttr>("pulse.duration").getInt());
if (callSequenceOp->hasAttr("pulse.duration"))
return static_cast<uint64_t>(
callSequenceOp->getAttrOfType<IntegerAttr>("pulse.duration").getInt());
return llvm::createStringError(
llvm::inconvertibleErrorCode(),
"Operation does not have a pulse.duration attribute.");
}

static ParseResult parseSequenceOp(OpAsmParser &parser,
OperationState &result) {
auto buildSequenceType =
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Pulse/Transforms/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ void registerPulsePasses() {
PassRegistration<MergeDelayPass>();
PassRegistration<RemoveUnusedArgumentsPass>();
PassRegistration<SchedulePortPass>();
PassRegistration<quantumCircuitPulseSchedulingPass>();
PassRegistration<ClassicalOnlyDetectionPass>();
}

Expand Down
Loading

0 comments on commit c2637fe

Please sign in to comment.