Skip to content

Commit

Permalink
[FIRRTL][FIRParser] Defer inner symbols to post-processing, fix race. (
Browse files Browse the repository at this point in the history
…#7584)

Adding inner symbols, specifically to module ports, mutates the
attribute dictionary and that races / breaks concurrent access to
properties (e.g., port names and types) commonly inspected during
parsing other module bodies (e.g., InstanceOp's).

Avoid this by instead gathering a list of operations and their
intended target as "fixups" to be processed after the original
(possibly parallel) parsing has completed and it is safe to do so.

This could be parallellized (per-module fixups)
but is not expected to have sufficient work to justify the overhead.
  • Loading branch information
dtzSiFive authored Sep 4, 2024
1 parent 9d500f5 commit 75a192e
Showing 1 changed file with 96 additions and 20 deletions.
116 changes: 96 additions & 20 deletions lib/Dialect/FIRRTL/Import/FIRParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ struct SharedParserConstants {
hiIdentifier(StringAttr::get(context, "hi")),
amountIdentifier(StringAttr::get(context, "amount")),
fieldIndexIdentifier(StringAttr::get(context, "fieldIndex")),
indexIdentifier(StringAttr::get(context, "index")) {}
indexIdentifier(StringAttr::get(context, "index")),
placeholderInnerRef(
hw::InnerRefAttr::get(StringAttr::get(context, "module"),
StringAttr::get(context, "placeholder"))) {}

/// The context we're parsing into.
MLIRContext *const context;
Expand All @@ -91,6 +94,9 @@ struct SharedParserConstants {
const StringAttr loIdentifier, hiIdentifier, amountIdentifier;
const StringAttr fieldIndexIdentifier, indexIdentifier;

/// Cached placeholder inner-ref used until fixed up.
const hw::InnerRefAttr placeholderInnerRef;

private:
SharedParserConstants(const SharedParserConstants &) = delete;
void operator=(const SharedParserConstants &) = delete;
Expand Down Expand Up @@ -1640,20 +1646,68 @@ struct LazyLocationListener : public OpBuilder::Listener {
};
} // end anonymous namespace

namespace {
/// This class tracks inner-ref users and their intended targets,
/// (presently there must be just one) for post-processing at a point
/// where adding the symbols is safe without risk of races.
struct InnerSymFixups {
/// Add a fixup to be processed later.
void add(hw::InnerRefUserOpInterface user, hw::InnerSymTarget target) {
fixups.push_back({user, target});
}

/// Resolve all stored fixups, if any. Not expected to fail,
/// as checking should primarily occur during original parsing.
LogicalResult resolve(hw::InnerSymbolNamespaceCollection &isnc);

private:
struct Fixup {
hw::InnerRefUserOpInterface innerRefUser;
hw::InnerSymTarget target;
};
SmallVector<Fixup, 0> fixups;
};
} // end anonymous namespace

LogicalResult
InnerSymFixups::resolve(hw::InnerSymbolNamespaceCollection &isnc) {
for (auto &f : fixups) {
auto ref = getInnerRefTo(
f.target, [&isnc](FModuleLike module) -> hw::InnerSymbolNamespace & {
return isnc.get(module);
});
assert(ref && "unable to resolve inner symbol target");

// Per-op fixup logic. Only RWProbeOp's presently.
auto result =
TypeSwitch<Operation *, LogicalResult>(f.innerRefUser.getOperation())
.Case<RWProbeOp>([ref](RWProbeOp op) {
op.setTargetAttr(ref);
return success();
})
.Default([](auto *op) {
return op->emitError("unknown inner-ref user requiring fixup");
});
if (failed(result))
return failure();
}
return success();
}

namespace {
/// This class implements logic and state for parsing statements, suites, and
/// similar module body constructs.
struct FIRStmtParser : public FIRParser {
explicit FIRStmtParser(Block &blockToInsertInto,
FIRModuleContext &moduleContext,
hw::InnerSymbolNamespace &modNameSpace,
InnerSymFixups &innerSymFixups,
const SymbolTable &circuitSymTbl, FIRVersion version,
SymbolRefAttr layerSym = {})
: FIRParser(moduleContext.getConstants(), moduleContext.getLexer(),
version),
builder(UnknownLoc::get(getContext()), getContext()),
locationProcessor(this->builder), moduleContext(moduleContext),
modNameSpace(modNameSpace), layerSym(layerSym),
innerSymFixups(innerSymFixups), layerSym(layerSym),
circuitSymTbl(circuitSymTbl) {
builder.setInsertionPointToEnd(&blockToInsertInto);
}
Expand Down Expand Up @@ -1774,7 +1828,8 @@ struct FIRStmtParser : public FIRParser {
// Extra information maintained across a module.
FIRModuleContext &moduleContext;

hw::InnerSymbolNamespace &modNameSpace;
/// Inner symbol users to fixup after parsing.
InnerSymFixups &innerSymFixups;

// An optional symbol that contains the current layer block that we are in.
// This is used to construct a nested symbol for a layer block operation.
Expand Down Expand Up @@ -2745,7 +2800,7 @@ ParseResult FIRStmtParser::parseSubBlock(Block &blockToInsertInto,
// We parse the substatements into their own parser, so they get inserted
// into the specified 'when' region.
auto subParser = std::make_unique<FIRStmtParser>(
blockToInsertInto, moduleContext, modNameSpace, circuitSymTbl, version,
blockToInsertInto, moduleContext, innerSymFixups, circuitSymTbl, version,
layerSym);

// Figure out whether the body is a single statement or a nested one.
Expand Down Expand Up @@ -3071,7 +3126,7 @@ ParseResult FIRStmtParser::parseWhen(unsigned whenIndent) {
if (getToken().is(FIRToken::kw_when)) {
// We create a sub parser for the else block.
auto subParser = std::make_unique<FIRStmtParser>(
whenStmt.getElseBlock(), moduleContext, modNameSpace, circuitSymTbl,
whenStmt.getElseBlock(), moduleContext, innerSymFixups, circuitSymTbl,
version, layerSym);

return subParser->parseSimpleStmt(whenIndent);
Expand Down Expand Up @@ -3211,9 +3266,9 @@ ParseResult FIRStmtParser::parseMatch(unsigned matchIndent) {
return failure();

// Parse a block of statements that are indented more than the case.
auto subParser =
std::make_unique<FIRStmtParser>(*caseBlock, moduleContext, modNameSpace,
circuitSymTbl, version, layerSym);
auto subParser = std::make_unique<FIRStmtParser>(
*caseBlock, moduleContext, innerSymFixups, circuitSymTbl, version,
layerSym);
if (subParser->parseSimpleStmtBlock(*caseIndent))
return failure();
}
Expand Down Expand Up @@ -3661,11 +3716,11 @@ ParseResult FIRStmtParser::parseRWProbe(Value &result) {
return emitError(startTok.getLoc(), "cannot force target of type ")
<< targetType;

// Get InnerRef for target field.
auto sym = getInnerRefTo(
getTargetFor(staticRef),
[&](auto _) -> hw::InnerSymbolNamespace & { return modNameSpace; });
result = builder.create<RWProbeOp>(forceableType, sym);
// Create the operation with a placeholder reference and add to fixup list.
auto op = builder.create<RWProbeOp>(forceableType,
getConstants().placeholderInnerRef);
innerSymFixups.add(op, getTargetFor(staticRef));
result = op;
return success();
}

Expand Down Expand Up @@ -4669,9 +4724,15 @@ struct FIRCircuitParser : public FIRParser {
};

ParseResult parseModuleBody(const SymbolTable &circuitSymTbl,
DeferredModuleToParse &deferredModule);
DeferredModuleToParse &deferredModule,
InnerSymFixups &fixups);

SmallVector<DeferredModuleToParse, 0> deferredModules;

SmallVector<InnerSymFixups, 0> moduleFixups;

hw::InnerSymbolNamespaceCollection innerSymbolNamespaces;

ModuleOp mlirModule;
};

Expand Down Expand Up @@ -5438,7 +5499,8 @@ ParseResult FIRCircuitParser::parseLayer(CircuitOp circuit) {
// Parse the body of this module.
ParseResult
FIRCircuitParser::parseModuleBody(const SymbolTable &circuitSymTbl,
DeferredModuleToParse &deferredModule) {
DeferredModuleToParse &deferredModule,
InnerSymFixups &fixups) {
FModuleLike moduleOp = deferredModule.moduleOp;
auto &body = moduleOp->getRegion(0).front();
auto &portLocs = deferredModule.portLocs;
Expand All @@ -5465,9 +5527,7 @@ FIRCircuitParser::parseModuleBody(const SymbolTable &circuitSymTbl,
return failure();
}

hw::InnerSymbolNamespace modNameSpace(moduleOp);
FIRStmtParser stmtParser(body, moduleContext, modNameSpace, circuitSymTbl,
version);
FIRStmtParser stmtParser(body, moduleContext, fixups, circuitSymTbl, version);

// Parse the moduleBlock.
auto result = stmtParser.parseSimpleStmtBlock(deferredModule.indent);
Expand Down Expand Up @@ -5665,16 +5725,32 @@ ParseResult FIRCircuitParser::parseCircuit(

SymbolTable circuitSymTbl(circuit);

moduleFixups.resize(deferredModules.size());

// Stub out inner symbol namespace for each module,
// none should be added so do this now to avoid walking later
// to discover that this is the case.
for (auto &d : deferredModules)
innerSymbolNamespaces.get(d.moduleOp.getOperation());

// Next, parse all the module bodies.
auto anyFailed = mlir::failableParallelForEachN(
getContext(), 0, deferredModules.size(), [&](size_t index) {
if (parseModuleBody(circuitSymTbl, deferredModules[index]))
if (parseModuleBody(circuitSymTbl, deferredModules[index],
moduleFixups[index]))
return failure();
return success();
});
if (failed(anyFailed))
return failure();

// Walk operations created that have inner symbol references
// that need replacing now that it's safe to create inner symbols everywhere.
for (auto &fixups : moduleFixups) {
if (failed(fixups.resolve(innerSymbolNamespaces)))
return failure();
}

// Helper to transform a layer name specification of the form `A::B::C` into
// a SymbolRefAttr.
auto parseLayerName = [&](StringRef name) {
Expand Down

0 comments on commit 75a192e

Please sign in to comment.