Skip to content

Commit

Permalink
Add Symbol Cache Analysis (#301)
Browse files Browse the repository at this point in the history
Adds a SymbolCacheAnalysis to standardizing the caching of
symbols for Circuits, Sequences and other function like operations which
follow a call_<name> by callee, <name> @callee pattern.

Co-authored-by: mbhealy <[email protected]>
  • Loading branch information
bcdonovan and mbhealy authored Apr 2, 2024
1 parent e5d0cbf commit c4542ba
Show file tree
Hide file tree
Showing 17 changed files with 322 additions and 174 deletions.
5 changes: 2 additions & 3 deletions include/Conversion/QUIRToPulse/QUIRToPulse.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "Dialect/OQ3/IR/OQ3Ops.h"
#include "Dialect/Pulse/IR/PulseOps.h"
#include "Dialect/QCS/IR/QCSOps.h"
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -157,9 +158,7 @@ struct QUIRToPulsePass
void parsePulseWaveformContainerOps(std::string &waveformContainerPath);
std::map<std::string, Waveform_CreateOp> pulseNameToWaveformMap;

llvm::StringMap<Operation *> symbolMap;
mlir::quir::CircuitOp getCircuitOp(mlir::quir::CallCircuitOp &callCircuitOp);
mlir::pulse::SequenceOp getSequenceOp(std::string const &symbolName);
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};
};
} // namespace mlir::pulse

Expand Down
4 changes: 3 additions & 1 deletion include/Dialect/Pulse/Transforms/SchedulePort.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

#include "Dialect/Pulse/IR/PulseOps.h"
#include "Utils/DebugIndent.h"
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/Pass/Pass.h"

#include <map>
Expand Down Expand Up @@ -60,7 +62,7 @@ class SchedulePortPass

void addTimepoints(mlir::OpBuilder &builder,
mixedFrameMap_t &mixedFrameSequences, int64_t &maxTime);
llvm::StringMap<mlir::pulse::SequenceOp> sequenceOps;
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};
};
} // namespace mlir::pulse

Expand Down
5 changes: 2 additions & 3 deletions include/Dialect/Pulse/Transforms/Scheduling.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#define SCHEDULING_PULSE_SEQUENCES_H

#include "Dialect/Pulse/IR/PulseOps.h"
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -86,9 +87,7 @@ struct QuantumCircuitPulseSchedulingPass
std::unordered_set<unsigned int> &mixFramesBlockArgNums,
int64_t updatedAvailableTime);
bool sequenceOpIncludeCapture(mlir::pulse::SequenceOp quantumGateSequenceOp);
llvm::StringMap<Operation *> symbolMap;
mlir::pulse::SequenceOp
getSequenceOp(mlir::pulse::CallSequenceOp callSequenceOp);
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};
};
} // namespace mlir::pulse

Expand Down
5 changes: 4 additions & 1 deletion include/Dialect/QUIR/Transforms/BreakReset.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#ifndef QUIR_BREAK_RESET_H
#define QUIR_BREAK_RESET_H

#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"

Expand Down Expand Up @@ -73,7 +75,8 @@ struct BreakResetPass

private:
// keep track of all circuits
llvm::StringMap<Operation *> circuitsSymbolMap;
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};

