Skip to content

Commit

Permalink
[FIRRTL] Add DPI call intrinsic and lowering pass (#7139)
Browse files Browse the repository at this point in the history
This PR adds DPICallIntrinsicOp and its lowering pass. DPICallIntrinsicOp is lowered into sim.func.dpi.call and sim.func.dpi ops. At FIRRTL level DPICallIntrinsicOp doesn't have symbols and instead LowerDPI pass accumulates call sites and creates symbols for dpi functions. LowerDPI pass directly lowers FIRRTL intrinsic into Sim dialect since FIRRTL doesn't have DPI/Function construct. LowerDPI pass could be simplified (or migrated into LowerToHW) once FIRRTL gets 1st class support for Function. Unrealized conversion cast is used to mix FIRRTL and HW type values before LowerToHW.
  • Loading branch information
uenoku authored Jun 13, 2024
1 parent 2876be2 commit 44becae
Show file tree
Hide file tree
Showing 14 changed files with 388 additions and 1 deletion.
28 changes: 28 additions & 0 deletions docs/Dialects/FIRRTL/FIRRTLIntrinsics.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,31 @@ ifdef USE_FORMAL_ONLY_CONSTRAINTS
`endif // USE_UNR_ONLY_CONSTRAINTS
endif // USE_FORMAL_ONLY_CONSTRAINTS
```

### circt_dpi_call

Call a DPI function. `clock` is optional and if `clock` is not provided,
the callee is invoked when input values are changed.
If provided, the dpi function is called at clock's posedge. The result values behave
like registers and the DPI function is used as a state transfer function of them.

`enable` operand is used to conditionally call the DPI since DPI call could be quite
more expensive than native constructs. When `enable` is low, results of unclocked
calls are undefined and evaluated into `X`. Users are expected to gate result values
by another `enable` to model a default value of results.

For clocked calls, a low enable means that its register state transfer function is
not called. Hence their values will not be modify in that clock.

| Parameter | Type | Description |
| ------------- | ------ | -------------------------------- |
| isClocked | int | Set 1 if the dpi call is clocked |
| functionName | string | Specify the function name |


| Port | Direction | Type | Description |
| ----------------- | --------- | -------- | ------------------------------- |
| clock (optional) | input | Clock | Optional clock operand |
| enable | input | UInt<1> | Enable signal |
| ... | input | Signals | Arguments to DPI function call |
| result (optional) | output | Signal | Optional result of the dpi call |
19 changes: 19 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,23 @@ def HasBeenResetIntrinsicOp : FIRRTLOp<"int.has_been_reset", [Pure]> {
}


def DPICallIntrinsicOp : FIRRTLOp<"int.dpi.call",
[AttrSizedOperandSegments]> {
let summary = "Import and call DPI function";
let description = [{
The `int.dpi.call` intrinsic calls an external function.
See Sim dialect DPI call op.
}];

let arguments = (ins StrAttr:$functionName,
Optional<NonConstClockType>:$clock,
Optional<NonConstUInt1Type>:$enable,
Variadic<PassiveType>:$inputs);
let results = (outs Optional<PassiveType>:$result);
let assemblyFormat = [{
$functionName `(` $inputs `)` (`clock` $clock^)? (`enable` $enable^)?
attr-dict `:` functional-type($inputs, results)
}];
}

#endif // CIRCT_DIALECT_FIRRTL_FIRRTLINTRINSICS_TD
2 changes: 2 additions & 0 deletions include/circt/Dialect/FIRRTL/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@ std::unique_ptr<mlir::Pass> createCreateCompanionAssume();

std::unique_ptr<mlir::Pass> createModuleSummaryPass();

std::unique_ptr<mlir::Pass> createLowerDPIPass();

/// Generate the code for registering passes.
#define GEN_PASS_REGISTRATION
#include "circt/Dialect/FIRRTL/Passes.h.inc"
Expand Down
6 changes: 6 additions & 0 deletions include/circt/Dialect/FIRRTL/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,4 +908,10 @@ def ModuleSummary :
let constructor = "circt::firrtl::createModuleSummaryPass()";
}

def LowerDPI : Pass<"firrtl-lower-dpi", "firrtl::CircuitOp"> {
let summary = "Lower DPI intrinsic into Sim DPI operations";
let constructor = "circt::firrtl::createLowerDPIPass()";
let dependentDialects = ["hw::HWDialect", "seq::SeqDialect", "sim::SimDialect"];
}

