Skip to content

Commit

Permalink
Merge pull request google#377 from j2kun:support-no-return-types
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599260952
  • Loading branch information
copybara-github committed Jan 17, 2024
2 parents d5ab397 + 2babb03 commit 43c1771
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 88 deletions.
4 changes: 2 additions & 2 deletions docs/content/en/docs/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ Example output:
}
}
],
"return_type": {
"return_types": [{
"memref": {
"element_type": {
"integer": {
Expand All @@ -209,7 +209,7 @@ Example output:
},
"shape": [1, 3, 2, 1]
}
}
}]
}
]
}
Expand Down
5 changes: 2 additions & 3 deletions include/Target/Verilog/VerilogEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class VerilogEmitter {

// A helper to generalize the work of emitting a func.return and a
// secret.yield
LogicalResult printReturnLikeOp(Value returnValue);
LogicalResult printReturnLikeOp(ValueRange returnValues);

// Functions for printing individual ops
LogicalResult printOperation(mlir::ModuleOp op,
Expand Down Expand Up @@ -105,8 +105,6 @@ class VerilogEmitter {
LogicalResult printOperation(mlir::affine::AffineLoadOp op);
LogicalResult printOperation(mlir::affine::AffineStoreOp op);
LogicalResult printOperation(mlir::func::CallOp op);
LogicalResult printOperation(mlir::func::ReturnOp op);
LogicalResult printOperation(mlir::heir::secret::YieldOp op);
LogicalResult printOperation(mlir::math::CountLeadingZerosOp op);
LogicalResult printOperation(mlir::memref::LoadOp op);

Expand All @@ -116,6 +114,7 @@ class VerilogEmitter {

// Emit a Verilog type of the form `wire [width-1:0]`
LogicalResult emitType(Type type);
LogicalResult emitType(Type type, raw_ostream &os);

// Emit a Verilog array shape specifier of the form `[width]`
LogicalResult emitArrayShapeSuffix(Type type);
Expand Down
49 changes: 26 additions & 23 deletions lib/Target/Metadata/MetadataEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,25 +165,6 @@ FailureOr<llvm::json::Object> MetadataEmitter::typeAsJson(MemRefType &ty) {
}

FailureOr<llvm::json::Object> MetadataEmitter::emitOperation(FuncOp funcOp) {
auto result_types = funcOp.getFunctionType().getResults();
if (result_types.size() != 1) {
emitError(funcOp.getLoc(),
"Only functions with a single return type are supported");
return failure();
}
auto output_type = result_types[0];
auto status =
llvm::TypeSwitch<Type &, FailureOr<llvm::json::Object>>(output_type)
.Case<IntegerType, MemRefType>(
[&](auto ty) { return typeAsJson(ty); })
.Default([&](Type &) { return failure(); });

if (failed(status)) {
funcOp.emitOpError(
llvm::formatv("Failed to handle output type {0}", output_type));
return failure();
}

llvm::json::Array arguments;
for (auto arg : funcOp.getArguments()) {
Type type = arg.getType();
Expand All @@ -203,13 +184,35 @@ FailureOr<llvm::json::Object> MetadataEmitter::emitOperation(FuncOp funcOp) {
});
}

llvm::json::Object function{
auto resultTypes = funcOp.getFunctionType().getResults();
llvm::json::Array resultsJson;
if (resultTypes.size() > 1) {
emitError(funcOp.getLoc(),
"Only functions with <=1 return types are supported");
return failure();
}

if (resultTypes.size() == 1) {
auto outputType = resultTypes[0];
auto status =
llvm::TypeSwitch<Type &, FailureOr<llvm::json::Object>>(outputType)
.Case<IntegerType, MemRefType>(
[&](auto ty) { return typeAsJson(ty); })
.Default([&](Type &) { return failure(); });

if (failed(status)) {
funcOp.emitOpError(
llvm::formatv("Failed to handle output type {0}", outputType));
return failure();
}
resultsJson.push_back(std::move(status.value()));
}

return llvm::json::Object{
{"name", funcOp.getName()},
{"return_type", std::move(status.value())},
{"return_types", std::move(resultsJson)},
{"params", std::move(arguments)},
};

return std::move(function);
}

} // namespace heir
Expand Down
105 changes: 63 additions & 42 deletions lib/Target/Verilog/VerilogEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,23 @@ void printRawDataFromAttr(DenseElementsAttr attr, raw_ostream &os) {
}
}

llvm::SmallString<128> variableLoadStr(StringRef memref, StringRef index,
llvm::SmallString<128> flattenIndexExpression(
const llvm::ArrayRef<StringRef> indices,
const llvm::ArrayRef<int64_t> sizes, int width) {
llvm::SmallString<128> accum = llvm::formatv("{0}", indices[0]);
for (int i = 1; i < indices.size(); ++i) {
accum = llvm::formatv("{0} + {1} * ({2})", indices[i], sizes[i], accum);
}
return indices.size() == 1 ? llvm::formatv("{0} * {1}", width, accum)
: llvm::formatv("{0} * ({1})", width, accum);
}

llvm::SmallString<128> variableLoadStr(StringRef memref,
llvm::ArrayRef<StringRef> indices,
llvm::ArrayRef<int64_t> sizes,
unsigned int width) {
return llvm::formatv("{0}[{1} + {2} * {3} : {2} * {3}]", memref, width - 1,
width, index);
llvm::SmallString<128> index = flattenIndexExpression(indices, sizes, width);
return llvm::formatv("{0}[{1} + {2} : {2}]", memref, width - 1, index);
}

struct CtlzValueStruct {
Expand Down Expand Up @@ -170,8 +183,10 @@ LogicalResult VerilogEmitter::translate(
.Case<ModuleOp, func::FuncOp, secret::GenericOp>(
[&](auto op) { return printOperation(op, moduleName); })
// Func ops.
.Case<func::ReturnOp, func::CallOp, secret::YieldOp>(
[&](auto op) { return printOperation(op); })
.Case<func::CallOp>([&](auto op) { return printOperation(op); })
// Return-like ops
.Case<func::ReturnOp, secret::YieldOp>(
[&](auto op) { return printReturnLikeOp(op.getOperands()); })
// Arithmetic ops.
.Case<arith::ConstantOp>([&](auto op) {
if (auto iAttr = dyn_cast<IndexType>(op.getValue().getType())) {
Expand Down Expand Up @@ -265,28 +280,36 @@ LogicalResult VerilogEmitter::printFunctionLikeOp(
*/
os_ << "module " << verilogModuleName << "(\n";
os_.indent();
llvm::SmallVector<std::string, 4> argsToPrint;
for (auto arg : arguments) {
// e.g., `input wire [31:0] arg0,`
os_ << "input ";
if (failed(emitType(arg.getType()))) {
std::string result;
llvm::raw_string_ostream ss(result);
ss << "input ";
if (failed(emitType(arg.getType(), ss))) {
op->emitError() << "failed to emit type" << arg.getType();
return failure();
}
os_ << " " << getOrCreateName(arg) << ",\n";
ss << " " << getOrCreateName(arg);
argsToPrint.push_back(ss.str());
}
os_ << llvm::join(argsToPrint.begin(), argsToPrint.end(), ",\n");

// output arg declaration
if (resultTypes.size() != 1) {
if (resultTypes.size() > 1) {
emitError(op->getLoc(),
"Only functions with a single return type are supported");
"Only functions with a <= 1 return types are supported");
return failure();
}
os_ << "output ";
if (failed(emitType(resultTypes.front()))) {
op->emitError() << "failed to emit type" << resultTypes.front();
return failure();
if (resultTypes.size() == 1) {
os_ << ",\n";
os_ << "output ";
if (failed(emitType(resultTypes.front()))) {
op->emitError() << "failed to emit type" << resultTypes.front();
return failure();
}
os_ << " " << kOutputName;
}
os_ << " " << kOutputName;

// End of module header
os_.unindent();
Expand Down Expand Up @@ -403,23 +426,16 @@ LogicalResult VerilogEmitter::printOperation(
blocks->begin(), blocks->end());
}

LogicalResult VerilogEmitter::printReturnLikeOp(Value returnValue) {
LogicalResult VerilogEmitter::printReturnLikeOp(ValueRange returnValues) {
// Return is an assignment to the output wire
// e.g., assign out = x1200;
os_ << "assign " << kOutputName << " = " << getName(returnValue) << ";\n";
if (returnValues.empty()) {
return success();
}
os_ << "assign " << kOutputName << " = " << getName(returnValues[0]) << ";\n";
return success();
}

LogicalResult VerilogEmitter::printOperation(func::ReturnOp op) {
// Only support one return value.
return printReturnLikeOp(op.getOperands()[0]);
}

LogicalResult VerilogEmitter::printOperation(secret::YieldOp op) {
// Only support one return value.
return printReturnLikeOp(op.getOperands()[0]);
}

LogicalResult VerilogEmitter::printOperation(func::CallOp op) {
// e.g., submodule submod_call(xInput0, xInput1, xOutput);
std::string opName = (getOrCreateName(op.getResult(0)) + "_call").str();
Expand Down Expand Up @@ -620,12 +636,14 @@ LogicalResult VerilogEmitter::printOperation(affine::AffineLoadOp op) {
os_ << memrefStr << "[" << flattenedBitIndex + width - 1 << " : "
<< flattenedBitIndex << "];\n";
} else {
if (op.getMemRefType().getRank() > 1) {
// TODO(b/284323495): Handle multi-dim variable access.
return failure();
}
emitAssignPrefix(op.getResult());
os_ << variableLoadStr(memrefStr, getOrCreateName(op.getIndices()[0]),

llvm::SmallVector<StringRef, 4> indices;
for (auto index : op.getIndices()) {
indices.push_back(getOrCreateName(index));
}

os_ << variableLoadStr(memrefStr, indices, op.getMemRefType().getShape(),
width)
<< ";\n";
}
Expand All @@ -640,17 +658,16 @@ LogicalResult VerilogEmitter::printOperation(memref::LoadOp op) {
return failure();
}

auto memrefStr = getOrCreateName(op.getMemref());
auto indexStr = getOrCreateName(op.getIndices()[0]);
auto width = iType.getWidth();
emitAssignPrefix(op.getResult());

if (op.getMemRefType().getRank() > 1) {
// TODO(b/284323495): Handle multi-dim variable access.
return failure();
llvm::SmallVector<StringRef, 4> indices;
for (auto index : op.getIndices()) {
indices.push_back(getOrCreateName(index));
}

emitAssignPrefix(op.getResult());
os_ << variableLoadStr(memrefStr, indexStr, width) << ";\n";
os_ << variableLoadStr(getOrCreateName(op.getMemref()), indices,
op.getMemRefType().getShape(), iType.getWidth())
<< ";\n";

return success();
}
Expand Down Expand Up @@ -740,15 +757,19 @@ LogicalResult VerilogEmitter::printOperation(
}

LogicalResult VerilogEmitter::emitType(Type type) {
return emitType(type, os_);
}

LogicalResult VerilogEmitter::emitType(Type type, raw_ostream &os) {
if (auto iType = dyn_cast<IntegerType>(type)) {
int32_t width = iType.getWidth();
return (os_ << wireDeclaration(iType, width)), success();
return (os << wireDeclaration(iType, width)), success();
}
if (auto memRefType = dyn_cast<MemRefType>(type)) {
auto elementType = memRefType.getElementType();
if (auto iType = dyn_cast<IntegerType>(elementType)) {
int32_t flattenedWidth = memRefType.getNumElements() * iType.getWidth();
return (os_ << wireDeclaration(iType, flattenedWidth)), success();
return (os << wireDeclaration(iType, flattenedWidth)), success();
}
}
return failure();
Expand Down
63 changes: 47 additions & 16 deletions tests/verilog/emit_metadata.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: heir-translate --allow-unregistered-dialect --emit-metadata %s > %t
// RUN: FileCheck %s < %t
// RUN: heir-translate --allow-unregistered-dialect --emit-metadata %s | FileCheck %s

module {
func.func @main(%arg0: memref<1x80xi8>) -> memref<1x3x2x1xi8> {
Expand All @@ -21,6 +20,13 @@ module {
}
return %alloc_0 : memref<1x3x2x1xi8>
}

func.func @main2(%arg0: memref<80xi8>) {
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : i8
memref.store %c8, %arg0[%c0] : memref<80xi8>
return
}
}

// CHECK: {
Expand All @@ -46,22 +52,47 @@ module {
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ],
// CHECK-NEXT: "return_type": {
// CHECK-NEXT: "memref": {
// CHECK-NEXT: "element_type": {
// CHECK-NEXT: "integer": {
// CHECK-NEXT: "is_signed": false,
// CHECK-NEXT: "width": 8
// CHECK-NEXT: "return_types": [
// CHECK-NEXT: {
// CHECK-NEXT: "memref": {
// CHECK-NEXT: "element_type": {
// CHECK-NEXT: "integer": {
// CHECK-NEXT: "is_signed": false,
// CHECK-NEXT: "width": 8
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "shape": [
// CHECK-NEXT: 1,
// CHECK-NEXT: 3,
// CHECK-NEXT: 2,
// CHECK-NEXT: 1
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: },

// CHECK-NEXT: {
// CHECK-NEXT: "name": "main2",
// CHECK-NEXT: "params": [
// CHECK-NEXT: {
// CHECK-NEXT: "index": 0,
// CHECK-NEXT: "type": {
// CHECK-NEXT: "memref": {
// CHECK-NEXT: "element_type": {
// CHECK-NEXT: "integer": {
// CHECK-NEXT: "is_signed": false,
// CHECK-NEXT: "width": 8
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "shape": [
// CHECK-NEXT: 80
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: },
// CHECK-NEXT: "shape": [
// CHECK-NEXT: 1,
// CHECK-NEXT: 3,
// CHECK-NEXT: 2,
// CHECK-NEXT: 1
// CHECK-NEXT: ]
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: ],
// CHECK-NEXT: "return_types": []
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
Loading

0 comments on commit 43c1771

Please sign in to comment.