void insertMeasureInCircuit(mlir::func::FuncOp &mainFunc,
mlir::quir::MeasureOp measureOp);
void insertCallGateInCircuit(mlir::func::FuncOp &mainFunc,
Expand Down
3 changes: 2 additions & 1 deletion include/Dialect/QUIR/Transforms/ExtractCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#define QUIR_EXTRACT_CIRCUITS_H

#include "Dialect/QUIR/IR/QUIROps.h"
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
Expand Down Expand Up @@ -53,7 +54,7 @@ struct ExtractCircuitsPass
void addToCircuit(mlir::Operation *currentOp, OpBuilder circuitBuilder,
llvm::SmallVector<Operation *> &eraseList);
uint64_t circuitCount = 0;
llvm::StringMap<Operation *> circuitOpsMap;
qssc::utils::SymbolCacheAnalysis *symbolCache{nullptr};

mlir::quir::CircuitOp currentCircuitOp = nullptr;
mlir::quir::CallCircuitOp newCallCircuitOp;
Expand Down
5 changes: 2 additions & 3 deletions include/Dialect/QUIR/Transforms/MergeCircuits.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#define QUIR_MERGE_CIRCUITS_H

#include "Dialect/QUIR/IR/QUIROps.h"
#include "Utils/SymbolCacheAnalysis.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -34,12 +35,10 @@ struct MergeCircuitsPass
: public PassWrapper<MergeCircuitsPass, OperationPass<>> {
void runOnOperation() override;

static CircuitOp getCircuitOp(CallCircuitOp callCircuitOp,
llvm::StringMap<Operation *> *symbolMap);
static LogicalResult mergeCallCircuits(
MLIRContext *context, PatternRewriter &rewriter,
CallCircuitOp callCircuitOp, CallCircuitOp nextCallCircuitOp,
llvm::StringMap<Operation *> *symbolMap,
qssc::utils::SymbolCacheAnalysis *symbolCache,
std::optional<llvm::SmallVector<Operation *>> barriers = std::nullopt);

llvm::StringRef getArgument() const override;
Expand Down
6 changes: 4 additions & 2 deletions include/Dialect/QUIR/Transforms/SubroutineCloning.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#ifndef QUIR_SUBROUTINE_CLONING_H
#define QUIR_SUBROUTINE_CLONING_H

#include "Utils/SymbolCacheAnalysis.h"

#include <deque>
#include <unordered_set>

Expand All @@ -36,7 +38,7 @@ class Operation;

namespace mlir::quir {

using SymbolOpMap = llvm::StringMap<Operation *>;
using SymbolCache = qssc::utils::SymbolCacheAnalysis;

struct SubroutineCloningPass
: public PassWrapper<SubroutineCloningPass, OperationPass<>> {
Expand All @@ -45,7 +47,7 @@ struct SubroutineCloningPass
template <class CallLikeOp>
auto getMangledName(Operation *op) -> std::string;
template <class CallLikeOp, class FuncLikeOp>
void processCallOp(Operation *op, SymbolOpMap &symbolOpMap);
void processCallOp(Operation *op, SymbolCache &symbolOpMap);

void runOnOperation() override;

Expand Down
203 changes: 203 additions & 0 deletions include/Utils/SymbolCacheAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
//===- SymbolCacheAnalysis.h - Cache symbols --------------------*- C++ -*-===//
//
// (C) Copyright IBM 2024.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.
//
//===----------------------------------------------------------------------===//
///
/// This file implements an analysis for caching symbols that match a
/// call -> callee pattern. This currently includes circuit / call_circuit
/// and sequence / call_sequence.
///
///
//===----------------------------------------------------------------------===//

#ifndef CACHE_SYMBOLS_ANALYSIS_H
#define CACHE_SYMBOLS_ANALYSIS_H

#include "HAL/SystemConfiguration.h"

#include "mlir/IR/Operation.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"

#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>

namespace qssc::utils {

// This analysis maintains a mapping of symbol name to operation in
// symbolOpsMap. It will also maintain a cache of CallOp to CalleeOp
// when operations are looked up through the getOp method. The CallOp
// to CalleeOp cache is intended to reduce string comparison where
// possible
//
// Example usage:
// auto & cache = getAnalysis<qssc::utils::SymbolCacheAnalysis>()
// .addToCache<CircuitOp>();
//
// multiple symbol types may be cached using:
// auto & cache = getAnalysis<qssc::utils::SymbolCacheAnalysis>()
// .addToCache<CircuitOp>()
// .addToCache<SequenceOp>();
//
// This analysis is intended to be used with MLIR's getAnalysis
// framework. It has been designed to reused the chached value
// and will not be invalidated automatically with each pass.
// If a pass manipulates the symbols that are cached with this
// analysis then it should use the addCallee method to update the
// map or call invalidate after appying updates.
// Note this analysis should always be used by reference or
// via a pointer to ensure that updates are applied to the maps
// stored by the MLIR analysis framework.
//
// Passes may force the maps to be re-loaded by calling invalidate
// before calling addToCache:
//
// auto & cache = getAnalysis<qssc::utils::SymbolCacheAnalysis>()
// .invalidate()
// .addToCache<CircuitOp>();

class SymbolCacheAnalysis {
public:
SymbolCacheAnalysis(mlir::Operation *op) {
if (topOp && topOp != op)
invalidate();
topOp = op;
}
SymbolCacheAnalysis(mlir::Operation *op,
qssc::hal::SystemConfiguration *config) {
if (topOp && topOp != op)
invalidate();
topOp = op;
}

template <class CalleeOp>
SymbolCacheAnalysis &addToCache() {
return addToCache<CalleeOp>(topOp);
}

template <class CalleeOp>
SymbolCacheAnalysis &addToCache(mlir::Operation *op) {
std::string typeName = typeid(CalleeOp).name();

if (!invalid && (cachedTypes.find(typeName) != cachedTypes.end())) {
// already cached skipping
return *this;
}

op->walk([&](CalleeOp op) {
symbolOpsMap[op.getSymName()] = op.getOperation();
});
cachedTypes.insert(typeName);
invalid = false;
return *this;
}

template <class CallOp>
SymbolCacheAnalysis &cacheCallMap() {
return cacheCallMap<CallOp>(topOp);
}

template <class CallOp>
SymbolCacheAnalysis &cacheCallMap(mlir::Operation *op) {
std::string typeName = typeid(CallOp).name();
if ((cachedTypes.find(typeName) != cachedTypes.end()))
return *this;

op->walk([&](CallOp callOp) {
auto search = symbolOpsMap.find(callOp.getCallee());
if (search != symbolOpsMap.end())
callMap[callOp.getOperation()] = search->second;
});
return *this;
}

template <class CalleeOp>
CalleeOp getOpByName(llvm::StringRef callee) {
auto search = symbolOpsMap.find(callee);
assert(search != symbolOpsMap.end() && "matching callee not found");
auto calleeOp = llvm::dyn_cast<CalleeOp>(search->second);
assert(calleeOp && "callee is not of the expected type");
return calleeOp;
}

template <class CalleeOp, class CallOp>
CalleeOp getOpByCall(CallOp callOp) {
auto search = callMap.find(callOp.getOperation());
if (search == callMap.end()) {
auto calleeOp = getOpByName<CalleeOp>(callOp.getCallee());
callMap[callOp.getOperation()] = calleeOp.getOperation();
return calleeOp;
}
auto calleeOp = llvm::dyn_cast<CalleeOp>(search->second);
assert(calleeOp && "callee is not of the expected type");
return calleeOp;
}

template <class CalleeOp, class CallOp>
CalleeOp getOp(CallOp callOp) {
return getOpByCall<CalleeOp, CallOp>(callOp);
}

template <class CalleeOp>
void addCallee(CalleeOp calleeOp) {
addCallee(calleeOp.getSymName(), calleeOp.getOperation());
}

void addCallee(llvm::StringRef name, mlir::Operation *op) {
// if this is an update to existing symbol clear callMap cache
if (symbolOpsMap.contains(name))
callMap.clear();
symbolOpsMap[name] = op;
}

template <class CallOp, class CalleeOp>
void cacheCall(CallOp callOp, CalleeOp calleeOp) {
callMap[callOp.getOperation()] = calleeOp.getOperation();
}

bool contains(llvm::StringRef name) { return symbolOpsMap.contains(name); }

template <class CalleeOp>
void erase(CalleeOp calleeOp) {
symbolOpsMap.erase(calleeOp.getSymName());
// TODO: determine if it is worth just clearing the callers of calleeOp
callMap.clear();
}

SymbolCacheAnalysis &invalidate() {
symbolOpsMap.clear();
callMap.clear();
cachedTypes.clear();
invalid = true;
return *this;
}

bool isInvalidated(const mlir::AnalysisManager::PreservedAnalyses &pa) {
return invalid;
}

// for debugging purposes
void listSymbols() {
for (auto &[key, value] : symbolOpsMap)
llvm::errs() << key << "\n";
}

private:
llvm::StringMap<mlir::Operation *> symbolOpsMap;
std::unordered_map<mlir::Operation *, mlir::Operation *> callMap;
std::unordered_set<std::string> cachedTypes;
mlir::Operation *topOp{nullptr};
bool invalid{true};
};
} // namespace qssc::utils

#endif // CACHE_SYMBOLS_ANALYSIS_H
Loading

0 comments on commit c4542ba

Please sign in to comment.