#endif // CIRCT_DIALECT_FIRRTL_PASSES_TD
2 changes: 1 addition & 1 deletion include/circt/Dialect/Sim/SimOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def FatalOp : SimOp<"fatal"> {

def DPIFuncOp : SimOp<"func.dpi",
[IsolatedFromAbove, Symbol, OpAsmOpInterface,
FunctionOpInterface, HasParent<"mlir::ModuleOp">]> {
FunctionOpInterface]> {
let summary = "A System Verilog function";
let description = [{
`sim.func.dpi` models an external function in a core dialect.
Expand Down
40 changes: 40 additions & 0 deletions lib/Dialect/FIRRTL/FIRRTLIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,45 @@ class CirctUnclockedAssumeConverter : public IntrinsicConverter {
}
};

class CirctDPICallConverter : public IntrinsicConverter {
static bool getIsClocked(GenericIntrinsic gi) {
return !gi.getParamValue<IntegerAttr>("isClocked").getValue().isZero();
}

public:
using IntrinsicConverter::IntrinsicConverter;

bool check(GenericIntrinsic gi) override {
if (gi.hasNParam(2) || gi.namedIntParam("isClocked") ||
gi.namedParam("functionName"))
return true;
auto isClocked = getIsClocked(gi);
// If clocked, the first operand must be a clock.
if (isClocked && gi.typedInput<ClockType>(0))
return true;
// Enable must be UInt<1>.
if (gi.sizedInput<UIntType>(isClocked, 1))
return true;

return false;
}

void convert(GenericIntrinsic gi, GenericIntrinsicOpAdaptor adaptor,
PatternRewriter &rewriter) override {
auto isClocked = getIsClocked(gi);
auto functionName = gi.getParamValue<StringAttr>("functionName");
// Clock and enable are optional.
Value clock = isClocked ? adaptor.getOperands()[0] : Value();
Value enable = adaptor.getOperands()[static_cast<size_t>(isClocked)];

auto inputs =
adaptor.getOperands().drop_front(static_cast<size_t>(isClocked) + 1);

rewriter.replaceOpWithNewOp<DPICallIntrinsicOp>(
gi.op, gi.op.getResultTypes(), functionName, clock, enable, inputs);
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -704,4 +743,5 @@ void FIRRTLIntrinsicLoweringDialectInterface::populateIntrinsicLowerings(
lowering.add<CirctCoverConverter>("circt.chisel_cover", "circt_chisel_cover");
lowering.add<CirctUnclockedAssumeConverter>("circt.unclocked_assume",
"circt_unclocked_assume");
lowering.add<CirctDPICallConverter>("circt.dpi_call", "circt_dpi_call");
}
3 changes: 3 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
LowerAnnotations.cpp
LowerCHIRRTL.cpp
LowerClasses.cpp
LowerDPI.cpp
LowerIntmodules.cpp
LowerIntrinsics.cpp
LowerLayers.cpp
Expand Down Expand Up @@ -69,6 +70,8 @@ add_circt_dialect_library(CIRCTFIRRTLTransforms
CIRCTEmit
CIRCTHW
CIRCTOM
CIRCTSeq
CIRCTSim
CIRCTSV
CIRCTSupport
MLIRIR
Expand Down
158 changes: 158 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/LowerDPI.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
//===- LowerDPI.cpp - Lower to DPI to Sim dialects ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the LowerDPI pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetails.h"
#include "circt/Dialect/FIRRTL/FIRRTLDialect.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Dialect/FIRRTL/Namespace.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Threading.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/MapVector.h"

using namespace mlir;
using namespace llvm;
using namespace circt;
using namespace circt::firrtl;

struct LowerDPIPass : public LowerDPIBase<LowerDPIPass> {
void runOnOperation() override;
};

void LowerDPIPass::runOnOperation() {
auto circuitOp = getOperation();

CircuitNamespace nameSpace(circuitOp);
MapVector<StringAttr, SmallVector<DPICallIntrinsicOp>> funcNameToCallSites;
{
// A helper struct to collect DPI calls in the circuit.
struct DpiCallCollections {
FModuleOp module;
SmallVector<DPICallIntrinsicOp> dpiOps;
};

SmallVector<DpiCallCollections, 0> collections;
collections.reserve(64);

for (auto module : circuitOp.getOps<FModuleOp>())
collections.push_back(DpiCallCollections{module, {}});

parallelForEach(&getContext(), collections, [](auto &result) {
result.module.walk(
[&](DPICallIntrinsicOp dpi) { result.dpiOps.push_back(dpi); });
});

for (auto &collection : collections)
for (auto dpi : collection.dpiOps)
funcNameToCallSites[dpi.getFunctionNameAttr()].push_back(dpi);
}

for (auto [name, calls] : funcNameToCallSites) {
auto firstDPICallop = calls.front();
// Construct DPI func op.
auto inputTypes = firstDPICallop.getInputs().getTypes();
auto outputTypes = firstDPICallop.getResultTypes();
SmallVector<hw::ModulePort> ports;
ImplicitLocOpBuilder builder(firstDPICallop.getLoc(),
circuitOp.getOperation());
ports.reserve(inputTypes.size() + outputTypes.size());

// Add input arguments.
for (auto [idx, inType] : llvm::enumerate(inputTypes)) {
hw::ModulePort port;
port.dir = hw::ModulePort::Direction::Input;
port.name = builder.getStringAttr(Twine("in_") + Twine(idx));
port.type = lowerType(inType);
ports.push_back(port);
}

// Add output arguments.
for (auto [idx, outType] : llvm::enumerate(outputTypes)) {
hw::ModulePort port;
port.dir = hw::ModulePort::Direction::Output;
port.name = builder.getStringAttr(Twine("out_") + Twine(idx));
port.type = lowerType(outType);
ports.push_back(port);
}

auto modType = hw::ModuleType::get(&getContext(), ports);
auto funcSymbol =
nameSpace.newName(firstDPICallop.getFunctionNameAttr().getValue());
builder.setInsertionPointToStart(circuitOp.getBodyBlock());
auto sim = builder.create<sim::DPIFuncOp>(
funcSymbol, modType, ArrayAttr(), ArrayAttr(),
firstDPICallop.getFunctionNameAttr());
sim.setPrivate();

auto lowerCall = [&builder, funcSymbol](DPICallIntrinsicOp dpiOp) {
auto getLowered = [&](Value value) -> Value {
// Insert an unrealized conversion to cast FIRRTL type to HW type.
if (!value)
return value;
auto type = lowerType(value.getType());
return builder.create<mlir::UnrealizedConversionCastOp>(type, value)
->getResult(0);
};
builder.setInsertionPoint(dpiOp);
auto clock = getLowered(dpiOp.getClock());
auto enable = getLowered(dpiOp.getEnable());
SmallVector<Value, 4> inputs;
inputs.reserve(dpiOp.getInputs().size());
for (auto input : dpiOp.getInputs())
inputs.push_back(getLowered(input));

SmallVector<Type> outputTypes;
if (dpiOp.getResult())
outputTypes.push_back(lowerType(dpiOp.getResult().getType()));

auto call = builder.create<sim::DPICallOp>(outputTypes, funcSymbol, clock,
enable, inputs);
if (!call.getResults().empty()) {
// Insert unrealized conversion cast HW type to FIRRTL type.
auto result = builder
.create<mlir::UnrealizedConversionCastOp>(
dpiOp.getResult().getType(), call.getResult(0))
->getResult(0);
dpiOp.getResult().replaceAllUsesWith(result);
}
dpiOp.erase();
};

lowerCall(firstDPICallop);
for (auto dpiOp : llvm::ArrayRef(calls).drop_front()) {
// Check that all DPI declaration match.
// TODO: This should be implemented as a verifier once function is added
// to FIRRTL.
if (dpiOp.getInputs().getTypes() != inputTypes) {
auto diag = firstDPICallop.emitOpError()
<< "DPI function " << firstDPICallop.getFunctionNameAttr()
<< " input types don't match ";
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
return signalPassFailure();
}
if (dpiOp.getResultTypes() != outputTypes) {
auto diag = firstDPICallop.emitOpError()
<< "DPI function " << firstDPICallop.getFunctionNameAttr()
<< " output types don't match";
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here";
return signalPassFailure();
}
lowerCall(dpiOp);
}
}
}

std::unique_ptr<mlir::Pass> circt::firrtl::createLowerDPIPass() {
return std::make_unique<LowerDPIPass>();
}
8 changes: 8 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/PassDetails.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ namespace emit {
class EmitDialect;
} // namespace emit

namespace seq {
class SeqDialect;
} // namespace seq

namespace sim {
class SimDialect;
} // namespace sim

namespace sv {
class SVDialect;
} // namespace sv
Expand Down
1 change: 1 addition & 0 deletions lib/Firtool/Firtool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ LogicalResult firtool::populateLowFIRRTLToHW(mlir::PassManager &pm,
// RefType ports and ops.
pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerXMRPass());

pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerDPIPass());
pm.nest<firrtl::CircuitOp>().addPass(firrtl::createLowerClassesPass());
pm.nest<firrtl::CircuitOp>().addPass(om::createVerifyObjectFieldsPass());

Expand Down
24 changes: 24 additions & 0 deletions test/Dialect/FIRRTL/lower-dpi-error.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: circt-opt -firrtl-lower-dpi %s -verify-diagnostics --split-input-file

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
firrtl.module @DPI(in %in_0: !firrtl.uint<8>, in %in_1: !firrtl.uint<16>) attributes {convention = #firrtl<convention scalarized>} {
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" input types don't match}}
firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> ()
// expected-note @below {{mismatched caller is here}}
firrtl.int.dpi.call "foo"(%in_1) : (!firrtl.uint<16>) -> ()
}
}

// -----

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
firrtl.module @DPI(in %in_0: !firrtl.uint<8>) attributes {convention = #firrtl<convention scalarized>} {
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" output types don't match}}
%0 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<16>)
// expected-note @below {{mismatched caller is here}}
%1 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<8>)
}
}

34 changes: 34 additions & 0 deletions test/Dialect/FIRRTL/lower-dpi.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: circt-opt -firrtl-lower-dpi %s | FileCheck %s

// CHECK-LABEL: firrtl.circuit "DPI" {
firrtl.circuit "DPI" {
// CHECK-NEXT: sim.func.dpi private @unclocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "unclocked_result"}
// CHECK-NEXT: sim.func.dpi private @clocked_void(in %in_0 : i2, in %in_1 : i2) attributes {verilogName = "clocked_void"}
// CHECK-NEXT: sim.func.dpi private @clocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "clocked_result"}
// CHECK-LABEL: firrtl.module @DPI
firrtl.module @DPI(in %clock: !firrtl.clock, in %enable: !firrtl.uint<1>, in %in_0: !firrtl.uint<2>, in %in_1: !firrtl.uint<2>, out %out_0: !firrtl.uint<2>, out %out_1: !firrtl.uint<2>) attributes {convention = #firrtl<convention scalarized>} {
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %2 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %3 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: %4 = sim.func.dpi.call @clocked_result(%2, %3) clock %0 enable %1 : (i2, i2) -> i2
// CHECK-NEXT: %5 = builtin.unrealized_conversion_cast %4 : i2 to !firrtl.uint<2>
// CHECK-NEXT: %6 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock
// CHECK-NEXT: %7 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %8 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: sim.func.dpi.call @clocked_void(%8, %9) clock %6 enable %7 : (i2, i2) -> ()
// CHECK-NEXT: %10 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1
// CHECK-NEXT: %11 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2
// CHECK-NEXT: %12 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2
// CHECK-NEXT: %13 = sim.func.dpi.call @unclocked_result(%11, %12) enable %10 : (i2, i2) -> i2
// CHECK-NEXT: %14 = builtin.unrealized_conversion_cast %13 : i2 to !firrtl.uint<2>
// CHECK-NEXT:firrtl.matchingconnect %out_0, %5 : !firrtl.uint<2>
// CHECK-NEXT:firrtl.matchingconnect %out_1, %14 : !firrtl.uint<2>
%0 = firrtl.int.dpi.call "clocked_result"(%in_0, %in_1) clock %clock enable %enable {name = "result1"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
firrtl.int.dpi.call "clocked_void"(%in_0, %in_1) clock %clock enable %enable : (!firrtl.uint<2>, !firrtl.uint<2>) -> ()
%1 = firrtl.int.dpi.call "unclocked_result"(%in_0, %in_1) enable %enable {name = "result2"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2>
firrtl.matchingconnect %out_0, %0 : !firrtl.uint<2>
firrtl.matchingconnect %out_1, %1 : !firrtl.uint<2>
}
}
Loading

0 comments on commit 44becae

Please sign in to comment.