diff --git a/include/Dialect/Pulse/Utils/Utils.h b/include/Dialect/Pulse/Utils/Utils.h index f93bddd59..2cdea100c 100644 --- a/include/Dialect/Pulse/Utils/Utils.h +++ b/include/Dialect/Pulse/Utils/Utils.h @@ -41,9 +41,8 @@ Waveform_CreateOp getWaveformOp(PlayOp pulsePlayOp, Waveform_CreateOp getWaveformOp(PlayOp pulsePlayOp, CallSequenceStack_t &callSequenceOpStack); - -double getPhaseValue(ShiftPhaseOp shiftPhaseOp, - CallSequenceStack_t &callSequenceOpStack); +mlir::Value getPhaseValue(ShiftPhaseOp shiftPhaseOp, + CallSequenceStack_t &callSequenceOpStack); /// this function goes over all the blocks of the input pulse sequence, and for /// each block, it sorts the pulse ops within the block according to their diff --git a/lib/Dialect/Pulse/Utils/Utils.cpp b/lib/Dialect/Pulse/Utils/Utils.cpp index eb7fb113a..02cfe20ef 100644 --- a/lib/Dialect/Pulse/Utils/Utils.cpp +++ b/lib/Dialect/Pulse/Utils/Utils.cpp @@ -64,25 +64,20 @@ Waveform_CreateOp getWaveformOp(PlayOp pulsePlayOp, return waveformOp; } -double getPhaseValue(ShiftPhaseOp shiftPhaseOp, - CallSequenceStack_t &callSequenceOpStack) { +mlir::Value getPhaseValue(ShiftPhaseOp shiftPhaseOp, + CallSequenceStack_t &callSequenceOpStack) { auto phaseOffsetIndex = 0; mlir::Value phaseOffset = shiftPhaseOp.getPhaseOffset(); for (auto it = callSequenceOpStack.rbegin(); it != callSequenceOpStack.rend(); ++it) { - if (phaseOffset.isa()) { - phaseOffsetIndex = phaseOffset.dyn_cast().getArgNumber(); - phaseOffset = it->getOperand(phaseOffsetIndex); - } else + if (auto blockArg = dyn_cast(phaseOffset)) + phaseOffset = it->getOperand(blockArg.getArgNumber()); + else break; } - auto phaseOffsetOp = - dyn_cast(phaseOffset.getDefiningOp()); - if (!phaseOffsetOp) - phaseOffsetOp->emitError() << "Phase offset is not a ConstantFloatOp."; - return phaseOffsetOp.value().convertToDouble(); + return phaseOffset; } void sortOpsByTimepoint(SequenceOp &sequenceOp) {