-
Notifications
You must be signed in to change notification settings - Fork 289
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FIRRTL] Add DPI call intrinsic and lowering pass (#7139)
This PR adds DPICallIntrinsicOp and its lowering pass. DPICallIntrinsicOp is lowered into sim.func.dpi.call and sim.func.dpi ops. At FIRRTL level DPICallIntrinsicOp doesn't have symbols and instead LowerDPI pass accumulates call sites and creates symbols for dpi functions. LowerDPI pass directly lowers FIRRTL intrinsic into Sim dialect since FIRRTL doesn't have DPI/Function construct. LowerDPI pass could be simplified (or migrated into LowerToHW) once FIRRTL gets 1st class support for Function. Unrealized conversion cast is used to mix FIRRTL and HW type values before LowerToHW.
- Loading branch information
Showing
14 changed files
with
388 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
//===- LowerDPI.cpp - Lower to DPI to Sim dialects ------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file defines the LowerDPI pass. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "PassDetails.h" | ||
#include "circt/Dialect/FIRRTL/FIRRTLDialect.h" | ||
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h" | ||
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h" | ||
#include "circt/Dialect/FIRRTL/Namespace.h" | ||
#include "circt/Dialect/Sim/SimOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Threading.h" | ||
#include "mlir/Support/LogicalResult.h" | ||
#include "llvm/ADT/MapVector.h" | ||
|
||
using namespace mlir; | ||
using namespace llvm; | ||
using namespace circt; | ||
using namespace circt::firrtl; | ||
|
||
struct LowerDPIPass : public LowerDPIBase<LowerDPIPass> { | ||
void runOnOperation() override; | ||
}; | ||
|
||
void LowerDPIPass::runOnOperation() { | ||
auto circuitOp = getOperation(); | ||
|
||
CircuitNamespace nameSpace(circuitOp); | ||
MapVector<StringAttr, SmallVector<DPICallIntrinsicOp>> funcNameToCallSites; | ||
{ | ||
// A helper struct to collect DPI calls in the circuit. | ||
struct DpiCallCollections { | ||
FModuleOp module; | ||
SmallVector<DPICallIntrinsicOp> dpiOps; | ||
}; | ||
|
||
SmallVector<DpiCallCollections, 0> collections; | ||
collections.reserve(64); | ||
|
||
for (auto module : circuitOp.getOps<FModuleOp>()) | ||
collections.push_back(DpiCallCollections{module, {}}); | ||
|
||
parallelForEach(&getContext(), collections, [](auto &result) { | ||
result.module.walk( | ||
[&](DPICallIntrinsicOp dpi) { result.dpiOps.push_back(dpi); }); | ||
}); | ||
|
||
for (auto &collection : collections) | ||
for (auto dpi : collection.dpiOps) | ||
funcNameToCallSites[dpi.getFunctionNameAttr()].push_back(dpi); | ||
} | ||
|
||
for (auto [name, calls] : funcNameToCallSites) { | ||
auto firstDPICallop = calls.front(); | ||
// Construct DPI func op. | ||
auto inputTypes = firstDPICallop.getInputs().getTypes(); | ||
auto outputTypes = firstDPICallop.getResultTypes(); | ||
SmallVector<hw::ModulePort> ports; | ||
ImplicitLocOpBuilder builder(firstDPICallop.getLoc(), | ||
circuitOp.getOperation()); | ||
ports.reserve(inputTypes.size() + outputTypes.size()); | ||
|
||
// Add input arguments. | ||
for (auto [idx, inType] : llvm::enumerate(inputTypes)) { | ||
hw::ModulePort port; | ||
port.dir = hw::ModulePort::Direction::Input; | ||
port.name = builder.getStringAttr(Twine("in_") + Twine(idx)); | ||
port.type = lowerType(inType); | ||
ports.push_back(port); | ||
} | ||
|
||
// Add output arguments. | ||
for (auto [idx, outType] : llvm::enumerate(outputTypes)) { | ||
hw::ModulePort port; | ||
port.dir = hw::ModulePort::Direction::Output; | ||
port.name = builder.getStringAttr(Twine("out_") + Twine(idx)); | ||
port.type = lowerType(outType); | ||
ports.push_back(port); | ||
} | ||
|
||
auto modType = hw::ModuleType::get(&getContext(), ports); | ||
auto funcSymbol = | ||
nameSpace.newName(firstDPICallop.getFunctionNameAttr().getValue()); | ||
builder.setInsertionPointToStart(circuitOp.getBodyBlock()); | ||
auto sim = builder.create<sim::DPIFuncOp>( | ||
funcSymbol, modType, ArrayAttr(), ArrayAttr(), | ||
firstDPICallop.getFunctionNameAttr()); | ||
sim.setPrivate(); | ||
|
||
auto lowerCall = [&builder, funcSymbol](DPICallIntrinsicOp dpiOp) { | ||
auto getLowered = [&](Value value) -> Value { | ||
// Insert an unrealized conversion to cast FIRRTL type to HW type. | ||
if (!value) | ||
return value; | ||
auto type = lowerType(value.getType()); | ||
return builder.create<mlir::UnrealizedConversionCastOp>(type, value) | ||
->getResult(0); | ||
}; | ||
builder.setInsertionPoint(dpiOp); | ||
auto clock = getLowered(dpiOp.getClock()); | ||
auto enable = getLowered(dpiOp.getEnable()); | ||
SmallVector<Value, 4> inputs; | ||
inputs.reserve(dpiOp.getInputs().size()); | ||
for (auto input : dpiOp.getInputs()) | ||
inputs.push_back(getLowered(input)); | ||
|
||
SmallVector<Type> outputTypes; | ||
if (dpiOp.getResult()) | ||
outputTypes.push_back(lowerType(dpiOp.getResult().getType())); | ||
|
||
auto call = builder.create<sim::DPICallOp>(outputTypes, funcSymbol, clock, | ||
enable, inputs); | ||
if (!call.getResults().empty()) { | ||
// Insert unrealized conversion cast HW type to FIRRTL type. | ||
auto result = builder | ||
.create<mlir::UnrealizedConversionCastOp>( | ||
dpiOp.getResult().getType(), call.getResult(0)) | ||
->getResult(0); | ||
dpiOp.getResult().replaceAllUsesWith(result); | ||
} | ||
dpiOp.erase(); | ||
}; | ||
|
||
lowerCall(firstDPICallop); | ||
for (auto dpiOp : llvm::ArrayRef(calls).drop_front()) { | ||
// Check that all DPI declaration match. | ||
// TODO: This should be implemented as a verifier once function is added | ||
// to FIRRTL. | ||
if (dpiOp.getInputs().getTypes() != inputTypes) { | ||
auto diag = firstDPICallop.emitOpError() | ||
<< "DPI function " << firstDPICallop.getFunctionNameAttr() | ||
<< " input types don't match "; | ||
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here"; | ||
return signalPassFailure(); | ||
} | ||
if (dpiOp.getResultTypes() != outputTypes) { | ||
auto diag = firstDPICallop.emitOpError() | ||
<< "DPI function " << firstDPICallop.getFunctionNameAttr() | ||
<< " output types don't match"; | ||
diag.attachNote(dpiOp.getLoc()) << " mismatched caller is here"; | ||
return signalPassFailure(); | ||
} | ||
lowerCall(dpiOp); | ||
} | ||
} | ||
} | ||
|
||
std::unique_ptr<mlir::Pass> circt::firrtl::createLowerDPIPass() { | ||
return std::make_unique<LowerDPIPass>(); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// RUN: circt-opt -firrtl-lower-dpi %s -verify-diagnostics --split-input-file | ||
|
||
// CHECK-LABEL: firrtl.circuit "DPI" { | ||
firrtl.circuit "DPI" { | ||
firrtl.module @DPI(in %in_0: !firrtl.uint<8>, in %in_1: !firrtl.uint<16>) attributes {convention = #firrtl<convention scalarized>} { | ||
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" input types don't match}} | ||
firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> () | ||
// expected-note @below {{mismatched caller is here}} | ||
firrtl.int.dpi.call "foo"(%in_1) : (!firrtl.uint<16>) -> () | ||
} | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: firrtl.circuit "DPI" { | ||
firrtl.circuit "DPI" { | ||
firrtl.module @DPI(in %in_0: !firrtl.uint<8>) attributes {convention = #firrtl<convention scalarized>} { | ||
// expected-error @below {{firrtl.int.dpi.call' op DPI function "foo" output types don't match}} | ||
%0 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<16>) | ||
// expected-note @below {{mismatched caller is here}} | ||
%1 = firrtl.int.dpi.call "foo"(%in_0) : (!firrtl.uint<8>) -> (!firrtl.uint<8>) | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// RUN: circt-opt -firrtl-lower-dpi %s | FileCheck %s | ||
|
||
// CHECK-LABEL: firrtl.circuit "DPI" { | ||
firrtl.circuit "DPI" { | ||
// CHECK-NEXT: sim.func.dpi private @unclocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "unclocked_result"} | ||
// CHECK-NEXT: sim.func.dpi private @clocked_void(in %in_0 : i2, in %in_1 : i2) attributes {verilogName = "clocked_void"} | ||
// CHECK-NEXT: sim.func.dpi private @clocked_result(in %in_0 : i2, in %in_1 : i2, out out_0 : i2) attributes {verilogName = "clocked_result"} | ||
// CHECK-LABEL: firrtl.module @DPI | ||
firrtl.module @DPI(in %clock: !firrtl.clock, in %enable: !firrtl.uint<1>, in %in_0: !firrtl.uint<2>, in %in_1: !firrtl.uint<2>, out %out_0: !firrtl.uint<2>, out %out_1: !firrtl.uint<2>) attributes {convention = #firrtl<convention scalarized>} { | ||
// CHECK-NEXT: %0 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock | ||
// CHECK-NEXT: %1 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1 | ||
// CHECK-NEXT: %2 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: %3 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: %4 = sim.func.dpi.call @clocked_result(%2, %3) clock %0 enable %1 : (i2, i2) -> i2 | ||
// CHECK-NEXT: %5 = builtin.unrealized_conversion_cast %4 : i2 to !firrtl.uint<2> | ||
// CHECK-NEXT: %6 = builtin.unrealized_conversion_cast %clock : !firrtl.clock to !seq.clock | ||
// CHECK-NEXT: %7 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1 | ||
// CHECK-NEXT: %8 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: %9 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: sim.func.dpi.call @clocked_void(%8, %9) clock %6 enable %7 : (i2, i2) -> () | ||
// CHECK-NEXT: %10 = builtin.unrealized_conversion_cast %enable : !firrtl.uint<1> to i1 | ||
// CHECK-NEXT: %11 = builtin.unrealized_conversion_cast %in_0 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: %12 = builtin.unrealized_conversion_cast %in_1 : !firrtl.uint<2> to i2 | ||
// CHECK-NEXT: %13 = sim.func.dpi.call @unclocked_result(%11, %12) enable %10 : (i2, i2) -> i2 | ||
// CHECK-NEXT: %14 = builtin.unrealized_conversion_cast %13 : i2 to !firrtl.uint<2> | ||
// CHECK-NEXT:firrtl.matchingconnect %out_0, %5 : !firrtl.uint<2> | ||
// CHECK-NEXT:firrtl.matchingconnect %out_1, %14 : !firrtl.uint<2> | ||
%0 = firrtl.int.dpi.call "clocked_result"(%in_0, %in_1) clock %clock enable %enable {name = "result1"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2> | ||
firrtl.int.dpi.call "clocked_void"(%in_0, %in_1) clock %clock enable %enable : (!firrtl.uint<2>, !firrtl.uint<2>) -> () | ||
%1 = firrtl.int.dpi.call "unclocked_result"(%in_0, %in_1) enable %enable {name = "result2"} : (!firrtl.uint<2>, !firrtl.uint<2>) -> !firrtl.uint<2> | ||
firrtl.matchingconnect %out_0, %0 : !firrtl.uint<2> | ||
firrtl.matchingconnect %out_1, %1 : !firrtl.uint<2> | ||
} | ||
} |
Oops, something went wrong.