From 2babb0310d840d361e82b5c10eddd2a068b47443 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Thu, 11 Jan 2024 16:35:41 -0800 Subject: [PATCH] support no return values in emit_metadata and verilog emitter --- docs/content/en/docs/pipelines.md | 4 +- include/Target/Verilog/VerilogEmitter.h | 5 +- lib/Target/Metadata/MetadataEmitter.cpp | 49 ++++----- lib/Target/Verilog/VerilogEmitter.cpp | 105 ++++++++++++-------- tests/verilog/emit_metadata.mlir | 63 +++++++++--- tests/verilog/emit_verilog_memref_load.mlir | 26 ++++- 6 files changed, 164 insertions(+), 88 deletions(-) diff --git a/docs/content/en/docs/pipelines.md b/docs/content/en/docs/pipelines.md index f756b9311..ba77f23a9 100644 --- a/docs/content/en/docs/pipelines.md +++ b/docs/content/en/docs/pipelines.md @@ -199,7 +199,7 @@ Example output: } } ], - "return_type": { + "return_types": [{ "memref": { "element_type": { "integer": { @@ -209,7 +209,7 @@ Example output: }, "shape": [1, 3, 2, 1] } - } + }] } ] } diff --git a/include/Target/Verilog/VerilogEmitter.h b/include/Target/Verilog/VerilogEmitter.h index 75461940a..a0479cc9d 100644 --- a/include/Target/Verilog/VerilogEmitter.h +++ b/include/Target/Verilog/VerilogEmitter.h @@ -77,7 +77,7 @@ class VerilogEmitter { // A helper to generalize the work of emitting a func.return and a // secret.yield - LogicalResult printReturnLikeOp(Value returnValue); + LogicalResult printReturnLikeOp(ValueRange returnValues); // Functions for printing individual ops LogicalResult printOperation(mlir::ModuleOp op, @@ -105,8 +105,6 @@ class VerilogEmitter { LogicalResult printOperation(mlir::affine::AffineLoadOp op); LogicalResult printOperation(mlir::affine::AffineStoreOp op); LogicalResult printOperation(mlir::func::CallOp op); - LogicalResult printOperation(mlir::func::ReturnOp op); - LogicalResult printOperation(mlir::heir::secret::YieldOp op); LogicalResult printOperation(mlir::math::CountLeadingZerosOp op); LogicalResult printOperation(mlir::memref::LoadOp op); @@ -116,6 +114,7 @@ class VerilogEmitter { // Emit a Verilog type of the form `wire [width-1:0]` LogicalResult emitType(Type type); + LogicalResult emitType(Type type, raw_ostream &os); // Emit a Verilog array shape specifier of the form `[width]` LogicalResult emitArrayShapeSuffix(Type type); diff --git a/lib/Target/Metadata/MetadataEmitter.cpp b/lib/Target/Metadata/MetadataEmitter.cpp index 34bb628f0..8e1fd9fe6 100644 --- a/lib/Target/Metadata/MetadataEmitter.cpp +++ b/lib/Target/Metadata/MetadataEmitter.cpp @@ -165,25 +165,6 @@ FailureOr MetadataEmitter::typeAsJson(MemRefType &ty) { } FailureOr MetadataEmitter::emitOperation(FuncOp funcOp) { - auto result_types = funcOp.getFunctionType().getResults(); - if (result_types.size() != 1) { - emitError(funcOp.getLoc(), - "Only functions with a single return type are supported"); - return failure(); - } - auto output_type = result_types[0]; - auto status = - llvm::TypeSwitch>(output_type) - .Case( - [&](auto ty) { return typeAsJson(ty); }) - .Default([&](Type &) { return failure(); }); - - if (failed(status)) { - funcOp.emitOpError( - llvm::formatv("Failed to handle output type {0}", output_type)); - return failure(); - } - llvm::json::Array arguments; for (auto arg : funcOp.getArguments()) { Type type = arg.getType(); @@ -203,13 +184,35 @@ FailureOr MetadataEmitter::emitOperation(FuncOp funcOp) { }); } - llvm::json::Object function{ + auto resultTypes = funcOp.getFunctionType().getResults(); + llvm::json::Array resultsJson; + if (resultTypes.size() > 1) { + emitError(funcOp.getLoc(), + "Only functions with <=1 return types are supported"); + return failure(); + } + + if (resultTypes.size() == 1) { + auto outputType = resultTypes[0]; + auto status = + llvm::TypeSwitch>(outputType) + .Case( + [&](auto ty) { return typeAsJson(ty); }) + .Default([&](Type &) { return failure(); }); + + if (failed(status)) { + funcOp.emitOpError( + llvm::formatv("Failed to handle output type {0}", outputType)); + return failure(); + } + resultsJson.push_back(std::move(status.value())); + } + + return llvm::json::Object{ {"name", funcOp.getName()}, - {"return_type", std::move(status.value())}, + {"return_types", std::move(resultsJson)}, {"params", std::move(arguments)}, }; - - return std::move(function); } } // namespace heir diff --git a/lib/Target/Verilog/VerilogEmitter.cpp b/lib/Target/Verilog/VerilogEmitter.cpp index 3129028eb..1ccd509e8 100644 --- a/lib/Target/Verilog/VerilogEmitter.cpp +++ b/lib/Target/Verilog/VerilogEmitter.cpp @@ -94,10 +94,23 @@ void printRawDataFromAttr(DenseElementsAttr attr, raw_ostream &os) { } } -llvm::SmallString<128> variableLoadStr(StringRef memref, StringRef index, +llvm::SmallString<128> flattenIndexExpression( + const llvm::ArrayRef indices, + const llvm::ArrayRef sizes, int width) { + llvm::SmallString<128> accum = llvm::formatv("{0}", indices[0]); + for (int i = 1; i < indices.size(); ++i) { + accum = llvm::formatv("{0} + {1} * ({2})", indices[i], sizes[i], accum); + } + return indices.size() == 1 ? llvm::formatv("{0} * {1}", width, accum) + : llvm::formatv("{0} * ({1})", width, accum); +} + +llvm::SmallString<128> variableLoadStr(StringRef memref, + llvm::ArrayRef indices, + llvm::ArrayRef sizes, unsigned int width) { - return llvm::formatv("{0}[{1} + {2} * {3} : {2} * {3}]", memref, width - 1, - width, index); + llvm::SmallString<128> index = flattenIndexExpression(indices, sizes, width); + return llvm::formatv("{0}[{1} + {2} : {2}]", memref, width - 1, index); } struct CtlzValueStruct { @@ -170,8 +183,10 @@ LogicalResult VerilogEmitter::translate( .Case( [&](auto op) { return printOperation(op, moduleName); }) // Func ops. - .Case( - [&](auto op) { return printOperation(op); }) + .Case([&](auto op) { return printOperation(op); }) + // Return-like ops + .Case( + [&](auto op) { return printReturnLikeOp(op.getOperands()); }) // Arithmetic ops. .Case([&](auto op) { if (auto iAttr = dyn_cast(op.getValue().getType())) { @@ -265,28 +280,36 @@ LogicalResult VerilogEmitter::printFunctionLikeOp( */ os_ << "module " << verilogModuleName << "(\n"; os_.indent(); + llvm::SmallVector argsToPrint; for (auto arg : arguments) { // e.g., `input wire [31:0] arg0,` - os_ << "input "; - if (failed(emitType(arg.getType()))) { + std::string result; + llvm::raw_string_ostream ss(result); + ss << "input "; + if (failed(emitType(arg.getType(), ss))) { op->emitError() << "failed to emit type" << arg.getType(); return failure(); } - os_ << " " << getOrCreateName(arg) << ",\n"; + ss << " " << getOrCreateName(arg); + argsToPrint.push_back(ss.str()); } + os_ << llvm::join(argsToPrint.begin(), argsToPrint.end(), ",\n"); // output arg declaration - if (resultTypes.size() != 1) { + if (resultTypes.size() > 1) { emitError(op->getLoc(), - "Only functions with a single return type are supported"); + "Only functions with a <= 1 return types are supported"); return failure(); } - os_ << "output "; - if (failed(emitType(resultTypes.front()))) { - op->emitError() << "failed to emit type" << resultTypes.front(); - return failure(); + if (resultTypes.size() == 1) { + os_ << ",\n"; + os_ << "output "; + if (failed(emitType(resultTypes.front()))) { + op->emitError() << "failed to emit type" << resultTypes.front(); + return failure(); + } + os_ << " " << kOutputName; } - os_ << " " << kOutputName; // End of module header os_.unindent(); @@ -403,23 +426,16 @@ LogicalResult VerilogEmitter::printOperation( blocks->begin(), blocks->end()); } -LogicalResult VerilogEmitter::printReturnLikeOp(Value returnValue) { +LogicalResult VerilogEmitter::printReturnLikeOp(ValueRange returnValues) { // Return is an assignment to the output wire // e.g., assign out = x1200; - os_ << "assign " << kOutputName << " = " << getName(returnValue) << ";\n"; + if (returnValues.empty()) { + return success(); + } + os_ << "assign " << kOutputName << " = " << getName(returnValues[0]) << ";\n"; return success(); } -LogicalResult VerilogEmitter::printOperation(func::ReturnOp op) { - // Only support one return value. - return printReturnLikeOp(op.getOperands()[0]); -} - -LogicalResult VerilogEmitter::printOperation(secret::YieldOp op) { - // Only support one return value. - return printReturnLikeOp(op.getOperands()[0]); -} - LogicalResult VerilogEmitter::printOperation(func::CallOp op) { // e.g., submodule submod_call(xInput0, xInput1, xOutput); std::string opName = (getOrCreateName(op.getResult(0)) + "_call").str(); @@ -620,12 +636,14 @@ LogicalResult VerilogEmitter::printOperation(affine::AffineLoadOp op) { os_ << memrefStr << "[" << flattenedBitIndex + width - 1 << " : " << flattenedBitIndex << "];\n"; } else { - if (op.getMemRefType().getRank() > 1) { - // TODO(b/284323495): Handle multi-dim variable access. - return failure(); - } emitAssignPrefix(op.getResult()); - os_ << variableLoadStr(memrefStr, getOrCreateName(op.getIndices()[0]), + + llvm::SmallVector indices; + for (auto index : op.getIndices()) { + indices.push_back(getOrCreateName(index)); + } + + os_ << variableLoadStr(memrefStr, indices, op.getMemRefType().getShape(), width) << ";\n"; } @@ -640,17 +658,16 @@ LogicalResult VerilogEmitter::printOperation(memref::LoadOp op) { return failure(); } - auto memrefStr = getOrCreateName(op.getMemref()); - auto indexStr = getOrCreateName(op.getIndices()[0]); - auto width = iType.getWidth(); + emitAssignPrefix(op.getResult()); - if (op.getMemRefType().getRank() > 1) { - // TODO(b/284323495): Handle multi-dim variable access. - return failure(); + llvm::SmallVector indices; + for (auto index : op.getIndices()) { + indices.push_back(getOrCreateName(index)); } - emitAssignPrefix(op.getResult()); - os_ << variableLoadStr(memrefStr, indexStr, width) << ";\n"; + os_ << variableLoadStr(getOrCreateName(op.getMemref()), indices, + op.getMemRefType().getShape(), iType.getWidth()) + << ";\n"; return success(); } @@ -740,15 +757,19 @@ LogicalResult VerilogEmitter::printOperation( } LogicalResult VerilogEmitter::emitType(Type type) { + return emitType(type, os_); +} + +LogicalResult VerilogEmitter::emitType(Type type, raw_ostream &os) { if (auto iType = dyn_cast(type)) { int32_t width = iType.getWidth(); - return (os_ << wireDeclaration(iType, width)), success(); + return (os << wireDeclaration(iType, width)), success(); } if (auto memRefType = dyn_cast(type)) { auto elementType = memRefType.getElementType(); if (auto iType = dyn_cast(elementType)) { int32_t flattenedWidth = memRefType.getNumElements() * iType.getWidth(); - return (os_ << wireDeclaration(iType, flattenedWidth)), success(); + return (os << wireDeclaration(iType, flattenedWidth)), success(); } } return failure(); diff --git a/tests/verilog/emit_metadata.mlir b/tests/verilog/emit_metadata.mlir index 1ef7d3c5f..d83995947 100644 --- a/tests/verilog/emit_metadata.mlir +++ b/tests/verilog/emit_metadata.mlir @@ -1,5 +1,4 @@ -// RUN: heir-translate --allow-unregistered-dialect --emit-metadata %s > %t -// RUN: FileCheck %s < %t +// RUN: heir-translate --allow-unregistered-dialect --emit-metadata %s | FileCheck %s module { func.func @main(%arg0: memref<1x80xi8>) -> memref<1x3x2x1xi8> { @@ -21,6 +20,13 @@ module { } return %alloc_0 : memref<1x3x2x1xi8> } + + func.func @main2(%arg0: memref<80xi8>) { + %c0 = arith.constant 0 : index + %c8 = arith.constant 8 : i8 + memref.store %c8, %arg0[%c0] : memref<80xi8> + return + } } // CHECK: { @@ -46,22 +52,47 @@ module { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK-NEXT: ], -// CHECK-NEXT: "return_type": { -// CHECK-NEXT: "memref": { -// CHECK-NEXT: "element_type": { -// CHECK-NEXT: "integer": { -// CHECK-NEXT: "is_signed": false, -// CHECK-NEXT: "width": 8 +// CHECK-NEXT: "return_types": [ +// CHECK-NEXT: { +// CHECK-NEXT: "memref": { +// CHECK-NEXT: "element_type": { +// CHECK-NEXT: "integer": { +// CHECK-NEXT: "is_signed": false, +// CHECK-NEXT: "width": 8 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "shape": [ +// CHECK-NEXT: 1, +// CHECK-NEXT: 3, +// CHECK-NEXT: 2, +// CHECK-NEXT: 1 +// CHECK-NEXT: ] +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: }, + +// CHECK-NEXT: { +// CHECK-NEXT: "name": "main2", +// CHECK-NEXT: "params": [ +// CHECK-NEXT: { +// CHECK-NEXT: "index": 0, +// CHECK-NEXT: "type": { +// CHECK-NEXT: "memref": { +// CHECK-NEXT: "element_type": { +// CHECK-NEXT: "integer": { +// CHECK-NEXT: "is_signed": false, +// CHECK-NEXT: "width": 8 +// CHECK-NEXT: } +// CHECK-NEXT: }, +// CHECK-NEXT: "shape": [ +// CHECK-NEXT: 80 +// CHECK-NEXT: ] // CHECK-NEXT: } -// CHECK-NEXT: }, -// CHECK-NEXT: "shape": [ -// CHECK-NEXT: 1, -// CHECK-NEXT: 3, -// CHECK-NEXT: 2, -// CHECK-NEXT: 1 -// CHECK-NEXT: ] +// CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: } +// CHECK-NEXT: ], +// CHECK-NEXT: "return_types": [] // CHECK-NEXT: } // CHECK-NEXT: ] // CHECK-NEXT: } diff --git a/tests/verilog/emit_verilog_memref_load.mlir b/tests/verilog/emit_verilog_memref_load.mlir index 2ead90ba8..820ae0065 100644 --- a/tests/verilog/emit_verilog_memref_load.mlir +++ b/tests/verilog/emit_verilog_memref_load.mlir @@ -1,8 +1,7 @@ // Test emit-verilog supporting lowering global memref constants and lookups to // Verilog arrays and accessors. -// RUN: heir-translate %s --emit-verilog > %t -// RUN: FileCheck %s < %t +// RUN: heir-translate %s --emit-verilog | FileCheck %s module { memref.global "private" constant @__constant_513xi16 : memref<513xi16> = dense<"0x00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000001000100010001000100010001000200020002000200020003000300030004000400050005000600060007000800080009000A000B000C000E000F00110012001400160018001A001D002000230026002A002E00330038003D00430049005000580061006A0074007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F007F00"> @@ -25,3 +24,26 @@ module { // CHECK-NEXT: assign [[V3]] = [[V2]][15 + 16 * [[ARG]] : 16 * [[ARG]]]; // CHECK-NEXT: assign [[OUT]] = [[V3]]; // CHECK: endmodule + + +module { + memref.global "private" constant @__data : memref<2x2xi8> = dense<"0x4D415448"> + func.func @main(%arg0: i8, %arg1: i8) { + %1 = memref.get_global @__data : memref<2x2xi8> + %i0 = arith.index_cast %arg0 : i8 to index + %i1 = arith.index_cast %arg1 : i8 to index + %4 = memref.load %1[%i0, %i1] : memref<2x2xi8> + return + } +} + +// CHECK: module main +// CHECK-NEXT: input wire signed [7:0] [[ARG0:.*]], +// CHECK-NEXT: input wire signed [7:0] [[ARG1:.*]] +// CHECK-NEXT: ); +// CHECK-NEXT: wire signed [31:0] [[DATA:.*]]; +// CHECK-NEXT: wire signed [7:0] [[LOAD_DEST:.*]]; +// CHECK-NEXT: assign [[DATA]] = 32'h{{[A-Z0-9]+}}; +// CHECK-EMPTY: +// CHECK-NEXT: assign [[LOAD_DEST]] = [[DATA]][7 + 8 * ([[ARG1]] + 2 * ([[ARG0]])) : 8 * ([[ARG1]] + 2 * ([[ARG0]]))]; +// CHECK: endmodule