diff --git a/include/Dialect/Pulse/IR/PulseInterfaces.h b/include/Dialect/Pulse/IR/PulseInterfaces.h index 1d4cf984d..9d007ad65 100644 --- a/include/Dialect/Pulse/IR/PulseInterfaces.h +++ b/include/Dialect/Pulse/IR/PulseInterfaces.h @@ -42,6 +42,7 @@ llvm::Optional getSetupLatency(mlir::Operation *op); void setSetupLatency(mlir::Operation *op, uint64_t setupLatency); llvm::Expected getDuration(mlir::Operation *op, mlir::Operation *callSequenceOp = nullptr); +llvm::Expected getPorts(mlir::Operation *op); void setDuration(mlir::Operation *op, uint64_t duration); } // namespace mlir::pulse::interfaces_impl diff --git a/include/Dialect/Pulse/IR/PulseInterfaces.td b/include/Dialect/Pulse/IR/PulseInterfaces.td index 943c08334..50c321054 100644 --- a/include/Dialect/Pulse/IR/PulseInterfaces.td +++ b/include/Dialect/Pulse/IR/PulseInterfaces.td @@ -100,6 +100,17 @@ def PulseOpSchedulingInterface : OpInterface<"PulseOpSchedulingInterface"> { return PulseOpSchedulingInterface::setDuration($_op, other); }] >, + InterfaceMethod< + /*desc=*/"Get the ports of a pulse operation", + /*retTy=*/"::llvm::Expected", + /*methodName=*/"getPorts", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + // By default, return the pulse.argPorts attribute + return PulseOpSchedulingInterface::getPorts($_op); + }] + >, ]; let extraSharedClassDeclaration = [{ @@ -107,6 +118,10 @@ def PulseOpSchedulingInterface : OpInterface<"PulseOpSchedulingInterface"> { return interfaces_impl::getTimepoint(op); } + static llvm::Expected getPorts(mlir::Operation *op) { + return interfaces_impl::getPorts(op); + } + static void setTimepoint(mlir::Operation *op, int64_t timepoint) { return interfaces_impl::setTimepoint(op, timepoint); } diff --git a/include/Dialect/Pulse/IR/PulseOps.td b/include/Dialect/Pulse/IR/PulseOps.td index ff919d368..8cd020c3e 100644 --- a/include/Dialect/Pulse/IR/PulseOps.td +++ b/include/Dialect/Pulse/IR/PulseOps.td @@ -712,6 +712,7 @@ def Pulse_CallSequenceOp : Pulse_Op<"call_sequence", [CallOpInterface, MemRefsNo def Pulse_SequenceOp : Pulse_Op<"sequence", [ AutomaticAllocationScope, CallableOpInterface, + DeclareOpInterfaceMethods, FunctionOpInterface, IsolatedFromAbove, Symbol, SequenceAllowed ]> { let summary = "An operation with a name containing a single `SSACFG` region corresponding to a pulse sequence execution"; diff --git a/include/Dialect/Pulse/Transforms/Passes.h b/include/Dialect/Pulse/Transforms/Passes.h index bc5502209..e8e1245ac 100644 --- a/include/Dialect/Pulse/Transforms/Passes.h +++ b/include/Dialect/Pulse/Transforms/Passes.h @@ -27,6 +27,7 @@ #include "Dialect/Pulse/Transforms/MergeDelays.h" #include "Dialect/Pulse/Transforms/RemoveUnusedArguments.h" #include "Dialect/Pulse/Transforms/SchedulePort.h" +#include "Dialect/Pulse/Transforms/Scheduling.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" diff --git a/include/Dialect/Pulse/Transforms/Scheduling.h b/include/Dialect/Pulse/Transforms/Scheduling.h index 7b2c98ef5..c1b37643f 100644 --- a/include/Dialect/Pulse/Transforms/Scheduling.h +++ b/include/Dialect/Pulse/Transforms/Scheduling.h @@ -1,4 +1,4 @@ -//===- Scheduling.h - Add absolute timing to defcal calls. ------*- C++ -*-===// +//===- scheduling.h --- quantum circuits pulse scheduling -------*- C++ -*-===// // // (C) Copyright IBM 2023. // @@ -14,72 +14,60 @@ // //===----------------------------------------------------------------------===// /// -/// This file declares the pass for adding absolute timing to defcal calls. +/// This file implements the pass for scheduling the quantum circuits at pulse +/// level, based on the availability of involved ports /// //===----------------------------------------------------------------------===// -//#ifndef PULSE_SCHEDULING_H -//#define PULSE_SCHEDULING_H - -//#include -//#include - -//#include "Dialect/Pulse/IR/PulseOps.h" -//#include "mlir/Pass/Pass.h" - -// namespace mlir::pulse { - -//// This pass applies absolute timing to each relevant Pulse IR instruction. -//// Timing is calculated on a per frame basis. -///*** Steps: -// * 1. Identify each defcal gate call. -// * 2. Find associated defcal body. -// * 3. Compute and store duration of each waveform and initialze time on each -// *frame. 4. For each play/delay instruction, increment the frame timing. Add -// *time attribute to instruction. For each barrier instruction, resolve to -// delays *on push forward basis (frames will be delayed to maximum time amongst -// all *frames). -// ***/ -// struct SchedulingPass : public PassWrapper> { - -// std::unordered_set -// scheduledDefCals; // hashes of defcal's that have already been scheduled -// std::unordered_map -// pulseDurations; // mapping of waveform hashes to durations -// std::unordered_map -// frameTimes; // mapping of frame hashes to time on that frame - -// // Hash an operation based on the result. -// auto getResultHash(Operation *op) -> uint; - -// // Check if the pulse hash is cached in pulse durations. -// // If it is cached, the hash will be found in pulseDurations. -// auto pulseCached(llvm::hash_code hash) -> bool; - -// // Get hash and time of a frame as a std::pair -// auto getFrameHashAndTime(mlir::Value &frame) -> std::pair; - -// // Get the maximum time among a set of frames -// auto getMaxTime(mlir::OperandRange &frames) -> uint; - -// // Process each operation in the defcal -// template -// void processOp(WaveformOp &wfrOp); - -// void processOp(Frame_CreateOp &frameOp); - -// void processOp(DelayOp &delayOp); -// void processOp(BarrierOp &barrierOp); -// void processOp(PlayOp &playOp); -// void processOp(CaptureOp &captureOp); - -// // Schedule the defcal -// void schedule(Operation *defCalOp); - -// // Entry point for the pass -// void runOnOperation() override; - -//}; // end struct SchedulingPass -//} // namespace mlir::pulse - -//#endif // PULSE_SCHEDULING_H +#ifndef SCHEDULING_PULSE_SEQUENCES_H +#define SCHEDULING_PULSE_SEQUENCES_H + +#include "Dialect/Pulse/IR/PulseOps.h" + +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::pulse { + +struct quantumCircuitPulseSchedulingPass + : public PassWrapper> { +public: + enum SchedulingMethod { ALAP, ASAP }; + SchedulingMethod SCHEDULING_METHOD = ALAP; + + // this pass can optionally receive an string specifying the scheduling + // method; default method is alap scheduling + quantumCircuitPulseSchedulingPass() = default; + quantumCircuitPulseSchedulingPass( + const quantumCircuitPulseSchedulingPass &pass) + : PassWrapper(pass) {} + quantumCircuitPulseSchedulingPass(SchedulingMethod inSchedulingMethod) { + SCHEDULING_METHOD = inSchedulingMethod; + } + + void runOnOperation() override; + + llvm::StringRef getArgument() const override; + llvm::StringRef getDescription() const override; + + // optionally, one can override the scheduling method with this option + Option schedulingMethod{ + *this, "scheduling-method", + llvm::cl::desc("an string to specify scheduling method"), + llvm::cl::value_desc("filename"), llvm::cl::init("")}; + +private: + // map to keep track of next availability of ports + std::map portNameToNextAvailabilityMap; + + void scheduleAlap(mlir::pulse::CallSequenceOp quantumCircuitCallSequenceOp); + int getNextAvailableTimeOfPorts(mlir::ArrayAttr ports); + void updatePortAvailabilityMap(mlir::ArrayAttr ports, + int updatedAvailableTime); + static mlir::pulse::SequenceOp + getSequenceOp(mlir::pulse::CallSequenceOp callSequenceOp); +}; +} // namespace mlir::pulse + +#endif // SCHEDULING_PULSE_SEQUENCES_H diff --git a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp index 63912ee61..8630442a8 100644 --- a/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp +++ b/lib/Conversion/QUIRToPulse/LoadPulseCals.cpp @@ -253,7 +253,6 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::BarrierOp barrierOp, void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp, CallCircuitOp callCircuitOp, FuncOp funcOp) { - OpBuilder builder(funcOp.body()); std::vector qubitOperands; @@ -264,7 +263,7 @@ void LoadPulseCalsPass::loadPulseCals(mlir::quir::DelayOp delayOp, delayOp->setAttr("pulse.calName", builder.getStringAttr(gateMangledName)); if (pulseCalsNameToSequenceMap.find(gateMangledName) != pulseCalsNameToSequenceMap.end()) { - // found a pulse calibration for the barrier gate + // found a pulse calibration for the delay gate addPulseCalToModule(funcOp, pulseCalsNameToSequenceMap[gateMangledName]); return; } diff --git a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp index f6df21441..d36079aaa 100644 --- a/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp +++ b/lib/Conversion/QUIRToPulse/QUIRToPulse.cpp @@ -143,6 +143,29 @@ void QUIRToPulsePass::convertCircuitToSequence(CallCircuitOp callCircuitOp, convertedPulseSequenceOpReturnTypes.push_back(type); for (auto val : pulseCalCallSequenceOp.res()) convertedPulseSequenceOpReturnValues.push_back(val); + + // add starting timepoint for delayOp + if (auto delayOp = dyn_cast(quirOp)) { + uint64_t durValue = 0; + if (delayOp.time().isa()) { + uint argNum = delayOp.time().dyn_cast().getArgNumber(); + auto durOpConstantOp = callCircuitOp.getOperand(argNum) + .getDefiningOp(); + auto durOp = quir::getDuration(durOpConstantOp).get(); + durValue = static_cast(durOp.getDuration().convertToDouble()); + assert(durOp.getType().dyn_cast().getUnits() == + TimeUnits::dt && + "this pass only accepts durations with dt unit"); + } else { + auto durOp = quir::getDuration(delayOp).get(); + durValue = static_cast(durOp.getDuration().convertToDouble()); + assert(durOp.getType().dyn_cast().getUnits() == + TimeUnits::dt && + "this pass only accepts durations with dt unit"); + } + PulseOpSchedulingInterface::setDuration(pulseCalCallSequenceOp, + durValue); + } } else assert(((isa(quirOp) or isa(quirOp) or isa(quirOp))) && diff --git a/lib/Dialect/Pulse/IR/PulseInterfaces.cpp b/lib/Dialect/Pulse/IR/PulseInterfaces.cpp index e24ff399a..4fc6b8bbf 100644 --- a/lib/Dialect/Pulse/IR/PulseInterfaces.cpp +++ b/lib/Dialect/Pulse/IR/PulseInterfaces.cpp @@ -69,6 +69,14 @@ interfaces_impl::getDuration(Operation *op, Operation *callSequenceOp) { "Operation does not have a pulse.duration attribute."); } +llvm::Expected interfaces_impl::getPorts(mlir::Operation *op) { + if (op->hasAttrOfType("pulse.argPorts")) + return op->getAttrOfType("pulse.argPorts"); + return llvm::createStringError( + llvm::inconvertibleErrorCode(), + "Operation does not have a pulse.argPorts attribute."); +} + void interfaces_impl::setDuration(Operation *op, uint64_t duration) { mlir::OpBuilder builder(op); op->setAttr("pulse.duration", builder.getI64IntegerAttr(duration)); diff --git a/lib/Dialect/Pulse/IR/PulseOps.cpp b/lib/Dialect/Pulse/IR/PulseOps.cpp index 36e7f6c25..c3f9c62ef 100644 --- a/lib/Dialect/Pulse/IR/PulseOps.cpp +++ b/lib/Dialect/Pulse/IR/PulseOps.cpp @@ -184,6 +184,23 @@ CallSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // //===----------------------------------------------------------------------===// +llvm::Expected +SequenceOp::getDuration(mlir::Operation *callSequenceOp = nullptr) { + // first, check if the sequence has duration attribute. If not, also check if + // the call sequence has duration attribute; e.g., for sequences that receives + // delay arguments, duration of the sequence can vary depending on the + // argument, so we look at the duration of call sequence as well + if ((*this)->hasAttr("pulse.duration")) + return static_cast( + (*this)->getAttrOfType("pulse.duration").getInt()); + if (callSequenceOp->hasAttr("pulse.duration")) + return static_cast( + callSequenceOp->getAttrOfType("pulse.duration").getInt()); + return llvm::createStringError( + llvm::inconvertibleErrorCode(), + "Operation does not have a pulse.duration attribute."); +} + static ParseResult parseSequenceOp(OpAsmParser &parser, OperationState &result) { auto buildSequenceType = diff --git a/lib/Dialect/Pulse/Transforms/Passes.cpp b/lib/Dialect/Pulse/Transforms/Passes.cpp index a465a44fe..7ae5e76a7 100644 --- a/lib/Dialect/Pulse/Transforms/Passes.cpp +++ b/lib/Dialect/Pulse/Transforms/Passes.cpp @@ -43,6 +43,7 @@ void registerPulsePasses() { PassRegistration(); PassRegistration(); PassRegistration(); + PassRegistration(); PassRegistration(); } diff --git a/lib/Dialect/Pulse/Transforms/Scheduling.cpp b/lib/Dialect/Pulse/Transforms/Scheduling.cpp index 635500722..2f1eb006b 100644 --- a/lib/Dialect/Pulse/Transforms/Scheduling.cpp +++ b/lib/Dialect/Pulse/Transforms/Scheduling.cpp @@ -1,4 +1,4 @@ -//===- Scheduling.cpp - Determine absolute timing in defcal's. ---*- C++-*-===// +//===- Scheduling.cpp --- quantum circuits pulse scheduling ----*- C++ -*-===// // // (C) Copyright IBM 2023. // @@ -14,222 +14,182 @@ // //===----------------------------------------------------------------------===// /// -/// This file implements the pass for filling in absolute timing attributes -/// within defcal calls. +/// This file implements the pass for scheduling the quantum circuits at pulse +/// level, based on the availability of involved ports /// //===----------------------------------------------------------------------===// -//#include "Dialect/Pulse/IR/PulseEnums.h" -//#include "Dialect/Pulse/Transforms/Scheduling.h" - -// using namespace mlir; -// using namespace mlir::pulse; - -// auto SchedulingPass::getResultHash(Operation *op) -> uint { -// auto res = op->getResult(0); -// return mlir::hash_value(res); -//} - -// auto SchedulingPass::pulseCached(llvm::hash_code hash) -> bool { -// if (pulseDurations.find(hash) != pulseDurations.end()) -// return true; -// return false; -//} - -// auto SchedulingPass::getFrameHashAndTime(mlir::Value &frame) -// -> std::pair { -// auto *frameOp = frame.getDefiningOp(); -// auto frameHash = getResultHash(frameOp); -// auto time = frameTimes[frameHash]; -// return std::make_pair(frameHash, time); -//} - -// auto SchedulingPass::getMaxTime(mlir::OperandRange &frames) -> uint { -// uint maxTime = 0; -// for (auto frame : frames) { -// // get frame time -// auto pair = getFrameHashAndTime(frame); -// auto time = pair.second; -// if (time > maxTime) -// maxTime = time; -// } -// return maxTime; -//} - -///// templated waveform processing method -// template -// void SchedulingPass::processOp(WaveformOp &wfrOp) { -// // compute and cache waveform duration -// auto wfrHash = getResultHash(wfrOp); -// pulseDurations[wfrHash] = wfrOp.getDuration(); -//} - -// void SchedulingPass::processOp(Frame_CreateOp &frameOp) { -// // initialize frame timing -// auto frameHash = getResultHash(frameOp); -// frameTimes[frameHash] = 0; -//} - -// void SchedulingPass::processOp(DelayOp &delayOp) { -// OpBuilder delayBuilder(delayOp); // delay operation builder -// auto dur = delayOp.dur(); // operand -// auto frames = delayOp.frames(); -// auto intDur = delayOp.getDuration(); // integer duration of delay - -// // split delays onto each frame and add timing attribute -// for (auto frame : frames) { -// // get frame hash and time -// auto [frameHash, time] = getFrameHashAndTime(frame); -// // create delay on frame w/ time tagged -// auto frameDelayOp = -// delayBuilder.create(delayOp->getLoc(), dur, frame); -// frameDelayOp->setAttr( -// llvm::StringRef("t"), -// delayBuilder.getIntegerAttr(delayBuilder.getI32Type(), time)); -// // update frame time -// frameTimes[frameHash] += intDur; -// } -// // delete original op -// delayOp.erase(); -//} - -// void SchedulingPass::processOp(BarrierOp &barrierOp) { -// OpBuilder barrierBuilder(barrierOp); // barrier operation builder -// auto frames = barrierOp.frames(); -// auto maxTime = getMaxTime(frames); // maximum time across frames - -// // more than one frame: compute max time among frames and -// // add delays on all other frames to sync with this time -// if (frames.size() > 1) { -// for (auto frame : frames) { -// // get frame hash and time -// auto [frameHash, time] = getFrameHashAndTime(frame); -// // create delays on non-max time frames -// if (time < maxTime) { -// auto len = maxTime - time; -// // create length -// auto lenOp = barrierBuilder.create( -// barrierOp->getLoc(), len, barrierBuilder.getI32Type()); -// // create delay -// auto delOp = -// barrierBuilder.create(barrierOp->getLoc(), lenOp, frame); -// delOp->setAttr( -// llvm::StringRef("t"), -// barrierBuilder.getIntegerAttr(barrierBuilder.getI32Type(), time)); -// // update frame time -// frameTimes[frameHash] += len; -// } -// } -// } -// // delete original op -// barrierOp.erase(); -//} - -// void SchedulingPass::processOp(PlayOp &playOp) { -// OpBuilder playBuilder(playOp); // play operation builder -// // SSA form: waveform and frame should already be declared before -// // play -// auto *wfrOp = playOp.wfr().getDefiningOp(); -// auto *frameOp = playOp.frame().getDefiningOp(); -// auto wfrHash = getResultHash(wfrOp); -// auto frameHash = getResultHash(frameOp); - -// // Add frame time as attribute -// auto time = frameTimes[frameHash]; -// playOp->setAttr(llvm::StringRef("t"), -// playBuilder.getIntegerAttr(playBuilder.getI32Type(), time)); - -// // Update frame time -// auto wfrDur = pulseDurations[wfrHash]; -// frameTimes[frameHash] += wfrDur; -//} - -// void SchedulingPass::processOp(CaptureOp &captureOp) { -// OpBuilder captureBuilder(captureOp); // capture operation builder -// // SSA form: frame should already be declared before capture -// auto *frameOp = captureOp.frame().getDefiningOp(); -// auto frameHash = getResultHash(frameOp); - -// // Add frame time as attribute -// auto time = frameTimes[frameHash]; -// captureOp->setAttr( -// llvm::StringRef("t"), -// captureBuilder.getIntegerAttr(captureBuilder.getI32Type(), time)); - -// // Update frame time -// auto capDur = captureOp.getDuration(); -// frameTimes[frameHash] += capDur; -//} - -// void SchedulingPass::schedule(Operation *defCalOp) { -// // add timing attributes for Pulse IR operations -// defCalOp->walk([&](Operation *dcOp) { -// if (auto sampWfrOp = dyn_cast(dcOp)) { -// processOp(sampWfrOp); -// } else if (auto gaussOp = dyn_cast(dcOp)) { -// processOp(gaussOp); -// } else if (auto gaussSqOp = dyn_cast(dcOp)) { -// processOp(gaussSqOp); -// } else if (auto dragOp = dyn_cast(dcOp)) { -// processOp(dragOp); -// } else if (auto constWfrOp = dyn_cast(dcOp)) { -// processOp(constWfrOp); -// } else if (auto frameOp = dyn_cast(dcOp)) { -// processOp(frameOp); -// } else if (auto delayOp = dyn_cast(dcOp)) { -// processOp(delayOp); -// } else if (auto barrierOp = dyn_cast(dcOp)) { -// processOp(barrierOp); -// } else if (auto playOp = dyn_cast(dcOp)) { -// processOp(playOp); -// } else if (auto captureOp = dyn_cast(dcOp)) { -// processOp(captureOp); -// } -// // else: Pulse IR op does not impact scheduling -// }); // defCalOp->walk -//} - -///*** ASSUMPTIONS -// * Times are integers. -// * Barriers and delays are on frames, not qubits. -// * Stretches have been resolved. -// * No control flow. -// * Gate calls are resolved to defcal calls. -// ***/ -///*** TODO PASSES : must occur before scheduling. -// * TODO: Write timing pass to update all quir lengths to std::constant -// integers -// *(in dt units). -// * TODO: Write pass to lower all barriers and delays onto frames from qubits. -// * TODO: Resolve stretch pass (should this be at scheduling time or earlier?) -// * TODO: Pass to handle control flow timing -> add mlir branching. -// * TODO: Pass to resolve gate calls to defcal calls (this should be a QUIR -// *pass). -// ***/ -// void SchedulingPass::runOnOperation() { -// // This pass is only called on the top-level module Op -// Operation *moduleOperation = getOperation(); -// moduleOperation->walk([&](Operation *op) { -// // find defcal call -// if (auto callOp = dyn_cast(op)) { -// // find defcal body -// auto calleeStr = callOp.getCallee(); -// // TODO: Lookup should include full function signature, not just the -// // string -// auto *defCalOp = SymbolTable::lookupSymbolIn(moduleOperation, -// calleeStr); if (!defCalOp) { -// callOp->emitError() -// << "Could not find defcal body for " << calleeStr << "."; -// return; -// } - -// auto defCalHash = mlir::OperationEquivalence::computeHash(defCalOp); -// // schedule this defcal if it has not already been scheduled -// if (scheduledDefCals.find(defCalHash) == scheduledDefCals.end()) { -// schedule(defCalOp); -// // add defcal to those already scheduled -// scheduledDefCals.insert(defCalHash); -// } -// } -// }); // moduleOperation->walk -//} // runOnOperation +#include "Dialect/Pulse/Transforms/Scheduling.h" +#include "Dialect/QUIR/Utils/Utils.h" + +#define DEBUG_TYPE "SchedulingDebug" + +using namespace mlir; +using namespace mlir::pulse; + +void quantumCircuitPulseSchedulingPass::runOnOperation() { + // check for command line override of the scheduling method + if (schedulingMethod.hasValue()) { + if (schedulingMethod.getValue() == "alap") + SCHEDULING_METHOD = ALAP; + else if (schedulingMethod.getValue() == "asap") + SCHEDULING_METHOD = ASAP; + else + llvm_unreachable("scheduling method not supported currently"); + } + + ModuleOp moduleOp = getOperation(); + + // schedule all the quantum circuits which are root call sequence ops + moduleOp->walk([&](mlir::pulse::CallSequenceOp callSequenceOp) { + // return if the call sequence op is not a root op + if (isa(callSequenceOp->getParentOp())) + return; + switch (SCHEDULING_METHOD) { + case ALAP: + scheduleAlap(callSequenceOp); + break; + default: + llvm_unreachable("scheduling method not supported currently"); + } + }); +} + +void quantumCircuitPulseSchedulingPass::scheduleAlap( + mlir::pulse::CallSequenceOp quantumCircuitCallSequenceOp) { + + auto quantumCircuitSequenceOp = getSequenceOp(quantumCircuitCallSequenceOp); + std::string sequenceName = quantumCircuitSequenceOp.sym_name().str(); + LLVM_DEBUG(llvm::dbgs() << "\nscheduling " << sequenceName << "\n"); + + int totalDurationOfQuantumCircuitNegative = 0; + portNameToNextAvailabilityMap.clear(); + + // get the MLIR block of the quantum circuit + auto quantumCircuitSequenceOpBlock = quantumCircuitSequenceOp.body().begin(); + // go over the MLIR operation of the block in reverse order, and find + // CallSequenceOps, each of which corresponds to a quantum gate. for each + // CallSequenceOps, we add a timepoint based on the availability of involved + // ports; timepoints are <=0 because we're walking in reverse order. Note this + // pass assumes that the operations inside these CallSequenceOps are already + // scheduled + for (auto opIt = quantumCircuitSequenceOpBlock->rbegin(), + opEnd = quantumCircuitSequenceOpBlock->rend(); + opIt != opEnd; ++opIt) { + auto &op = *opIt; + if (auto quantumGateCallSequenceOp = + dyn_cast(op)) { + // find quantum gate SequenceOp + auto quantumGateSequenceOp = getSequenceOp(quantumGateCallSequenceOp); + std::string quantumGateSequenceName = + quantumGateSequenceOp.sym_name().str(); + LLVM_DEBUG(llvm::dbgs() << "\tprocessing inner sequence " + << quantumGateSequenceName << "\n"); + + // find ports of the quantum gate SequenceOp + auto portsOrError = + PulseOpSchedulingInterface::getPorts(quantumGateSequenceOp); + if (auto err = portsOrError.takeError()) { + quantumGateSequenceOp.emitError() << toString(std::move(err)); + signalPassFailure(); + } + auto ports = portsOrError.get(); + + // find duration of the quantum gate callSequenceOp + llvm::Expected durOrError = + quantumGateSequenceOp.getDuration(quantumGateCallSequenceOp); + if (auto err = durOrError.takeError()) { + quantumGateSequenceOp.emitError() << toString(std::move(err)); + signalPassFailure(); + } + uint64_t quantumGateCallSequenceOpDuration = durOrError.get(); + LLVM_DEBUG(llvm::dbgs() << "\t\tduration " + << quantumGateCallSequenceOpDuration << "\n"); + + // find next available time for all the ports + int nextAvailableTimeOfAllPorts = getNextAvailableTimeOfPorts(ports); + LLVM_DEBUG(llvm::dbgs() << "\t\tnext availability is at " + << nextAvailableTimeOfAllPorts << "\n"); + + // find the updated available time, i.e., when the current quantum gate + // will be scheduled + int updatedAvailableTime = + nextAvailableTimeOfAllPorts - quantumGateCallSequenceOpDuration; + LLVM_DEBUG(llvm::dbgs() << "\t\tcurrent gate scheduled at " + << updatedAvailableTime << "\n"); + // update the port availability map + updatePortAvailabilityMap(ports, updatedAvailableTime); + + // keep track of total duration of the quantum circuit + if (updatedAvailableTime < totalDurationOfQuantumCircuitNegative) + totalDurationOfQuantumCircuitNegative = updatedAvailableTime; + + // set the timepoint of quantum gate + PulseOpSchedulingInterface::setTimepoint(quantumGateCallSequenceOp, + updatedAvailableTime); + } + } + + // multiply by -1 so that quantum circuit duration becomes positive + int totalDurationOfQuantumCircuit = -totalDurationOfQuantumCircuitNegative; + LLVM_DEBUG(llvm::dbgs() << "\ttotal duration of quantum circuit " + << totalDurationOfQuantumCircuit << "\n"); + + // setting duration of the quantum circuit + PulseOpSchedulingInterface::setDuration(quantumCircuitSequenceOp, + totalDurationOfQuantumCircuit); + // setting timepoint of the quantum circuit; at this point, we can add + // totalDurationOfQuantumCircuit to above <=0 timepoints, so that they become + // >=0, however, that would require walking the IR again. Instead, we add a + // postive timepoint to the parent op, i.e., quantum circuit sequence op, and + // later passes would need to add this value as an offset to determine the + // effective timepoints + PulseOpSchedulingInterface::setTimepoint(quantumCircuitSequenceOp, + totalDurationOfQuantumCircuit); +} + +int quantumCircuitPulseSchedulingPass::getNextAvailableTimeOfPorts( + mlir::ArrayAttr ports) { + int nextAvailableTimeOfAllPorts = 0; + for (auto attr : ports) { + std::string portName = attr.dyn_cast().getValue().str(); + if (portName.empty()) + continue; + if (portNameToNextAvailabilityMap.find(portName) != + portNameToNextAvailabilityMap.end()) { + if (portNameToNextAvailabilityMap[portName] < nextAvailableTimeOfAllPorts) + nextAvailableTimeOfAllPorts = portNameToNextAvailabilityMap[portName]; + } + } + return nextAvailableTimeOfAllPorts; +} + +void quantumCircuitPulseSchedulingPass::updatePortAvailabilityMap( + mlir::ArrayAttr ports, int updatedAvailableTime) { + for (auto attr : ports) { + std::string portName = attr.dyn_cast().getValue().str(); + if (portName.empty()) + continue; + portNameToNextAvailabilityMap[portName] = updatedAvailableTime; + } +} + +mlir::pulse::SequenceOp quantumCircuitPulseSchedulingPass::getSequenceOp( + mlir::pulse::CallSequenceOp callSequenceOp) { + auto seqAttr = callSequenceOp->getAttrOfType("callee"); + assert(seqAttr && "Requires a 'callee' symbol reference attribute"); + + auto sequenceOp = + SymbolTable::lookupNearestSymbolFrom( + callSequenceOp, seqAttr); + assert(sequenceOp && "matching sequence not found"); + return sequenceOp; +} + +llvm::StringRef quantumCircuitPulseSchedulingPass::getArgument() const { + return "quantum-circuit-pulse-scheduling"; +} + +llvm::StringRef quantumCircuitPulseSchedulingPass::getDescription() const { + return "Scheduling a quantum circuit at pulse level."; +}