Skip to content

Commit

Permalink
[NFC] Cache common lookups in ModuleType (#6892)
Browse files Browse the repository at this point in the history
Use custom storage for ModuleType to cache input/output <-> index mappings.  Speeds up many things in small ways.
  • Loading branch information
darthscsi authored Apr 2, 2024
1 parent 562f4d7 commit f77c002
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 57 deletions.
44 changes: 44 additions & 0 deletions include/circt/Dialect/HW/HWTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,50 @@ struct ModulePort {
Direction dir;
};

static bool operator==(const ModulePort &a, const ModulePort &b) {
return a.dir == b.dir && a.name == b.name && a.type == b.type;
}
static llvm::hash_code hash_value(const ModulePort &port) {
return llvm::hash_combine(port.dir, port.name, port.type);
}

namespace detail {
struct ModuleTypeStorage : public TypeStorage {
ModuleTypeStorage(ArrayRef<ModulePort> inPorts);

using KeyTy = ArrayRef<ModulePort>;

/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {
return std::equal(key.begin(), key.end(), ports.begin(), ports.end());
}

/// Define a hash function for the key type.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_combine_range(key.begin(), key.end());
}

/// Define a construction method for creating a new instance of this storage.
static ModuleTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<ModuleTypeStorage>()) ModuleTypeStorage(key);
}

/// Construct an instance of the key from this storage class.
KeyTy getAsKey() const { return ports; }

ArrayRef<ModulePort> getPorts() const { return ports; }

/// The parametric data held by the storage class.
SmallVector<ModulePort> ports;
// Cache of common lookups
SmallVector<size_t> inputToAbs;
SmallVector<size_t> outputToAbs;
SmallVector<size_t> absToInput;
SmallVector<size_t> absToOutput;
};
} // namespace detail

class HWSymbolCache;
class ParamDeclAttr;
class TypedeclOp;
Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/HW/HWTypesImpl.td
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def ModuleTypeImpl : HWType<"Module"> {
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let mnemonic = "modty";
let genStorageClass = 0;

let extraClassDeclaration = [{
// Many of these are transitional and will be removed when modules and instances
Expand Down
1 change: 1 addition & 0 deletions lib/CAPI/Dialect/FIRRTL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ FIRRTLValueFlow firrtlValueFoldFlow(MlirValue value, FIRRTLValueFlow flow) {
case Flow::Duplex:
return FIRRTL_VALUE_FLOW_DUPLEX;
}
llvm_unreachable("invalid flow");
}

bool firrtlImportAnnotationsFromJSONRaw(
Expand Down
2 changes: 0 additions & 2 deletions lib/Dialect/HW/HWOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,8 +1075,6 @@ static LogicalResult verifyModuleCommon(HWModuleLike module) {
assert(isa<HWModuleLike>(module) &&
"verifier hook should only be called on modules");

auto moduleType = module.getHWModuleType();

SmallPtrSet<Attribute, 4> paramNames;

// Check parameter default values are sensible.
Expand Down
91 changes: 36 additions & 55 deletions lib/Dialect/HW/HWTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -823,60 +823,30 @@ LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
}

size_t ModuleType::getPortIdForInputId(size_t idx) {
for (auto [i, p] : llvm::enumerate(getPorts())) {
if (p.dir != ModulePort::Direction::Output) {
if (!idx)
return i;
--idx;
}
}
assert(0 && "Out of bounds input port id");
return ~0UL;
assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
return getImpl()->inputToAbs[idx];
}

size_t ModuleType::getPortIdForOutputId(size_t idx) {
for (auto [i, p] : llvm::enumerate(getPorts())) {
if (p.dir == ModulePort::Direction::Output) {
if (!idx)
return i;
--idx;
}
}
assert(0 && "Out of bounds output port id");
return ~0UL;
assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
return getImpl()->outputToAbs[idx];
}

size_t ModuleType::getInputIdForPortId(size_t idx) {
auto ports = getPorts();
assert(ports[idx].dir != ModulePort::Direction::Output);
size_t retval = 0;
for (size_t i = 0; i < idx; ++i)
if (ports[i].dir != ModulePort::Direction::Output)
++retval;
return retval;
auto nIdx = getImpl()->absToInput[idx];
assert(nIdx != ~0ULL);
return nIdx;
}

size_t ModuleType::getOutputIdForPortId(size_t idx) {
auto ports = getPorts();
assert(ports[idx].dir == ModulePort::Direction::Output);
size_t retval = 0;
for (size_t i = 0; i < idx; ++i)
if (ports[i].dir == ModulePort::Direction::Output)
++retval;
return retval;
auto nIdx = getImpl()->absToOutput[idx];
assert(nIdx != ~0ULL);
return nIdx;
}

size_t ModuleType::getNumInputs() {
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
return p.dir != ModulePort::Direction::Output;
});
}
size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }

size_t ModuleType::getNumOutputs() {
return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
return p.dir == ModulePort::Direction::Output;
});
}
size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }

size_t ModuleType::getNumPorts() { return getPorts().size(); }

Expand Down Expand Up @@ -984,6 +954,10 @@ FunctionType ModuleType::getFuncType() {
return FunctionType::get(getContext(), inputs, outputs);
}

ArrayRef<ModulePort> ModuleType::getPorts() const {
return getImpl()->getPorts();
}

FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
LocationAttr loc,
bool emitErrors) {
Expand Down Expand Up @@ -1021,7 +995,7 @@ static ModulePort::Direction strToDir(StringRef str) {
}

/// Parse a list of field names and types within <>. E.g.:
/// <foo: i7, bar: i8>
/// <input foo: i7, output bar: i8>
static ParseResult parsePorts(AsmParser &p,
SmallVectorImpl<ModulePort> &ports) {
return p.parseCommaSeparatedList(
Expand Down Expand Up @@ -1060,18 +1034,6 @@ void ModuleType::print(AsmPrinter &odsPrinter) const {
printPorts(odsPrinter, getPorts());
}

namespace circt {
namespace hw {

static bool operator==(const ModulePort &a, const ModulePort &b) {
return a.dir == b.dir && a.name == b.name && a.type == b.type;
}
static llvm::hash_code hash_value(const ModulePort &port) {
return llvm::hash_combine(port.dir, port.name, port.type);
}
} // namespace hw
} // namespace circt

ModuleType circt::hw::detail::fnToMod(Operation *op,
ArrayRef<Attribute> inputNames,
ArrayRef<Attribute> outputNames) {
Expand Down Expand Up @@ -1109,6 +1071,25 @@ ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
return ModuleType::get(fnty.getContext(), ports);
}

detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
: ports(inPorts) {
size_t nextInput = 0;
size_t nextOutput = 0;
for (auto [idx, p] : llvm::enumerate(ports)) {
if (p.dir == ModulePort::Direction::Output) {
outputToAbs.push_back(idx);
absToOutput.push_back(nextOutput);
absToInput.push_back(~0ULL);
++nextOutput;
} else {
inputToAbs.push_back(idx);
absToInput.push_back(nextInput);
absToOutput.push_back(~0ULL);
++nextInput;
}
}
}

////////////////////////////////////////////////////////////////////////////////
// BoilerPlate
////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit f77c002

Please sign in to comment.