From 0f0c5fa4d0ea181797147876e489ce24738163ab Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Thu, 14 Sep 2023 22:31:49 +0000 Subject: [PATCH] feat: add support for circt's comb and hw dialect Signed-off-by: Asra Ali make some more fixes Signed-off-by: Asra Ali remove spurious includes Signed-off-by: Asra Ali --- .pre-commit-config.yaml | 4 + tests/comb.mlir | 8 + third_party/circt/BUILD | 4 + .../circt/include/circt/Dialect/Comb/BUILD | 149 + .../include/circt/Dialect/Comb/CMakeLists.txt | 14 + .../circt/include/circt/Dialect/Comb/Comb.td | 44 + .../include/circt/Dialect/Comb/CombDialect.h | 25 + .../include/circt/Dialect/Comb/CombOps.h | 63 + .../include/circt/Dialect/Comb/CombPasses.h | 33 + .../include/circt/Dialect/Comb/CombVisitors.h | 113 + .../circt/Dialect/Comb/Combinational.td | 318 ++ .../include/circt/Dialect/Comb/Passes.td | 25 + .../circt/include/circt/Dialect/HW/BUILD | 245 ++ .../include/circt/Dialect/HW/CMakeLists.txt | 44 + .../circt/Dialect/HW/ConversionPatterns.h | 36 + .../circt/Dialect/HW/CustomDirectiveImpl.h | 70 + .../circt/include/circt/Dialect/HW/HW.td | 30 + .../include/circt/Dialect/HW/HWAggregates.td | 306 ++ .../include/circt/Dialect/HW/HWAttributes.h | 45 + .../include/circt/Dialect/HW/HWAttributes.td | 309 ++ .../circt/Dialect/HW/HWAttributesNaming.td | 70 + .../include/circt/Dialect/HW/HWDialect.h | 26 + .../include/circt/Dialect/HW/HWDialect.td | 49 + .../circt/Dialect/HW/HWInstanceGraph.h | 56 + .../include/circt/Dialect/HW/HWMiscOps.td | 192 + .../include/circt/Dialect/HW/HWModuleGraph.h | 186 + .../include/circt/Dialect/HW/HWOpInterfaces.h | 297 ++ .../circt/Dialect/HW/HWOpInterfaces.td | 528 +++ .../circt/include/circt/Dialect/HW/HWOps.h | 140 + .../circt/include/circt/Dialect/HW/HWPasses.h | 38 + .../include/circt/Dialect/HW/HWReductions.h | 29 + .../include/circt/Dialect/HW/HWStructure.td | 720 ++++ .../include/circt/Dialect/HW/HWSymCache.h | 118 + .../include/circt/Dialect/HW/HWTypeDecls.td | 63 + .../circt/Dialect/HW/HWTypeInterfaces.h | 44 + .../circt/Dialect/HW/HWTypeInterfaces.td | 81 + .../circt/include/circt/Dialect/HW/HWTypes.h | 165 + .../circt/include/circt/Dialect/HW/HWTypes.td | 179 + .../include/circt/Dialect/HW/HWTypesImpl.td | 280 ++ .../include/circt/Dialect/HW/HWVisitors.h | 140 + .../circt/Dialect/HW/InnerSymbolNamespace.h | 51 + .../circt/Dialect/HW/InnerSymbolTable.h | 264 ++ .../circt/Dialect/HW/InstanceImplementation.h | 104 + .../circt/Dialect/HW/ModuleImplementation.h | 56 + .../circt/include/circt/Dialect/HW/Passes.td | 61 + .../include/circt/Dialect/HW/PortConverter.h | 182 + .../circt/include/circt/Support/APInt.h | 30 + third_party/circt/include/circt/Support/BUILD | 52 + .../include/circt/Support/BackedgeBuilder.h | 100 + .../include/circt/Support/BuilderUtils.h | 45 + .../include/circt/Support/CMakeLists.txt | 1 + .../circt/Support/ConversionPatterns.h | 36 + .../circt/Support/CustomDirectiveImpl.h | 91 + .../circt/include/circt/Support/FieldRef.h | 116 + .../circt/include/circt/Support/FoldUtils.h | 38 + .../include/circt/Support/InstanceGraph.h | 453 +++ .../circt/Support/InstanceGraphInterface.h | 23 + .../circt/Support/InstanceGraphInterface.td | 74 + .../circt/include/circt/Support/JSON.h | 30 + .../circt/include/circt/Support/LLVM.h | 293 ++ .../include/circt/Support/LoweringOptions.h | 169 + .../circt/Support/LoweringOptionsParser.h | 57 + .../circt/include/circt/Support/Namespace.h | 134 + .../include/circt/Support/ParsingUtils.h | 57 + .../circt/include/circt/Support/Passes.h | 64 + .../circt/include/circt/Support/Path.h | 30 + .../include/circt/Support/PrettyPrinter.h | 325 ++ .../circt/Support/PrettyPrinterHelpers.h | 374 ++ .../circt/include/circt/Support/SymCache.h | 132 + .../circt/include/circt/Support/ValueMapper.h | 63 + .../circt/include/circt/Support/Version.h | 16 + third_party/circt/lib/Dialect/Comb/BUILD | 23 + .../circt/lib/Dialect/Comb/CMakeLists.txt | 26 + .../circt/lib/Dialect/Comb/CombAnalysis.cpp | 87 + .../circt/lib/Dialect/Comb/CombDialect.cpp | 62 + .../circt/lib/Dialect/Comb/CombFolds.cpp | 3026 +++++++++++++++ .../circt/lib/Dialect/Comb/CombOps.cpp | 302 ++ .../Dialect/Comb/Transforms/CMakeLists.txt | 15 + .../lib/Dialect/Comb/Transforms/LowerComb.cpp | 88 + .../lib/Dialect/Comb/Transforms/PassDetails.h | 27 + third_party/circt/lib/Dialect/HW/BUILD | 35 + .../circt/lib/Dialect/HW/CMakeLists.txt | 52 + .../lib/Dialect/HW/ConversionPatterns.cpp | 103 + .../lib/Dialect/HW/CustomDirectiveImpl.cpp | 136 + .../circt/lib/Dialect/HW/HWAttributes.cpp | 1032 +++++ .../circt/lib/Dialect/HW/HWDialect.cpp | 116 + .../circt/lib/Dialect/HW/HWInstanceGraph.cpp | 33 + .../lib/Dialect/HW/HWModuleOpInterface.cpp | 88 + .../circt/lib/Dialect/HW/HWOpInterfaces.cpp | 99 + third_party/circt/lib/Dialect/HW/HWOps.cpp | 3376 +++++++++++++++++ .../circt/lib/Dialect/HW/HWReductions.cpp | 157 + .../circt/lib/Dialect/HW/HWTypeInterfaces.cpp | 74 + third_party/circt/lib/Dialect/HW/HWTypes.cpp | 974 +++++ .../circt/lib/Dialect/HW/InnerSymbolTable.cpp | 251 ++ .../lib/Dialect/HW/InstanceImplementation.cpp | 347 ++ .../lib/Dialect/HW/ModuleImplementation.cpp | 328 ++ .../circt/lib/Dialect/HW/PortConverter.cpp | 233 ++ .../lib/Dialect/HW/Transforms/CMakeLists.txt | 20 + .../lib/Dialect/HW/Transforms/FlattenIO.cpp | 429 +++ .../HW/Transforms/HWPrintInstanceGraph.cpp | 36 + .../Dialect/HW/Transforms/HWSpecialize.cpp | 422 +++ .../lib/Dialect/HW/Transforms/PassDetails.h | 31 + .../HW/Transforms/PrintHWModuleGraph.cpp | 42 + .../HW/Transforms/VerifyInnerRefNamespace.cpp | 44 + third_party/circt/lib/Support/APInt.cpp | 27 + third_party/circt/lib/Support/BUILD | 37 + .../circt/lib/Support/BackedgeBuilder.cpp | 71 + third_party/circt/lib/Support/CMakeLists.txt | 52 + .../circt/lib/Support/ConversionPatterns.cpp | 90 + .../circt/lib/Support/CustomDirectiveImpl.cpp | 130 + third_party/circt/lib/Support/FieldRef.cpp | 23 + .../circt/lib/Support/InstanceGraph.cpp | 314 ++ third_party/circt/lib/Support/JSON.cpp | 123 + .../circt/lib/Support/LoweringOptions.cpp | 187 + .../circt/lib/Support/ParsingUtils.cpp | 50 + third_party/circt/lib/Support/Passes.cpp | 21 + third_party/circt/lib/Support/Path.cpp | 32 + .../circt/lib/Support/PrettyPrinter.cpp | 305 ++ .../lib/Support/PrettyPrinterHelpers.cpp | 50 + third_party/circt/lib/Support/SymCache.cpp | 29 + third_party/circt/lib/Support/ValueMapper.cpp | 62 + third_party/circt/lib/Support/Version.cpp.in | 6 + tools/BUILD | 7 +- tools/heir-opt.cpp | 3 + 124 files changed, 22539 insertions(+), 1 deletion(-) create mode 100644 tests/comb.mlir create mode 100644 third_party/circt/BUILD create mode 100644 third_party/circt/include/circt/Dialect/Comb/BUILD create mode 100644 third_party/circt/include/circt/Dialect/Comb/CMakeLists.txt create mode 100644 third_party/circt/include/circt/Dialect/Comb/Comb.td create mode 100644 third_party/circt/include/circt/Dialect/Comb/CombDialect.h create mode 100644 third_party/circt/include/circt/Dialect/Comb/CombOps.h create mode 100644 third_party/circt/include/circt/Dialect/Comb/CombPasses.h create mode 100644 third_party/circt/include/circt/Dialect/Comb/CombVisitors.h create mode 100644 third_party/circt/include/circt/Dialect/Comb/Combinational.td create mode 100644 third_party/circt/include/circt/Dialect/Comb/Passes.td create mode 100644 third_party/circt/include/circt/Dialect/HW/BUILD create mode 100644 third_party/circt/include/circt/Dialect/HW/CMakeLists.txt create mode 100644 third_party/circt/include/circt/Dialect/HW/ConversionPatterns.h create mode 100644 third_party/circt/include/circt/Dialect/HW/CustomDirectiveImpl.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HW.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWAggregates.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWAttributes.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWAttributes.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWAttributesNaming.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWDialect.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWDialect.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWInstanceGraph.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWMiscOps.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWModuleGraph.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWOps.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWPasses.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWReductions.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWStructure.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWSymCache.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypeDecls.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypes.h create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypes.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWTypesImpl.td create mode 100644 third_party/circt/include/circt/Dialect/HW/HWVisitors.h create mode 100644 third_party/circt/include/circt/Dialect/HW/InnerSymbolNamespace.h create mode 100644 third_party/circt/include/circt/Dialect/HW/InnerSymbolTable.h create mode 100644 third_party/circt/include/circt/Dialect/HW/InstanceImplementation.h create mode 100644 third_party/circt/include/circt/Dialect/HW/ModuleImplementation.h create mode 100644 third_party/circt/include/circt/Dialect/HW/Passes.td create mode 100644 third_party/circt/include/circt/Dialect/HW/PortConverter.h create mode 100644 third_party/circt/include/circt/Support/APInt.h create mode 100644 third_party/circt/include/circt/Support/BUILD create mode 100644 third_party/circt/include/circt/Support/BackedgeBuilder.h create mode 100644 third_party/circt/include/circt/Support/BuilderUtils.h create mode 100644 third_party/circt/include/circt/Support/CMakeLists.txt create mode 100644 third_party/circt/include/circt/Support/ConversionPatterns.h create mode 100644 third_party/circt/include/circt/Support/CustomDirectiveImpl.h create mode 100644 third_party/circt/include/circt/Support/FieldRef.h create mode 100644 third_party/circt/include/circt/Support/FoldUtils.h create mode 100644 third_party/circt/include/circt/Support/InstanceGraph.h create mode 100644 third_party/circt/include/circt/Support/InstanceGraphInterface.h create mode 100644 third_party/circt/include/circt/Support/InstanceGraphInterface.td create mode 100644 third_party/circt/include/circt/Support/JSON.h create mode 100644 third_party/circt/include/circt/Support/LLVM.h create mode 100644 third_party/circt/include/circt/Support/LoweringOptions.h create mode 100644 third_party/circt/include/circt/Support/LoweringOptionsParser.h create mode 100644 third_party/circt/include/circt/Support/Namespace.h create mode 100644 third_party/circt/include/circt/Support/ParsingUtils.h create mode 100644 third_party/circt/include/circt/Support/Passes.h create mode 100644 third_party/circt/include/circt/Support/Path.h create mode 100644 third_party/circt/include/circt/Support/PrettyPrinter.h create mode 100644 third_party/circt/include/circt/Support/PrettyPrinterHelpers.h create mode 100644 third_party/circt/include/circt/Support/SymCache.h create mode 100644 third_party/circt/include/circt/Support/ValueMapper.h create mode 100644 third_party/circt/include/circt/Support/Version.h create mode 100644 third_party/circt/lib/Dialect/Comb/BUILD create mode 100644 third_party/circt/lib/Dialect/Comb/CMakeLists.txt create mode 100644 third_party/circt/lib/Dialect/Comb/CombAnalysis.cpp create mode 100644 third_party/circt/lib/Dialect/Comb/CombDialect.cpp create mode 100644 third_party/circt/lib/Dialect/Comb/CombFolds.cpp create mode 100644 third_party/circt/lib/Dialect/Comb/CombOps.cpp create mode 100644 third_party/circt/lib/Dialect/Comb/Transforms/CMakeLists.txt create mode 100644 third_party/circt/lib/Dialect/Comb/Transforms/LowerComb.cpp create mode 100644 third_party/circt/lib/Dialect/Comb/Transforms/PassDetails.h create mode 100644 third_party/circt/lib/Dialect/HW/BUILD create mode 100644 third_party/circt/lib/Dialect/HW/CMakeLists.txt create mode 100644 third_party/circt/lib/Dialect/HW/ConversionPatterns.cpp create mode 100644 third_party/circt/lib/Dialect/HW/CustomDirectiveImpl.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWAttributes.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWDialect.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWInstanceGraph.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWModuleOpInterface.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWOpInterfaces.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWOps.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWReductions.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWTypeInterfaces.cpp create mode 100644 third_party/circt/lib/Dialect/HW/HWTypes.cpp create mode 100644 third_party/circt/lib/Dialect/HW/InnerSymbolTable.cpp create mode 100644 third_party/circt/lib/Dialect/HW/InstanceImplementation.cpp create mode 100644 third_party/circt/lib/Dialect/HW/ModuleImplementation.cpp create mode 100644 third_party/circt/lib/Dialect/HW/PortConverter.cpp create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/CMakeLists.txt create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/FlattenIO.cpp create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/HWSpecialize.cpp create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/PassDetails.h create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/PrintHWModuleGraph.cpp create mode 100644 third_party/circt/lib/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp create mode 100644 third_party/circt/lib/Support/APInt.cpp create mode 100644 third_party/circt/lib/Support/BUILD create mode 100644 third_party/circt/lib/Support/BackedgeBuilder.cpp create mode 100644 third_party/circt/lib/Support/CMakeLists.txt create mode 100644 third_party/circt/lib/Support/ConversionPatterns.cpp create mode 100644 third_party/circt/lib/Support/CustomDirectiveImpl.cpp create mode 100644 third_party/circt/lib/Support/FieldRef.cpp create mode 100644 third_party/circt/lib/Support/InstanceGraph.cpp create mode 100644 third_party/circt/lib/Support/JSON.cpp create mode 100644 third_party/circt/lib/Support/LoweringOptions.cpp create mode 100644 third_party/circt/lib/Support/ParsingUtils.cpp create mode 100644 third_party/circt/lib/Support/Passes.cpp create mode 100644 third_party/circt/lib/Support/Path.cpp create mode 100644 third_party/circt/lib/Support/PrettyPrinter.cpp create mode 100644 third_party/circt/lib/Support/PrettyPrinterHelpers.cpp create mode 100644 third_party/circt/lib/Support/SymCache.cpp create mode 100644 third_party/circt/lib/Support/ValueMapper.cpp create mode 100644 third_party/circt/lib/Support/Version.cpp.in diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45361e6a3..6bf5a9b86 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,10 @@ repos: rev: "v2.2.5" hooks: - id: codespell + exclude: > + (?x)^( + third_party/.* + )$ # Changes tabs to spaces - repo: https://github.com/Lucas-C/pre-commit-hooks diff --git a/tests/comb.mlir b/tests/comb.mlir new file mode 100644 index 000000000..d8b9684db --- /dev/null +++ b/tests/comb.mlir @@ -0,0 +1,8 @@ +// RUN: heir-opt %s -verify-diagnostics + +module { + func.func @comb(%a: i1, %b: i1) -> () { + %0 = comb.truth_table %a, %b -> [true, false, true, false] + return + } +} diff --git a/third_party/circt/BUILD b/third_party/circt/BUILD new file mode 100644 index 000000000..1fd12b91a --- /dev/null +++ b/third_party/circt/BUILD @@ -0,0 +1,4 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) diff --git a/third_party/circt/include/circt/Dialect/Comb/BUILD b/third_party/circt/include/circt/Dialect/Comb/BUILD new file mode 100644 index 000000000..7751a6a7c --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/BUILD @@ -0,0 +1,149 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + glob(["*.h"]), +) + +cc_library( + name = "headers", + hdrs = [ + "CombDialect.h", + "CombOps.h", + "CombPasses.h", + "CombVisitors.h", + ], + deps = [ + "@heir//third_party/circt/include/circt/Dialect/Comb:dialect_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:enum_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:ops_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:type_inc_gen", + "@heir//third_party/circt/lib/Support", + ], +) + +td_library( + name = "td_files", + srcs = [ + "Comb.td", + "Combinational.td", + ], + includes = ["/third_party/circt/include"], + deps = [ + "@heir//third_party/circt/include/circt/Dialect/HW:td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + "-dialect=comb", + ], + "CombDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=comb", + ], + "CombDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "ops_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "Comb.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "Comb.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + +gentbl_cc_library( + name = "type_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + ], + "CombTypes.h.inc", + ), + ( + [ + "-gen-typedef-defs", + ], + "CombTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "enum_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-enum-decls", + ], + "CombEnums.h.inc", + ), + ( + [ + "-gen-enum-defs", + ], + "CombEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) diff --git a/third_party/circt/include/circt/Dialect/Comb/CMakeLists.txt b/third_party/circt/include/circt/Dialect/Comb/CMakeLists.txt new file mode 100644 index 000000000..a7bb1a625 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/CMakeLists.txt @@ -0,0 +1,14 @@ +add_circt_dialect(Comb comb) +add_circt_dialect_doc(Comb comb) + +set(LLVM_TARGET_DEFINITIONS Comb.td) +mlir_tablegen(CombEnums.h.inc -gen-enum-decls) +mlir_tablegen(CombEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRCombEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(CIRCTCombTransformsIncGen) + +# Generate Pass documentation. +add_circt_doc(Passes CombPasses -gen-pass-doc) diff --git a/third_party/circt/include/circt/Dialect/Comb/Comb.td b/third_party/circt/include/circt/Dialect/Comb/Comb.td new file mode 100644 index 000000000..b69ce57d5 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/Comb.td @@ -0,0 +1,44 @@ +//===- Comb.td - Comb dialect definition --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top level file for the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef COMB_TD +#define COMB_TD + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" + +def CombDialect : Dialect { + let name = "comb"; + + let summary = "Types and operations for comb dialect"; + let description = [{ + This dialect defines the `comb` dialect, which is intended to be a generic + representation of combinational logic outside of a particular use-case. + }]; + let hasConstantMaterializer = 1; + let cppNamespace = "::circt::comb"; + + // This will be the default after next LLVM bump. + let usePropertiesForAttributes = 1; + +} + +// Base class for the operation in this dialect. +class CombOp traits = []> : + Op; + +include "circt/Dialect/HW/HWTypes.td" +include "circt/Dialect/Comb/Combinational.td" + +#endif // COMB_TD diff --git a/third_party/circt/include/circt/Dialect/Comb/CombDialect.h b/third_party/circt/include/circt/Dialect/Comb/CombDialect.h new file mode 100644 index 000000000..13fe75b7a --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/CombDialect.h @@ -0,0 +1,25 @@ +//===- CombDialect.h - Comb dialect declaration -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Combinational MLIR dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_COMB_COMBDIALECT_H +#define CIRCT_DIALECT_COMB_COMBDIALECT_H + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" + +// Pull in the Dialect definition. +#include "circt/Dialect/Comb/CombDialect.h.inc" + +// Pull in all enum type definitions and utility function declarations. +#include "circt/Dialect/Comb/CombEnums.h.inc" + +#endif // CIRCT_DIALECT_COMB_COMBDIALECT_H diff --git a/third_party/circt/include/circt/Dialect/Comb/CombOps.h b/third_party/circt/include/circt/Dialect/Comb/CombOps.h new file mode 100644 index 000000000..d4199c690 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/CombOps.h @@ -0,0 +1,63 @@ +//===- CombOps.h - Declare Comb dialect operations --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation classes for the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_COMB_COMBOPS_H +#define CIRCT_DIALECT_COMB_COMBOPS_H + +#include "circt/Dialect/Comb/CombDialect.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "circt/Support/LLVM.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace llvm { +struct KnownBits; +} + +namespace mlir { +class PatternRewriter; +} + +#define GET_OP_CLASSES +#include "circt/Dialect/Comb/Comb.h.inc" + +namespace circt { +namespace comb { + +using llvm::KnownBits; + +/// Compute "known bits" information about the specified value - the set of bits +/// that are guaranteed to always be zero, and the set of bits that are +/// guaranteed to always be one (these must be exclusive!). A bit that exists +/// in neither set is unknown. +KnownBits computeKnownBits(Value value); + +/// Create a sign extension operation from a value of integer type to an equal +/// or larger integer type. +Value createOrFoldSExt(Location loc, Value value, Type destTy, + OpBuilder &builder); +Value createOrFoldSExt(Value value, Type destTy, ImplicitLocOpBuilder &builder); + +/// Create a ``Not'' gate on a value. +Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, + bool twoState = false); +Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, + bool twoState = false); + +} // namespace comb +} // namespace circt + +#endif // CIRCT_DIALECT_COMB_COMBOPS_H diff --git a/third_party/circt/include/circt/Dialect/Comb/CombPasses.h b/third_party/circt/include/circt/Dialect/Comb/CombPasses.h new file mode 100644 index 000000000..31585ffc4 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/CombPasses.h @@ -0,0 +1,33 @@ +//===- Passes.h - Comb pass entry points ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_COMB_COMBPASSES_H +#define CIRCT_DIALECT_COMB_COMBPASSES_H + +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace circt { +namespace comb { + +/// Generate the code for registering passes. +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "circt/Dialect/Comb/Passes.h.inc" + +} // namespace comb +} // namespace circt + +#endif // CIRCT_DIALECT_COMB_COMBPASSES_H diff --git a/third_party/circt/include/circt/Dialect/Comb/CombVisitors.h b/third_party/circt/include/circt/Dialect/Comb/CombVisitors.h new file mode 100644 index 000000000..3cef0ee04 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/CombVisitors.h @@ -0,0 +1,113 @@ +//===- CombVisitors.h - Comb Dialect Visitors -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines visitors that make it easier to work with the Comb IR. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_COMB_COMBVISITORS_H +#define CIRCT_DIALECT_COMB_COMBVISITORS_H + +#include "circt/Dialect/Comb/CombOps.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace circt { +namespace comb { + +/// This helps visit Combinational nodes. +template +class CombinationalVisitor { + public: + ResultType dispatchCombinationalVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case< + // Arithmetic and Logical Binary Operations. + AddOp, SubOp, MulOp, DivUOp, DivSOp, ModUOp, ModSOp, ShlOp, ShrUOp, + ShrSOp, + // Bitwise operations + AndOp, OrOp, XorOp, + // Comparison operations + ICmpOp, + // Reduction Operators + ParityOp, + // Other operations. + ConcatOp, ReplicateOp, ExtractOp, MuxOp>( + [&](auto expr) -> ResultType { + return thisCast->visitComb(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidComb(op, args...); + }); + } + + /// This callback is invoked on any non-expression operations. + ResultType visitInvalidComb(Operation *op, ExtraArgs... args) { + op->emitOpError("unknown combinational node"); + abort(); + } + + /// This callback is invoked on any combinational operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledComb(Operation *op, ExtraArgs... args) { + return ResultType(); + } + + /// This fallback is invoked on any binary node that isn't explicitly handled. + /// The default implementation delegates to the 'unhandled' fallback. + ResultType visitBinaryComb(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledComb(op, args...); + } + + ResultType visitUnaryComb(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledComb(op, args...); + } + + ResultType visitVariadicComb(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledComb(op, args...); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitComb(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##Comb(op, \ + args...); \ + } + + // Arithmetic and Logical Binary Operations. + HANDLE(AddOp, Binary); + HANDLE(SubOp, Binary); + HANDLE(MulOp, Binary); + HANDLE(DivUOp, Binary); + HANDLE(DivSOp, Binary); + HANDLE(ModUOp, Binary); + HANDLE(ModSOp, Binary); + HANDLE(ShlOp, Binary); + HANDLE(ShrUOp, Binary); + HANDLE(ShrSOp, Binary); + + HANDLE(AndOp, Variadic); + HANDLE(OrOp, Variadic); + HANDLE(XorOp, Variadic); + + HANDLE(ParityOp, Unary); + + HANDLE(ICmpOp, Binary); + + // Other operations. + HANDLE(ConcatOp, Unhandled); + HANDLE(ReplicateOp, Unhandled); + HANDLE(ExtractOp, Unhandled); + HANDLE(MuxOp, Unhandled); +#undef HANDLE +}; + +} // namespace comb +} // namespace circt + +#endif // CIRCT_DIALECT_COMB_COMBVISITORS_H diff --git a/third_party/circt/include/circt/Dialect/Comb/Combinational.td b/third_party/circt/include/circt/Dialect/Comb/Combinational.td new file mode 100644 index 000000000..8e53c843a --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/Combinational.td @@ -0,0 +1,318 @@ +//===- Combinational.td - combinational logic ops ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the MLIR ops for combinational logic. +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Arithmetic and Logical Operations +//===----------------------------------------------------------------------===// + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/EnumAttr.td" + +// Base class for binary operators. +class BinOp traits = []> : + CombOp { + let arguments = (ins HWIntegerType:$lhs, HWIntegerType:$rhs, UnitAttr:$twoState); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$lhs `,` $rhs (`bin` $twoState^)? attr-dict `:` functional-type($args, $results)"; +} + +// Binary operator with uniform input/result types. +class UTBinOp traits = []> : + BinOp { + let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))"; +} + +// Base class for variadic operators. +class VariadicOp traits = []> : + CombOp { + let arguments = (ins Variadic:$inputs, UnitAttr:$twoState); + let results = (outs HWIntegerType:$result); +} + +class UTVariadicOp traits = []> : + VariadicOp { + + let hasCanonicalizeMethod = true; + let hasFolder = true; + let hasVerifier = 1; + + let assemblyFormat = "(`bin` $twoState^)? $inputs attr-dict `:` qualified(type($result))"; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs, CArg<"bool", "false">:$twoState), [{ + return build($_builder, $_state, lhs.getType(), + ValueRange{lhs, rhs}, twoState); + }]> + ]; +} + +// Arithmetic and Logical Operations. +def AddOp : UTVariadicOp<"add", [Commutative]>; +def MulOp : UTVariadicOp<"mul", [Commutative]>; +let hasFolder = true in { + def DivUOp : UTBinOp<"divu">; + def DivSOp : UTBinOp<"divs">; + def ModUOp : UTBinOp<"modu">; + def ModSOp : UTBinOp<"mods">; + let hasCanonicalizeMethod = true in { + def ShlOp : UTBinOp<"shl">; + def ShrUOp : UTBinOp<"shru">; + def ShrSOp : UTBinOp<"shrs">; + def SubOp : UTBinOp<"sub">; + } +} + +def AndOp : UTVariadicOp<"and", [Commutative]>; +def OrOp : UTVariadicOp<"or", [Commutative]>; +def XorOp : UTVariadicOp<"xor", [Commutative]> { + let extraClassDeclaration = [{ + /// Return true if this is a two operand xor with an all ones constant as + /// its RHS operand. + bool isBinaryNot(); + }]; +} + +//===----------------------------------------------------------------------===// +// Comparisons +//===----------------------------------------------------------------------===// + +def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; +def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; +def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; +def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; +def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; +def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; +def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; +def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; +def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; +def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; +// SV case equality +def ICmpPredicateCEQ : I64EnumAttrCase<"ceq", 10>; +def ICmpPredicateCNE : I64EnumAttrCase<"cne", 11>; +// SV wild card equality +def ICmpPredicateWEQ : I64EnumAttrCase<"weq", 12>; +def ICmpPredicateWNE : I64EnumAttrCase<"wne", 13>; +let cppNamespace = "circt::comb" in +def ICmpPredicate : I64EnumAttr< + "ICmpPredicate", + "hw.icmp comparison predicate", + [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, + ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, + ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE, + ICmpPredicateWEQ, ICmpPredicateWNE]>; + +def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { + let summary = "Compare two integer values"; + let description = [{ + This operation compares two integers using a predicate. If the predicate is + true, returns 1, otherwise returns 0. This operation always returns a one + bit wide result. + + ``` + %r = comb.icmp eq %a, %b : i4 + ``` + }]; + + let arguments = (ins ICmpPredicate:$predicate, + HWIntegerType:$lhs, HWIntegerType:$rhs, UnitAttr:$twoState); + let results = (outs I1:$result); + + let assemblyFormat = "(`bin` $twoState^)? $predicate $lhs `,` $rhs attr-dict `:` qualified(type($lhs))"; + + let hasFolder = true; + let hasCanonicalizeMethod = true; + + let extraClassDeclaration = [{ + /// Returns the flipped predicate, reversing the LHS and RHS operands. The + /// lhs and rhs operands should be flipped to match the new predicate. + static ICmpPredicate getFlippedPredicate(ICmpPredicate predicate); + + /// Returns true if the predicate is signed. + static bool isPredicateSigned(ICmpPredicate predicate); + + /// Returns the predicate for a logically negated comparison, e.g. mapping + /// EQ => NE and SLE => SGT. + static ICmpPredicate getNegatedPredicate(ICmpPredicate predicate); + + /// Return true if this is an equality test with -1, which is a "reduction + /// and" operation in Verilog. + bool isEqualAllOnes(); + + /// Return true if this is a not equal test with 0, which is a "reduction + /// or" operation in Verilog. + bool isNotEqualZero(); + }]; +} + +//===----------------------------------------------------------------------===// +// Unary Operations +//===----------------------------------------------------------------------===// + +// Base class for unary reduction operations that produce an i1. +class UnaryI1ReductionOp traits = []> : + CombOp { + let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState); + let results = (outs I1:$result); + let hasFolder = 1; + + let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))"; +} + +def ParityOp : UnaryI1ReductionOp<"parity">; + +//===----------------------------------------------------------------------===// +// Integer width modifying operations. +//===----------------------------------------------------------------------===// + +// Extract a range of bits from the specified input. +def ExtractOp : CombOp<"extract", [Pure]> { + let summary = "Extract a range of bits into a smaller value, lowBit " + "specifies the lowest bit included."; + + let arguments = (ins HWIntegerType:$input, I32Attr:$lowBit); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$input `from` $lowBit attr-dict `:` functional-type($input, $result)"; + + let hasFolder = true; + let hasVerifier = 1; + let hasCanonicalizeMethod = true; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "int32_t":$lowBit, "int32_t":$bitWidth), [{ + auto resultType = $_builder.getIntegerType(bitWidth); + return build($_builder, $_state, resultType, lhs, lowBit); + }]> + ]; +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// +def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { + let summary = "Concatenate a variadic list of operands together."; + let description = [{ + See the comb rationale document for details on operand ordering. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs HWIntegerType:$result); + + let hasFolder = true; + let hasCanonicalizeMethod = true; + let hasVerifier = 1; + + let assemblyFormat = "$inputs attr-dict `:` qualified(type($inputs))"; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ + return build($_builder, $_state, ValueRange{lhs, rhs}); + }]>, + OpBuilder<(ins "Value":$hd, "ValueRange":$tl)>, + ]; + + let extraClassDeclaration = [{ + /// Infer the return types of this operation. + static LogicalResult inferReturnTypes(MLIRContext *context, + std::optional loc, + ValueRange operands, + DictionaryAttr attrs, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + SmallVectorImpl &results); + }]; +} + +def ReplicateOp : CombOp<"replicate", [Pure]> { + let summary = "Concatenate the operand a constant number of times"; + + let arguments = (ins HWIntegerType:$input); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$input attr-dict `:` functional-type($input, $result)"; + + let hasFolder = true; + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "Value":$operand, "int32_t":$multiple), [{ + auto bitWidth = operand.getType().cast().getWidth(); + auto resultType = $_builder.getIntegerType(bitWidth*multiple); + return build($_builder, $_state, resultType, operand); + }]> + ]; + + let extraClassDeclaration = [{ + /// Returns the number of times the operand is replicated. + size_t getMultiple() { + auto opWidth = getInput().getType().cast().getWidth(); + return getType().cast().getWidth()/opWidth; + } + }]; +} + +// Select one of two values based on a condition. +def MuxOp : CombOp<"mux", + [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>]> { + let summary = "Return one or the other operand depending on a selector bit"; + let description = [{ + ``` + %0 = mux %pred, %tvalue, %fvalue : i4 + ``` + }]; + + let arguments = (ins I1:$cond, AnyType:$trueValue, + AnyType:$falseValue, UnitAttr:$twoState); + let results = (outs AnyType:$result); + + let assemblyFormat = + "(`bin` $twoState^)? $cond `,` $trueValue `,` $falseValue attr-dict `:` qualified(type($result))"; + + let hasFolder = true; + let hasCanonicalizer = true; +} + +def TruthTableOp : CombOp<"truth_table", [Pure]> { + let summary = "Return a true/false based on a lookup table"; + let description = [{ + ``` + %a = ... : i1 + %b = ... : i1 + %0 = comb.truth_table %a, %b -> [false, true, true, false] + ``` + + This operation assumes a fully elaborated table -- 2^n entries. Inputs are + sorted MSB -> LSB from left to right and the offset into `lookupTable` is + computed from them. The table is sorted from 0 -> (2^n - 1) from left to + right. + + No difference from array_get into an array of constants except for xprop + behavior. If one of the inputs is unknown, but said input doesn't make a + difference in the output (based on the lookup table) the result should not + be 'x' -- it should be the well-known result. + }]; + + let arguments = (ins Variadic:$inputs, BoolArrayAttr:$lookupTable); + let results = (outs I1:$result); + + let assemblyFormat = [{ + $inputs `->` $lookupTable attr-dict + }]; + + let hasVerifier = 1; +} diff --git a/third_party/circt/include/circt/Dialect/Comb/Passes.td b/third_party/circt/include/circt/Dialect/Comb/Passes.td new file mode 100644 index 000000000..5ecaa997d --- /dev/null +++ b/third_party/circt/include/circt/Dialect/Comb/Passes.td @@ -0,0 +1,25 @@ +//===-- Passes.td - Comb pass definition file --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_COMB_PASSES_TD +#define CIRCT_DIALECT_COMB_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def LowerComb : Pass<"lower-comb", "mlir::ModuleOp"> { + let summary = "Lowers the some of the comb ops"; + let description = [{ + Some operations in the comb dialect (e.g. `comb.truth_table`) are not + directly supported by ExportVerilog. They need to be lowered into ops which + are supported. There are many ways to lower these ops so we do this in a + separate pass. This also allows the lowered form to participate in + optimizations like the comb canonicalizers. + }]; +} + +#endif // CIRCT_DIALECT_COMB_PASSES_TD diff --git a/third_party/circt/include/circt/Dialect/HW/BUILD b/third_party/circt/include/circt/Dialect/HW/BUILD new file mode 100644 index 000000000..73f8411a5 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/BUILD @@ -0,0 +1,245 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + glob(["*.h"]), +) + +cc_library( + name = "headers", + hdrs = [ + "ConversionPatterns.h", + "CustomDirectiveImpl.h", + "HWAttributes.h", + "HWDialect.h", + "HWInstanceGraph.h", + "HWOpInterfaces.h", + "HWOps.h", + "HWSymCache.h", + "HWTypeInterfaces.h", + "HWTypes.h", + "HWVisitors.h", + "InnerSymbolTable.h", + "InstanceImplementation.h", + "ModuleImplementation.h", + "PortConverter.h", + "@heir//third_party/circt/include/circt/Dialect/Comb:CombDialect.h", + "@heir//third_party/circt/include/circt/Dialect/Comb:CombOps.h", + ], + strip_include_prefix = "/third_party/circt/include", + deps = [ + "@heir//third_party/circt/include/circt/Dialect/HW:attributes_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:dialect_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:enum_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:op_interfaces_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:ops_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:type_interfaces_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:types_inc_gen", + "@heir//third_party/circt/lib/Support", + ], +) + +td_library( + name = "td_files", + srcs = glob([ + "*.td", + ]), + includes = ["/third_party/circt/include"], + deps = [ + "@heir//third_party/circt/include/circt/Support:td_files", + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "HWDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "HWDialect.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HWDialect.td", + deps = [ + ":td_files", + "@heir//third_party/circt/include/circt/Support:interfaces_inc_gen", + ], +) + +gentbl_cc_library( + name = "types_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + ], + "HWTypes.h.inc", + ), + ( + [ + "-gen-typedef-defs", + ], + "HWTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HW.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + +gentbl_cc_library( + name = "ops_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "HW.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "HW.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HW.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ":types_inc_gen", + ], +) + +gentbl_cc_library( + name = "attributes_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-attrdef-decls", + ], + "HWAttributes.h.inc", + ), + ( + [ + "-gen-attrdef-defs", + ], + "HWAttributes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HW.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "op_interfaces_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-op-interface-decls", + ], + "HWOpInterfaces.h.inc", + ), + ( + [ + "-gen-op-interface-defs", + ], + "HWOpInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HWOpInterfaces.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "type_interfaces_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-type-interface-decls", + ], + "HWTypeInterfaces.h.inc", + ), + ( + [ + "-gen-type-interface-defs", + ], + "HWTypeInterfaces.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HWTypeInterfaces.td", + deps = [ + ":attributes_inc_gen", + ":dialect_inc_gen", + ":td_files", + ], +) + +gentbl_cc_library( + name = "enum_inc_gen", + includes = ["/third_party/circt/include"], + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-enum-decls", + ], + "HWEnums.h.inc", + ), + ( + [ + "-gen-enum-defs", + ], + "HWEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "HW.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) diff --git a/third_party/circt/include/circt/Dialect/HW/CMakeLists.txt b/third_party/circt/include/circt/Dialect/HW/CMakeLists.txt new file mode 100644 index 000000000..8572e8a76 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/CMakeLists.txt @@ -0,0 +1,44 @@ +add_circt_dialect(HW hw) + +set(LLVM_TARGET_DEFINITIONS HW.td) + +mlir_tablegen(HWAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(HWAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRHWAttrIncGen) +add_dependencies(circt-headers MLIRHWAttrIncGen) + +mlir_tablegen(HWEnums.h.inc -gen-enum-decls) +mlir_tablegen(HWEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRHWEnumsIncGen) +add_dependencies(circt-headers MLIRHWEnumsIncGen) + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(CIRCTHWTransformsIncGen) + +set(LLVM_TARGET_DEFINITIONS HWOpInterfaces.td) +mlir_tablegen(HWOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(HWOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(CIRCTHWOpInterfacesIncGen) +add_dependencies(circt-headers CIRCTHWOpInterfacesIncGen) + +set(LLVM_TARGET_DEFINITIONS HWTypeInterfaces.td) +mlir_tablegen(HWTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(HWTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(CIRCTHWTypeInterfacesIncGen) +add_dependencies(circt-headers CIRCTHWTypeInterfacesIncGen) + +# Generate Dialect documentation. +add_circt_doc(HWAggregates Dialects/HWAggregateOps -gen-op-doc) +add_circt_doc(HWAttributes Dialects/HWAttributes -gen-attrdef-doc) +add_circt_doc(HWAttributesNaming Dialects/HWAttributesNaming -gen-attrdef-doc) +add_circt_doc(HWMiscOps Dialects/HWMiscOps -gen-op-doc) +add_circt_doc(HWOpInterfaces Dialects/HWOpInterfaces -gen-op-interface-docs) +add_circt_doc(HWStructure Dialects/HWStructureOps -gen-op-doc) +add_circt_doc(HWTypeDecls Dialects/HWTypeDeclsOps -gen-op-doc) +add_circt_doc(HWTypeInterfaces Dialects/HWTypeInterfaces -gen-type-interface-docs) +add_circt_doc(HWTypes Dialects/HWTypes -gen-typedef-doc) +add_circt_doc(HWTypesImpl Dialects/HWTypesImpl -gen-typedef-doc) + +# Generate Pass documentation. +add_circt_doc(Passes HWPasses -gen-pass-doc) diff --git a/third_party/circt/include/circt/Dialect/HW/ConversionPatterns.h b/third_party/circt/include/circt/Dialect/HW/ConversionPatterns.h new file mode 100644 index 000000000..28daf72d8 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/ConversionPatterns.h @@ -0,0 +1,36 @@ +//===- ConversionPatterns.h - Common Conversion patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_CONVERSIONPATTERNS_H +#define CIRCT_SUPPORT_CONVERSIONPATTERNS_H + +#include "circt/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace circt { + +/// Generic pattern which replaces an operation by one of the same operation +/// name, but with converted attributes, operands, and result types to eliminate +/// illegal types. Uses generic builders based on OperationState to make sure +/// that this pattern can apply to _any_ operation. +/// +/// Useful when a conversion can be entirely defined by a TypeConverter. +struct TypeConversionPattern : public mlir::ConversionPattern { + public: + TypeConversionPattern(TypeConverter &converter, MLIRContext *context) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {} + using ConversionPattern::ConversionPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_CONVERSIONPATTERNS_H diff --git a/third_party/circt/include/circt/Dialect/HW/CustomDirectiveImpl.h b/third_party/circt/include/circt/Dialect/HW/CustomDirectiveImpl.h new file mode 100644 index 000000000..a0cf5ac13 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/CustomDirectiveImpl.h @@ -0,0 +1,70 @@ +//===- CustomDirectiveImpl.h - Table-gen custom directive impl --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides common custom directives for table-gen assembly formats. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H +#define CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" + +namespace circt { + +//===----------------------------------------------------------------------===// +// InputPortList Custom Directive +//===----------------------------------------------------------------------===// + +/// Parse a list of instance input ports. +/// input-list ::= `(` ( input-element (`,` input-element )* )? `)` +/// input-element ::= identifier `:` value `:` type +ParseResult parseInputPortList( + OpAsmParser &parser, + SmallVectorImpl &inputs, + SmallVectorImpl &inputTypes, ArrayAttr &inputNames); + +/// Print a list of instance input ports. +void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, + TypeRange inputTypes, ArrayAttr inputNames); + +//===----------------------------------------------------------------------===// +// OutputPortList Custom Directive +//===----------------------------------------------------------------------===// + +/// Parse a list of instance output ports. +/// output-list ::= `(` ( output-element (`,` output-element )* )? `)` +/// output-element ::= identifier `:` type +ParseResult parseOutputPortList(OpAsmParser &parser, + SmallVectorImpl &resultTypes, + ArrayAttr &resultNames); + +/// Print a list of instance output ports. +void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, + ArrayAttr resultNames); + +//===----------------------------------------------------------------------===// +// OptionalParameterList Custom Directive +//===----------------------------------------------------------------------===// + +/// Parse an parameter list if present. +/// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>` +/// parameter-decl ::= identifier `:` type +/// parameter-decl ::= identifier `:` type `=` attribute +ParseResult parseOptionalParameterList(OpAsmParser &parser, + ArrayAttr ¶meters); + +/// Print a parameter list for a module or instance. +void printOptionalParameterList(OpAsmPrinter &p, Operation *op, + ArrayAttr parameters); + +} // namespace circt + +#endif // CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H diff --git a/third_party/circt/include/circt/Dialect/HW/HW.td b/third_party/circt/include/circt/Dialect/HW/HW.td new file mode 100644 index 000000000..cbee1e53f --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HW.td @@ -0,0 +1,30 @@ +//===- HW.td - HW dialect definition -----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top level file for the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HW_TD +#define CIRCT_DIALECT_HW_HW_TD + +include "circt/Dialect/HW/HWDialect.td" +include "circt/Dialect/HW/HWAttributes.td" +include "circt/Dialect/HW/HWAttributesNaming.td" + +include "circt/Dialect/HW/HWTypesImpl.td" +include "circt/Dialect/HW/HWTypes.td" + +include "circt/Dialect/HW/HWOpInterfaces.td" +include "circt/Dialect/HW/HWTypeInterfaces.td" +include "circt/Dialect/HW/HWMiscOps.td" +include "circt/Dialect/HW/HWAggregates.td" +include "circt/Dialect/HW/HWStructure.td" +include "circt/Dialect/HW/HWTypeDecls.td" + +#endif // CIRCT_DIALECT_HW_HW_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWAggregates.td b/third_party/circt/include/circt/Dialect/HW/HWAggregates.td new file mode 100644 index 000000000..8294bb5ec --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWAggregates.td @@ -0,0 +1,306 @@ +//===- HWAggregates.td - HW ops for structs/arrays/etc -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the MLIR ops for working with aggregate values like structs +// and arrays. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWAGGREGATES_TD +#define CIRCT_DIALECT_HW_HWAGGREGATES_TD + +include "circt/Dialect/HW/HWDialect.td" +include "circt/Dialect/HW/HWTypes.td" +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/OpAsmInterface.td" + +//===----------------------------------------------------------------------===// +// Constants +//===----------------------------------------------------------------------===// + +def AggregateConstantOp : HWOp<"aggregate_constant", [Pure, ConstantLike]> { + let summary = "Produce a constant aggregate value"; + let description = [{ + This operation produces a constant value of an aggregate type. Clock and + reset values are supported. For nested aggregates, embedded arrays are + used. + + Examples: + ```mlir + %result = hw.aggregate.constant [1 : i1, 2 : i2, 3 : i2] : !hw.struct + %result = hw.aggregate.constant [1 : i1, [2 : i2, 3 : i2]] : !hw.struct> + ``` + }]; + + let arguments = (ins ArrayAttr:$fields); + let results = (outs HWAggregateType:$result); + let assemblyFormat = "$fields attr-dict `:` type($result)"; + let hasVerifier = 1; + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// Packed Array Processing Operations +//===----------------------------------------------------------------------===// + +def ArrayCreateOp : HWOp<"array_create", [Pure, SameTypeOperands]> { + let summary = "Create an array from values"; + let description = [{ + Creates an array from a variable set of values. One or more values must be + listed. + + ``` + // %a, %b, %c are all i4 + %array = hw.array_create %a, %b, %c : i4 + ``` + + See the HW-SV rationale document for details on operand ordering. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs ArrayType:$result); + + let hasVerifier = 1; + let hasFolder = 1; + let hasCanonicalizeMethod = 1; + + let hasCustomAssemblyFormat = 1; + let builders = [ + // ValueRange needs to contain at least one element. + OpBuilder<(ins "ValueRange":$input)> + ]; + + let extraClassDeclaration = [{ + /// If the all elements of the array are identical, returns that element + /// value. Otherwise returns a null value. + Value getUniformElement(); + /// Returns true if all array elements are identical. + bool isUniform() { return !!getUniformElement(); } + }]; +} + +def ArrayConcatOp : HWOp<"array_concat", [Pure]> { + let summary = "Concatenate some arrays"; + let description = [{ + Creates an array by concatenating a variable set of arrays. One or more + values must be listed. + + ``` + // %a, %b, %c are hw arrays of i4 with sizes 2, 5, and 4 respectively. + %array = hw.array_concat %a, %b, %c : (2, 5, 4 x i4) + // %array is !hw.array<11 x i4> + ``` + + See the HW-SV rationale document for details on operand ordering. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs ArrayType:$result); + + let assemblyFormat = [{ + $inputs attr-dict `:` custom(type($inputs), qualified(type($result))) + }]; + + let builders = [ + // ValueRange needs to contain at least one element. + OpBuilder<(ins "ValueRange":$inputs)> + ]; + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +def ArraySliceOp : HWOp<"array_slice", [Pure]> { + let summary = "Get a range of values from an array"; + let description = [{ + Extracts a sub-range from an array. The range is from `lowIndex` to + `lowIndex` + the number of elements in the return type, non-inclusive on + the high end. For instance, + + ``` + // Slices 16 elements starting at '%offset'. + %subArray = hw.slice %largerArray at %offset : + (!hw.array<1024xi8>) -> !hw.array<16xi8> + ``` + + Would translate to the following SystemVerilog: + + ``` + logic [7:0][15:0] subArray = largerArray[offset +: 16]; + ``` + + Width of 'idx' is defined to be the precise number of bits required to + index the 'input' array. More precisely: for an input array of size M, + the width of 'idx' is ceil(log2(M)). Lower and upper bound indexes which + are larger than the size of the 'input' array results in undefined + behavior. + }]; + + let arguments = (ins ArrayType:$input, HWIntegerType:$lowIndex); + let results = (outs ArrayType:$dst); + + let hasVerifier = 1; + let hasFolder = 1; + let hasCanonicalizeMethod = 1; + + let assemblyFormat = [{ + $input`[`$lowIndex`]` attr-dict `:` + `(` custom(type($input), qualified(type($lowIndex))) `)` `->` qualified(type($dst)) + }]; +} + +class IndexBitWidthConstraint + : PredOpTrait<"Index width should be exactly clog2 (size of array), or either 0 or 1 if the array is a singleton.", + CPred<"isValidIndexBitWidth($" # index # ", $" # input # ")">>; + +class ArrayElementTypeConstraint + : TypesMatchWith<"Result must be arrays element type", + input, result, + "type_cast($_self).getElementType()">; + +// hw.array_get does not work with unpacked arrays. +def ArrayGetOp : HWOp<"array_get", + [Pure, IndexBitWidthConstraint<"index", "input">, + ArrayElementTypeConstraint<"result", "input">]> { + let summary = "Get the value in an array at the specified index"; + let arguments = (ins ArrayType:$input, HWIntegerType:$index); + let results = (outs HWNonInOutType:$result); + + let assemblyFormat = [{ + $input`[`$index`]` attr-dict `:` qualified(type($input)) `,` qualified(type($index)) + }]; + + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +//===----------------------------------------------------------------------===// +// Structure Processing Operations +//===----------------------------------------------------------------------===// + +def StructCreateOp : HWOp<"struct_create", [Pure]> { + let summary = "Create a struct from constituent parts."; + let arguments = (ins Variadic:$input); + let results = (outs StructType:$result); + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let hasVerifier = 1; +} + +// Extract the value of a field of a structure. +def StructExtractOp : HWOp<"struct_extract", + [Pure, + DeclareOpInterfaceMethods + ]> { + let summary = "Extract a named field from a struct."; + let description = [{ + ``` + %result = hw.struct_extract %input["field"] : !hw.struct + ``` + }]; + + let arguments = (ins StructType:$input, StrAttr:$field); + let results = (outs HWNonInOutType:$result); + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "Value":$input, "StructType::FieldInfo":$field)>, + OpBuilder<(ins "Value":$input, "StringAttr":$field)>, + OpBuilder<(ins "Value":$input, "StringRef":$field), [{ + build(odsBuilder, odsState, input, odsBuilder.getStringAttr(field)); + }]> + ]; + + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +// Create a structure by replacing a field with a value in an existing one. +def StructInjectOp : HWOp<"struct_inject", [Pure, + AllTypesMatch<["input", "result"]>]> { + let summary = "Inject a value into a named field of a struct."; + let description = [{ + ``` + %result = hw.struct_inject %input["field"], %newValue + : !hw.struct + ``` + }]; + + let arguments = (ins StructType:$input, StrAttr:$field, + HWNonInOutType:$newValue); + let results = (outs StructType:$result); + let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let hasCanonicalizeMethod = 1; +} + +def StructExplodeOp : HWOp<"struct_explode", [Pure, + DeclareOpInterfaceMethods + ]> { + let summary = "Expand a struct into its constituent parts."; + let description = [{ + ``` + %result:2 = hw.struct_explode %input : !hw.struct + ``` + }]; + let arguments = (ins StructType:$input); + let results = (outs Variadic:$result); + let hasCustomAssemblyFormat = 1; + + let builders = [ + OpBuilder<(ins "Value":$input)> + ]; + + let hasFolder = 1; + let hasCanonicalizeMethod = 1; + } + +//===----------------------------------------------------------------------===// +// Union operations +//===----------------------------------------------------------------------===// + +def UnionCreateOp : HWOp<"union_create", [Pure]> { + let summary = "Create a union with the specified value."; + let description = [{ + Create a union with the value 'input', which can then be accessed via the + specified field. + + ``` + %x = hw.constant 0 : i3 + %z = hw.union_create "bar", %x : !hw.union + ``` + }]; + + let arguments = (ins StrAttr:$field, HWNonInOutType:$input); + let results = (outs UnionType:$result); + let hasCustomAssemblyFormat = 1; +} + +def UnionExtractOp : HWOp<"union_extract", [Pure, + DeclareOpInterfaceMethods, + ]> { + let summary = "Get a union member."; + let description = [{ + Get the value of a union, interpreting it as the type of the specified + member field. Extracting a value belonging to a different field than the + union was initially created will result in undefined behavior. + + ``` + %u = ... + %v = hw.union_extract %u["foo"] : !hw.union + // %v is of type 'i3' + ``` + }]; + + let arguments = (ins UnionType:$input, StrAttr:$field); + let results = (outs HWNonInOutType:$result); + let hasCustomAssemblyFormat = 1; +} + +#endif // CIRCT_DIALECT_HW_HWAGGREGATES_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWAttributes.h b/third_party/circt/include/circt/Dialect/HW/HWAttributes.h new file mode 100644 index 000000000..3c6e32bcb --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWAttributes.h @@ -0,0 +1,45 @@ +//===- HWAttributes.h - Declare HW dialect attributes ------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_ATTRIBUTES_H +#define CIRCT_DIALECT_HW_ATTRIBUTES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace circt { +namespace hw { +class PEOAttr; +class EnumType; +enum class PEO : uint32_t; + +// Forward declaration. +class GlobalRefOp; + +/// Returns a resolved version of 'type' wherein any parameter reference +/// has been evaluated based on the set of provided 'parameters'. +mlir::FailureOr evaluateParametricType(mlir::Location loc, + mlir::ArrayAttr parameters, + mlir::Type type); + +/// Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set +/// of provided parameter values. +mlir::FailureOr evaluateParametricAttr( + mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr); + +/// Returns true if any part of t is parametric. +bool isParametricType(mlir::Type t); + +} // namespace hw +} // namespace circt + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/HW/HWAttributes.h.inc" + +#endif // CIRCT_DIALECT_HW_ATTRIBUTES_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWAttributes.td b/third_party/circt/include/circt/Dialect/HW/HWAttributes.td new file mode 100644 index 000000000..bf94e433b --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWAttributes.td @@ -0,0 +1,309 @@ +//===- HWAttributes.td - Attributes for HW dialect ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines HW dialect specific attributes. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWATTRIBUTES_TD +#define CIRCT_DIALECT_HW_HWATTRIBUTES_TD + +include "circt/Dialect/HW/HWDialect.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinAttributeInterfaces.td" + +/// An attribute to indicate the output file an operation should be emitted to. +def OutputFileAttr : AttrDef { + let summary = "Output file attribute"; + let description = [{ + This attribute represents an output file for something which will be + printed. The `filename` string is the file to be output to. If `filename` + ends in a `/` it is considered an output directory. + + When ExportVerilog runs, one of the files produced is a list of all other + files which are produced. The flag `excludeFromFileList` controls if this + file should be included in this list. If any `OutputFileAttr` referring to + the same file sets this to `true`, it will be included in the file list. + This option defaults to `false`. + + For each file emitted by the verilog emitter, certain prelude output will + be included before the main content. The flag `includeReplicatedOps` can + be used to disable the addition of the prelude text. All `OutputFileAttr`s + referring to the same file must use a consistent setting for this value. + This option defaults to `true`. + + Examples: + ```mlir + #hw.ouput_file<"/home/tester/t.sv"> + #hw.ouput_file<"t.sv", excludeFromFileList, includeReplicatedOps> + ``` + }]; + let mnemonic = "output_file"; + let parameters = (ins "::mlir::StringAttr":$filename, + "::mlir::BoolAttr":$excludeFromFilelist, + "::mlir::BoolAttr":$includeReplicatedOps); + let builders = [ + AttrBuilderWithInferredContext<(ins + "::mlir::StringAttr":$filename, + "::mlir::BoolAttr":$excludeFromFileList, + "::mlir::BoolAttr":$includeReplicatedOps), [{ + return get(filename.getContext(), filename, excludeFromFileList, + includeReplicatedOps); + }]>, + ]; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Get an OutputFileAttr from a string filename, canonicalizing the + /// filename. + static OutputFileAttr getFromFilename(::mlir::MLIRContext *context, + const ::mlir::Twine &filename, + bool excludeFromFileList = false, + bool includeReplicatedOps = false); + + /// Get an OutputFileAttr from a string filename, resolving it relative to + /// `directory`. If `filename` is an absolute path, the given `directory` + /// will not be used. + static OutputFileAttr getFromDirectoryAndFilename( + ::mlir::MLIRContext *context, + const ::mlir::Twine &directory, + const ::mlir::Twine &filename, + bool excludeFromFileList = false, + bool includeReplicatedOps = false); + + /// Get an OutputFileAttr from a string directory name. The name will have + /// a trailing `/` added if it is not there, ensuring that this will be + /// an output directory. + static OutputFileAttr getAsDirectory(::mlir::MLIRContext *context, + const ::mlir::Twine &directory, + bool excludeFromFileList = false, + bool includeReplicatedOps = false); + + /// Returns true if this a directory. + bool isDirectory(); + }]; +} + +// An attribute to indicate which filelist an operation's file should be +// included in. +def FileListAttr : AttrDef { + let summary = "Ouput filelist attribute"; + let description = [{ + This attribute represents an output filelist for something which will be + printed. The `filename` string is the file which the filename of the + operation to be output to. + + When ExportVerilog runs, some of the files produced are lists of other files + which are produced. Each filelist exported contains entities' output file + with `FileListAttr` marked. + + + Examples: + ```mlir + #hw.ouput_filelist<"/home/tester/t.F"> + #hw.ouput_filelist<"t.f"> + ``` + }]; + let mnemonic = "output_filelist"; + let parameters = (ins "::mlir::StringAttr":$filename); + let builders = [ + AttrBuilderWithInferredContext<(ins + "::mlir::StringAttr":$filename), [{ + return get(filename.getContext(), filename); + }]>, + ]; + + let assemblyFormat = "`<` $filename `>`"; + + let extraClassDeclaration = [{ + /// Get an OutputFileAttr from a string filename, canonicalizing the + /// filename. + static FileListAttr getFromFilename(::mlir::MLIRContext *context, + const ::mlir::Twine &filename); + }]; + +} + +def ParamDeclAttr : AttrDef { + let summary = "Module or instance parameter definition"; + let description = [{ + An attribute describing a module parameter, or instance parameter + specification. + }]; + + /// The value of the attribute - in a module, this is the default + /// value (and may be missing). In an instance, this is a required field that + /// specifies the value being passed. The verilog emitter omits printing the + /// parameter for an instance when the applied value and the default value are + /// the same. + let parameters = (ins "::mlir::StringAttr":$name, + AttributeSelfTypeParameter<"">:$type, + "::mlir::Attribute":$value); + let mnemonic = "param.decl"; + + let hasCustomAssemblyFormat = 1; + + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, + "::mlir::Type":$type), + "auto *context = type.getContext();\n" + "return $_get(context, name, type, Attribute());">, + AttrBuilderWithInferredContext<(ins "::mlir::StringRef":$name, + "::mlir::Type":$type), + "return get(StringAttr::get(type.getContext(), name), type);">, + + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, + "::mlir::TypedAttr":$value), + "auto *context = value.getContext();\n" + "return $_get(context, name, value.getType(), value);">, + AttrBuilderWithInferredContext<(ins "::mlir::StringRef":$name, + "::mlir::TypedAttr":$value), + "return get(StringAttr::get(value.getContext(), name), value);"> + ]; + + let extraClassDeclaration = [{ + static ParamDeclAttr getWithName(ParamDeclAttr param, + ::mlir::StringAttr name) { + return get(param.getContext(), name, param.getType(), param.getValue()); + } + }]; +} + +/// An array of ParamDeclAttr's that may or may not have a 'value' specified, +/// to be used on hw.module or hw.instance. The hw.instance verifier further +/// ensures that all the values are specified. +def ParamDeclArrayAttr + : TypedArrayAttrBase; + +/// This attribute models a reference to a named parameter within a module body. +/// The type of the ParamDeclRefAttr must always be the same as the type of the +/// parameter being referenced. +def ParamDeclRefAttr : AttrDef { + let summary = "Is a reference to a parameter value."; + let parameters = (ins "::mlir::StringAttr":$name, + AttributeSelfTypeParameter<"">:$type); + let mnemonic = "param.decl.ref"; + + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, + "::mlir::Type":$type), [{ + return get(name.getContext(), name, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; +} + +def ParamVerbatimAttr : AttrDef { + let summary = + "Represents text to emit directly to SystemVerilog for a parameter"; + let parameters = (ins "::mlir::StringAttr":$value, + AttributeSelfTypeParameter<"">:$type); + let mnemonic = "param.verbatim"; + let hasCustomAssemblyFormat = 1; + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$value), [{ + return get(value, NoneType::get(value.getContext())); + }]>, + AttrBuilderWithInferredContext< + (ins "::mlir::StringAttr":$value, "::mlir::Type":$type), [{ + return get(value.getContext(), value, type); + }]>, + ]; +} + +/// Parameter Expression Opcodes. +let cppNamespace = "circt::hw" in { + +/// Fully Associative Expression Opcodes. +def PEO_Add : I32EnumAttrCase<"Add", 0, "add">; +def PEO_Mul : I32EnumAttrCase<"Mul", 1, "mul">; +def PEO_And : I32EnumAttrCase<"And", 2, "and">; +def PEO_Or : I32EnumAttrCase<"Or", 3, "or">; +def PEO_Xor : I32EnumAttrCase<"Xor", 4, "xor">; + +// Binary Expression Opcodes. +def PEO_Shl : I32EnumAttrCase<"Shl" , 5, "shl">; +def PEO_ShrU : I32EnumAttrCase<"ShrU", 6, "shru">; +def PEO_ShrS : I32EnumAttrCase<"ShrS", 7, "shrs">; +def PEO_DivU : I32EnumAttrCase<"DivU", 8, "divu">; +def PEO_DivS : I32EnumAttrCase<"DivS", 9, "divs">; +def PEO_ModU : I32EnumAttrCase<"ModU",10, "modu">; +def PEO_ModS : I32EnumAttrCase<"ModS",11, "mods">; + +// Unary Expression Opcodes. +def PEO_CLog2 : I32EnumAttrCase<"CLog2", 12, "clog2">; + +// String manipulation Opcodes. +def PEO_StrConcat : I32EnumAttrCase<"StrConcat", 13, "str.concat">; + +def PEOAttr : I32EnumAttr<"PEO", "Parameter Expression Opcode", + [PEO_Add, PEO_Mul, PEO_And, PEO_Or, PEO_Xor, + PEO_Shl, PEO_ShrU, PEO_ShrS, + PEO_DivU, PEO_DivS, PEO_ModU, PEO_ModS, + PEO_CLog2, PEO_StrConcat]>; +} + +def ParamExprAttr : AttrDef { + let summary = "Parameter expression combining operands"; + let parameters = (ins "PEO":$opcode, + ArrayRefParameter<"::mlir::TypedAttr">:$operands, + AttributeSelfTypeParameter<"">:$type); + let mnemonic = "param.expr"; + + // Force all clients to go through our building logic so we can canonicalize + // during building. + let skipDefaultBuilders = 1; + + let extraClassDeclaration = [{ + /// Build a parameter expression. This automatically canonicalizes and + /// folds, so it may not necessarily return a ParamExprAttr. + static mlir::TypedAttr get(PEO opcode, + mlir::ArrayRef operands); + + /// Build a binary parameter expression for convenience. + static mlir::TypedAttr get(PEO opcode, mlir::TypedAttr lhs, + mlir::TypedAttr rhs) { + mlir::TypedAttr operands[] = { lhs, rhs }; + return get(opcode, operands); + } + }]; + + let hasCustomAssemblyFormat = 1; +} + +// An attribute to indicate an enumeration value. +def EnumFieldAttr : AttrDef { + let summary = "Enumeration field attribute"; + let description = [{ + This attribute represents a field of an enumeration. + + Examples: + ```mlir + #hw.enum.value> + ``` + }]; + let mnemonic = "enum.field"; + let parameters = (ins "::mlir::StringAttr":$field, "::mlir::TypeAttr":$type); + + // Force all clients to go through our custom builder so we can check + // whether the requested enum value is part of the provided enum type. + let skipDefaultBuilders = 1; + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Builds a new EnumFieldAttr of the provided value. + /// This will fail if the value is not a member of the provided enum type. + static EnumFieldAttr get(::mlir::Location loc, ::mlir::StringAttr value, mlir::Type type); + }]; +} + +#endif // CIRCT_DIALECT_HW_HWATTRIBUTES_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWAttributesNaming.td b/third_party/circt/include/circt/Dialect/HW/HWAttributesNaming.td new file mode 100644 index 000000000..e54de58c9 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWAttributesNaming.td @@ -0,0 +1,70 @@ +//===- HWAttributesNaming.td - Attributes for HW dialect ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines HW dialect attributes used in other dialects. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWATTRIBUTESNAMING +#define CIRCT_DIALECT_HW_HWATTRIBUTESNAMING + +include "circt/Dialect/HW/HWDialect.td" +include "mlir/IR/AttrTypeBase.td" + +def InnerRefAttr : AttrDef { + let summary = "Refer to a name inside a module"; + let description = [{ + This works like a symbol reference, but to a name inside a module. + }]; + let mnemonic = "innerNameRef"; + let parameters = (ins "::mlir::FlatSymbolRefAttr":$moduleRef, + "::mlir::StringAttr":$name); + + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$module, + "::mlir::StringAttr":$name), [{ + return $_get( + module.getContext(), mlir::FlatSymbolRefAttr::get(module), name); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + /// Get the InnerRefAttr for an operation and add the sym on it. + static InnerRefAttr getFromOperation(mlir::Operation *op, + mlir::StringAttr symName, + mlir::StringAttr moduleName); + + /// Return the name of the referenced module. + mlir::StringAttr getModule() const { return getModuleRef().getAttr(); } + }]; +} + +def GlobalRefAttr : AttrDef { + let summary = "Refer to a non-local symbol"; + let description = [{ + This works like a symbol reference, but to a global symbol with a possible + unique instance path. + }]; + let mnemonic = "globalNameRef"; + let parameters = (ins "::mlir::FlatSymbolRefAttr":$glblSym); + let builders = [ + AttrBuilderWithInferredContext<(ins "::circt::hw::GlobalRefOp":$ref),[{ + return get(ref.getContext(), SymbolRefAttr::get(ref)); + }]>, + ]; + + let assemblyFormat = "`<` $glblSym `>`"; + + let extraClassDeclaration = [{ + static constexpr char DialectAttrName[] = "circt.globalRef"; + }]; +} + +#endif // CIRCT_DIALECT_HW_HWATTRIBUTESNAMING diff --git a/third_party/circt/include/circt/Dialect/HW/HWDialect.h b/third_party/circt/include/circt/Dialect/HW/HWDialect.h new file mode 100644 index 000000000..f506b8a66 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWDialect.h @@ -0,0 +1,26 @@ +//===- HWDialect.h - HW dialect declaration ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines an HW MLIR dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWDIALECT_H +#define CIRCT_DIALECT_HW_HWDIALECT_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" + +// Pull in the dialect definition. +#include "circt/Dialect/HW/HWDialect.h.inc" + +// Pull in all enum type definitions and utility function declarations. +#include "circt/Dialect/HW/HWEnums.h.inc" + +#endif // CIRCT_DIALECT_HW_HWDIALECT_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWDialect.td b/third_party/circt/include/circt/Dialect/HW/HWDialect.td new file mode 100644 index 000000000..4d9a12965 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWDialect.td @@ -0,0 +1,49 @@ +//===- HWDialect.td - HW dialect definition ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This contains the HWDialect definition to be included in other files. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWDIALECT +#define CIRCT_DIALECT_HW_HWDIALECT + +include "mlir/IR/OpBase.td" + +def HWDialect : Dialect { + let name = "hw"; + let cppNamespace = "::circt::hw"; + + let summary = "Types and operations for the hardware dialect"; + let description = [{ + This dialect defines the `hw` dialect, which is intended to be a generic + representation of HW outside of a particular use-case. + }]; + + let hasConstantMaterializer = 1; + let useDefaultTypePrinterParser = 1; + + // Opt-out of properties for now, must migrate by LLVM 19. #5273. + let usePropertiesForAttributes = 0; + + let extraClassDeclaration = [{ + /// Register all HW types. + void registerTypes(); + /// Register all HW attributes. + void registerAttributes(); + + Attribute parseAttribute(DialectAsmParser &p, Type type) const override; + void printAttribute(Attribute attr, DialectAsmPrinter &p) const override; + }]; +} + +// Base class for the operation in this dialect. +class HWOp traits = []> : + Op; + +#endif // CIRCT_DIALECT_HW_HWDIALECT diff --git a/third_party/circt/include/circt/Dialect/HW/HWInstanceGraph.h b/third_party/circt/include/circt/Dialect/HW/HWInstanceGraph.h new file mode 100644 index 000000000..f53eaa088 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWInstanceGraph.h @@ -0,0 +1,56 @@ +//===- InstanceGraph.h - Instance graph -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the HW InstanceGraph, which is similar to a CallGraph. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H +#define CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H + +#include "circt/Dialect/HW/HWOpInterfaces.h" +#include "circt/Support/InstanceGraph.h" + +namespace circt { +namespace hw { + +/// HW-specific instance graph with a virtual entry node linking to +/// all publicly visible modules. +class InstanceGraph : public igraph::InstanceGraph { + public: + InstanceGraph(Operation *operation); + + /// Return the entry node linking to all public modules. + igraph::InstanceGraphNode *getTopLevelNode() override { return &entry; } + + /// Adds a module, updating links to entry. + igraph::InstanceGraphNode *addHWModule(HWModuleLike module); + + /// Erases a module, updating links to entry. + void erase(igraph::InstanceGraphNode *node) override; + + private: + using igraph::InstanceGraph::addModule; + igraph::InstanceGraphNode entry; +}; + +} // namespace hw +} // namespace circt + +// Specialisation for the HW instance graph. +template <> +struct llvm::GraphTraits + : public llvm::GraphTraits {}; + +template <> +struct llvm::DOTGraphTraits + : public llvm::DOTGraphTraits { + using llvm::DOTGraphTraits::DOTGraphTraits; +}; + +#endif // CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWMiscOps.td b/third_party/circt/include/circt/Dialect/HW/HWMiscOps.td new file mode 100644 index 000000000..bd534a5b7 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWMiscOps.td @@ -0,0 +1,192 @@ +//===- HWMiscOps.td - Miscellaneous HW ops -----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This defines miscellaneous generic HW ops, like ConstantOp and BitcastOp. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWMISCOPS_TD +#define CIRCT_DIALECT_HW_HWMISCOPS_TD + +include "circt/Dialect/HW/HWAttributes.td" +include "circt/Dialect/HW/HWDialect.td" +include "circt/Dialect/HW/HWOpInterfaces.td" +include "circt/Dialect/HW/HWTypes.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def ConstantOp + : HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType, + DeclareOpInterfaceMethods]> { + let summary = "Produce a constant value"; + let description = [{ + The constant operation produces a constant value of standard integer type + without a sign. + ``` + %result = hw.constant 42 : t1 + ``` + }]; + + let arguments = (ins APIntAttr:$value); + let results = (outs HWIntegerType:$result); + let hasCustomAssemblyFormat = 1; + + let builders = [ + /// Build a ConstantOp from an APInt, infering the result type from the + /// width of the APInt. + OpBuilder<(ins "const APInt &":$value)>, + + /// This builder allows construction of small signed integers like 0, 1, -1 + /// matching a specified MLIR IntegerType. This shouldn't be used for + /// general constant folding because it only works with values that can be + /// expressed in an int64_t. Use APInt's instead. + OpBuilder<(ins "Type":$type, "int64_t":$value)>, + + /// Build a ConstantOp from a prebuilt attribute. + OpBuilder<(ins "IntegerAttr":$value)> + ]; + let hasFolder = true; + let hasVerifier = 1; +} + +def WireOp : HWOp<"wire", [ + SameOperandsAndResultType, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Assign a name or symbol to an SSA edge"; + let description = [{ + An `hw.wire` is used to assign a human-readable name or a symbol for remote + references to an SSA edge. It takes a single operand and returns its value + unchanged as a result. The operation guarantees the following: + + - If the wire has a symbol, the value of its operand remains observable + under that symbol within the IR. + + - If the wire has a name, the name is treated as a hint. If the wire + persists until code generation the resulting wire will have this name, + with a potential suffix to ensure uniqueness. If the wire is canonicalized + away, its name is propagated to its input operand as a name hint. + + - The users of its result will always observe the operand through the + operation itself, meaning that optimizations cannot bypass the wire. This + ensures that if the wire's value is *forced*, for example through a + Verilog force statement, the forced value will affect all users of the + wire in the output. + + Example: + ``` + %1 = hw.wire %0 : i42 + %2 = hw.wire %0 sym @mySym : i42 + %3 = hw.wire %0 name "myWire" : i42 + %myWire = hw.wire %0 : i42 + ``` + }]; + + let arguments = (ins AnyType:$input, + OptionalAttr:$name, + OptionalAttr:$inner_sym); + let results = (outs AnyType:$result); + + let hasFolder = true; + let hasCanonicalizeMethod = 1; + let builders = [ + OpBuilder<(ins "mlir::Value":$input, + CArg<"const StringAttrOrRef &", "{}">:$name, + CArg<"hw::InnerSymAttr", "{}">:$innerSym), [{ + auto *context = odsBuilder.getContext(); + odsState.addOperands(input); + if (auto attr = name.get(context)) + odsState.addAttribute(getNameAttrName(odsState.name), attr); + if (innerSym) + odsState.addAttribute(getInnerSymAttrName(odsState.name), innerSym); + odsState.addTypes(input.getType()); + }]> + ]; + + let assemblyFormat = [{ + $input (`sym` $inner_sym^)? custom($name) attr-dict + `:` qualified(type($input)) + }]; +} + +def KnownBitWidthType : Type, + "Type wherein the bitwidth in hardware is known">; + +def BitcastOp: HWOp<"bitcast", [Pure]> { + let summary = [{ + Reinterpret one value to another value of the same size and + potentially different type. See the `hw` dialect rationale document for + more details. + }]; + + let arguments = (ins KnownBitWidthType:$input); + let results = (outs KnownBitWidthType:$result); + let hasCanonicalizeMethod = true; + let hasFolder = true; + let hasVerifier = 1; + + let assemblyFormat = "$input attr-dict `:` functional-type($input, $result)"; +} + +def ParamValueOp : HWOp<"param.value", + [FirstAttrDerivedResultType, Pure, + ConstantLike]> { + let summary = [{ + Return the value of a parameter expression as an SSA value that may be used + by other ops. + }]; + + let arguments = (ins AnyAttr:$value); + let results = (outs HWValueType:$result); + let assemblyFormat = "custom($value, qualified(type($result))) attr-dict"; + let hasVerifier = 1; + let hasFolder = true; +} + +def EnumConstantOp : HWOp<"enum.constant", [Pure, ConstantLike, + DeclareOpInterfaceMethods]> { + let summary = "Produce a constant enumeration value."; + let description = [{ + The enum.constant operation produces an enumeration value of the specified + enum value attribute. + ``` + %0 = hw.enum.constant A : !hw.enum + ``` + }]; + + let arguments = (ins EnumFieldAttr:$field); + let results = (outs EnumType:$result); + let hasCustomAssemblyFormat = 1; + let hasFolder = true; + let hasVerifier = true; + let builders = [ + OpBuilder<(ins "hw::EnumFieldAttr":$field)>, + ]; +} + +def EnumCmpOp : HWOp<"enum.cmp", [Pure]> { + let summary = "Compare two values of an enumeration"; + let description = [{ + This operation compares two values with the same canonical enumeration + type, returning 0 if they are different, and 1 if they are the same. + + Example: + ```mlir + %enumcmp = hw.enum.cmp %A, %B : !hw.enum, !hw.enum + ``` + }]; + let arguments = (ins EnumType:$lhs, EnumType:$rhs); + let results = (outs I1:$result); + let hasVerifier = true; + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) + }]; +} + +#endif // CIRCT_DIALECT_HW_HWMISCOPS_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWModuleGraph.h b/third_party/circt/include/circt/Dialect/HW/HWModuleGraph.h new file mode 100644 index 000000000..214fb7db2 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWModuleGraph.h @@ -0,0 +1,186 @@ +//===- HWModuleGraph.h - HWModule graph -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the HWModuleGraph. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWMODULEGRAPH_H +#define CIRCT_DIALECT_HW_HWMODULEGRAPH_H + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWInstanceGraph.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/Seq/SeqOps.h" +#include "circt/Support/LLVM.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/iterator.h" +#include "llvm/Support/DOTGraphTraits.h" +#include "llvm/Support/GraphWriter.h" + +namespace circt { +namespace hw { +namespace detail { + +// Using declaration to avoid polluting global namespace with CIRCT-specific +// graph traits for mlir::Operation. +using HWOperation = mlir::Operation; + +} // namespace detail +} // namespace hw +} // namespace circt + +template <> +struct llvm::GraphTraits { + using NodeType = circt::hw::detail::HWOperation; + using NodeRef = NodeType *; + + using ChildIteratorType = mlir::Operation::user_iterator; + static NodeRef getEntryNode(NodeRef op) { return op; } + static ChildIteratorType child_begin(NodeRef op) { return op->user_begin(); } + static ChildIteratorType child_end(NodeRef op) { return op->user_end(); } +}; + +template <> +struct llvm::GraphTraits + : public llvm::GraphTraits { + using GraphType = circt::hw::HWModuleOp; + + static NodeRef getEntryNode(GraphType mod) { + return &mod.getBodyBlock()->front(); + } + + using nodes_iterator = pointer_iterator; + static nodes_iterator nodes_begin(GraphType mod) { + return nodes_iterator{mod.getBodyBlock()->begin()}; + } + static nodes_iterator nodes_end(GraphType mod) { + return nodes_iterator{mod.getBodyBlock()->end()}; + } +}; + +template <> +struct llvm::DOTGraphTraits + : public llvm::DefaultDOTGraphTraits { + using DefaultDOTGraphTraits::DefaultDOTGraphTraits; + + static std::string getNodeLabel(circt::hw::detail::HWOperation *node, + circt::hw::HWModuleOp) { + return llvm::TypeSwitch(node) + .Case([&](auto) { return "+"; }) + .Case([&](auto) { return "-"; }) + .Case([&](auto) { return "&"; }) + .Case([&](auto) { return "|"; }) + .Case([&](auto) { return "^"; }) + .Case([&](auto) { return "*"; }) + .Case([&](auto) { return "mux"; }) + .Case( + [&](auto) { return ">>"; }) + .Case([&](auto) { return "<<"; }) + .Case([&](auto op) { + switch (op.getPredicate()) { + case circt::comb::ICmpPredicate::eq: + case circt::comb::ICmpPredicate::ceq: + case circt::comb::ICmpPredicate::weq: + return "=="; + case circt::comb::ICmpPredicate::wne: + case circt::comb::ICmpPredicate::cne: + case circt::comb::ICmpPredicate::ne: + return "!="; + case circt::comb::ICmpPredicate::uge: + case circt::comb::ICmpPredicate::sge: + return ">="; + case circt::comb::ICmpPredicate::ugt: + case circt::comb::ICmpPredicate::sgt: + return ">"; + case circt::comb::ICmpPredicate::ule: + case circt::comb::ICmpPredicate::sle: + return "<="; + case circt::comb::ICmpPredicate::ult: + case circt::comb::ICmpPredicate::slt: + return "<"; + } + llvm_unreachable("unhandled ICmp predicate"); + }) + .Case( + [&](auto op) { return op.getName().str(); }) + .Case([&](auto op) { + llvm::SmallString<64> valueString; + op.getValue().toString(valueString, 10, false); + return valueString.str().str(); + }) + .Default([&](auto op) { return op->getName().getStringRef().str(); }); + } + + std::string getNodeAttributes(circt::hw::detail::HWOperation *node, + circt::hw::HWModuleOp) { + return llvm::TypeSwitch(node) + .Case( + [&](auto) { return "fillcolor=darkgoldenrod1,style=filled"; }) + .Case([&](auto) { + return "shape=invtrapezium,fillcolor=bisque,style=filled"; + }) + .Case( + [&](auto) { return "fillcolor=lightblue,style=filled"; }) + .Default([&](auto op) { + return llvm::TypeSwitch( + op->getDialect()) + .Case([&](auto) { + return "shape=oval,fillcolor=bisque,style=filled"; + }) + .template Case([&](auto) { + return "shape=folder,fillcolor=gainsboro,style=filled"; + }) + .Default([&](auto) { return ""; }); + }); + } + + static void addCustomGraphFeatures( + circt::hw::HWModuleOp mod, llvm::GraphWriter &g) { + // Add module input args. + auto &os = g.getOStream(); + os << "subgraph cluster_entry_args {\n"; + os << "label=\"Input arguments\";\n"; + auto iports = mod.getPortList(); + for (auto [info, arg] : + llvm::zip(iports.getInputs(), mod.getBodyBlock()->getArguments())) { + g.emitSimpleNode(reinterpret_cast(&arg), "", + info.getName().str()); + } + os << "}\n"; + for (auto [info, arg] : + llvm::zip(iports.getInputs(), mod.getBodyBlock()->getArguments())) { + for (auto *user : arg.getUsers()) { + g.emitEdge(reinterpret_cast(&arg), 0, user, -1, ""); + } + } + } + + template + static std::string getEdgeAttributes(circt::hw::detail::HWOperation *node, + Iterator it, circt::hw::HWModuleOp mod) { + mlir::OpOperand &operand = *it.getCurrent(); + mlir::Value v = operand.get(); + std::string str; + llvm::raw_string_ostream os(str); + auto verboseEdges = mod->getAttrOfType("dot_verboseEdges"); + if (verboseEdges.getValue()) { + os << "label=\"" << operand.getOperandNumber() << " (" << v.getType() + << ")\""; + } + + int64_t width = circt::hw::getBitWidth(v.getType()); + if (width > 1) os << " style=bold"; + + return os.str(); + } +}; + +#endif // CIRCT_DIALECT_HW_HWMODULEGRAPH_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.h b/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.h new file mode 100644 index 000000000..bcff33a33 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.h @@ -0,0 +1,297 @@ +//===- HWOpInterfaces.h - Declare HW op interfaces --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation interfaces for the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWOPINTERFACES_H +#define CIRCT_DIALECT_HW_HWOPINTERFACES_H + +#include "circt/Dialect/HW/HWTypes.h" +#include "circt/Dialect/HW/InnerSymbolTable.h" +#include "circt/Support/InstanceGraphInterface.h" +#include "circt/Support/LLVM.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/SymbolTable.h" + +namespace circt { +namespace hw { + +void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, + RewritePatternSet &patterns, + TypeConverter &converter); + +/// This holds the name, type, direction of a module's ports +struct PortInfo : public ModulePort { + /// This is the argument index or the result index depending on the direction. + /// "0" for an output means the first output, "0" for a in/inout means the + /// first argument. + size_t argNum = ~0U; + + /// The optional symbol for this port. + InnerSymAttr sym = {}; + DictionaryAttr attrs = {}; + LocationAttr loc = {}; + + StringRef getName() const { return name.getValue(); } + bool isInput() const { return dir == ModulePort::Direction::Input; } + bool isOutput() const { return dir == ModulePort::Direction::Output; } + bool isInOut() const { return dir == ModulePort::Direction::InOut; } + + /// Return a unique numeric identifier for this port. + ssize_t getId() const { return isOutput() ? argNum : (-1 - argNum); }; +}; + +raw_ostream &operator<<(raw_ostream &printer, PortInfo port); + +/// This holds a decoded list of input/inout and output ports for a module or +/// instance. +struct ModulePortInfo { + explicit ModulePortInfo(ArrayRef inputs, + ArrayRef outputs) { + ports.insert(ports.end(), inputs.begin(), inputs.end()); + ports.insert(ports.end(), outputs.begin(), outputs.end()); + sanitizeInOut(); + } + + explicit ModulePortInfo(ArrayRef mergedPorts) + : ports(mergedPorts.begin(), mergedPorts.end()) { + sanitizeInOut(); + } + + using iterator = SmallVector::iterator; + using const_iterator = SmallVector::const_iterator; + + iterator begin() { return ports.begin(); } + iterator end() { return ports.end(); } + const_iterator begin() const { return ports.begin(); } + const_iterator end() const { return ports.end(); } + + using PortDirectionRange = llvm::iterator_range< + llvm::filter_iterator>>; + + using ConstPortDirectionRange = llvm::iterator_range>>; + + PortDirectionRange getPortsOfDirection(bool input) { + std::function predicateFn; + if (input) { + predicateFn = [](const PortInfo &port) -> bool { + return port.dir == ModulePort::Direction::Input || + port.dir == ModulePort::Direction::InOut; + }; + } else { + predicateFn = [](const PortInfo &port) -> bool { + return port.dir == ModulePort::Direction::Output; + }; + } + return llvm::make_filter_range(ports, predicateFn); + } + + ConstPortDirectionRange getPortsOfDirection(bool input) const { + std::function predicateFn; + if (input) { + predicateFn = [](const PortInfo &port) -> bool { + return port.dir == ModulePort::Direction::Input || + port.dir == ModulePort::Direction::InOut; + }; + } else { + predicateFn = [](const PortInfo &port) -> bool { + return port.dir == ModulePort::Direction::Output; + }; + } + return llvm::make_filter_range(ports, predicateFn); + } + + PortDirectionRange getInputs() { return getPortsOfDirection(true); } + + PortDirectionRange getOutputs() { return getPortsOfDirection(false); } + + ConstPortDirectionRange getInputs() const { + return getPortsOfDirection(true); + } + + ConstPortDirectionRange getOutputs() const { + return getPortsOfDirection(false); + } + + size_t size() const { return ports.size(); } + size_t sizeInputs() const { + auto r = getInputs(); + return std::distance(r.begin(), r.end()); + } + size_t sizeOutputs() const { + auto r = getOutputs(); + return std::distance(r.begin(), r.end()); + } + + size_t portNumForInput(size_t idx) const { + size_t port = 0; + while (idx || ports[port].isOutput()) { + if (!ports[port].isOutput()) --idx; + ++port; + } + return port; + } + + size_t portNumForOutput(size_t idx) const { + size_t port = 0; + while (idx || !ports[port].isOutput()) { + if (ports[port].isOutput()) --idx; + ++port; + } + return port; + } + + PortInfo &at(size_t idx) { return ports[idx]; } + PortInfo &atInput(size_t idx) { return ports[portNumForInput(idx)]; } + PortInfo &atOutput(size_t idx) { return ports[portNumForOutput(idx)]; } + + const PortInfo &at(size_t idx) const { return ports[idx]; } + const PortInfo &atInput(size_t idx) const { + return ports[portNumForInput(idx)]; + } + const PortInfo &atOutput(size_t idx) const { + return ports[portNumForOutput(idx)]; + } + + void eraseInput(size_t idx) { + assert(idx < sizeInputs()); + ports.erase(ports.begin() + portNumForInput(idx)); + } + + private: + // convert input inout -> inout type + void sanitizeInOut() { + for (auto &p : ports) + if (auto inout = dyn_cast(p.type)) { + p.type = inout.getElementType(); + p.dir = ModulePort::Direction::InOut; + } + } + + /// This contains a list of all ports. Input first. + SmallVector ports; +}; + +// This provides capability for looking up port indices based on port names. +struct ModulePortLookupInfo { + FailureOr lookupPortIndex( + const llvm::DenseMap &portMap, + StringAttr name) const { + auto it = portMap.find(name); + if (it == portMap.end()) return failure(); + return it->second; + } + + public: + explicit ModulePortLookupInfo(MLIRContext *ctx, + const ModulePortInfo &portInfo) + : ctx(ctx) { + for (auto &in : portInfo.getInputs()) inputPortMap[in.name] = in.argNum; + + for (auto &out : portInfo.getOutputs()) + outputPortMap[out.name] = out.argNum; + } + + // Return the index of the input port with the specified name. + FailureOr getInputPortIndex(StringAttr name) const { + return lookupPortIndex(inputPortMap, name); + } + + // Return the index of the output port with the specified name. + FailureOr getOutputPortIndex(StringAttr name) const { + return lookupPortIndex(outputPortMap, name); + } + + FailureOr getInputPortIndex(StringRef name) const { + return getInputPortIndex(StringAttr::get(ctx, name)); + } + + FailureOr getOutputPortIndex(StringRef name) const { + return getOutputPortIndex(StringAttr::get(ctx, name)); + } + + private: + llvm::DenseMap inputPortMap; + llvm::DenseMap outputPortMap; + MLIRContext *ctx; +}; + +class InnerSymbolOpInterface; +/// Verification hook for verifying InnerSym Attribute. +LogicalResult verifyInnerSymAttr(InnerSymbolOpInterface op); + +namespace detail { +LogicalResult verifyInnerRefNamespace(Operation *op); +} // namespace detail + +/// Classify operations that are InnerRefNamespace-like, +/// until structure is in place to do this via Traits. +/// Useful for getParentOfType<>, or scheduling passes. +/// Prefer putting the trait on operations here or downstream. +struct InnerRefNamespaceLike { + /// Return if this operation is explicitly an IRN or appears compatible. + static bool classof(mlir::Operation *op); + /// Return if this operation is explicitly an IRN or appears compatible. + static bool classof(const mlir::RegisteredOperationName *opInfo); +}; + +} // namespace hw +} // namespace circt + +namespace mlir { +namespace OpTrait { + +/// This trait is for operations that define a scope for resolving InnerRef's, +/// and provides verification for InnerRef users (via InnerRefUserOpInterface). +template +class InnerRefNamespace : public TraitBase { + public: + static LogicalResult verifyRegionTrait(Operation *op) { + static_assert( + ConcreteType::template hasTrait<::mlir::OpTrait::SymbolTable>(), + "expected operation to be a SymbolTable"); + + if (op->getNumRegions() != 1) + return op->emitError("expected operation to have a single region"); + if (!op->getRegion(0).hasOneBlock()) + return op->emitError("expected operation to have a single block"); + + // Verify all InnerSymbolTable's and InnerRef users. + return ::circt::hw::detail::verifyInnerRefNamespace(op); + } +}; + +/// A trait for inner symbol table functionality on an operation. +template +class InnerSymbolTable : public TraitBase { + public: + static LogicalResult verifyRegionTrait(Operation *op) { + // Insist that ops with InnerSymbolTable's provide a Symbol, this is + // essential to how InnerRef's work. + static_assert( + ConcreteType::template hasTrait<::mlir::SymbolOpInterface::Trait>(), + "expected operation to define a Symbol"); + + // InnerSymbolTable's must be directly nested within an InnerRefNamespace. + auto *parent = op->getParentOp(); + if (!parent || !isa(parent)) + return op->emitError( + "InnerSymbolTable must have InnerRefNamespace parent"); + + return success(); + } +}; +} // namespace OpTrait +} // namespace mlir + +#include "circt/Dialect/HW/HWOpInterfaces.h.inc" + +#endif // CIRCT_DIALECT_HW_HWOPINTERFACES_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.td b/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.td new file mode 100644 index 000000000..1d2b1789a --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWOpInterfaces.td @@ -0,0 +1,528 @@ +//===- HWOpInterfaces.td - Operation Interfaces ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the HW operation interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWOPINTERFACES +#define CIRCT_DIALECT_HW_HWOPINTERFACES + +include "mlir/IR/SymbolInterfaces.td" +include "mlir/IR/OpBase.td" +include "circt/Support/InstanceGraphInterface.td" + +def PortList : OpInterface<"PortList", []> { + let cppNamespace = "circt::hw"; + let description = "Operations which produce a unified port list representation"; + let methods = [ + InterfaceMethod<"Get port list", + "::circt::hw::ModulePortInfo", "getPortList", (ins)>, + + InterfaceMethod<"Get the port a specific input", + "size_t", "getPortIdForInputId", (ins "size_t":$idx)>, + + InterfaceMethod<"Get the port a specific output", + "size_t", "getPortIdForOutputId", (ins "size_t":$idx)>, + + InterfaceMethod<"Get the number of ports", + "size_t", "getNumPorts", (ins)>, + InterfaceMethod<"Get the number of input ports", + "size_t", "getNumInputPorts", (ins)>, + InterfaceMethod<"Get the number of output ports", + "size_t", "getNumOutputPorts", (ins)>, + ]; +} + +def SymboledPortList : OpInterface<"SymboledPortList", [PortList]> { + let cppNamespace = "circt::hw"; + let description = "Operations which produce a unified port list representation which includes port symbols"; + let methods = [ + InterfaceMethod<"Get a port symbol attribute", + "::circt::hw::InnerSymAttr", "getPortSymbolAttr", (ins "size_t":$portIndex)>, + ]; +} + + +def HWModuleLike : OpInterface<"HWModuleLike", [ + Symbol, SymboledPortList, InstanceGraphModuleOpInterface]> { + let cppNamespace = "circt::hw"; + let description = "Provide common module information."; + + let methods = [ + InterfaceMethod<"Get the module type", + "::circt::hw::ModuleType", "getHWModuleType", (ins)>, + + InterfaceMethod<"Get the port Attributes", + "SmallVector", "getAllPortAttrs", (ins)>, + + InterfaceMethod<"Set the port Attributes", + "void", "setAllPortAttrs", (ins "ArrayRef":$attrs)>, + + InterfaceMethod<"Remove the port Attributes", + "void", "removeAllPortAttrs", (ins)>, + + InterfaceMethod<"Get the port Locations", + "SmallVector", "getAllPortLocs", (ins)>, + + InterfaceMethod<"Set the port Locations", + "void", "setAllPortLocs", (ins "ArrayRef":$locs)>, + + InterfaceMethod<"Set the module type (and port names)", + "void", "setHWModuleType", (ins "::circt::hw::ModuleType":$type)>, + + InterfaceMethod<"Set the port names", + "void", "setAllPortNames", (ins "ArrayRef":$names)>, + + ]; + + let extraSharedClassDeclaration = [{ + /// Attribute name for port symbols. + static StringRef getPortSymbolAttrName() { + return "hw.exportPort"; + } + + /// Return the region containing the body of this function. + Region &getModuleBody() { return $_op->getRegion(0); } + Block *getBodyBlock() { + Region &body = getModuleBody(); + if (body.empty()) + return nullptr; + return &body.front(); + } + + /// Returns the entry block argument at the given index. + BlockArgument getArgumentForInput(unsigned idx) { + return $_op.getModuleBody().getArgument(idx); + } + + /// Returns the entry block argument for the given port. May be null. + BlockArgument getArgumentForPort(unsigned idx) { + return $_op.getModuleBody().getArgument($_op.getHWModuleType().getInputIdForPortId(idx)); + } + + SmallVector getPortTypes() { + return $_op.getHWModuleType().getPortTypes(); + } + + SmallVector getInputTypes() { + return $_op.getHWModuleType().getInputTypes(); + } + + SmallVector getOutputTypes() { + return $_op.getHWModuleType().getOutputTypes(); + } + + /// Return the set of names on input and inout ports + SmallVector getInputNamesStr() { + return $_op.getHWModuleType().getInputNamesStr(); + } + + /// Return the set of names on output ports + SmallVector getOutputNamesStr() { + return $_op.getHWModuleType().getOutputNamesStr(); + } + + /// Return the set of names on input and inout ports + SmallVector getInputNames() { + return $_op.getHWModuleType().getInputNames(); + } + + /// Return the set of names on output ports + SmallVector getOutputNames() { + return $_op.getHWModuleType().getOutputNames(); + } + + void setInputNames(ArrayRef names) { + SmallVector newNames(names.begin(), names.end()); + auto resNames = $_op.getOutputNames(); + newNames.append(resNames.begin(), resNames.end()); + $_op.setAllPortNames(newNames); + } + + void setOutputNames(ArrayRef names) { + SmallVector newNames = $_op.getInputNames(); + newNames.append(names.begin(), names.end()); + $_op.setAllPortNames(newNames); + } + + // Get the name for the specified input or inout port + StringRef getPortName(size_t idx) { + return $_op.getHWModuleType().getPortName(idx); + } + + // Get the name for the specified input or inout port + StringRef getInputName(size_t idx) { + return $_op.getHWModuleType().getInputName(idx); + } + + // Get the name for the specified output port + StringRef getOutputName(size_t idx) { + return $_op.getHWModuleType().getOutputName(idx); + } + + StringAttr getInputNameAttr(size_t idx) { + return $_op.getHWModuleType().getInputNameAttr(idx); + } + + StringAttr getOutputNameAttr(size_t idx) { + return $_op.getHWModuleType().getOutputNameAttr(idx); + } + + Attribute getPortAttrs(size_t idx) { + return $_op.getAllPortAttrs()[idx]; + } + + SmallVector getAllInputAttrs() { + auto attrs = $_op.getAllPortAttrs(); + SmallVector retval; + auto modType = $_op.getHWModuleType(); + for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) + retval.push_back(attrs[modType.getPortIdForInputId(x)]); + return retval; + } + + SmallVector getAllOutputAttrs() { + auto attrs = $_op.getAllPortAttrs(); + SmallVector retval; + auto modType = $_op.getHWModuleType(); + for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) + retval.push_back(attrs[modType.getPortIdForOutputId(x)]); + return retval; + } + + void setAllInputAttrs(ArrayRef attrs) { + SmallVector retval(attrs.begin(), attrs.end()); + auto resAttrs = $_op.getAllOutputAttrs(); + retval.append(resAttrs.begin(), resAttrs.end()); + $_op.setAllPortAttrs(retval); + } + + void setAllOutputAttrs(ArrayRef attrs) { + SmallVector retval = $_op.getAllInputAttrs(); + retval.append(attrs.begin(), attrs.end()); + $_op.setAllPortAttrs(retval); + } + + Attribute getInputAttrs(size_t idx) { + return $_op.getAllPortAttrs()[$_op.getPortIdForInputId(idx)]; + } + + Attribute getOutputAttrs(size_t idx) { + return $_op.getAllPortAttrs()[$_op.getPortIdForOutputId(idx)]; + } + + void setPortAttrs(size_t idx, DictionaryAttr attr) { + auto attrs = $_op.getAllPortAttrs(); + attrs[idx] = attr; + $_op.setAllPortAttrs(attrs); + } + + void setPortAttr(size_t idx, StringAttr name, Attribute value) { + auto attrs = $_op.getAllPortAttrs(); + NamedAttrList pattr(cast(attrs[idx])); + Attribute oldValue; + if (!value) + oldValue = pattr.erase(name); + else + oldValue = pattr.set(name, value); + if (oldValue != value) { + attrs[idx] = pattr.getDictionary($_op.getContext()); + $_op.setAllPortAttrs(attrs); + } + } + + void setPortAttrs(StringAttr attrName, ArrayRef newAttrs) { + auto attrs = $_op.getAllPortAttrs(); + auto ctxt = $_op.getContext(); + assert(newAttrs.size() == attrs.size()); + for (size_t idx = 0, e = attrs.size(); idx != e; ++idx) { + NamedAttrList pattr(cast(attrs[idx])); + auto newAttr = newAttrs[idx]; + if (newAttr) + pattr.set(attrName, newAttr); + else + pattr.erase(attrName); + attrs[idx] = pattr.getDictionary(ctxt); + } + $_op.setAllPortAttrs(attrs); + } + + Location getPortLoc(size_t idx) { + return $_op.getAllPortLocs()[idx]; + } + + void setPortLoc(size_t idx, Location loc) { + auto locs = $_op.getAllPortLocs(); + locs[idx] = loc; + return $_op.setAllPortLocs(locs); + } + + Location getInputLoc(size_t idx) { + return $_op.getAllPortLocs()[$_op.getPortIdForInputId(idx)]; + } + + Location getOutputLoc(size_t idx) { + return $_op.getAllPortLocs()[$_op.getPortIdForOutputId(idx)]; + } + + SmallVector getInputLocs() { + auto locs = $_op.getAllPortLocs(); + SmallVector retval; + for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) + retval.push_back(locs[$_op.getPortIdForInputId(x)]); + return retval; + } + + ArrayAttr getInputLocsAttr() { + auto locs = $_op.getAllPortLocs(); + SmallVector retval; + for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) + retval.push_back(locs[$_op.getPortIdForInputId(x)]); + return ArrayAttr::get($_op->getContext(), retval); + } + + void setInputLocs(ArrayRef inAttrs) { + assert(inAttrs.size() == $_op.getNumInputPorts()); + auto outAttrs = getOutputLocs(); + SmallVector attrs; + attrs.append(inAttrs.begin(), inAttrs.end()); + attrs.append(outAttrs.begin(), outAttrs.end()); + $_op.setAllPortLocs(attrs); + } + + SmallVector getOutputLocs() { + auto locs = $_op.getAllPortLocs(); + SmallVector retval; + for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) + retval.push_back(locs[$_op.getPortIdForOutputId(x)]); + return retval; + } + + ArrayAttr getOutputLocsAttr() { + auto locs = $_op.getAllPortLocs(); + SmallVector retval; + for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) + retval.push_back(locs[$_op.getPortIdForOutputId(x)]); + return ArrayAttr::get($_op->getContext(), retval); + } + + void setOutputLocs(ArrayRef outAttrs) { + assert(outAttrs.size() == $_op.getNumOutputPorts()); + auto inAttrs = getInputLocs(); + SmallVector attrs; + attrs.append(inAttrs.begin(), inAttrs.end()); + attrs.append(outAttrs.begin(), outAttrs.end()); + $_op.setAllPortLocs(attrs); + } + }]; + + let verify = [{ + static_assert( + ConcreteOp::template hasTrait<::mlir::SymbolOpInterface::Trait>(), + "expected operation to be a symbol"); + return success(); + }]; +} + +def HWMutableModuleLike : OpInterface<"HWMutableModuleLike", [HWModuleLike]> { + let cppNamespace = "circt::hw"; + let description = "Provide methods to mutate a module."; + + let methods = [ + + InterfaceMethod<"Get a handle to a utility class which provides by-name lookup of port indices. The returned object does _not_ update if the module is mutated.", + "::circt::hw::ModulePortLookupInfo", "getPortLookupInfo", (ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return hw::ModulePortLookupInfo( + $_op->getContext(), + $_op.getPortList()); + }]>, + + /// Insert and remove input and output ports of this module. Does not modify + /// the block arguments of the module body. The insertion and removal + /// indices must be in ascending order. The indices refer to the port + /// positions before any insertion or removal occurs. Ports inserted at the + /// same index will appear in the module in the same order as they were + /// listed in the insertion arrays. + InterfaceMethod<"Insert and remove input and output ports", + "void", "modifyPorts", (ins + "ArrayRef>":$insertInputs, + "ArrayRef>":$insertOutputs, + "ArrayRef":$eraseInputs, "ArrayRef":$eraseOutputs), + /*methodBody=*/[{ + $_op.modifyPorts(insertInputs, insertOutputs, eraseInputs, eraseOutputs); + }]>, + + /// Insert ports into the module. Does not modify the block arguments of the + /// module body. + InterfaceMethod<"Insert ports into this module", + "void", "insertPorts", (ins + "ArrayRef>":$insertInputs, + "ArrayRef>":$insertOutputs), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + $_op.modifyPorts(insertInputs, insertOutputs, {}, {}); + }]>, + + /// Erase ports from the module. Does not modify the block arguments of the + /// module body. + InterfaceMethod<"Erase ports from this module", + "void", "erasePorts", (ins + "ArrayRef":$eraseInputs, + "ArrayRef":$eraseOutputs), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + $_op.modifyPorts({}, {}, eraseInputs, eraseOutputs); + }]>, + + /// Appends output ports to the module with the specified names and rewrites + /// the output op to return the associated values. + InterfaceMethod<"Append output values to this module", + "void", "appendOutputs", (ins + "ArrayRef>":$outputs)> + ]; +} + + +def HWInstanceLike : OpInterface<"HWInstanceLike", [ + PortList, InstanceGraphInstanceOpInterface]> { + let cppNamespace = "circt::hw"; + let description = "Provide common module information."; +} + +def InnerRefNamespace : NativeOpTrait<"InnerRefNamespace">; + +def InnerSymbol : OpInterface<"InnerSymbolOpInterface"> { + let description = [{ + This interface describes an operation that may define an + `inner_sym`. An `inner_sym` operation resides + in arbitrarily-nested regions of a region that defines a + `InnerSymbolTable`. + Inner Symbols are different from normal symbols due to + MLIR symbol table resolution rules. Specifically normal + symbols are resolved by first going up to the closest + parent symbol table and resolving from there (recursing + down for complex symbol paths). In HW and SV, modules + define a symbol in a circuit or std.module symbol table. + For instances to be able to resolve the modules they + instantiate, the symbol use in an instance must resolve + in the top-level symbol table. If a module were a + symbol table, instances resolving a symbol would start from + their own module, never seeing other modules (since + resolution would start in the parent module of the + instance and be unable to go to the global scope). + The second problem arises from nesting. Symbols defining + ops must be immediate children of a symbol table. HW + and SV operations which define a inner_sym are grandchildren, + at least, of a symbol table and may be much further nested. + Lastly, ports need to define inner_sym, something not allowed + by normal symbols. + + Any operation implementing an InnerSymbol may have the inner symbol be + optional and all methods should be robuse to the attribute not being + defined. + }]; + + let cppNamespace = "::circt::hw"; + let methods = [ + InterfaceMethod<"Returns the name of the top-level inner symbol defined by this operation, if present.", + "::mlir::StringAttr", "getInnerNameAttr", (ins), [{}], + /*defaultImplementation=*/[{ + if (auto attr = + this->getOperation()->template getAttrOfType( + circt::hw::InnerSymbolTable::getInnerSymbolAttrName())) + return attr.getSymName(); + return {}; + }] + >, + InterfaceMethod<"Returns the name of the top-level inner symbol defined by this operation, if present.", + "::std::optional<::mlir::StringRef>", "getInnerName", (ins), [{}], + /*defaultImplementation=*/[{ + auto attr = this->getInnerNameAttr(); + return attr ? ::std::optional(attr.getValue()) : ::std::nullopt; + }] + >, + InterfaceMethod<"Sets the name of the top-level inner symbol defined by this operation to the specified string, dropping any symbols on fields.", + "void", "setInnerSymbol", (ins "::mlir::StringAttr":$name), [{}], + /*defaultImplementation=*/[{ + this->getOperation()->setAttr( + InnerSymbolTable::getInnerSymbolAttrName(), hw::InnerSymAttr::get(name)); + }] + >, + InterfaceMethod<"Sets the inner symbols defined by this operation.", + "void", "setInnerSymbolAttr", (ins "::circt::hw::InnerSymAttr":$sym), [{}], + /*defaultImplementation=*/[{ + if (sym && !sym.empty()) + this->getOperation()->setAttr( + InnerSymbolTable::getInnerSymbolAttrName(), sym); + else + this->getOperation()->removeAttr(InnerSymbolTable::getInnerSymbolAttrName()); + }] + >, + InterfaceMethod<"Returns an InnerRef to this operation's top-level inner symbol, which must be present.", + "::circt::hw::InnerRefAttr", "getInnerRef", (ins), [{}], + /*defaultImplementation=*/[{ + auto *op = this->getOperation(); + return hw::InnerRefAttr::get( + SymbolTable::getSymbolName( + op->template getParentWithTrait()), + InnerSymbolTable::getInnerSymbol(op)); + }] + >, + InterfaceMethod<"Returns the InnerSymAttr representing all inner symbols defined by this operation.", + "::circt::hw::InnerSymAttr", "getInnerSymAttr", (ins), [{}], + /*defaultImplementation=*/[{ + return this->getOperation()->template getAttrOfType( + circt::hw::InnerSymbolTable::getInnerSymbolAttrName()); + }] + >, + // Ask an operation if per-field symbols are allowed. + // Defaults to indicating they're allowed iff there's a defined target result, + // but let operations answer this differently if for some reason that makes sense. + StaticInterfaceMethod<"Returns whether per-field symbols are supported for this operation type.", + "bool", "supportsPerFieldSymbols", (ins), [{}], /*defaultImplementation=*/[{ + return ConcreteOp::getTargetResultIndex().has_value(); + }]>, + StaticInterfaceMethod<"Returns the index of the result the innner symbol targets, if applicable. Per-field symbols are resolved into this.", + "std::optional", "getTargetResultIndex">, + InterfaceMethod<"Returns the result the innner symbol targets, if applicable. Per-field symbols are resolved into this.", + "OpResult", "getTargetResult", (ins), [{}], /*defaultImplementation=*/[{ + auto idx = ConcreteOp::getTargetResultIndex(); + if (!idx) + return {}; + return $_op->getResult(*idx); + }]>, + ]; + + let verify = [{ + return verifyInnerSymAttr(cast(op)); + }]; +} + +def InnerSymbolTable : NativeOpTrait<"InnerSymbolTable">; + +def InnerRefUserOpInterface : OpInterface<"InnerRefUserOpInterface"> { + let description = [{ + This interface describes an operation that may use a `InnerRef`. This + interface allows for users of inner symbols to hook into verification and + other inner symbol related utilities that are either costly or otherwise + disallowed within a traditional operation. + }]; + let cppNamespace = "::circt::hw"; + + let methods = [ + InterfaceMethod<"Verify the inner ref uses held by this operation.", + "::mlir::LogicalResult", "verifyInnerRefs", + (ins "::circt::hw::InnerRefNamespace&":$ns) + >, + ]; +} + +#endif diff --git a/third_party/circt/include/circt/Dialect/HW/HWOps.h b/third_party/circt/include/circt/Dialect/HW/HWOps.h new file mode 100644 index 000000000..7f52381c0 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWOps.h @@ -0,0 +1,140 @@ +//===- HWOps.h - Declare HW dialect operations ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation classes for the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_OPS_H +#define CIRCT_DIALECT_HW_OPS_H + +#include "circt/Dialect/HW/HWDialect.h" +#include "circt/Dialect/HW/HWOpInterfaces.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "circt/Support/BuilderUtils.h" +#include "llvm/ADT/StringExtras.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace circt { +namespace hw { + +class EnumFieldAttr; + +/// Flip a port direction. +ModulePort::Direction flip(ModulePort::Direction direction); + +/// TODO: Move all these functions to a hw::ModuleLike interface. + +/// Insert and remove ports of a module. The insertion and removal indices must +/// be in ascending order. The indices refer to the port positions before any +/// insertion or removal occurs. Ports inserted at the same index will appear in +/// the module in the same order as they were listed in the `insert*` array. +/// If 'body' is provided, additionally inserts/removes the corresponding +/// block arguments. +void modifyModulePorts(Operation *op, + ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef removeInputs, + ArrayRef removeOutputs, Block *body = nullptr); + +// Helpers for working with modules. + +/// Return true if isAnyModule or instance. +bool isAnyModuleOrInstance(Operation *module); + +/// Return the signature for the specified module as a function type. +FunctionType getModuleType(Operation *module); + +/// Returns the verilog module name attribute or symbol name of any module-like +/// operations. +StringAttr getVerilogModuleNameAttr(Operation *module); +inline StringRef getVerilogModuleName(Operation *module) { + return getVerilogModuleNameAttr(module).getValue(); +} + +// Index width should be exactly clog2 (size of array), or either 0 or 1 if the +// array is a singleton. +bool isValidIndexBitWidth(Value index, Value array); + +/// Return true if the specified operation is a combinational logic op. +bool isCombinational(Operation *op); + +/// Return true if the specified attribute tree is made up of nodes that are +/// valid in a parameter expression. +bool isValidParameterExpression(Attribute attr, Operation *module); + +/// Check parameter specified by `value` to see if it is valid within the scope +/// of the specified module `module`. If not, emit an error at the location of +/// `usingOp` and return failure, otherwise return success. +/// +/// If `disallowParamRefs` is true, then parameter references are not allowed. +LogicalResult checkParameterInContext(Attribute value, Operation *module, + Operation *usingOp, + bool disallowParamRefs = false); + +/// Check parameter specified by `value` to see if it is valid according to the +/// module's parameters. If not, emit an error to the diagnostic provided as an +/// argument to the lambda 'instanceError' and return failure, otherwise return +/// success. +/// +/// If `disallowParamRefs` is true, then parameter references are not allowed. +LogicalResult checkParameterInContext( + Attribute value, ArrayAttr moduleParameters, + const std::function)> + &instanceError, + bool disallowParamRefs = false); + +// Check whether an integer value is an offset from a base. +bool isOffset(Value base, Value index, uint64_t offset); + +// A class for providing access to the in- and output ports of a module through +// use of the HWModuleBuilder. +class HWModulePortAccessor { + public: + HWModulePortAccessor(Location loc, const ModulePortInfo &info, + Region &bodyRegion); + + // Returns the i'th/named input port of the module. + Value getInput(unsigned i); + Value getInput(StringRef name); + ValueRange getInputs() { return inputArgs; } + + // Assigns the i'th/named output port of the module. + void setOutput(unsigned i, Value v); + void setOutput(StringRef name, Value v); + + const ModulePortInfo &getPortList() const { return info; } + const llvm::SmallVector &getOutputOperands() const { + return outputOperands; + } + + private: + llvm::StringMap inputIdx, outputIdx; + llvm::SmallVector inputArgs; + llvm::SmallVector outputOperands; + ModulePortInfo info; +}; + +using HWModuleBuilder = + llvm::function_ref; + +} // namespace hw +} // namespace circt + +#define GET_OP_CLASSES +#include "circt/Dialect/HW/HW.h.inc" + +#endif // CIRCT_DIALECT_HW_OPS_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWPasses.h b/third_party/circt/include/circt/Dialect/HW/HWPasses.h new file mode 100644 index 000000000..acd357f66 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWPasses.h @@ -0,0 +1,38 @@ +//===- Passes.h - HW pass entry points --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWPASSES_H +#define CIRCT_DIALECT_HW_HWPASSES_H + +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" + +namespace circt { +namespace hw { + +std::unique_ptr createPrintInstanceGraphPass(); +std::unique_ptr createHWSpecializePass(); +std::unique_ptr createPrintHWModuleGraphPass(); +std::unique_ptr createFlattenIOPass(); +std::unique_ptr createVerifyInnerRefNamespacePass(); + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "circt/Dialect/HW/Passes.h.inc" + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_HWPASSES_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWReductions.h b/third_party/circt/include/circt/Dialect/HW/HWReductions.h new file mode 100644 index 000000000..bd72e1508 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWReductions.h @@ -0,0 +1,29 @@ +//===- HWReductions.h - HW reduction interface declaration ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWREDUCTIONS_H +#define CIRCT_DIALECT_HW_HWREDUCTIONS_H + +#include "circt/Reduce/Reduction.h" + +namespace circt { +namespace hw { + +/// A dialect interface to provide reduction patterns to a reducer tool. +struct HWReducePatternDialectInterface : public ReducePatternDialectInterface { + using ReducePatternDialectInterface::ReducePatternDialectInterface; + void populateReducePatterns(circt::ReducePatternSet &patterns) const override; +}; + +/// Register the HW Reduction pattern dialect interface to the given registry. +void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry); + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_HWREDUCTIONS_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWStructure.td b/third_party/circt/include/circt/Dialect/HW/HWStructure.td new file mode 100644 index 000000000..c1b8ce705 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWStructure.td @@ -0,0 +1,720 @@ +//===- HWStructure.td - HW structure ops -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the MLIR ops for structure. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWSTRUCTURE_TD +#define CIRCT_DIALECT_HW_HWSTRUCTURE_TD + +include "circt/Dialect/HW/HWAttributes.td" +include "circt/Dialect/HW/HWDialect.td" +include "circt/Dialect/HW/HWOpInterfaces.td" +include "circt/Dialect/HW/HWTypes.td" +include "mlir/Interfaces/FunctionInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/RegionKindInterface.td" +include "mlir/IR/SymbolInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +/// Base class factoring out some of the additional class declarations common to +/// the module-like operations. +class HWModuleOpBase traits = []> : + HWOp, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Symbol, InnerSymbolTable, + OpAsmOpInterface, HasParent<"mlir::ModuleOp">]> { + /// Additional class declarations inside the module op. + code extraModuleClassDeclaration = ?; + + let extraClassDeclaration = extraModuleClassDeclaration # [{ + /// Insert and remove input and output ports of this module. Does not modify + /// the block arguments of the module body. The insertion and removal + /// indices must be in ascending order. The indices refer to the port + /// positions before any insertion or removal occurs. Ports inserted at the + /// same index will appear in the module in the same order as they were + /// listed in the insertion arrays. + void modifyPorts( + ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef eraseInputs, + ArrayRef eraseOutputs + ); + + void setPortSymbolAttr(size_t portIndex, ::circt::hw::InnerSymAttr sym); + + StringAttr getArgName(size_t index) { + return getHWModuleType().getInputNameAttr(index); + } + + }]; + + /// Additional class definitions inside the module op. + code extraModuleClassDefinition = [{}]; + + let extraClassDefinition = extraModuleClassDefinition # [{ + + ModuleType $cppClass::getHWModuleType() { + return getModuleType(); + } + + ::circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) { + DictionaryAttr portAttrs = cast_or_null(getPortAttrs(portIndex)); + if (portAttrs) + return portAttrs.getAs<::circt::hw::InnerSymAttr>( + getPortSymbolAttrName()); + return {}; + } + + void $cppClass::setPortSymbolAttr(size_t portIndex, ::circt::hw::InnerSymAttr sym) { + auto portSymAttr = StringAttr::get(getContext(), getPortSymbolAttrName()); + setPortAttr(portIndex, portSymAttr, sym); + } + + size_t $cppClass::getNumPorts() { + auto modty = getHWModuleType(); + return modty.getNumPorts(); + } + + size_t $cppClass::getNumInputPorts() { + auto modty = getHWModuleType(); + return modty.getNumInputs(); + } + + size_t $cppClass::getNumOutputPorts() { + auto modty = getHWModuleType(); + return modty.getNumOutputs(); + } + + size_t $cppClass::getPortIdForInputId(size_t idx) { + auto modty = getHWModuleType(); + return modty.getPortIdForInputId(idx); + } + + size_t $cppClass::getPortIdForOutputId(size_t idx) { + auto modty = getHWModuleType(); + return modty.getPortIdForOutputId(idx); + } + + ModulePortInfo $cppClass::getPortList() { + auto modTy = getHWModuleType(); + SmallVector inputs, outputs; + auto emptyDict = DictionaryAttr::get(getContext()); + auto argTypes = modTy.getInputTypes(); + auto argNames = modTy.getInputNames(); + auto argLocs = getArgLocs(); + auto argAttrs = getArgAttrs(); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + auto type = argTypes[i]; + auto direction = ModulePort::Direction::Input; + if (auto inout = type.dyn_cast()) { + type = inout.getElementType(); + direction = ModulePort::Direction::InOut; + } + LocationAttr loc; + if (argLocs) + loc = argLocs[i].cast(); + DictionaryAttr attrs = emptyDict; + if (argAttrs && (*argAttrs)[i]) + attrs = cast((*argAttrs)[i]); + inputs.push_back({{cast(argNames[i]), type, direction}, + i, + extractSym(attrs), + attrs, + loc}); + } + + auto resultNames = modTy.getOutputNames(); + auto resultTypes = modTy.getOutputTypes(); + auto resultLocs = getResultLocs(); + auto resultAttrs = getResAttrs(); + for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) { + LocationAttr loc; + if (resultLocs) + loc = resultLocs[i].cast(); + DictionaryAttr attrs = emptyDict; + if (resultAttrs && (*resultAttrs)[i]) + attrs = cast((*resultAttrs)[i]); + outputs.push_back({{cast(resultNames[i]), resultTypes[i], + ModulePort::Direction::Output}, + i, + extractSym(attrs), + attrs, + loc}); + } + return ModulePortInfo(inputs, outputs); + } + + }]; + +} + +def HWTestModuleOp : HWOp<"testmodule", [Symbol, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + OpAsmOpInterface, + HasParent<"mlir::ModuleOp">, + IsolatedFromAbove, + SingleBlockImplicitTerminator<"OutputOp">]> { + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$module_type, + OptionalAttr:$port_attrs, + OptionalAttr:$port_locs, + ParamDeclArrayAttr:$parameters, + OptionalAttr:$comment); + let results = (outs); + let regions = (region SizedRegion<1>:$body); + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + void getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn); + + Block *getBodyBlock() { return &getRegion().front(); } + }]; + +} + +def HWModuleOp : HWModuleOpBase<"module", + [IsolatedFromAbove, RegionKindInterface, + SingleBlockImplicitTerminator<"OutputOp">]>{ + let summary = "HW Module"; + let description = [{ + The "hw.module" operation represents a Verilog module, including a given + name, a list of ports, a list of parameters, and a body that represents the + connections within the module. + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$module_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + LocationArrayAttr:$argLocs, + LocationArrayAttr:$resultLocs, + ParamDeclArrayAttr:$parameters, + StrAttr:$comment); + let results = (outs); + let regions = (region SizedRegion<1>:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "StringAttr":$name, "ArrayRef":$ports, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes, + CArg<"StringAttr", "{}">:$comment)>, + OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes, + CArg<"StringAttr", "{}">:$comment, + CArg<"bool", "true">:$shouldEnsureTerminator)>, + OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, + "HWModuleBuilder":$modBuilder, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes, + CArg<"StringAttr", "{}">:$comment)> + ]; + + let extraModuleClassDeclaration = [{ + + // Implement RegionKindInterface. + static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph;} + + /// Append an input with a given name and type to the port list. + /// If the name is not unique, a unique name is created and returned. + std::pair + appendInput(const Twine &name, Type ty) { + return insertInput(getNumInputPorts(), name, ty); + } + + std::pair + appendInput(StringAttr name, Type ty) { + return insertInput(getNumInputPorts(), name.getValue(), ty); + } + + /// Prepend an input with a given name and type to the port list. + /// If the name is not unique, a unique name is created and returned. + std::pair + prependInput(const Twine &name, Type ty) { + return insertInput(0, name, ty); + } + + std::pair + prependInput(StringAttr name, Type ty) { + return insertInput(0, name.getValue(), ty); + } + + /// Insert an input with a given name and type into the port list. + /// The input is added at the specified index. + std::pair + insertInput(unsigned index, StringAttr name, Type ty); + + std::pair + insertInput(unsigned index, const Twine &name, Type ty) { + ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); + return insertInput(index, nameAttr, ty); + } + + /// Append an output with a given name and type to the port list. + /// If the name is not unique, a unique name is created. + void appendOutput(StringAttr name, Value value) { + return insertOutputs(getNumOutputPorts(), {{name, value}}); + } + + void appendOutput(const Twine &name, Value value) { + ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); + return insertOutputs(getNumOutputPorts(), {{nameAttr, value}}); + } + + /// Prepend an output with a given name and type to the port list. + /// If the name is not unique, a unique name is created. + void prependOutput(StringAttr name, Value value) { + return insertOutputs(0, {{name, value}}); + } + + void prependOutput(const Twine &name, Value value) { + ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); + return insertOutputs(0, {{nameAttr, value}}); + } + + /// Inserts a list of output ports into the port list at a specific + /// location, shifting all subsequent ports. Rewrites the output op + /// to return the associated values. + void insertOutputs(unsigned index, + ArrayRef> outputs); + + // Get the module's symbolic name as StringAttr. + StringAttr getNameAttr() { + return (*this)->getAttrOfType( + ::mlir::SymbolTable::getSymbolAttrName()); + } + + // Get the module's symbolic name. + StringRef getName() { + return getNameAttr().getValue(); + } + void getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn); + + /// Verifies the body of the function. + LogicalResult verifyBody(); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def HWModuleExternOp : HWModuleOpBase<"module.extern"> { + let summary = "HW external Module"; + let description = [{ + The "hw.module.extern" operation represents an external reference to a + Verilog module, including a given name and a list of ports. + + The 'verilogName' attribute (when present) specifies the spelling of the + module name in Verilog we can use. TODO: This is a hack because we don't + have proper parameterization in the hw.dialect. We need a way to represent + parameterized types instead of just concrete types. + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + TypeAttrOf:$module_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + LocationArrayAttr:$argLocs, + LocationArrayAttr:$resultLocs, + ParamDeclArrayAttr:$parameters, + OptionalAttr:$verilogName); + let results = (outs); + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "StringAttr":$name, "ArrayRef":$ports, + CArg<"StringRef", "StringRef()">:$verilogName, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, + CArg<"StringRef", "StringRef()">:$verilogName, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes)> + ]; + + let extraModuleClassDeclaration = [{ + + /// Return the name to use for the Verilog module that we're referencing + /// here. This is typically the symbol, but can be overridden with the + /// verilogName attribute. + StringRef getVerilogModuleName() { + return getVerilogModuleNameAttr().getValue(); + } + + /// Return the name to use for the Verilog module that we're referencing + /// here. This is typically the symbol, but can be overridden with the + /// verilogName attribute. + StringAttr getVerilogModuleNameAttr(); + + // Get the module's symbolic name as StringAttr. + StringAttr getNameAttr() { + return (*this)->getAttrOfType( + ::mlir::SymbolTable::getSymbolAttrName()); + } + + // Get the module's symbolic name. + StringRef getName() { + return getNameAttr().getValue(); + } + + void getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn); + + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def HWGeneratorSchemaOp : HWOp<"generator.schema", + [Symbol, HasParent<"mlir::ModuleOp">]> { + let summary = "HW Generator Schema declaration"; + let description = [{ + The "hw.generator.schema" operation declares a kind of generated module by + declaring the schema of meta-data required. + A generated module instance of a schema is independent of the external + method of producing it. It is assumed that for well known schema instances, + multiple external tools might exist which can process it. Generator nodes + list attributes required by hw.module.generated instances. + + For example: + generator.schema @MEMORY, "Simple-Memory", ["ports", "write_latency", "read_latency"] + module.generated @mymem, @MEMORY(ports) + -> (ports) {write_latency=1, read_latency=1, ports=["read","write"]} + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, StrAttr:$descriptor, + StrArrayAttr:$requiredAttrs); + let results = (outs); + let assemblyFormat = "$sym_name `,` $descriptor `,` $requiredAttrs attr-dict"; +} + +def HWModuleGeneratedOp : HWModuleOpBase<"module.generated", [ + DeclareOpInterfaceMethods, + IsolatedFromAbove]> { + let summary = "HW Generated Module"; + let description = [{ + The "hw.module.generated" operation represents a reference to an external + module that will be produced by some external process. + This represents the name and list of ports to be generated. + + The 'verilogName' attribute (when present) specifies the spelling of the + module name in Verilog we can use. See hw.module for an explanation. + }]; + let arguments = (ins SymbolNameAttr:$sym_name, + FlatSymbolRefAttr:$generatorKind, + TypeAttrOf:$module_type, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + LocationArrayAttr:$argLocs, + LocationArrayAttr:$resultLocs, + ParamDeclArrayAttr:$parameters, + OptionalAttr:$verilogName); + let results = (outs); + let regions = (region AnyRegion:$body); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "FlatSymbolRefAttr":$genKind, "StringAttr":$name, + "ArrayRef":$ports, + CArg<"StringRef", "StringRef()">:$verilogName, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "FlatSymbolRefAttr":$genKind, "StringAttr":$name, + "const ModulePortInfo &":$ports, + CArg<"StringRef", "StringRef()">:$verilogName, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"ArrayRef", "{}">:$attributes)> + ]; + + let extraModuleClassDeclaration = [{ + /// Return the name to use for the Verilog module that we're referencing + /// here. This is typically the symbol, but can be overridden with the + /// verilogName attribute. + StringRef getVerilogModuleName() { + return getVerilogModuleNameAttr().getValue(); + } + + /// Return the name to use for the Verilog module that we're referencing + /// here. This is typically the symbol, but can be overridden with the + /// verilogName attribute. + StringAttr getVerilogModuleNameAttr(); + + /// Lookup the generator kind for the symbol. This returns null on + /// invalid IR. + Operation *getGeneratorKindOp(); + + void getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn); + + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def InstanceOp : HWOp<"instance", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Create an instance of a module"; + let description = [{ + This represents an instance of a module. The inputs and results are + the referenced module's inputs and outputs. The `argNames` and + `resultNames` attributes must match the referenced module. + + Any parameters in the "old" format (slated to be removed) are stored in the + `oldParameters` dictionary. + }]; + + let arguments = (ins StrAttr:$instanceName, + FlatSymbolRefAttr:$moduleName, + Variadic:$inputs, + StrArrayAttr:$argNames, StrArrayAttr:$resultNames, + ParamDeclArrayAttr:$parameters, + OptionalAttr:$inner_sym); + let results = (outs Variadic:$results); + + let builders = [ + /// Create a instance that refers to a known module. + OpBuilder<(ins "Operation*":$module, "StringAttr":$name, + "ArrayRef":$inputs, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"InnerSymAttr", "{}">:$innerSym)>, + /// Create a instance that refers to a known module. + OpBuilder<(ins "Operation*":$module, "StringRef":$name, + "ArrayRef":$inputs, + CArg<"ArrayAttr", "{}">:$parameters, + CArg<"InnerSymAttr", "{}">:$innerSym), [{ + build($_builder, $_state, module, $_builder.getStringAttr(name), inputs, + parameters, innerSym); + }]>, + ]; + + let extraClassDeclaration = [{ + /// Return the name of the specified input port or null if it cannot be + /// determined. + StringAttr getArgumentName(size_t i); + + /// Return the name of the specified result or null if it cannot be + /// determined. + StringAttr getResultName(size_t i); + + /// Change the name of the specified input port. + void setArgumentName(size_t i, StringAttr name); + + /// Change the name of the specified output port. + void setResultName(size_t i, StringAttr name); + + /// Change the names of all the input ports. + void setInputNames(ArrayAttr names) { + setArgNamesAttr(names); + } + + /// Change the names of all the result ports. + void setOutputNames(ArrayAttr names) { + setResultNamesAttr(names); + } + + /// Lookup the module or extmodule for the symbol. This returns null on + /// invalid IR. + Operation *getReferencedModule(const HWSymbolCache *cache); + Operation *getReferencedModule(SymbolTable& tbl); + Operation *getReferencedModuleSlow(); + + /// Return the values for the port in port order. + /// Note: The module ports may not be input, output ordered. This computes + /// the port index to instance result/input Value mapping. + void getValues(SmallVectorImpl &values, const ModulePortInfo &mpi); + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + /// An InstanceOp may optionally define a symbol. + bool isOptionalSymbol() { return true; } + + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def OutputOp : HWOp<"output", [Terminator, HasParent<"HWModuleOp, HWTestModuleOp">, + Pure, ReturnLike]> { + let summary = "HW termination operation"; + let description = [{ + "hw.output" marks the end of a region in the HW dialect and the values + to put on the output ports. + }]; + + let arguments = (ins Variadic:$outputs); + + let builders = [ + OpBuilder<(ins), "build($_builder, $_state, std::nullopt);"> + ]; + + let assemblyFormat = "attr-dict ($outputs^ `:` qualified(type($outputs)))?"; + + let hasVerifier = 1; +} + +def GlobalRefOp : HWOp<"globalRef", [ + DeclareOpInterfaceMethods, + IsolatedFromAbove, Symbol]> { + let summary = "A global reference to uniquely identify each" + "instance of an operation"; + let description = [{ + This works like a symbol reference to an operation by specifying the + instance path to uniquely identify it globally. + It can be used to attach per instance metadata (non-local attributes). + This also lets components of the path point to a common entity. + }]; + + let arguments = (ins SymbolNameAttr:$sym_name, NameRefArrayAttr:$namepath); + + let assemblyFormat = [{ $sym_name $namepath attr-dict }]; +} + +def HierPathOp : HWOp<"hierpath", + [IsolatedFromAbove, Symbol, + DeclareOpInterfaceMethods]> { + let summary = "Hierarchical path specification"; + let description = [{ + The "hw.hierpath" operation represents a path through the hierarchy. + This is used to specify namable things for use in other operations, for + example in verbatim substitution. Non-local annotations also use these. + }]; + let arguments = (ins SymbolNameAttr:$sym_name, NameRefArrayAttr:$namepath); + let results = (outs); + let hasCustomAssemblyFormat = 1; + let extraClassDeclaration = [{ + /// Drop the module from the namepath. If its a InnerNameRef, then drop + /// the Module-Instance pair, else drop the final module from the namepath. + /// Return true if any update is made. + bool dropModule(StringAttr moduleToDrop); + + /// Inline the module in the namepath. + /// Update the symbol name for the inlined module instance, by prepending + /// the symbol name of the instance at which the inling was done. + /// Return true if any update is made. + bool inlineModule(StringAttr moduleToDrop); + + /// Replace the oldMod module with newMod module in the namepath of the NLA. + /// Return true if any update is made. + bool updateModule(StringAttr oldMod, StringAttr newMod); + + /// Replace the oldMod module with newMod module in the namepath of the NLA. + /// Since the module is being updated, the symbols inside the module should + /// also be renamed. Use the rename Map to update the corresponding + /// inner_sym names in the namepath. Return true if any update is made. + bool updateModuleAndInnerRef(StringAttr oldMod, StringAttr newMod, + const llvm::DenseMap &innerSymRenameMap); + + /// Truncate the namepath for this NLA, at atMod module. + /// If includeMod is false, drop atMod and beyond, else include it and drop + /// everything after it. + /// Return true if any update is made. + bool truncateAtModule(StringAttr atMod, bool includeMod = true); + + /// Return just the module part of the namepath at a specific index. + StringAttr modPart(unsigned i); + + /// Return the root module. + StringAttr root(); + + /// Return just the reference part of the namepath at a specific index. + /// This will return an empty attribute if this is the leaf and the leaf is + /// a module. + StringAttr refPart(unsigned i); + + /// Return the leaf reference. This returns an empty attribute if the leaf + /// reference is a module. + StringAttr ref(); + + /// Return the leaf Module. + StringAttr leafMod(); + + /// Returns true, if the NLA path contains the module. + bool hasModule(StringAttr modName); + + /// Returns true, if the NLA path contains the InnerSym {modName, symName}. + bool hasInnerSym(StringAttr modName, StringAttr symName) const; + + /// Returns true if this NLA targets a module or instance of a module (as + /// opposed to an instance's port or something inside an instance). + bool isModule(); + + /// Returns true if this NLA targets something inside a module (as opposed + /// to a module or an instance of a module); + bool isComponent(); + }]; +} + +// Edge behavior for trigger blocks. Currently these map 1:1 to SV event +// control kinds. + +/// AtPosEdge triggers on a rise from 0 to 1/X/Z, or X/Z to 1. +def AtPosEdge: I32EnumAttrCase<"AtPosEdge", 0, "posedge">; +/// AtNegEdge triggers on a drop from 1 to 0/X/Z, or X/Z to 0. +def AtNegEdge: I32EnumAttrCase<"AtNegEdge", 1, "negedge">; +/// AtEdge(v) is syntactic sugar for "AtPosEdge(v) or AtNegEdge(v)". +def AtEdge : I32EnumAttrCase<"AtEdge", 2, "edge">; + +def EventControlAttr : I32EnumAttr<"EventControl", "edge control trigger", + [AtPosEdge, AtNegEdge, AtEdge]> { + let cppNamespace = "circt::hw"; +} + +def TriggeredOp : HWOp<"triggered", [ + IsolatedFromAbove, SingleBlock, NoTerminator]> { + let summary = "A procedural region with a trigger condition"; + let description = [{ + A procedural region that can be triggered by an event. The trigger + condition is a 1-bit value that is activated based on some event control + attribute. + The operation is `IsolatedFromAbove`, and thus requires values passed into + the trigger region to be explicitly passed in through the `inputs` list. + }]; + + let regions = (region SizedRegion<1>:$body); + let arguments = (ins EventControlAttr:$event, I1:$trigger, Variadic:$inputs); + let results = (outs); + + let assemblyFormat = [{ + $event $trigger `(` $inputs `)` `:` type($inputs) $body attr-dict + }]; + + let extraClassDeclaration = [{ + Block *getBodyBlock() { return &getBody().front(); } + + // Return the input arguments inside the trigger region. + ArrayRef getInnerInputs() { + return getBodyBlock()->getArguments(); + } + }]; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "EventControlAttr":$event, "Value":$trigger, "ValueRange":$inputs)> + ]; +} + +#endif // CIRCT_DIALECT_HW_HWSTRUCTURE_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWSymCache.h b/third_party/circt/include/circt/Dialect/HW/HWSymCache.h new file mode 100644 index 000000000..a8e645b55 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWSymCache.h @@ -0,0 +1,118 @@ +//===- HWSymCache.h - Declare Symbol Cache ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a Symbol Cache specialized for HW instances. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_SYMCACHE_H +#define CIRCT_DIALECT_HW_SYMCACHE_H + +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Support/SymCache.h" + +namespace circt { +namespace hw { + +/// This stores lookup tables to make manipulating and working with the IR more +/// efficient. There are two phases to this object: the "building" phase in +/// which it is "write only" and then the "using" phase which is read-only (and +/// thus can be used by multiple threads). The "freeze" method transitions +/// between the two states. +class HWSymbolCache : public SymbolCacheBase { + public: + class Item { + public: + Item(mlir::Operation *op) : op(op), port(~0ULL) {} + Item(mlir::Operation *op, size_t port) : op(op), port(port) {} + bool hasPort() const { return port != ~0ULL; } + size_t getPort() const { return port; } + mlir::Operation *getOp() const { return op; } + + private: + mlir::Operation *op; + size_t port; + }; + + // Add inner names, which might be ports + void addDefinition(mlir::StringAttr modSymbol, mlir::StringAttr name, + mlir::Operation *op, size_t port = ~0ULL) { + auto key = InnerRefAttr::get(modSymbol, name); + symbolCache.try_emplace(key, op, port); + } + + void addDefinition(mlir::Attribute key, mlir::Operation *op) override { + assert(!isFrozen && "cannot mutate a frozen cache"); + symbolCache.try_emplace(key, op); + } + + // Pull in getDefinition(mlir::FlatSymbolRefAttr symbol) + using SymbolCacheBase::getDefinition; + mlir::Operation *getDefinition(mlir::Attribute attr) const override { + assert(isFrozen && "cannot read from this cache until it is frozen"); + auto it = symbolCache.find(attr); + if (it == symbolCache.end()) return nullptr; + assert(!it->second.hasPort() && "Module names should never be ports"); + return it->second.getOp(); + } + + HWSymbolCache::Item getInnerDefinition(mlir::StringAttr modSymbol, + mlir::StringAttr name) const { + return lookupInner(InnerRefAttr::get(modSymbol, name)); + } + + HWSymbolCache::Item getInnerDefinition(InnerRefAttr inner) const { + return lookupInner(inner); + } + + /// Mark the cache as frozen, which allows it to be shared across threads. + void freeze() { isFrozen = true; } + + private: + Item lookupInner(InnerRefAttr attr) const { + assert(isFrozen && "cannot read from this cache until it is frozen"); + auto it = symbolCache.find(attr); + return it == symbolCache.end() ? Item{nullptr, ~0ULL} : it->second; + } + + bool isFrozen = false; + + /// This stores a lookup table from symbol attribute to the item + /// that defines it. + llvm::DenseMap symbolCache; + + private: + // Iterator support. Map from Item's to their inner operations. + using Iterator = decltype(symbolCache)::iterator; + struct HwSymbolCacheIteratorImpl : public CacheIteratorImpl { + HwSymbolCacheIteratorImpl(Iterator it) : it(it) {} + CacheItem operator*() override { + return {it->getFirst(), it->getSecond().getOp()}; + } + void operator++() override { it++; } + bool operator==(CacheIteratorImpl *other) override { + return it == static_cast(other)->it; + } + Iterator it; + }; + + public: + SymbolCacheBase::Iterator begin() override { + return SymbolCacheBase::Iterator( + std::make_unique(symbolCache.begin())); + } + SymbolCacheBase::Iterator end() override { + return SymbolCacheBase::Iterator( + std::make_unique(symbolCache.end())); + } +}; + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_SYMCACHE_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypeDecls.td b/third_party/circt/include/circt/Dialect/HW/HWTypeDecls.td new file mode 100644 index 000000000..1c8758d29 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypeDecls.td @@ -0,0 +1,63 @@ +//===- HWTypeDecls.td - HW data type declaration ops and types ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Support for type declarations in the HW type system. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWTYPEDECLS_TD +#define CIRCT_DIALECT_HW_HWTYPEDECLS_TD + +include "circt/Dialect/HW/HWDialect.td" +include "mlir/IR/SymbolInterfaces.td" + +//===----------------------------------------------------------------------===// +// Declaration operations +//===----------------------------------------------------------------------===// + +def TypeScopeOp : HWOp<"type_scope", + [Symbol, SymbolTable, SingleBlock, NoTerminator, NoRegionArguments]> { + let summary = "Type declaration wrapper."; + let description = [{ + An operation whose one body block contains type declarations. This op + provides a scope for type declarations at the top level of an MLIR module. + It is a symbol that may be looked up within the module, as well as a symbol + table itself, so type declarations may be looked up. + }]; + + let regions = (region SizedRegion<1>:$body); + let arguments = (ins SymbolNameAttr:$sym_name); + + let assemblyFormat = "$sym_name $body attr-dict"; + + let extraClassDeclaration = [{ + Block *getBodyBlock() { return &getBody().front(); } + }]; +} + +def TypedeclOp : HWOp<"typedecl", [Symbol, HasParent<"TypeScopeOp">]> { + let summary = "Type declaration."; + let description = "Associate a symbolic name with a type."; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttr:$type, + OptionalAttr:$verilogName + ); + + let assemblyFormat = "$sym_name (`,` $verilogName^)? `:` $type attr-dict"; + + let extraClassDeclaration = [{ + StringRef getPreferredName(); + + // Returns the type alias type which this typedecl op defines. + Type getAliasType(); + }]; +} + +#endif // CIRCT_DIALECT_HW_HWTYPEDECLS_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.h b/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.h new file mode 100644 index 000000000..708b43bcb --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.h @@ -0,0 +1,44 @@ +//===- HWTypeInterfaces.h - Declare HW type interfaces ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares type interfaces for the HW Dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWTYPEINTERFACES_H +#define CIRCT_DIALECT_HW_HWTYPEINTERFACES_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/Types.h" + +namespace circt { +namespace hw { +namespace FieldIdImpl { +uint64_t getMaxFieldID(Type); + +std::pair<::mlir::Type, uint64_t> getSubTypeByFieldID(Type, uint64_t fieldID); + +::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID); + +std::pair projectToChildFieldID(Type, uint64_t fieldID, + uint64_t index); + +std::pair getIndexAndSubfieldID(Type type, + uint64_t fieldID); + +uint64_t getFieldID(Type type, uint64_t index); + +uint64_t getIndexForFieldID(Type type, uint64_t fieldID); + +} // namespace FieldIdImpl +} // namespace hw +} // namespace circt + +#include "circt/Dialect/HW/HWTypeInterfaces.h.inc" + +#endif // CIRCT_DIALECT_HW_HWTYPEINTERFACES_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.td b/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.td new file mode 100644 index 000000000..b71b6e5ed --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypeInterfaces.td @@ -0,0 +1,81 @@ +//===- HWTypeInterfaces.td - HW Type Interfaces ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the type interfaces of the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD +#define CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD + +include "mlir/IR/OpBase.td" + +def FieldIDTypeInterface : TypeInterface<"FieldIDTypeInterface"> { + let cppNamespace = "circt::hw"; + let description = [{ + Common methods for types which can be indexed by a FieldID. + FieldID is a depth-first numbering of the elements of a type. For example: + ``` + struct a /* 0 */ { + int b; /* 1 */ + struct c /* 2 */ { + int d; /* 3 */ + } + } + + int e; /* 0 */ + ``` + }]; + + let methods = [ + InterfaceMethod<"Get the maximum field ID for this type", + "uint64_t", "getMaxFieldID">, + + InterfaceMethod<[{ + Get the sub-type of a type for a field ID, and the subfield's ID. Strip + off a single layer of this type and return the sub-type and a field ID + targeting the same field, but rebased on the sub-type. + + The resultant type *may* not be a FieldIDTypeInterface if the resulting + fieldID is zero. This means that leaf types may be ground without + implementing an interface. An empty aggregate will also appear as a + zero. + }], "std::pair<::mlir::Type, uint64_t>", + "getSubTypeByFieldID", (ins "uint64_t":$fieldID)>, + + InterfaceMethod<[{ + Returns the effective field id when treating the index field as the + root of the type. Essentially maps a fieldID to a fieldID after a + subfield op. Returns the new id and whether the id is in the given + child. + }], "std::pair", "projectToChildFieldID", + (ins "uint64_t":$fieldID, "uint64_t":$index)>, + + InterfaceMethod<[{ + Returns the index (e.g. struct or vector element) for a given FieldID. + This returns the containing index in the case that the fieldID points to a + child field of a field. + }], "uint64_t", "getIndexForFieldID", (ins "uint64_t":$fieldID)>, + + InterfaceMethod<[{ + Return the fieldID of a given index (e.g. struct or vector element). + Field IDs start at 1, and are assigned + to each field in a recursive depth-first walk of all + elements. A field ID of 0 is used to reference the type itself. + }], "uint64_t", "getFieldID", (ins "uint64_t":$fieldID)>, + + InterfaceMethod<[{ + Find the index of the element that contains the given fieldID. + As well, rebase the fieldID to the element. + }], "std::pair", "getIndexAndSubfieldID", + (ins "uint64_t":$fieldID)>, + + ]; +} + +#endif // CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypes.h b/third_party/circt/include/circt/Dialect/HW/HWTypes.h new file mode 100644 index 000000000..6b2bd618c --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypes.h @@ -0,0 +1,165 @@ +//===- HWTypes.h - Types for the HW dialect ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Types for the HW dialect are mostly in tablegen. This file should contain +// C++ types used in MLIR type parameters. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_TYPES_H +#define CIRCT_DIALECT_HW_TYPES_H + +#include "circt/Dialect/HW/HWDialect.h" +#include "circt/Dialect/HW/HWTypeInterfaces.h" +#include "circt/Support/LLVM.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +namespace circt { +namespace hw { + +struct ModulePort { + enum Direction { Input, Output, InOut }; + mlir::StringAttr name; + mlir::Type type; + Direction dir; +}; + +class HWSymbolCache; +class ParamDeclAttr; +class TypedeclOp; +class ModuleType; + +namespace detail { + +ModuleType fnToMod(Operation *op, ArrayRef inputNames, + ArrayRef outputNames); +ModuleType fnToMod(FunctionType fn, ArrayRef inputNames, + ArrayRef outputNames); + +/// Struct defining a field. Used in structs. +struct FieldInfo { + mlir::StringAttr name; + mlir::Type type; +}; + +/// Struct defining a field with an offset. Used in unions. +struct OffsetFieldInfo { + StringAttr name; + Type type; + size_t offset; +}; +} // namespace detail +} // namespace hw +} // namespace circt + +#define GET_TYPEDEF_CLASSES +#include "circt/Dialect/HW/HWTypes.h.inc" + +namespace circt { +namespace hw { + +// Returns the canonical type of a HW type (inner type of a type alias). +mlir::Type getCanonicalType(mlir::Type type); + +/// Return true if the specified type is a value HW Integer type. This checks +/// that it is a signless standard dialect type. +bool isHWIntegerType(mlir::Type type); + +/// Return true if the specified type is a HW Enum type. +bool isHWEnumType(mlir::Type type); + +/// Return true if the specified type can be used as an HW value type, that is +/// the set of types that can be composed together to represent synthesized, +/// hardware but not marker types like InOutType or unknown types from other +/// dialects. +bool isHWValueType(mlir::Type type); + +/// Return the hardware bit width of a type. Does not reflect any encoding, +/// padding, or storage scheme, just the bit (and wire width) of a +/// statically-size type. Reflects the number of wires needed to transmit a +/// value of this type. Returns -1 if the type is not known or cannot be +/// statically computed. +int64_t getBitWidth(mlir::Type type); + +/// Return true if the specified type contains known marker types like +/// InOutType. Unlike isHWValueType, this is not conservative, it only returns +/// false on known InOut types, rather than any unknown types. +bool hasHWInOutType(mlir::Type type); + +template +bool type_isa(Type type) { + // First check if the type is the requested type. + if (type.isa()) return true; + + // Then check if it is a type alias wrapping the requested type. + if (auto alias = type.dyn_cast()) + return type_isa(alias.getInnerType()); + + return false; +} + +// type_isa for a nullable argument. +template +bool type_isa_and_nonnull(Type type) { // NOLINT(readability-identifier-naming) + if (!type) return false; + return type_isa(type); +} + +template +BaseTy type_cast(Type type) { + assert(type_isa(type) && "type must convert to requested type"); + + // If the type is the requested type, return it. + if (type.isa()) return type.cast(); + + // Otherwise, it must be a type alias wrapping the requested type. + return type_cast(type.cast().getInnerType()); +} + +template +BaseTy type_dyn_cast(Type type) { + if (!type_isa(type)) return BaseTy(); + + return type_cast(type); +} + +/// Utility type that wraps a type that may be one of several possible Types. +/// This is similar to std::variant but is implemented for mlir::Type, and it +/// understands how to handle type aliases. +template +class TypeVariant + : public ::mlir::Type::TypeBase, mlir::Type, + mlir::TypeStorage> { + using mlir::Type::TypeBase, mlir::Type, + mlir::TypeStorage>::Base::Base; + + public: + // Support LLVM isa/cast/dyn_cast to one of the possible types. + static bool classof(Type other) { return type_isa(other); } +}; + +template +class TypeAliasOr + : public ::mlir::Type::TypeBase, mlir::Type, + mlir::TypeStorage> { + using mlir::Type::TypeBase, mlir::Type, + mlir::TypeStorage>::Base::Base; + + public: + // Support LLVM isa/cast/dyn_cast to BaseTy. + static bool classof(Type other) { return type_isa(other); } + + // Support C++ implicit conversions to BaseTy. + operator BaseTy() const { return type_cast(*this); } +}; + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_TYPES_H diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypes.td b/third_party/circt/include/circt/Dialect/HW/HWTypes.td new file mode 100644 index 000000000..3d13624f4 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypes.td @@ -0,0 +1,179 @@ +//===- HWTypes.td - HW data type definitions ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Basic data types for the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWTYPES +#define CIRCT_DIALECT_HW_HWTYPES + +include "circt/Dialect/HW/HWDialect.td" +include "mlir/IR/AttrTypeBase.td" + +//===----------------------------------------------------------------------===// +// Type predicates +//===----------------------------------------------------------------------===// + +// Type constraint that indicates that an operand/result may only be a valid, +// known, non-directional type. +def HWIntegerType : DialectType, + "a signless integer bitvector", + "::circt::hw::TypeVariant<::mlir::IntegerType, ::circt::hw::IntType>">; + +// Type constraint that indicates that an operand/result may only be a valid, +// known, non-directional type. +def HWValueType : DialectType, "a known primitive element">; + +// Type constraint that indicates that an operand/result may only be a valid +// non-directional type. +def HWNonInOutType : DialectType, "a type without inout">; + +def InOutType : DialectType($_self)">, + "InOutType", "InOutType">; + +// A handle to refer to circt::hw::ArrayType in ODS. +def ArrayType : DialectType($_self)">, + "an ArrayType", "::circt::hw::TypeAliasOr">; + +// A handle to refer to circt::hw::StructType in ODS. +def StructType : DialectType($_self)">, + "a StructType", "::circt::hw::TypeAliasOr">; + +// A handle to refer to circt::hw::UnionType in ODS. +def UnionType : DialectType($_self)">, + "a UnionType", "::circt::hw::TypeAliasOr">; + +// A handle to refer to circt::hw::EnumType in ODS. +def EnumType : DialectType($_self)">, + "a EnumType", "::circt::hw::TypeAliasOr">; + +def HWAggregateType : DialectType($_self)}]>, + "an ArrayType or StructType", + [{::circt::hw::TypeVariant< + ::circt::hw::ArrayType, + ::circt::hw::UnpackedArrayType, + ::circt::hw::StructType>}]>; + +// A handle to refer to circt::hw::ModuleType in ODS. +def ModuleType : DialectType($_self)">, + "a module", "::circt::hw::ModuleType">; + +// A handle to refer to circt::hw::StringType in ODS. +def HWStringType : + DialectType($_self)">, + "a HW string", "::circt::hw::StringType">, + BuildableType<"::circt::hw::StringType::get($_builder.getContext())">; + +/// A flat symbol reference or a reference to a name within a module. +def NameRefAttr : Attr< + CPred<"$_self.isa<::mlir::FlatSymbolRefAttr, ::circt::hw::InnerRefAttr>()">, + "name reference attribute">{ + let returnType = "::mlir::Attribute"; + let convertFromStorage = "$_self"; + let valueType = NoneType; +} + +// Like a FlatSymbolRefArrayAttr, but can also refer to names inside modules. +def NameRefArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getArrayAttr($0)"; +} + +def InnerSymProperties : AttrDef { + let mnemonic = "innerSymProps"; + let parameters = (ins + "::mlir::StringAttr":$name, + DefaultValuedParameter<"uint64_t", "0">:$fieldID, + DefaultValuedParameter<"::mlir::StringAttr", "public">:$sym_visibility + ); + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$sym),[{ + return get(sym.getContext(), sym, 0, + mlir::StringAttr::get(sym.getContext(), "public") ); + }]> + ]; + let hasCustomAssemblyFormat = 1; + // The assembly format is as follows: + // "`<` `@` $name `,` $fieldID `,` $sym_visibility `>`"; + let genVerifyDecl = 1; +} + + +def InnerSymAttr : AttrDef { + let summary = "Inner symbol definition"; + let description = [{ + Defines the properties of an inner_sym attribute. It specifies the symbol + name and symbol visibility for each field ID. For any ground types, + there are no subfields and the field ID is 0. For aggregate types, a + unique field ID is assigned to each field by visiting them in a + depth-first pre-order. The custom assembly format ensures that for ground + types, only `@` is printed. + }]; + let mnemonic = "innerSym"; + let parameters = (ins ArrayRefParameter<"InnerSymPropertiesAttr">:$props); + let builders = [ + AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$sym),[{ + return get(sym.getContext(), + {InnerSymPropertiesAttr::get(sym.getContext(), sym, 0, + mlir::StringAttr::get(sym.getContext(), "public"))}); + }]>, + // Create an empty array, represents an invalid InnerSym. + AttrBuilder<(ins),[{ + return get($_ctxt, {}); + }]> + ]; + let extraClassDeclaration = [{ + /// Get the inner sym name for fieldID, if it exists. + mlir::StringAttr getSymIfExists(uint64_t fieldID) const; + + /// Get the inner sym name for fieldID=0, if it exists. + mlir::StringAttr getSymName() const { return getSymIfExists(0); } + + /// Get the number of inner symbols defined. + size_t size() const { return getProps().size(); } + + /// Check if this is an empty array, no sym names stored. + bool empty() const { return getProps().empty(); } + + /// Return an InnerSymAttr with the inner symbol for the specified fieldID removed. + InnerSymAttr erase(uint64_t fieldID) const; + + using iterator = mlir::ArrayRef::iterator; + /// Iterator begin for all the InnerSymProperties. + iterator begin() const { return getProps().begin(); } + + /// Iterator end for all the InnerSymProperties. + iterator end() const { return getProps().end(); } + + /// Invoke the func, for all sym names. Return success(), + /// if the callback function never returns failure(). + mlir::LogicalResult walkSymbols(llvm::function_ref< + mlir::LogicalResult (::mlir::StringAttr)>) const; + }]; + + let hasCustomAssemblyFormat = 1; + // Example format: + // firrtl.wire sym [<@x,1,private>, <@w,2,public>, <@syh,4,public>] +} + +#endif // CIRCT_DIALECT_HW_HWTYPES diff --git a/third_party/circt/include/circt/Dialect/HW/HWTypesImpl.td b/third_party/circt/include/circt/Dialect/HW/HWTypesImpl.td new file mode 100644 index 000000000..ae0206354 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWTypesImpl.td @@ -0,0 +1,280 @@ +//===- HWTypesImpl.td - HW data type definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Basic data type implementations for the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWTYPESIMPL_TD +#define CIRCT_DIALECT_HW_HWTYPESIMPL_TD + +include "circt/Dialect/HW/HWDialect.td" +include "circt/Dialect/HW/HWTypeInterfaces.td" +include "mlir/IR/AttrTypeBase.td" + +// Base class for other typedefs. Provides dialact-specific defaults. +class HWType traits = []> + : TypeDef { } + +//===----------------------------------------------------------------------===// +// Type declarations +//===----------------------------------------------------------------------===// + +// A parameterized integer type. Declares the hw::IntType in C++. +def IntTypeImpl : HWType<"Int"> { + let summary = "parameterized-width integer"; + let description = [{ + Parameterized integer types are equivalent to the MLIR standard integer + type: it is signless, and may be any width integer. This type represents + the case when the width is a parameter in the HW dialect sense. + }]; + + let mnemonic = "int"; + let parameters = (ins "::mlir::TypedAttr":$width); + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; + + let extraClassDeclaration = [{ + /// Get an integer type for the specified width. Note that this may return + /// a builtin integer type if the width is a known-constant value. + static Type get(::mlir::TypedAttr width); + }]; +} + +// A simple fixed size array. Declares the hw::ArrayType in C++. +def ArrayTypeImpl : HWType<"Array", [DeclareTypeInterfaceMethods]> { + let summary = "fixed-sized array"; + let description = [{ + Fixed sized HW arrays are roughly similar to C arrays. On the wire (vs. + in a memory), arrays are always packed. Memory layout is not defined as + it does not need to be since in silicon there is not implicit memory + sharing. + }]; + + let mnemonic = "array"; + let parameters = (ins "::mlir::Type":$elementType, "::mlir::Attribute":$sizeAttr); + let genVerifyDecl = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static ArrayType get(Type elementType, size_t size) { + auto *ctx = elementType.getContext(); + auto intType = ::mlir::IntegerType::get(ctx, 64); + auto sizeAttr = ::mlir::IntegerAttr::get(intType, size); + return ArrayType::get(ctx, elementType, sizeAttr); + } + size_t getNumElements() const; + }]; +} + +// An 'unpacked' array of fixed size. +def UnpackedArrayType : HWType<"UnpackedArray", [DeclareTypeInterfaceMethods]> { + let summary = "SystemVerilog 'unpacked' fixed-sized array"; + let description = [{ + Unpacked arrays are a more flexible array representation than packed arrays, + and are typically used to model memories. See SystemVerilog Spec 7.4.2. + }]; + + let mnemonic = "uarray"; + let parameters = (ins "::mlir::Type":$elementType, "::mlir::Attribute":$sizeAttr); + let genVerifyDecl = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static UnpackedArrayType get(Type elementType, size_t size) { + auto *ctx = elementType.getContext(); + auto intType = ::mlir::IntegerType::get(ctx, 64); + auto sizeAttr = ::mlir::IntegerAttr::get(intType, size); + return UnpackedArrayType::get(ctx, elementType, sizeAttr); + } + size_t getNumElements() const; + }]; +} + +def InOutTypeImpl : HWType<"InOut"> { + let summary = "inout type"; + let description = [{ + InOut type is used for model operations and values that have "connection" + semantics, instead of typical dataflow behavior. This is used for wires + and inout ports in Verilog. + }]; + + let mnemonic = "inout"; + let parameters = (ins "::mlir::Type":$elementType); + let genVerifyDecl = 1; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + static InOutType get(Type elementType) { + return get(elementType.getContext(), elementType); + } + }]; +} + +// A packed struct. Declares the hw::StructType in C++. +def StructTypeImpl : HWType<"Struct", [DeclareTypeInterfaceMethods]> { + let summary = "HW struct type"; + let description = [{ + Represents a structure of name, value pairs. + !hw.struct + }]; + let mnemonic = "struct"; + + let hasCustomAssemblyFormat = 1; + + let parameters = ( + ins ArrayRefParameter< + "::circt::hw::StructType::FieldInfo", "struct fields">: $elements + ); + + let extraClassDeclaration = [{ + using FieldInfo = ::circt::hw::detail::FieldInfo; + mlir::Type getFieldType(mlir::StringRef fieldName); + void getInnerTypes(mlir::SmallVectorImpl&); + std::optional getFieldIndex(mlir::StringRef fieldName); + std::optional getFieldIndex(mlir::StringAttr fieldName); + }]; +} + +// An enum type. Declares the hw::EnumType in C++. +def EnumTypeImpl : HWType<"Enum"> { + let summary = "HW Enum type"; + let description = [{ + Represents an enumeration of values. Enums are interpreted as integers with + a synthesis-defined encoding. + !hw.enum + }]; + let mnemonic = "enum"; + let parameters = ( + ins "mlir::ArrayAttr":$fields + ); + + let extraClassDeclaration = [{ + /// Returns true if the requested field is part of this enum + bool contains(mlir::StringRef field); + + /// Returns the number of bits used by the enum + size_t getBitWidth(); + + /// Returns the index of the requested field, or a nullopt if the field is + // not part of this enum. + std::optional indexOf(mlir::StringRef field); + }]; + + let hasCustomAssemblyFormat = 1; +} + +// An untagged union. Declares the hw::UnionType in C++. +def UnionTypeImpl : HWType<"Union"> { + let summary = "An untagged union of types"; + let parameters = ( + ins ArrayRefParameter< + "::circt::hw::UnionType::FieldInfo", "union fields">: $elements + ); + let mnemonic = "union"; + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + using FieldInfo = ::circt::hw::detail::OffsetFieldInfo; + + FieldInfo getFieldInfo(::mlir::StringRef fieldName); + + ::mlir::Type getFieldType(::mlir::StringRef fieldName); + }]; +} + +def TypeAliasType : HWType<"TypeAlias"> { + let summary = "An symbolic reference to a type declaration"; + let description = [{ + A TypeAlias is parameterized by a SymbolRefAttr, which points to a + TypedeclOp. The root reference should refer to a TypeScope within the same + outer ModuleOp, and the leaf reference should refer to a type within that + TypeScope. A TypeAlias is further parameterized by the inner type, which is + needed to be known at the time the type is parsed. + + Upon construction, a TypeAlias stores the symbol reference and type, and + canonicalizes the type to resolve any nested type aliases. The canonical + type is also cached to avoid recomputing it when needed. + }]; + + let mnemonic = "typealias"; + + let parameters = (ins + "mlir::SymbolRefAttr":$ref, + "mlir::Type":$innerType, + "mlir::Type":$canonicalType + ); + + let hasCustomAssemblyFormat = 1; + + let builders = [ + TypeBuilderWithInferredContext<(ins + "mlir::SymbolRefAttr":$ref, "mlir::Type":$innerType)> + ]; + + let extraClassDeclaration = [{ + /// Return the Typedecl referenced by this TypeAlias, given the module to + /// look in. This returns null when the IR is malformed. + TypedeclOp getTypeDecl(const HWSymbolCache &cache); + }]; +} + +def ModuleTypeImpl : HWType<"Module"> { + let summary = "Module Type"; + let description = [{ + Module types have ports. + }]; + let parameters = (ins ArrayRefParameter<"::circt::hw::ModulePort", "port list">:$ports); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + let mnemonic = "modty"; + + let extraClassDeclaration = [{ + // Many of these are transitional and will be removed when modules and instances + // have moved over to this type. + size_t getNumPorts(); + size_t getNumInputs(); + size_t getNumOutputs(); + SmallVector getPortTypes(); + SmallVector getInputTypes(); + SmallVector getOutputTypes(); + Type getPortType(size_t); + Type getInputType(size_t); + Type getOutputType(size_t); + SmallVector getInputNamesStr(); + SmallVector getOutputNamesStr(); + SmallVector getInputNames(); + SmallVector getOutputNames(); + StringAttr getPortNameAttr(size_t); + StringRef getPortName(size_t); + StringAttr getInputNameAttr(size_t); + StringRef getInputName(size_t); + StringAttr getOutputNameAttr(size_t); + StringRef getOutputName(size_t); + FunctionType getFuncType(); + bool isOutput(size_t); + size_t getInputIdForPortId(size_t); + size_t getOutputIdForPortId(size_t); + size_t getPortIdForInputId(size_t); + size_t getPortIdForOutputId(size_t); + }]; +} + +def HWStringTypeImpl : HWType<"String"> { + let summary = "String type"; + let description = "Defines a string type for the hw-centric dialects"; + let mnemonic = "string"; +} + +#endif // CIRCT_DIALECT_HW_HWTYPESIMPL_TD diff --git a/third_party/circt/include/circt/Dialect/HW/HWVisitors.h b/third_party/circt/include/circt/Dialect/HW/HWVisitors.h new file mode 100644 index 000000000..ffda925cb --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/HWVisitors.h @@ -0,0 +1,140 @@ +//===- HWVisitors.h - HW Dialect Visitors ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines visitors that make it easier to work with HW IR. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_HWVISITORS_H +#define CIRCT_DIALECT_HW_HWVISITORS_H + +#include "circt/Dialect/HW/HWOps.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace circt { +namespace hw { + +/// This helps visit TypeOp nodes. +template +class TypeOpVisitor { + public: + ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case([&](auto expr) -> ResultType { + return thisCast->visitTypeOp(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidTypeOp(op, args...); + }); + } + + /// This callback is invoked on any non-expression operations. + ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args) { + op->emitOpError("unknown HW combinational node"); + abort(); + } + + /// This callback is invoked on any combinational operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args) { + return ResultType(); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitTypeOp(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##TypeOp(op, \ + args...); \ + } + + HANDLE(ConstantOp, Unhandled); + HANDLE(AggregateConstantOp, Unhandled); + HANDLE(BitcastOp, Unhandled); + HANDLE(ParamValueOp, Unhandled); + HANDLE(StructCreateOp, Unhandled); + HANDLE(StructExtractOp, Unhandled); + HANDLE(StructInjectOp, Unhandled); + HANDLE(UnionCreateOp, Unhandled); + HANDLE(UnionExtractOp, Unhandled); + HANDLE(ArraySliceOp, Unhandled); + HANDLE(ArrayGetOp, Unhandled); + HANDLE(ArrayCreateOp, Unhandled); + HANDLE(ArrayConcatOp, Unhandled); + HANDLE(EnumCmpOp, Unhandled); + HANDLE(EnumConstantOp, Unhandled); +#undef HANDLE +}; + +/// This helps visit TypeOp nodes. +template +class StmtVisitor { + public: + ResultType dispatchStmtVisitor(Operation *op, ExtraArgs... args) { + auto *thisCast = static_cast(this); + return TypeSwitch(op) + .template Case( + [&](auto expr) -> ResultType { + return thisCast->visitStmt(expr, args...); + }) + .Default([&](auto expr) -> ResultType { + return thisCast->visitInvalidStmt(op, args...); + }); + } + + /// This callback is invoked on any non-expression operations. + ResultType visitInvalidStmt(Operation *op, ExtraArgs... args) { + op->emitOpError("unknown hw statement"); + abort(); + } + + /// This callback is invoked on any combinational operations that are not + /// handled by the concrete visitor. + ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args) { + return ResultType(); + } + + /// This fallback is invoked on any binary node that isn't explicitly handled. + /// The default implementation delegates to the 'unhandled' fallback. + ResultType visitBinaryTypeOp(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledTypeOp(op, args...); + } + + ResultType visitUnaryTypeOp(Operation *op, ExtraArgs... args) { + return static_cast(this)->visitUnhandledTypeOp(op, args...); + } + +#define HANDLE(OPTYPE, OPKIND) \ + ResultType visitStmt(OPTYPE op, ExtraArgs... args) { \ + return static_cast(this)->visit##OPKIND##Stmt(op, \ + args...); \ + } + + // Basic nodes. + HANDLE(OutputOp, Unhandled); + HANDLE(InstanceOp, Unhandled); + HANDLE(TypeScopeOp, Unhandled); + HANDLE(TypedeclOp, Unhandled); +#undef HANDLE +}; + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_HWVISITORS_H diff --git a/third_party/circt/include/circt/Dialect/HW/InnerSymbolNamespace.h b/third_party/circt/include/circt/Dialect/HW/InnerSymbolNamespace.h new file mode 100644 index 000000000..6da97b899 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/InnerSymbolNamespace.h @@ -0,0 +1,51 @@ +//===- InnerSymbolNamespace.h - Inner Symbol Table Namespace ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the InnerSymbolNamespace, which tracks the names +// used by inner symbols within an InnerSymbolTable. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H +#define CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H + +#include "circt/Dialect/HW/InnerSymbolTable.h" +#include "circt/Support/Namespace.h" + +namespace circt { +namespace hw { + +struct InnerSymbolNamespace : Namespace { + InnerSymbolNamespace() = default; + InnerSymbolNamespace(Operation *module) { add(module); } + + /// Populate the namespace from a module-like operation. This namespace will + /// be composed of the `inner_sym`s of the module's ports and declarations. + void add(Operation *module) { + hw::InnerSymbolTable::walkSymbols( + module, [&](StringAttr name, const InnerSymTarget &target) { + nextIndex.insert({name.getValue(), 0}); + }); + } +}; + +struct InnerSymbolNamespaceCollection { + InnerSymbolNamespace &get(Operation *op) { + return collection.try_emplace(op, op).first->second; + } + + InnerSymbolNamespace &operator[](Operation *op) { return get(op); } + + private: + DenseMap collection; +}; + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H diff --git a/third_party/circt/include/circt/Dialect/HW/InnerSymbolTable.h b/third_party/circt/include/circt/Dialect/HW/InnerSymbolTable.h new file mode 100644 index 000000000..fd48930ca --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/InnerSymbolTable.h @@ -0,0 +1,264 @@ +//===- InnerSymbolTable.h - Inner Symbol Table -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the InnerSymbolTable and related classes, used for +// managing and tracking "inner symbols". +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_INNERSYMBOLTABLE_H +#define CIRCT_DIALECT_HW_INNERSYMBOLTABLE_H + +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/SymbolTable.h" + +namespace circt { +namespace hw { + +/// The target of an inner symbol, the entity the symbol is a handle for. +class InnerSymTarget { + public: + /// Default constructor, invalid. + InnerSymTarget() { assert(!*this); } + + /// Target an operation. + explicit InnerSymTarget(Operation *op) : InnerSymTarget(op, 0) {} + + /// Target an operation and a field (=0 means the op itself). + InnerSymTarget(Operation *op, size_t fieldID) + : op(op), portIdx(invalidPort), fieldID(fieldID) {} + + /// Target a port, and optionally a field (=0 means the port itself). + InnerSymTarget(size_t portIdx, Operation *op, size_t fieldID = 0) + : op(op), portIdx(portIdx), fieldID(fieldID) {} + + InnerSymTarget(const InnerSymTarget &) = default; + InnerSymTarget(InnerSymTarget &&) = default; + + // Accessors: + + /// Return the target's fieldID. + auto getField() const { return fieldID; } + + /// Return the target's base operation. For ports, this is the module. + Operation *getOp() const { return op; } + + /// Return the target's port, if valid. Check "isPort()". + auto getPort() const { + assert(isPort()); + return portIdx; + } + + // Classification: + + /// Return if this targets a field (nonzero fieldID). + bool isField() const { return fieldID != 0; } + + /// Return if this targets a port. + bool isPort() const { return portIdx != invalidPort; } + + /// Returns if this targets an operation only (not port or field). + bool isOpOnly() const { return !isPort() && !isField(); } + + /// Return a target to the specified field within the given base. + /// FieldID is relative to the specified base target. + static InnerSymTarget getTargetForSubfield(const InnerSymTarget &base, + size_t fieldID) { + if (base.isPort()) + return InnerSymTarget(base.portIdx, base.op, base.fieldID + fieldID); + return InnerSymTarget(base.op, base.fieldID + fieldID); + } + + private: + auto asTuple() const { return std::tie(op, portIdx, fieldID); } + Operation *op = nullptr; + size_t portIdx = 0; + size_t fieldID = 0; + static constexpr size_t invalidPort = ~size_t{0}; + + public: + // Operators are defined below. + + // Comparison operators: + bool operator==(const InnerSymTarget &rhs) const { + return asTuple() == rhs.asTuple(); + } + + // Assignment operators: + InnerSymTarget &operator=(InnerSymTarget &&) = default; + InnerSymTarget &operator=(const InnerSymTarget &) = default; + + /// Check if this target is valid. + operator bool() const { return op; } +}; + +/// A table of inner symbols and their resolutions. +class InnerSymbolTable { + public: + /// Build an inner symbol table for the given operation. The operation must + /// have the InnerSymbolTable trait. + explicit InnerSymbolTable(Operation *op); + + /// Non-copyable + InnerSymbolTable(const InnerSymbolTable &) = delete; + InnerSymbolTable &operator=(const InnerSymbolTable &) = delete; + + // Moveable + InnerSymbolTable(InnerSymbolTable &&) = default; + InnerSymbolTable &operator=(InnerSymbolTable &&) = default; + + /// Look up a symbol with the specified name, returning empty InnerSymTarget + /// if no such name exists. Names never include the @ on them. + InnerSymTarget lookup(StringRef name) const; + InnerSymTarget lookup(StringAttr name) const; + + /// Look up a symbol with the specified name, returning null if no such + /// name exists or doesn't target just an operation. + Operation *lookupOp(StringRef name) const; + template + T lookupOp(StringRef name) const { + return dyn_cast_or_null(lookupOp(name)); + } + + /// Look up a symbol with the specified name, returning null if no such + /// name exists or doesn't target just an operation. + Operation *lookupOp(StringAttr name) const; + template + T lookupOp(StringAttr name) const { + return dyn_cast_or_null(lookupOp(name)); + } + + /// Get InnerSymbol for an operation. + static StringAttr getInnerSymbol(Operation *op); + + /// Get InnerSymbol for a target. + static StringAttr getInnerSymbol(const InnerSymTarget &target); + + /// Return the name of the attribute used for inner symbol names. + static StringRef getInnerSymbolAttrName() { return "inner_sym"; } + + /// Construct an InnerSymbolTable, checking for verification failure. + /// Emits diagnostics describing encountered issues. + static FailureOr get(Operation *op); + + using InnerSymCallbackFn = + llvm::function_ref; + + /// Walk the given IST operation and invoke the callback for all encountered + /// inner symbols. + /// This variant allows callbacks that return LogicalResult OR void, + /// and wraps the underlying implementation. + template > + static RetTy walkSymbols(Operation *op, FuncTy &&callback) { + if constexpr (std::is_void_v) + return (void)walkSymbols( + op, InnerSymCallbackFn( + [&](StringAttr name, const InnerSymTarget &target) { + std::invoke(std::forward(callback), name, target); + return success(); + })); + else + return walkSymbols( + op, InnerSymCallbackFn([&](StringAttr name, + const InnerSymTarget &target) { + return std::invoke(std::forward(callback), name, target); + })); + } + + /// Walk the given IST operation and invoke the callback for all encountered + /// inner symbols. + /// This variant is the underlying implementation. + /// If callback returns failure, the walk is aborted and failure is returned. + /// A successful walk with no failures returns success. + static LogicalResult walkSymbols(Operation *op, InnerSymCallbackFn callback); + + private: + using TableTy = DenseMap; + /// Construct an inner symbol table for the given operation, + /// with pre-populated table contents. + explicit InnerSymbolTable(Operation *op, TableTy &&table) + : innerSymTblOp(op), symbolTable(table){}; + + /// This is the operation this table is constructed for, which must have the + /// InnerSymbolTable trait. + Operation *innerSymTblOp; + + /// This maps inner symbol names to their targets. + TableTy symbolTable; +}; + +/// This class represents a collection of InnerSymbolTable's. +class InnerSymbolTableCollection { + public: + /// Get or create the InnerSymbolTable for the specified operation. + InnerSymbolTable &getInnerSymbolTable(Operation *op); + + /// Populate tables in parallel for all InnerSymbolTable operations in the + /// given InnerRefNamespace operation, verifying each and returning + /// the verification result. + LogicalResult populateAndVerifyTables(Operation *innerRefNSOp); + + explicit InnerSymbolTableCollection() = default; + explicit InnerSymbolTableCollection(Operation *innerRefNSOp) { + // Caller is not interested in verification, no way to report it upwards. + auto result = populateAndVerifyTables(innerRefNSOp); + (void)result; + assert(succeeded(result)); + } + InnerSymbolTableCollection(const InnerSymbolTableCollection &) = delete; + InnerSymbolTableCollection &operator=(const InnerSymbolTableCollection &) = + delete; + + private: + /// This maps Operations to their InnnerSymbolTable's. + DenseMap> symbolTables; +}; + +/// This class represents the namespace in which InnerRef's can be resolved. +struct InnerRefNamespace { + SymbolTable &symTable; + InnerSymbolTableCollection &innerSymTables; + + /// Resolve the InnerRef to its target within this namespace, returning empty + /// target if no such name exists. + InnerSymTarget lookup(hw::InnerRefAttr inner); + + /// Resolve the InnerRef to its target within this namespace, returning + /// empty target if no such name exists or it's not an operation. + /// Template type can be used to limit results to specified op type. + Operation *lookupOp(hw::InnerRefAttr inner); + template + T lookupOp(hw::InnerRefAttr inner) { + return dyn_cast_or_null(lookupOp(inner)); + } +}; + +/// Printing InnerSymTarget's. +template +OS &operator<<(OS &os, const InnerSymTarget &target) { + if (!target) return os << ""; + + if (target.isField()) os << "field " << target.getField() << " of "; + + if (target.isPort()) + os << "port " << target.getPort() << " on @" + << SymbolTable::getSymbolName(target.getOp()).getValue() << ""; + else + os << "op " << *target.getOp() << ""; + + return os; +} + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_FIRRTL_INNERSYMBOLTABLE_H diff --git a/third_party/circt/include/circt/Dialect/HW/InstanceImplementation.h b/third_party/circt/include/circt/Dialect/HW/InstanceImplementation.h new file mode 100644 index 000000000..c262a2029 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/InstanceImplementation.h @@ -0,0 +1,104 @@ +//===- InstanceImplementation.h - Instance-like Op utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utility functions for implementing instance-like +// operations, in particular, parsing, and printing common to instance-like +// operations. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H +#define CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H + +#include + +#include "circt/Support/LLVM.h" + +namespace circt { +namespace hw { +// Forward declarations. +class HWSymbolCache; + +namespace instance_like_impl { + +/// Whenever the nested function returns true, a note referring to the +/// referenced module is attached to the error. +using EmitErrorFn = + std::function)>; + +/// Return a pointer to the referenced module operation. +Operation *getReferencedModule(const HWSymbolCache *cache, + Operation *instanceOp, + mlir::FlatSymbolRefAttr moduleName); + +/// Verify that the instance refers to a valid HW module. +LogicalResult verifyReferencedModule(Operation *instanceOp, + SymbolTableCollection &symbolTable, + mlir::FlatSymbolRefAttr moduleName, + Operation *&module); + +/// Stores a resolved version of each type in @param types wherein any parameter +/// reference has been evaluated based on the set of provided @param parameters +/// in @param resolvedTypes +LogicalResult resolveParametricTypes(Location loc, ArrayAttr parameters, + ArrayRef types, + SmallVectorImpl &resolvedTypes, + const EmitErrorFn &emitError); + +/// Verify that the list of inputs of the instance and the module match in terms +/// of length, names, and types. +LogicalResult verifyInputs(ArrayAttr argNames, ArrayAttr moduleArgNames, + TypeRange inputTypes, + ArrayRef moduleInputTypes, + const EmitErrorFn &emitError); + +/// Verify that the list of outputs of the instance and the module match in +/// terms of length, names, and types. +LogicalResult verifyOutputs(ArrayAttr resultNames, ArrayAttr moduleResultNames, + TypeRange resultTypes, + ArrayRef moduleResultTypes, + const EmitErrorFn &emitError); + +/// Verify that the parameter lists of the instance and the module match in +/// terms of length, names, and types. +LogicalResult verifyParameters(ArrayAttr parameters, ArrayAttr moduleParameters, + const EmitErrorFn &emitError); + +/// Combines verifyReferencedModule, verifyInputs, verifyOutputs, and +/// verifyParameters. It is only allowed to call this function when the instance +/// refers to a HW module. The @param parameters attribute may be null in which +/// case not parameters are verified. +LogicalResult verifyInstanceOfHWModule( + Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, + TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, + ArrayAttr parameters, SymbolTableCollection &symbolTable); + +/// Check that all the parameter values specified to the instance are +/// structurally valid. +LogicalResult verifyParameterStructure(ArrayAttr parameters, + ArrayAttr moduleParameters, + const EmitErrorFn &emitError); + +/// Return the name at the specified index of the ArrayAttr or null if it cannot +/// be determined. +StringAttr getName(ArrayAttr names, size_t idx); + +/// Change the name at the specified index of the @param oldNames ArrayAttr to +/// @param name +ArrayAttr updateName(ArrayAttr oldNames, size_t i, StringAttr name); + +/// Suggest a name for each result value based on the saved result names +/// attribute. +void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, + ArrayAttr resultNames, ValueRange results); + +} // namespace instance_like_impl +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H diff --git a/third_party/circt/include/circt/Dialect/HW/ModuleImplementation.h b/third_party/circt/include/circt/Dialect/HW/ModuleImplementation.h new file mode 100644 index 000000000..2d5954558 --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/ModuleImplementation.h @@ -0,0 +1,56 @@ +//===- ModuleImplementation.h - Module-like Op utilities --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utility functions for implementing module-like +// operations, in particular, parsing, and printing common to module-like +// operations. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H +#define CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H + +#include "circt/Dialect/HW/HWTypes.h" +#include "circt/Support/LLVM.h" +#include "mlir/IR/DialectImplementation.h" + +namespace circt { +namespace hw { + +namespace module_like_impl { + +struct PortParse : OpAsmParser::Argument { + ModulePort::Direction direction; +}; + +/// This is a variant of mlir::parseFunctionSignature that allows names on +/// result arguments. +ParseResult parseModuleFunctionSignature( + OpAsmParser &parser, bool &isVariadic, + SmallVectorImpl &args, + SmallVectorImpl &argNames, SmallVectorImpl &argLocs, + SmallVectorImpl &resultNames, + SmallVectorImpl &resultAttrs, + SmallVectorImpl &resultLocs, TypeAttr &type); + +/// Print a module signature with named results. +void printModuleSignature(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes, bool &needArgNamesAttr); + +/// New Style parsing +ParseResult parseModuleSignature(OpAsmParser &parser, + SmallVectorImpl &args, + TypeAttr &modType); +void printModuleSignatureNew(OpAsmPrinter &p, Operation *op); + +} // namespace module_like_impl +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H diff --git a/third_party/circt/include/circt/Dialect/HW/Passes.td b/third_party/circt/include/circt/Dialect/HW/Passes.td new file mode 100644 index 000000000..15d95000f --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/Passes.td @@ -0,0 +1,61 @@ +//===-- Passes.td - HW pass definition file ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the passes that work on the HW dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_PASSES_TD +#define CIRCT_DIALECT_HW_PASSES_TD + +include "mlir/Pass/PassBase.td" + +def PrintInstanceGraph : Pass<"hw-print-instance-graph", "mlir::ModuleOp"> { + let summary = "Print a DOT graph of the module hierarchy."; + let constructor = "circt::hw::createPrintInstanceGraphPass()"; +} + +def PrintHWModuleGraph : Pass<"hw-print-module-graph", "mlir::ModuleOp"> { + let summary = "Print a DOT graph of the HWModule's within a top-level module."; + let constructor = "circt::hw::createPrintHWModuleGraphPass()"; + + let options = [ + Option<"verboseEdges", "verbose-edges", "bool", "false", + "Print information on SSA edges (types, operand #, ...)">, + ]; +} + +def FlattenIO : Pass<"hw-flatten-io", "mlir::ModuleOp"> { + let summary = "Flattens hw::Structure typed in- and output ports."; + let constructor = "circt::hw::createFlattenIOPass()"; + + let options = [ + Option<"recursive", "recursive", "bool", "false", + "Recursively flatten nested structs.">, + ]; +} + +def HWSpecialize : Pass<"hw-specialize", "mlir::ModuleOp"> { + let summary = "Specializes instances of parametric hw.modules"; + let constructor = "circt::hw::createHWSpecializePass()"; + let description = [{ + Any `hw.instance` operation instantiating a parametric `hw.module` will + trigger a specialization procedure which resolves all parametric types and + values within the module based on the set of provided parameters to the + `hw.instance` operation. This specialized module is created as a new + `hw.module` and the referring `hw.instance` operation is rewritten to + instantiate the newly specialized module. + }]; +} + +def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> { + let summary = "Verify InnerRefNamespaceLike operations, if not self-verifying."; + let constructor = "circt::hw::createVerifyInnerRefNamespacePass()"; +} + +#endif // CIRCT_DIALECT_HW_PASSES_TD diff --git a/third_party/circt/include/circt/Dialect/HW/PortConverter.h b/third_party/circt/include/circt/Dialect/HW/PortConverter.h new file mode 100644 index 000000000..8849c8b6d --- /dev/null +++ b/third_party/circt/include/circt/Dialect/HW/PortConverter.h @@ -0,0 +1,182 @@ +//===- PortConverter.h - Module I/O rewriting utility -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The PortConverter is a utility class for rewriting arguments of a +// HWMutableModuleLike operation. +// It is intended to be a generic utility that can facilitate replacement of +// a given module in- or output to an arbitrary set of new inputs and outputs +// (i.e. 1 port -> N in, M out ports). Typical usecases is where an in (or +// output) of a module represents some higher-level abstraction that will be +// implemented by a set of lower-level in- and outputs ports + supporting +// operations within a module. It also attempts to do so in an optimal way, by +// e.g. being able to collect multiple port modifications of a module, and +// perform them all at once. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_DIALECT_HW_PORTCONVERTER_H +#define CIRCT_DIALECT_HW_PORTCONVERTER_H + +#include "circt/Dialect/HW/HWInstanceGraph.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Support/BackedgeBuilder.h" +#include "circt/Support/LLVM.h" + +namespace circt { +namespace hw { + +class PortConversionBuilder; +class PortConversion; + +class PortConverterImpl { + public: + /// Run port conversion. + LogicalResult run(); + Block *getBody() const { return body; } + hw::HWMutableModuleLike getModule() const { return mod; } + + /// These two methods take care of allocating new ports in the correct place + /// based on the position of 'origPort'. The new port is based on the original + /// name and suffix. The specification for the new port is given by `newPort` + /// and is recorded internally. Any changes to 'newPort' after calling this + /// will not be reflected in the modules new port list. Will also add the new + /// input to the block arguments of the body of the module. + Value createNewInput(hw::PortInfo origPort, const Twine &suffix, Type type, + hw::PortInfo &newPort); + /// Same as above. 'output' is the value fed into the new port and is required + /// if 'body' is non-null. Important note: cannot be a backedge which gets + /// replaced since this isn't attached to an op until later in the pass. + void createNewOutput(hw::PortInfo origPort, const Twine &suffix, Type type, + Value output, hw::PortInfo &newPort); + + protected: + PortConverterImpl(igraph::InstanceGraphNode *moduleNode); + + std::unique_ptr ssb; + + private: + /// Updates an instance of the module. This is called after the module has + /// been updated. It will update the instance to match the new port + void updateInstance(hw::InstanceOp); + + // If the module has a block and it wants to be modified, this'll be + // non-null. + Block *body = nullptr; + + igraph::InstanceGraphNode *moduleNode; + hw::HWMutableModuleLike mod; + OpBuilder b; + + // Keep around a reference to the specific port conversion classes to + // facilitate updating the instance ops. Indexed by the original port + // location. + SmallVector> loweredInputs; + SmallVector> loweredOutputs; + + // Tracking information to modify the module. Populated by the + // 'createNew(Input|Output)' methods. Will be cleared once port changes have + // materialized. Default length is 0 to save memory in case we'll be keeping + // this around for later use. + SmallVector, 0> newInputs; + SmallVector, 0> newOutputs; + + // Maintain a handle to the terminator of the body, if any. This will get + // continuously updated during port conversion whenever a new output is added + // to the module. + Operation *terminator = nullptr; +}; + +/// Base class for the port conversion of a particular port. Abstracts the +/// details of a particular port conversion from the port layout. Subclasses +/// keep around port mapping information to use when updating instances. +class PortConversion { + public: + PortConversion(PortConverterImpl &converter, hw::PortInfo origPort) + : converter(converter), body(converter.getBody()), origPort(origPort) {} + virtual ~PortConversion() = default; + + // An optional initialization step that can be overridden by subclasses. + // This allows subclasses to perform a failable post-construction + // initialization step. + virtual LogicalResult init() { return success(); } + + // Lower the specified port into a wire-level signaling protocol. The two + // virtual methods 'build*Signals' should be overridden by subclasses. They + // should use the 'create*' methods in 'PortConverter' to create the + // necessary ports. + void lowerPort() { + if (origPort.dir == hw::ModulePort::Direction::Output) + buildOutputSignals(); + else + buildInputSignals(); + } + + /// Update an instance port to the new port information. + virtual void mapInputSignals(OpBuilder &b, Operation *inst, Value instValue, + SmallVectorImpl &newOperands, + ArrayRef newResults) = 0; + virtual void mapOutputSignals(OpBuilder &b, Operation *inst, Value instValue, + SmallVectorImpl &newOperands, + ArrayRef newResults) = 0; + + MLIRContext *getContext() { return getModule()->getContext(); } + bool isUntouched() const { return isUntouchedFlag; } + + protected: + // Build the input and output signals for the port. This pertains to modifying + // the module itself. + virtual void buildInputSignals() = 0; + virtual void buildOutputSignals() = 0; + + PortConverterImpl &converter; + Block *body; + hw::PortInfo origPort; + + hw::HWMutableModuleLike getModule() { return converter.getModule(); } + + // We don't need full LLVM-style RTTI support for PortConversion (would + // require some mechanism of registering user-provided PortConversion-derived + // classes), we only need to dynamically tell whether any given PortConversion + // is the UntouchedPortConversion. + bool isUntouchedFlag = false; +}; // namespace hw + +// A PortConversionBuilder will, given an input type, build the appropriate +// port conversion for that type. +class PortConversionBuilder { + public: + PortConversionBuilder(PortConverterImpl &converter) : converter(converter) {} + virtual ~PortConversionBuilder() = default; + + // Builds the appropriate port conversion for the port. Users should + // override this method with their own llvm::TypeSwitch-based dispatch code, + // and by default call this method when no port conversion applies. + virtual FailureOr> build(hw::PortInfo port); + + PortConverterImpl &converter; +}; + +// A PortConverter wraps a single HWMutableModuleLike operation, and is +// initialized from an instance graph node. The port converter is templated +// on a PortConversionBuilder, which is used to build the appropriate +// port conversion for each port type. +template +class PortConverter : public PortConverterImpl { + public: + template + PortConverter(hw::InstanceGraph &graph, hw::HWMutableModuleLike mod, + Args &&...args) + : PortConverterImpl(graph.lookup(cast(*mod))) { + ssb = std::make_unique(*this, args...); + } +}; + +} // namespace hw +} // namespace circt + +#endif // CIRCT_DIALECT_HW_PORTCONVERTER_H diff --git a/third_party/circt/include/circt/Support/APInt.h b/third_party/circt/include/circt/Support/APInt.h new file mode 100644 index 000000000..e27f66590 --- /dev/null +++ b/third_party/circt/include/circt/Support/APInt.h @@ -0,0 +1,30 @@ +//===- APInt.h - CIRCT Lowering Options -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for working around limitations of upstream LLVM APInts. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_APINT_H +#define CIRCT_SUPPORT_APINT_H + +#include "circt/Support/LLVM.h" + +namespace circt { + +/// A safe version of APInt::sext that will NOT assert on zero-width +/// signed APSInts. Instead of asserting, this will zero extend. +APInt sextZeroWidth(APInt value, unsigned width); + +/// A safe version of APSInt::extOrTrunc that will NOT assert on zero-width +/// signed APSInts. Instead of asserting, this will zero extend. +APSInt extOrTruncZeroWidth(APSInt value, unsigned width); + +} // namespace circt + +#endif // CIRCT_SUPPORT_APINT_H diff --git a/third_party/circt/include/circt/Support/BUILD b/third_party/circt/include/circt/Support/BUILD new file mode 100644 index 000000000..8a270d274 --- /dev/null +++ b/third_party/circt/include/circt/Support/BUILD @@ -0,0 +1,52 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + glob([ + "*.h", + ]), +) + +td_library( + name = "td_files", + srcs = glob([ + "*.td", + ]), + includes = ["include"], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "interfaces_inc_gen", + strip_include_prefix = "/third_party/circt/include", + tbl_outs = [ + ( + [ + "-gen-op-interface-decls", + ], + "InstanceGraphInterface.h.inc", + ), + ( + [ + "-gen-op-interface-defs", + ], + "InstanceGraphInterface.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "InstanceGraphInterface.td", + deps = [ + ":td_files", + ], +) diff --git a/third_party/circt/include/circt/Support/BackedgeBuilder.h b/third_party/circt/include/circt/Support/BackedgeBuilder.h new file mode 100644 index 000000000..5049ae315 --- /dev/null +++ b/third_party/circt/include/circt/Support/BackedgeBuilder.h @@ -0,0 +1,100 @@ +//===- BackedgeBuilder.h - Support for building backedges -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Backedges are operations/values which have to exist as operands before +// they are produced in a result. Since it isn't clear how to build backedges +// in MLIR, these helper classes set up a canonical way to do so. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_BACKEDGEBUILDER_H +#define CIRCT_SUPPORT_BACKEDGEBUILDER_H + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Value.h" + +namespace mlir { +class OpBuilder; +class PatternRewriter; +class Operation; +} // namespace mlir + +namespace circt { + +class Backedge; + +/// Instantiate one of these and use it to build typed backedges. Backedges +/// which get used as operands must be assigned to with the actual value before +/// this class is destructed, usually at the end of a scope. It will check that +/// invariant then erase all the backedge ops during destruction. +/// +/// Example use: +/// ``` +/// circt::BackedgeBuilder back(rewriter, loc); +/// circt::Backedge ready = back.get(rewriter.getI1Type()); +/// // Use `ready` as a `Value`. +/// auto addOp = rewriter.create(loc, ready); +/// // When the actual value is available, +/// ready.set(anotherOp.getResult(0)); +/// ``` +class BackedgeBuilder { + friend class Backedge; + + public: + /// To build a backedge op and manipulate it, we need a `PatternRewriter` and + /// a `Location`. Store them during construct of this instance and use them + /// when building. + BackedgeBuilder(mlir::OpBuilder &builder, mlir::Location loc); + BackedgeBuilder(mlir::PatternRewriter &rewriter, mlir::Location loc); + ~BackedgeBuilder(); + + /// Create a typed backedge. If no location is provided, the one passed to the + /// constructor will be used. + Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc = {}); + + /// Clear the backedges, erasing any remaining cursor ops. Returns `failure` + /// and emits diagnostic messages if a backedge is still active. + mlir::LogicalResult clearOrEmitError(); + + /// Abandon the backedges, suppressing any diagnostics if they are still + /// active upon destruction of the backedge builder. Also, any currently + /// existing cursor ops will be abandoned. + void abandon(); + + private: + mlir::OpBuilder &builder; + mlir::PatternRewriter *rewriter; + mlir::Location loc; + llvm::SmallVector edges; +}; + +/// `Backedge` is a wrapper class around a `Value`. When assigned another +/// `Value`, it replaces all uses of itself with the new `Value` then become a +/// wrapper around the new `Value`. +class Backedge { + friend class BackedgeBuilder; + + /// `Backedge` is constructed exclusively by `BackedgeBuilder`. + Backedge(mlir::Operation *op); + + public: + Backedge() {} + + explicit operator bool() const { return !!value; } + operator mlir::Value() const { return value; } + void setValue(mlir::Value); + + private: + mlir::Value value; + bool set = false; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_BACKEDGEBUILDER_H diff --git a/third_party/circt/include/circt/Support/BuilderUtils.h b/third_party/circt/include/circt/Support/BuilderUtils.h new file mode 100644 index 000000000..e42e61114 --- /dev/null +++ b/third_party/circt/include/circt/Support/BuilderUtils.h @@ -0,0 +1,45 @@ +//===- BuilderUtils.h - Operation builder utilities -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_BUILDERUTILS_H +#define CIRCT_SUPPORT_BUILDERUTILS_H + +#include "circt/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace circt { + +/// A helper union that can represent a `StringAttr`, `StringRef`, or `Twine`. +/// It is intended to be used as arguments to an op's `build` function. This +/// allows a single builder to accept any flavor value for a string attribute. +/// The `get` function can then be used to obtain a `StringAttr` from any of the +/// possible variants `StringAttrOrRef` can take. +class StringAttrOrRef { + using Value = llvm::PointerUnion; + Value value; + + public: + StringAttrOrRef() : value() {} + StringAttrOrRef(StringAttr attr) : value(attr) {} + StringAttrOrRef(const StringRef &str) + : value(const_cast(&str)) {} + StringAttrOrRef(const Twine &twine) : value(const_cast(&twine)) {} + + /// Return the represented string as a `StringAttr`. + StringAttr get(MLIRContext *context) const { + return TypeSwitch(value) + .Case([&](auto value) { return value; }) + .Case( + [&](auto value) { return StringAttr::get(context, *value); }) + .Default([](auto) { return StringAttr{}; }); + } +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_BUILDERUTILS_H diff --git a/third_party/circt/include/circt/Support/CMakeLists.txt b/third_party/circt/include/circt/Support/CMakeLists.txt new file mode 100644 index 000000000..e63fd5422 --- /dev/null +++ b/third_party/circt/include/circt/Support/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(InstanceGraphInterface) diff --git a/third_party/circt/include/circt/Support/ConversionPatterns.h b/third_party/circt/include/circt/Support/ConversionPatterns.h new file mode 100644 index 000000000..28daf72d8 --- /dev/null +++ b/third_party/circt/include/circt/Support/ConversionPatterns.h @@ -0,0 +1,36 @@ +//===- ConversionPatterns.h - Common Conversion patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_CONVERSIONPATTERNS_H +#define CIRCT_SUPPORT_CONVERSIONPATTERNS_H + +#include "circt/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace circt { + +/// Generic pattern which replaces an operation by one of the same operation +/// name, but with converted attributes, operands, and result types to eliminate +/// illegal types. Uses generic builders based on OperationState to make sure +/// that this pattern can apply to _any_ operation. +/// +/// Useful when a conversion can be entirely defined by a TypeConverter. +struct TypeConversionPattern : public mlir::ConversionPattern { + public: + TypeConversionPattern(TypeConverter &converter, MLIRContext *context) + : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {} + using ConversionPattern::ConversionPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_CONVERSIONPATTERNS_H diff --git a/third_party/circt/include/circt/Support/CustomDirectiveImpl.h b/third_party/circt/include/circt/Support/CustomDirectiveImpl.h new file mode 100644 index 000000000..8cce9f320 --- /dev/null +++ b/third_party/circt/include/circt/Support/CustomDirectiveImpl.h @@ -0,0 +1,91 @@ +//===- CustomDirectiveImpl.h - Custom TableGen directives -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides common custom directives for table-gen assembly formats. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H +#define CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" + +namespace circt { + +//===----------------------------------------------------------------------===// +// ImplicitSSAName Custom Directive +//===----------------------------------------------------------------------===// + +/// Parse an implicit SSA name string attribute. If the name is not provided in +/// the input text, its value is inferred from the SSA name of the operation's +/// first result. +/// +/// implicit-name ::= (`name` str-attr)? +ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr); + +/// Parse an attribute dictionary and ensure that it contains a `name` field by +/// inferring its value from the SSA name of the operation's first result if +/// necessary. +ParseResult parseImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs); + +/// Ensure that `attrs` contains a `name` attribute by inferring its value from +/// the SSA name of the operation's first result if necessary. Returns true if a +/// name was inferred, false if `attrs` already contained a `name`. +bool inferImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs); + +/// Print an implicit SSA name string attribute. If the given string attribute +/// does not match the SSA name of the operation's first result, the name is +/// explicitly printed. Prints a leading space in front of `name` if any name is +/// present. +/// +/// implicit-name ::= (`name` str-attr)? +void printImplicitSSAName(OpAsmPrinter &p, Operation *op, StringAttr attr); + +/// Print an attribute dictionary and elide the `name` field if its value +/// matches the SSA name of the operation's first result. +void printImplicitSSAName(OpAsmPrinter &p, Operation *op, DictionaryAttr attrs, + ArrayRef extraElides = {}); + +/// Check if the `name` attribute in `attrs` matches the SSA name of the +/// operation's first result. If it does, add `name` to `elides`. This is +/// helpful during printing of attribute dictionaries in order to determine if +/// the inclusion of the `name` field would be redundant. +void elideImplicitSSAName(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs, + SmallVectorImpl &elides); + +/// Print/parse binary operands type only when types are different. +/// optional-bin-op-types := type($lhs) (, type($rhs))? +void printOptionalBinaryOpTypes(OpAsmPrinter &p, Operation *op, Type lhs, + Type rhs); +ParseResult parseOptionalBinaryOpTypes(OpAsmParser &parser, Type &lhs, + Type &rhs); + +//===----------------------------------------------------------------------===// +// KeywordBool Custom Directive +//===----------------------------------------------------------------------===// + +/// Parse a boolean as one of two keywords. The `trueKeyword` will result in a +/// true boolean; the `falseKeyword` will result in a false boolean. +/// +/// labeled-bool ::= (true-label | false-label) +ParseResult parseKeywordBool(OpAsmParser &parser, BoolAttr &attr, + StringRef trueKeyword, StringRef falseKeyword); + +/// Print a boolean as one of two keywords. If the boolean is true, the +/// `trueKeyword` is used; if it is false, the `falseKeyword` is used. +/// +/// labeled-bool ::= (true-label | false-label) +void printKeywordBool(OpAsmPrinter &printer, Operation *op, BoolAttr attr, + StringRef trueKeyword, StringRef falseKeyword); + +} // namespace circt + +#endif // CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H diff --git a/third_party/circt/include/circt/Support/FieldRef.h b/third_party/circt/include/circt/Support/FieldRef.h new file mode 100644 index 000000000..02b5f1664 --- /dev/null +++ b/third_party/circt/include/circt/Support/FieldRef.h @@ -0,0 +1,116 @@ +//===- FieldRef.h - Field References ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines FieldRefs and helpers for them. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_FIELDREF_H +#define CIRCT_SUPPORT_FIELDREF_H + +#include "circt/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "mlir/IR/Value.h" + +namespace circt { + +/// This class represents a reference to a specific field or element of an +/// aggregate value. Typically, the user will assign a unique field ID to each +/// field in an aggregate type by visiting them in a depth-first pre-order. +/// +/// This can be used as the key in a hashtable to store field specific +/// information. +class FieldRef { + public: + /// Get a null FieldRef. + FieldRef() {} + + /// Get a FieldRef location for the specified value. + FieldRef(Value value, unsigned id) : value(value), id(id) {} + + /// Get the Value which created this location. + Value getValue() const { return value; } + + /// Get the operation which defines this field. If the field is a block + /// argument it will return the operation which owns the block. + Operation *getDefiningOp() const; + + /// Get the operation which defines this field and cast it to the OpTy. + /// Returns null if the defining operation is of a different type. + template + OpTy getDefiningOp() const { + return llvm::dyn_cast(getDefiningOp()); + } + + template + bool isa() const { + auto *op = getDefiningOp(); + assert(op && "isa<> used on a null type."); + return ::llvm::isa(op); + } + + /// Get the field ID of this FieldRef, which is a unique identifier mapped to + /// a specific field in a bundle. + unsigned getFieldID() const { return id; } + + /// Get a reference to a subfield. + FieldRef getSubField(unsigned subFieldID) const { + return FieldRef(value, id + subFieldID); + } + + /// Get the location associated with the value of this field ref. + Location getLoc() const { return getValue().getLoc(); } + + bool operator==(const FieldRef &other) const { + return value == other.value && id == other.id; + } + + bool operator<(const FieldRef &other) const { + if (value.getImpl() < other.value.getImpl()) return true; + if (value.getImpl() > other.value.getImpl()) return false; + return id < other.id; + } + + operator bool() const { return bool(value); } + + private: + /// A pointer to the value which created this. + Value value; + + /// A unique field ID. + unsigned id = 0; +}; + +/// Get a hash code for a FieldRef. +inline ::llvm::hash_code hash_value(const FieldRef &fieldRef) { + return llvm::hash_combine(fieldRef.getValue(), fieldRef.getFieldID()); +} + +} // namespace circt + +namespace llvm { +/// Allow using FieldRef with DenseMaps. This hash is based on the Value +/// identity and field ID. +template <> +struct DenseMapInfo { + static inline circt::FieldRef getEmptyKey() { + return circt::FieldRef(DenseMapInfo::getEmptyKey(), 0); + } + static inline circt::FieldRef getTombstoneKey() { + return circt::FieldRef(DenseMapInfo::getTombstoneKey(), 0); + } + static unsigned getHashValue(const circt::FieldRef &val) { + return circt::hash_value(val); + } + static bool isEqual(const circt::FieldRef &lhs, const circt::FieldRef &rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // CIRCT_SUPPORT_FIELDREF_H diff --git a/third_party/circt/include/circt/Support/FoldUtils.h b/third_party/circt/include/circt/Support/FoldUtils.h new file mode 100644 index 000000000..056593cba --- /dev/null +++ b/third_party/circt/include/circt/Support/FoldUtils.h @@ -0,0 +1,38 @@ +//===- FoldUtils.h - Common folder and canonicalizer utilities --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_FOLDUTILS_H +#define CIRCT_SUPPORT_FOLDUTILS_H + +#include "llvm/ADT/APInt.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace circt { + +/// Determine the integer value of a constant operand. +static inline std::optional getConstantInt(Attribute operand) { + if (!operand) return {}; + if (auto attr = dyn_cast(operand)) return attr.getValue(); + return {}; +} + +/// Determine whether a constant operand is a zero value. +static inline bool isConstantZero(Attribute operand) { + if (auto cst = getConstantInt(operand)) return cst->isZero(); + return false; +} + +/// Determine whether a constant operand is a one value. +static inline bool isConstantOne(Attribute operand) { + if (auto cst = getConstantInt(operand)) return cst->isOne(); + return false; +} + +} // namespace circt + +#endif // CIRCT_SUPPORT_FOLDUTILS_H diff --git a/third_party/circt/include/circt/Support/InstanceGraph.h b/third_party/circt/include/circt/Support/InstanceGraph.h new file mode 100644 index 000000000..860f92f78 --- /dev/null +++ b/third_party/circt/include/circt/Support/InstanceGraph.h @@ -0,0 +1,453 @@ +//===- InstanceGraph.h - Instance graph -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a generic instance graph for module- and instance-likes. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_INSTANCEGRAPH_H +#define CIRCT_SUPPORT_INSTANCEGRAPH_H + +#include "circt/Support/LLVM.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/iterator.h" +#include "llvm/Support/DOTGraphTraits.h" +#include "mlir/IR/OpDefinition.h" + +/// The InstanceGraph op interface, see InstanceGraphInterface.td for more +/// details. +#include "circt/Support/InstanceGraphInterface.h" + +namespace circt { +namespace igraph { + +namespace detail { +/// This just maps a iterator of references to an iterator of addresses. +template +struct AddressIterator + : public llvm::mapped_iterator { + // This using statement is to get around a bug in MSVC. Without it, it + // tries to look up "It" as a member type of the parent class. + using Iterator = It; + static typename Iterator::value_type *addrOf( + typename Iterator::value_type &v) noexcept { + return std::addressof(v); + } + /* implicit */ AddressIterator(Iterator iterator) + : llvm::mapped_iterator(iterator, + addrOf) {} +}; +} // namespace detail + +class InstanceGraphNode; + +/// This is an edge in the InstanceGraph. This tracks a specific instantiation +/// of a module. +class InstanceRecord + : public llvm::ilist_node_with_parent { + public: + /// Get the instance-like op that this is tracking. + template + auto getInstance() { + if constexpr (std::is_same::value) + return instance; + return dyn_cast_or_null(instance.getOperation()); + } + + /// Get the module where the instantiation lives. + InstanceGraphNode *getParent() const { return parent; } + + /// Get the module which the instance-like is instantiating. + InstanceGraphNode *getTarget() const { return target; } + + /// Erase this instance record, removing it from the parent module and the + /// target's use-list. + void erase(); + + private: + friend class InstanceGraph; + friend class InstanceGraphNode; + + InstanceRecord(InstanceGraphNode *parent, InstanceOpInterface instance, + InstanceGraphNode *target) + : parent(parent), instance(instance), target(target) {} + InstanceRecord(const InstanceRecord &) = delete; + + /// This is the module where the instantiation lives. + InstanceGraphNode *parent; + + /// The InstanceLike that this is tracking. + InstanceOpInterface instance; + + /// This is the module which the instance-like is instantiating. + InstanceGraphNode *target; + /// Intrusive linked list for other uses. + InstanceRecord *nextUse = nullptr; + InstanceRecord *prevUse = nullptr; +}; + +/// This is a Node in the InstanceGraph. Each node represents a Module in a +/// Circuit. Both external modules and regular modules can be represented by +/// this class. It is possible to efficiently iterate all modules instantiated +/// by this module, as well as all instantiations of this module. +class InstanceGraphNode : public llvm::ilist_node { + using InstanceList = llvm::iplist; + + public: + InstanceGraphNode() : module(nullptr) {} + + /// Get the module that this node is tracking. + template + auto getModule() { + if constexpr (std::is_same::value) + return module; + return cast(module.getOperation()); + } + + /// Iterate the instance records in this module. + using iterator = detail::AddressIterator; + iterator begin() { return instances.begin(); } + iterator end() { return instances.end(); } + + /// Return true if there are no more instances of this module. + bool noUses() { return !firstUse; } + + /// Return true if this module has exactly one use. + bool hasOneUse() { return llvm::hasSingleElement(uses()); } + + /// Get the number of direct instantiations of this module. + size_t getNumUses() { return std::distance(usesBegin(), usesEnd()); } + + /// Iterator for module uses. + struct UseIterator + : public llvm::iterator_facade_base< + UseIterator, std::forward_iterator_tag, InstanceRecord *> { + UseIterator() : current(nullptr) {} + UseIterator(InstanceGraphNode *node) : current(node->firstUse) {} + InstanceRecord *operator*() const { return current; } + using llvm::iterator_facade_base::operator++; + UseIterator &operator++() { + assert(current && "incrementing past end"); + current = current->nextUse; + return *this; + } + bool operator==(const UseIterator &other) const { + return current == other.current; + } + + private: + InstanceRecord *current; + }; + + /// Iterate the instance records which instantiate this module. + UseIterator usesBegin() { return {this}; } + UseIterator usesEnd() { return {}; } + llvm::iterator_range uses() { + return llvm::make_range(usesBegin(), usesEnd()); + } + + /// Record a new instance op in the body of this module. Returns a newly + /// allocated InstanceRecord which will be owned by this node. + InstanceRecord *addInstance(InstanceOpInterface instance, + InstanceGraphNode *target); + + private: + friend class InstanceRecord; + + InstanceGraphNode(const InstanceGraphNode &) = delete; + + /// Record that a module instantiates this module. + void recordUse(InstanceRecord *record); + + /// The module. + ModuleOpInterface module; + + /// List of instance operations in this module. This member owns the + /// InstanceRecords, which may be pointed to by other InstanceGraphNode's use + /// lists. + InstanceList instances; + + /// List of instances which instantiate this module. + InstanceRecord *firstUse = nullptr; + + // Provide access to the constructor. + friend class InstanceGraph; +}; + +/// This graph tracks modules and where they are instantiated. This is intended +/// to be used as a cached analysis on circuits. This class can be used +/// to walk the modules efficiently in a bottom-up or top-down order. +/// +/// To use this class, retrieve a cached copy from the analysis manager: +/// auto &instanceGraph = getAnalysis(getOperation()); +class InstanceGraph { + /// This is the list of InstanceGraphNodes in the graph. + using NodeList = llvm::iplist; + + public: + /// Create a new module graph of a circuit. Must be called on the parent + /// operation of ModuleOpInterface ops. + InstanceGraph(Operation *parent); + InstanceGraph(const InstanceGraph &) = delete; + virtual ~InstanceGraph() = default; + + /// Look up an InstanceGraphNode for a module. + InstanceGraphNode *lookup(ModuleOpInterface op); + + /// Lookup an module by name. + InstanceGraphNode *lookup(StringAttr name); + + /// Lookup an InstanceGraphNode for a module. + InstanceGraphNode *operator[](ModuleOpInterface op) { return lookup(op); } + + /// Look up the referenced module from an InstanceOp. This will use a + /// hashtable lookup to find the module, where + /// InstanceOp.getReferencedModule() will be a linear search through the IR. + template + auto getReferencedModule(InstanceOpInterface op) { + return cast(getReferencedModuleImpl(op).getOperation()); + } + + /// Check if child is instantiated by a parent. + bool isAncestor(ModuleOpInterface child, ModuleOpInterface parent); + + /// Get the node corresponding to the top-level module of a circuit. + virtual InstanceGraphNode *getTopLevelNode() { return nullptr; } + + /// Get the nodes corresponding to the inferred top-level modules of a + /// circuit. + FailureOr> getInferredTopLevelNodes(); + + /// Return the parent under which all nodes are nested. + Operation *getParent() { return parent; } + + /// Returns pointer to member of operation list. + static NodeList InstanceGraph::*getSublistAccess(Operation *) { + return &InstanceGraph::nodes; + } + + /// Iterate through all modules. + using iterator = detail::AddressIterator; + iterator begin() { return nodes.begin(); } + iterator end() { return nodes.end(); } + + //===------------------------------------------------------------------------- + // Methods to keep an InstanceGraph up to date. + // + // These methods are not thread safe. Make sure that modifications are + // properly synchronized or performed in a serial context. When the + // InstanceGraph is used as an analysis, this is only safe when the pass is + // on a CircuitOp or a ModuleOp. + + /// Add a newly created module to the instance graph. + virtual InstanceGraphNode *addModule(ModuleOpInterface module); + + /// Remove this module from the instance graph. This will also remove all + /// InstanceRecords in this module. All instances of this module must have + /// been removed from the graph. + virtual void erase(InstanceGraphNode *node); + + /// Replaces an instance of a module with another instance. The target module + /// of both InstanceOps must be the same. + virtual void replaceInstance(InstanceOpInterface inst, + InstanceOpInterface newInst); + + protected: + ModuleOpInterface getReferencedModuleImpl(InstanceOpInterface op); + + /// Get the node corresponding to the module. If the node has does not exist + /// yet, it will be created. + InstanceGraphNode *getOrAddNode(StringAttr name); + + /// The node under which all modules are nested. + Operation *parent; + + /// The storage for graph nodes, with deterministic iteration. + NodeList nodes; + + /// This maps each operation to its graph node. + llvm::DenseMap nodeMap; + + /// A caching of the inferred top level module(s). + llvm::SmallVector inferredTopLevelNodes; +}; + +struct InstancePathCache; + +/** + * An instance path composed of a series of instances. + */ +class InstancePath final { + public: + InstancePath() = default; + + InstanceOpInterface top() const { + assert(!empty() && "instance path is empty"); + return path[0]; + } + + InstanceOpInterface leaf() const { + assert(!empty() && "instance path is empty"); + return path.back(); + } + + InstancePath dropFront() const { return InstancePath(path.drop_front()); } + + InstanceOpInterface operator[](size_t idx) const { return path[idx]; } + ArrayRef::iterator begin() const { return path.begin(); } + ArrayRef::iterator end() const { return path.end(); } + size_t size() const { return path.size(); } + bool empty() const { return path.empty(); } + + /// Print the path to any stream-like object. + template + void print(T &into) const { + into << "$root"; + for (auto inst : path) + into << "/" << inst.getInstanceName() << ":" + << inst.getReferencedModuleName(); + } + + private: + // Only the path cache is allowed to create paths. + friend struct InstancePathCache; + InstancePath(ArrayRef path) : path(path) {} + + ArrayRef path; +}; + +template +static T &operator<<(T &os, const InstancePath &path) { + return path.print(os); +} + +/// A data structure that caches and provides absolute paths to module instances +/// in the IR. +struct InstancePathCache { + /// The instance graph of the IR. + InstanceGraph &instanceGraph; + + explicit InstancePathCache(InstanceGraph &instanceGraph) + : instanceGraph(instanceGraph) {} + ArrayRef getAbsolutePaths(ModuleOpInterface op); + + /// Replace an InstanceOp. This is required to keep the cache updated. + void replaceInstance(InstanceOpInterface oldOp, InstanceOpInterface newOp); + + /// Append an instance to a path. + InstancePath appendInstance(InstancePath path, InstanceOpInterface inst); + + /// Prepend an instance to a path. + InstancePath prependInstance(InstanceOpInterface inst, InstancePath path); + + private: + /// An allocator for individual instance paths and entire path lists. + llvm::BumpPtrAllocator allocator; + + /// Cached absolute instance paths. + DenseMap> absolutePathsCache; +}; + +} // namespace igraph +} // namespace circt + +// Graph traits for modules. +template <> +struct llvm::GraphTraits { + using NodeType = circt::igraph::InstanceGraphNode; + using NodeRef = NodeType *; + + // Helper for getting the module referenced by the instance op. + static NodeRef getChild(const circt::igraph::InstanceRecord *record) { + return record->getTarget(); + } + + using ChildIteratorType = + llvm::mapped_iterator; + + static NodeRef getEntryNode(NodeRef node) { return node; } + static ChildIteratorType child_begin(NodeRef node) { + return {node->begin(), &getChild}; + } + static ChildIteratorType child_end(NodeRef node) { + return {node->end(), &getChild}; + } +}; + +// Provide graph traits for iterating the modules in inverse order. +template <> +struct llvm::GraphTraits> { + using NodeType = circt::igraph::InstanceGraphNode; + using NodeRef = NodeType *; + + // Helper for getting the module containing the instance op. + static NodeRef getParent(const circt::igraph::InstanceRecord *record) { + return record->getParent(); + } + + using ChildIteratorType = + llvm::mapped_iterator; + + static NodeRef getEntryNode(Inverse inverse) { + return inverse.Graph; + } + static ChildIteratorType child_begin(NodeRef node) { + return {node->usesBegin(), &getParent}; + } + static ChildIteratorType child_end(NodeRef node) { + return {node->usesEnd(), &getParent}; + } +}; + +// Graph traits for the common instance graph. +template <> +struct llvm::GraphTraits + : public llvm::GraphTraits { + using nodes_iterator = circt::igraph::InstanceGraph::iterator; + + static NodeRef getEntryNode(circt::igraph::InstanceGraph *graph) { + return graph->getTopLevelNode(); + } + // NOLINTNEXTLINE(readability-identifier-naming) + static nodes_iterator nodes_begin(circt::igraph::InstanceGraph *graph) { + return graph->begin(); + } + // NOLINTNEXTLINE(readability-identifier-naming) + static nodes_iterator nodes_end(circt::igraph::InstanceGraph *graph) { + return graph->end(); + } +}; + +// Graph traits for DOT labeling. +template <> +struct llvm::DOTGraphTraits + : public llvm::DefaultDOTGraphTraits { + using DefaultDOTGraphTraits::DefaultDOTGraphTraits; + + static std::string getNodeLabel(circt::igraph::InstanceGraphNode *node, + circt::igraph::InstanceGraph *) { + // The name of the graph node is the module name. + return node->getModule().getModuleName().str(); + } + + template + static std::string getEdgeAttributes( + const circt::igraph::InstanceGraphNode *node, Iterator it, + circt::igraph::InstanceGraph *) { + // Set an edge label that is the name of the instance. + auto *instanceRecord = *it.getCurrent(); + auto instanceOp = instanceRecord->getInstance(); + return ("label=" + instanceOp.getInstanceName()).str(); + } +}; + +#endif // CIRCT_SUPPORT_INSTANCEGRAPH_H diff --git a/third_party/circt/include/circt/Support/InstanceGraphInterface.h b/third_party/circt/include/circt/Support/InstanceGraphInterface.h new file mode 100644 index 000000000..7063bc868 --- /dev/null +++ b/third_party/circt/include/circt/Support/InstanceGraphInterface.h @@ -0,0 +1,23 @@ +//===- InstanceGraphInterface.h - Instance graph interface ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares stuff related to the instance graph interface. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H +#define CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/OpDefinition.h" + +/// The InstanceGraph op interface, see InstanceGraphInterface.td for more +/// details. +#include "circt/Support/InstanceGraphInterface.h.inc" + +#endif // CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H diff --git a/third_party/circt/include/circt/Support/InstanceGraphInterface.td b/third_party/circt/include/circt/Support/InstanceGraphInterface.td new file mode 100644 index 000000000..261dd1570 --- /dev/null +++ b/third_party/circt/include/circt/Support/InstanceGraphInterface.td @@ -0,0 +1,74 @@ +//===- InstanceGraphInterface.td - Interface for instance graphs --------*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains interfaces and other utilities for interacting with the +// generic CIRCT instance graph. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD +#define CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/OpBase.td" + +def InstanceGraphInstanceOpInterface : OpInterface<"InstanceOpInterface"> { + let description = [{ + This interface provides hooks for an instance-like operation. + }]; + let cppNamespace = "::circt::igraph"; + + let methods = [ + InterfaceMethod<"Get the name of the instance", + "::llvm::StringRef", "getInstanceName", (ins)>, + + InterfaceMethod<"Get the name of the instance", + "::mlir::StringAttr", "getInstanceNameAttr", (ins)>, + + InterfaceMethod<"Get the name of the instantiated module", + "::llvm::StringRef", "getReferencedModuleName", (ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return $_op.getModuleName(); }]>, + + InterfaceMethod<"Get the name of the instantiated module", + "::mlir::StringAttr", "getReferencedModuleNameAttr", (ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return $_op.getModuleNameAttr().getAttr(); }]>, + + InterfaceMethod<[{ + Get the referenced module (slow, unsafe). This function directly accesses + the parent operation to lookup a symbol, which is unsafe in many contexts. + }], + "::mlir::Operation *", "getReferencedModuleSlow", (ins)>, + + InterfaceMethod<"Get the referenced module via a symbol table.", + "::mlir::Operation *", "getReferencedModule", (ins "SymbolTable&":$symtbl)>, + ]; +} + +def InstanceGraphModuleOpInterface : OpInterface<"ModuleOpInterface"> { + let description = [{ + This interface provides hooks for a module-like operation. + }]; + let cppNamespace = "::circt::igraph"; + + let methods = [ + InterfaceMethod<"Get the module name", + "::llvm::StringRef", "getModuleName", (ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return $_op.getModuleNameAttr().getValue(); }]>, + + InterfaceMethod<"Get the module name", + "::mlir::StringAttr", "getModuleNameAttr", (ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return $_op.getNameAttr(); }]>, + ]; + +} + +#endif // CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD diff --git a/third_party/circt/include/circt/Support/JSON.h b/third_party/circt/include/circt/Support/JSON.h new file mode 100644 index 000000000..1b76bdc18 --- /dev/null +++ b/third_party/circt/include/circt/Support/JSON.h @@ -0,0 +1,30 @@ +//===- Json.h - Json Utilities ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for JSON-to-Attribute conversion. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_JSON_H +#define CIRCT_SUPPORT_JSON_H + +#include "circt/Support/LLVM.h" +#include "llvm/Support/JSON.h" + +namespace circt { + +/// Convert a simple attribute to JSON. +LogicalResult convertAttributeToJSON(llvm::json::OStream &json, Attribute attr); + +/// Convert arbitrary JSON to an MLIR Attribute. +Attribute convertJSONToAttribute(MLIRContext *context, llvm::json::Value &value, + llvm::json::Path p); + +} // namespace circt + +#endif // CIRCT_SUPPORT_JSON_H diff --git a/third_party/circt/include/circt/Support/LLVM.h b/third_party/circt/include/circt/Support/LLVM.h new file mode 100644 index 000000000..4ab50feee --- /dev/null +++ b/third_party/circt/include/circt/Support/LLVM.h @@ -0,0 +1,293 @@ +//===- LLVM.h - Import and forward declare core LLVM types ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file forward declares and imports various common LLVM and MLIR datatypes +// that we want to use unqualified. +// +// Note that most of these are forward declared and then imported into the circt +// namespace with using decls, rather than being #included. This is because we +// want clients to explicitly #include the files they need. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_LLVM_H +#define CIRCT_SUPPORT_LLVM_H + +// MLIR includes a lot of forward declarations of LLVM types, use them. +#include "mlir/Support/LLVM.h" + +// Can not forward declare inline functions with default arguments, so we +// include the header directly. +#include "mlir/Support/LogicalResult.h" + +// Import classes from the `mlir` namespace into the `circt` namespace. All of +// the following classes have been already forward declared and imported from +// `llvm` in to the `mlir` namespace. For classes with default template +// arguments, MLIR does not import the type directly, it creates a templated +// using statement. This is due to the limitiation that only one declaration of +// a type can have default arguments. For those types, it is important to import +// the MLIR version, and not the LLVM version. To keep things simple, all +// classes here should be imported from the `mlir` namespace, not the `llvm` +// namespace. +namespace circt { +using mlir::APFloat; // NOLINT(misc-unused-using-decls) +using mlir::APInt; // NOLINT(misc-unused-using-decls) +using mlir::APSInt; // NOLINT(misc-unused-using-decls) +using mlir::ArrayRef; // NOLINT(misc-unused-using-decls) +using mlir::BitVector; // NOLINT(misc-unused-using-decls) +using mlir::cast; // NOLINT(misc-unused-using-decls) +using mlir::cast_or_null; // NOLINT(misc-unused-using-decls) +using mlir::DenseMap; // NOLINT(misc-unused-using-decls) +using mlir::DenseMapInfo; // NOLINT(misc-unused-using-decls) +using mlir::DenseSet; // NOLINT(misc-unused-using-decls) +using mlir::dyn_cast; // NOLINT(misc-unused-using-decls) +using mlir::dyn_cast_or_null; // NOLINT(misc-unused-using-decls) +using mlir::function_ref; // NOLINT(misc-unused-using-decls) +using mlir::isa; // NOLINT(misc-unused-using-decls) +using mlir::isa_and_nonnull; // NOLINT(misc-unused-using-decls) +using mlir::iterator_range; // NOLINT(misc-unused-using-decls) +using mlir::MutableArrayRef; // NOLINT(misc-unused-using-decls) +using mlir::PointerUnion; // NOLINT(misc-unused-using-decls) +using mlir::raw_ostream; // NOLINT(misc-unused-using-decls) +using mlir::SetVector; // NOLINT(misc-unused-using-decls) +using mlir::SmallPtrSet; // NOLINT(misc-unused-using-decls) +using mlir::SmallPtrSetImpl; // NOLINT(misc-unused-using-decls) +using mlir::SmallString; // NOLINT(misc-unused-using-decls) +using mlir::SmallVector; // NOLINT(misc-unused-using-decls) +using mlir::SmallVectorImpl; // NOLINT(misc-unused-using-decls) +using mlir::StringLiteral; // NOLINT(misc-unused-using-decls) +using mlir::StringRef; // NOLINT(misc-unused-using-decls) +using mlir::StringSet; // NOLINT(misc-unused-using-decls) +using mlir::TinyPtrVector; // NOLINT(misc-unused-using-decls) +using mlir::Twine; // NOLINT(misc-unused-using-decls) +using mlir::TypeSwitch; // NOLINT(misc-unused-using-decls) +} // namespace circt + +// Forward declarations of LLVM classes to be imported in to the circt +// namespace. +namespace llvm { +template +class SmallDenseMap; +template +class SmallSet; +} // namespace llvm + +// Import things we want into our namespace. +namespace circt { +using llvm::SmallDenseMap; // NOLINT(misc-unused-using-decls) +using llvm::SmallSet; // NOLINT(misc-unused-using-decls) +} // namespace circt + +// Forward declarations of classes to be imported in to the circt namespace. +namespace mlir { +class ArrayAttr; +class AsmParser; +class AsmPrinter; +class Attribute; +class Block; +class TypedAttr; +class IRMapping; +class BlockArgument; +class BoolAttr; +class Builder; +class NamedAttrList; +class ConversionPattern; +class ConversionPatternRewriter; +class ConversionTarget; +class DenseElementsAttr; +class Diagnostic; +class Dialect; +class DialectAsmParser; +class DialectAsmPrinter; +class DictionaryAttr; +class DistinctAttr; +class ElementsAttr; +class FileLineColLoc; +class FlatSymbolRefAttr; +class FloatAttr; +class FunctionType; +class FusedLoc; +class ImplicitLocOpBuilder; +class IndexType; +class InFlightDiagnostic; +class IntegerAttr; +class IntegerType; +class Location; +class LocationAttr; +class MemRefType; +class MLIRContext; +class ModuleOp; +class MutableOperandRange; +class NamedAttribute; +class NamedAttrList; +class NoneType; +class OpAsmDialectInterface; +class OpAsmParser; +class OpAsmPrinter; +class OpaqueProperties; +class OpBuilder; +class OperandRange; +class Operation; +class OpFoldResult; +class OpOperand; +class OpResult; +template +class OwningOpRef; +class ParseResult; +class Pass; +class PatternRewriter; +class Region; +class RewritePatternSet; +class ShapedType; +class SplatElementsAttr; +class StringAttr; +class SymbolRefAttr; +class SymbolTable; +class SymbolTableCollection; +class TupleType; +class Type; +class TypeAttr; +class TypeConverter; +class TypeID; +class TypeRange; +class TypeStorage; +class UnitAttr; +class UnknownLoc; +class Value; +class ValueRange; +class VectorType; +class WalkResult; +enum class RegionKind; +struct CallInterfaceCallable; +struct LogicalResult; +struct OperationState; +class OperationName; + +namespace affine { +struct MemRefAccess; +} // namespace affine + +template +class FailureOr; +template +class OpConversionPattern; +template +class OperationPass; +template +struct OpRewritePattern; + +using DefaultTypeStorage = TypeStorage; +using OpAsmSetValueNameFn = function_ref; + +namespace OpTrait {} + +} // namespace mlir + +// Import things we want into our namespace. +namespace circt { +// clang-tidy removes following using directives incorrectly. So force +// clang-tidy to ignore them. +// TODO: It is better to use `NOLINTBEGIN/END` comments to disable clang-tidy +// than adding `NOLINT` to every line. `NOLINTBEGIN/END` will supported from +// clang-tidy-14. +using mlir::ArrayAttr; // NOLINT(misc-unused-using-decls) +using mlir::AsmParser; // NOLINT(misc-unused-using-decls) +using mlir::AsmPrinter; // NOLINT(misc-unused-using-decls) +using mlir::Attribute; // NOLINT(misc-unused-using-decls) +using mlir::Block; // NOLINT(misc-unused-using-decls) +using mlir::BlockArgument; // NOLINT(misc-unused-using-decls) +using mlir::BoolAttr; // NOLINT(misc-unused-using-decls) +using mlir::Builder; // NOLINT(misc-unused-using-decls) +using mlir::CallInterfaceCallable; // NOLINT(misc-unused-using-decls) +using mlir::ConversionPattern; // NOLINT(misc-unused-using-decls) +using mlir::ConversionPatternRewriter; // NOLINT(misc-unused-using-decls) +using mlir::ConversionTarget; // NOLINT(misc-unused-using-decls) +using mlir::DefaultTypeStorage; // NOLINT(misc-unused-using-decls) +using mlir::DenseElementsAttr; // NOLINT(misc-unused-using-decls) +using mlir::Diagnostic; // NOLINT(misc-unused-using-decls) +using mlir::Dialect; // NOLINT(misc-unused-using-decls) +using mlir::DialectAsmParser; // NOLINT(misc-unused-using-decls) +using mlir::DialectAsmPrinter; // NOLINT(misc-unused-using-decls) +using mlir::DictionaryAttr; // NOLINT(misc-unused-using-decls) +using mlir::DistinctAttr; // NOLINT(misc-unused-using-decls) +using mlir::ElementsAttr; // NOLINT(misc-unused-using-decls) +using mlir::failed; // NOLINT(misc-unused-using-decls) +using mlir::failure; // NOLINT(misc-unused-using-decls) +using mlir::FailureOr; // NOLINT(misc-unused-using-decls) +using mlir::FileLineColLoc; // NOLINT(misc-unused-using-decls) +using mlir::FlatSymbolRefAttr; // NOLINT(misc-unused-using-decls) +using mlir::FloatAttr; // NOLINT(misc-unused-using-decls) +using mlir::FunctionType; // NOLINT(misc-unused-using-decls) +using mlir::FusedLoc; // NOLINT(misc-unused-using-decls) +using mlir::ImplicitLocOpBuilder; // NOLINT(misc-unused-using-decls) +using mlir::IndexType; // NOLINT(misc-unused-using-decls) +using mlir::InFlightDiagnostic; // NOLINT(misc-unused-using-decls) +using mlir::IntegerAttr; // NOLINT(misc-unused-using-decls) +using mlir::IntegerType; // NOLINT(misc-unused-using-decls) +using mlir::IRMapping; // NOLINT(misc-unused-using-decls) +using mlir::Location; // NOLINT(misc-unused-using-decls) +using mlir::LocationAttr; // NOLINT(misc-unused-using-decls) +using mlir::LogicalResult; // NOLINT(misc-unused-using-decls) +using mlir::MemRefType; // NOLINT(misc-unused-using-decls) +using mlir::MLIRContext; // NOLINT(misc-unused-using-decls) +using mlir::ModuleOp; // NOLINT(misc-unused-using-decls) +using mlir::MutableOperandRange; // NOLINT(misc-unused-using-decls) +using mlir::NamedAttribute; // NOLINT(misc-unused-using-decls) +using mlir::NamedAttrList; // NOLINT(misc-unused-using-decls) +using mlir::NoneType; // NOLINT(misc-unused-using-decls) +using mlir::OpaqueProperties; // NOLINT(misc-unused-using-decls) +using mlir::OpAsmDialectInterface; // NOLINT(misc-unused-using-decls) +using mlir::OpAsmParser; // NOLINT(misc-unused-using-decls) +using mlir::OpAsmPrinter; // NOLINT(misc-unused-using-decls) +using mlir::OpAsmSetValueNameFn; // NOLINT(misc-unused-using-decls) +using mlir::OpBuilder; // NOLINT(misc-unused-using-decls) +using mlir::OpConversionPattern; // NOLINT(misc-unused-using-decls) +using mlir::OperandRange; // NOLINT(misc-unused-using-decls) +using mlir::Operation; // NOLINT(misc-unused-using-decls) +using mlir::OperationName; // NOLINT(misc-unused-using-decls) +using mlir::OperationPass; // NOLINT(misc-unused-using-decls) +using mlir::OperationState; // NOLINT(misc-unused-using-decls) +using mlir::OpFoldResult; // NOLINT(misc-unused-using-decls) +using mlir::OpOperand; // NOLINT(misc-unused-using-decls) +using mlir::OpResult; // NOLINT(misc-unused-using-decls) +using mlir::OpRewritePattern; // NOLINT(misc-unused-using-decls) +using mlir::OwningOpRef; // NOLINT(misc-unused-using-decls) +using mlir::ParseResult; // NOLINT(misc-unused-using-decls) +using mlir::Pass; // NOLINT(misc-unused-using-decls) +using mlir::PatternRewriter; // NOLINT(misc-unused-using-decls) +using mlir::Region; // NOLINT(misc-unused-using-decls) +using mlir::RegionKind; // NOLINT(misc-unused-using-decls) +using mlir::RewritePatternSet; // NOLINT(misc-unused-using-decls) +using mlir::ShapedType; // NOLINT(misc-unused-using-decls) +using mlir::SplatElementsAttr; // NOLINT(misc-unused-using-decls) +using mlir::StringAttr; // NOLINT(misc-unused-using-decls) +using mlir::succeeded; // NOLINT(misc-unused-using-decls) +using mlir::success; // NOLINT(misc-unused-using-decls) +using mlir::SymbolRefAttr; // NOLINT(misc-unused-using-decls) +using mlir::SymbolTable; // NOLINT(misc-unused-using-decls) +using mlir::SymbolTableCollection; // NOLINT(misc-unused-using-decls) +using mlir::TupleType; // NOLINT(misc-unused-using-decls) +using mlir::Type; // NOLINT(misc-unused-using-decls) +using mlir::TypeAttr; // NOLINT(misc-unused-using-decls) +using mlir::TypeConverter; // NOLINT(misc-unused-using-decls) +using mlir::TypedAttr; // NOLINT(misc-unused-using-decls) +using mlir::TypeID; // NOLINT(misc-unused-using-decls) +using mlir::TypeRange; // NOLINT(misc-unused-using-decls) +using mlir::TypeStorage; // NOLINT(misc-unused-using-decls) +using mlir::UnitAttr; // NOLINT(misc-unused-using-decls) +using mlir::UnknownLoc; // NOLINT(misc-unused-using-decls) +using mlir::Value; // NOLINT(misc-unused-using-decls) +using mlir::ValueRange; // NOLINT(misc-unused-using-decls) +using mlir::VectorType; // NOLINT(misc-unused-using-decls) +using mlir::WalkResult; // NOLINT(misc-unused-using-decls) +using mlir::affine::MemRefAccess; // NOLINT(misc-unused-using-decls) +namespace OpTrait = mlir::OpTrait; +} // namespace circt + +#endif // CIRCT_SUPPORT_LLVM_H diff --git a/third_party/circt/include/circt/Support/LoweringOptions.h b/third_party/circt/include/circt/Support/LoweringOptions.h new file mode 100644 index 000000000..9c81237fd --- /dev/null +++ b/third_party/circt/include/circt/Support/LoweringOptions.h @@ -0,0 +1,169 @@ +//===- LoweringOptions.h - CIRCT Lowering Options ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Options for controlling the lowering process and verilog exporting. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_LOWERINGOPTIONS_H +#define CIRCT_SUPPORT_LOWERINGOPTIONS_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "mlir/IR/BuiltinAttributes.h" + +namespace mlir { +class ModuleOp; +} + +namespace circt { + +/// Options which control the emission from CIRCT to Verilog. +struct LoweringOptions { + /// Error callback type used to indicate errors parsing the options string. + using ErrorHandlerT = llvm::function_ref; + + /// Create a LoweringOptions with the default values. + LoweringOptions() = default; + + /// Create a LoweringOptions and read in options from a string, + /// overriding only the set options in the string. + LoweringOptions(llvm::StringRef options, ErrorHandlerT errorHandler); + + /// Create a LoweringOptions with values loaded from an MLIR ModuleOp. This + /// loads a string attribute with the key `circt.loweringOptions`. If there is + /// an error parsing the attribute this will print an error using the + /// ModuleOp. + LoweringOptions(mlir::ModuleOp module); + + /// Return the value of the `circt.loweringOptions` in the specified module + /// if present, or a null attribute if not. + static mlir::StringAttr getAttributeFrom(mlir::ModuleOp module); + + /// Read in options from a string, overriding only the set options in the + /// string. + void parse(llvm::StringRef options, ErrorHandlerT callback); + + /// Returns a string representation of the options. + std::string toString() const; + + /// Write the verilog emitter options to a module's attributes. + void setAsAttribute(mlir::ModuleOp module); + + /// Load any emitter options from the module. If there is an error validating + /// the attribute, this will print an error using the ModuleOp. + void parseFromAttribute(mlir::ModuleOp module); + + /// If true, emits `sv.alwayscomb` as Verilog `always @(*)` statements. + /// Otherwise, print them as `always_comb`. + bool noAlwaysComb = false; + + /// If true, expressions are allowed in the sensitivity list of `always` + /// statements, otherwise they are forced to be simple wires. Some EDA + /// tools rely on these being simple wires. + bool allowExprInEventControl = false; + + /// If true, eliminate packed arrays for tools that don't support them (e.g. + /// Yosys). + bool disallowPackedArrays = false; + + /// If true, eliminate packed struct assignments in favor of a wire + + /// assignments to the individual fields. + bool disallowPackedStructAssignments = false; + + /// If true, do not emit SystemVerilog locally scoped "automatic" or logic + /// declarations - emit top level wire and reg's instead. + bool disallowLocalVariables = false; + + /// If true, verification statements like `assert`, `assume`, and `cover` will + /// always be emitted with a label. If the statement has no label in the IR, a + /// generic one will be created. Some EDA tools require verification + /// statements to be labeled. + bool enforceVerifLabels = false; + + /// This is the maximum number of terms in an expression before that + /// expression spills a wire. + enum { DEFAULT_TERM_LIMIT = 256 }; + unsigned maximumNumberOfTermsPerExpression = DEFAULT_TERM_LIMIT; + + /// This is the target width of lines in an emitted Verilog source file in + /// columns. + enum { DEFAULT_LINE_LENGTH = 90 }; + unsigned emittedLineLength = DEFAULT_LINE_LENGTH; + + /// Add an explicit bitcast for avoiding bitwidth mismatch LINT errors. + bool explicitBitcast = false; + + /// If true, replicated ops are emitted to a header file. + bool emitReplicatedOpsToHeader = false; + + /// This option controls emitted location information style. + enum LocationInfoStyle { + Plain, // Default. + WrapInAtSquareBracket, // Wrap location info in @[..]. + None, // No location info comment. + } locationInfoStyle = Plain; + + /// If true, every port is declared separately + /// (each includes direction and type (e.g., `input [3:0]`)). + /// When false (default), ports share declarations when possible. + bool disallowPortDeclSharing = false; + + /// Print debug info. + bool printDebugInfo = false; + + /// If true, every mux expression is spilled to a wire. + bool disallowMuxInlining = false; + + /// This controls extra wire spilling performed in PrepareForEmission to + /// improve readablitiy and debuggability. + enum WireSpillingHeuristic : unsigned { + SpillLargeTermsWithNamehints = 1, // Spill wires for expressions with + // namehints if the term size is greater + // than `wireSpillingNamehintTermLimit`. + }; + + unsigned wireSpillingHeuristicSet = 0; + + bool isWireSpillingHeuristicEnabled(WireSpillingHeuristic heurisic) const { + return static_cast(wireSpillingHeuristicSet & heurisic); + } + + enum { DEFAULT_NAMEHINT_TERM_LIMIT = 3 }; + unsigned wireSpillingNamehintTermLimit = DEFAULT_NAMEHINT_TERM_LIMIT; + + /// If true, every expression passed to an instance port is driven by a wire. + /// Some lint tools dislike expressions being inlined into input ports so this + /// option avoids such warnings. + bool disallowExpressionInliningInPorts = false; + + /// If true, every expression used as an array index is driven by a wire, and + /// the wire is marked as `(* keep = "true" *)`. Certain versions of Vivado + /// produce incorrect synthesis results for certain arithmetic ops inlined + /// into the array index. + bool mitigateVivadoArrayIndexConstPropBug = false; + + /// If true, emit `wire` in port lists rather than nothing. Used in cases + /// where `default_nettype is not set to wire. + bool emitWireInPorts = false; + + /// If true, emit a comment wherever an instance wasn't printed, because + /// it's emitted elsewhere as a bind. + bool emitBindComments = false; + + /// If true, do not emit a version comment at the top of each verilog file. + bool omitVersionComment = false; + + /// If true, then unique names that collide with keywords case insensitively. + /// This is used to avoid stricter lint warnings which, e.g., treat "REG" as a + /// Verilog keyword. + bool caseInsensitiveKeywords = false; +}; +} // namespace circt + +#endif // CIRCT_SUPPORT_LOWERINGOPTIONS_H diff --git a/third_party/circt/include/circt/Support/LoweringOptionsParser.h b/third_party/circt/include/circt/Support/LoweringOptionsParser.h new file mode 100644 index 000000000..1a7244582 --- /dev/null +++ b/third_party/circt/include/circt/Support/LoweringOptionsParser.h @@ -0,0 +1,57 @@ +//===- LoweringOptionsParser.h - CIRCT Lowering Option Parser ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parser for lowering options. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H +#define CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H + +#include "circt/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" + +namespace circt { + +/// Commandline parser for LoweringOptions. Delegates to the parser +/// defined by LoweringOptions. +struct LoweringOptionsParser : public llvm::cl::parser { + LoweringOptionsParser(llvm::cl::Option &option) + : llvm::cl::parser(option) {} + + bool parse(llvm::cl::Option &option, StringRef argName, StringRef argValue, + LoweringOptions &value) { + bool failed = false; + value.parse(argValue, [&](Twine error) { failed = option.error(error); }); + return failed; + } +}; + +struct LoweringOptionsOption + : llvm::cl::opt { + LoweringOptionsOption(llvm::cl::OptionCategory &cat) + : llvm::cl::opt{ + "lowering-options", + llvm::cl::desc( + "Style options. Valid flags include: " + "noAlwaysComb, exprInEventControl, disallowPackedArrays, " + "disallowLocalVariables, verifLabels, emittedLineLength=, " + "maximumNumberOfTermsPerExpression=, " + "explicitBitcast, emitReplicatedOpsToHeader, " + "locationInfoStyle={plain,wrapInAtSquareBracket,none}, " + "disallowPortDeclSharing, printDebugInfo, " + "disallowExpressionInliningInPorts, disallowMuxInlining, " + "emitWireInPort, emitBindComments, omitVersionComment, " + "caseInsensitiveKeywords"), + llvm::cl::cat(cat), llvm::cl::value_desc("option")} {} +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H diff --git a/third_party/circt/include/circt/Support/Namespace.h b/third_party/circt/include/circt/Support/Namespace.h new file mode 100644 index 000000000..c22e2b9e1 --- /dev/null +++ b/third_party/circt/include/circt/Support/Namespace.h @@ -0,0 +1,134 @@ +//===- Namespace.h - Utilities for generating names -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities for generating new names that do not conflict +// with existing names. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_NAMESPACE_H +#define CIRCT_SUPPORT_NAMESPACE_H + +#include "circt/Support/LLVM.h" +#include "circt/Support/SymCache.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Twine.h" + +namespace circt { + +/// A namespace that is used to store existing names and generate new names in +/// some scope within the IR. This exists to work around limitations of +/// SymbolTables. This acts as a base class providing facilities common to all +/// namespaces implementations. +class Namespace { + public: + Namespace() {} + Namespace(const Namespace &other) = default; + Namespace(Namespace &&other) : nextIndex(std::move(other.nextIndex)) {} + + Namespace &operator=(const Namespace &other) = default; + Namespace &operator=(Namespace &&other) { + nextIndex = std::move(other.nextIndex); + return *this; + } + + /// SymbolCache initializer; initialize from every key that is convertible to + /// a StringAttr in the SymbolCache. + void add(SymbolCache &symCache) { + for (auto &&[attr, _] : symCache) + if (auto strAttr = attr.dyn_cast()) + nextIndex.insert({strAttr.getValue(), 0}); + } + + /// Empty the namespace. + void clear() { nextIndex.clear(); } + + /// Return a unique name, derived from the input `name`, and add the new name + /// to the internal namespace. There are two possible outcomes for the + /// returned name: + /// + /// 1. The original name is returned. + /// 2. The name is given a `_` suffix where `` is a number starting from + /// `0` and incrementing by one each time (`_0`, ...). + StringRef newName(const Twine &name) { + // Special case the situation where there is no name collision to avoid + // messing with the SmallString allocation below. + llvm::SmallString<64> tryName; + auto inserted = nextIndex.insert({name.toStringRef(tryName), 0}); + if (inserted.second) return inserted.first->getKey(); + + // Try different suffixes until we get a collision-free one. + if (tryName.empty()) + name.toVector(tryName); // toStringRef may leave tryName unfilled + + // Indexes less than nextIndex[tryName] are lready used, so skip them. + // Indexes larger than nextIndex[tryName] may be used in another name. + size_t &i = nextIndex[tryName]; + tryName.push_back('_'); + size_t baseLength = tryName.size(); + do { + tryName.resize(baseLength); + Twine(i++).toVector(tryName); // append integer to tryName + inserted = nextIndex.insert({tryName, 0}); + } while (!inserted.second); + + return inserted.first->getKey(); + } + + /// Return a unique name, derived from the input `name` and ensure the + /// returned name has the input `suffix`. Also add the new name to the + /// internal namespace. + /// There are two possible outcomes for the returned name: + /// 1. The original name + `_` is returned. + /// 2. The name is given a suffix `__` where `` is a number + /// starting from `0` and incrementing by one each time. + StringRef newName(const Twine &name, const Twine &suffix) { + // Special case the situation where there is no name collision to avoid + // messing with the SmallString allocation below. + llvm::SmallString<64> tryName; + auto inserted = nextIndex.insert( + {name.concat("_").concat(suffix).toStringRef(tryName), 0}); + if (inserted.second) return inserted.first->getKey(); + + // Try different suffixes until we get a collision-free one. + tryName.clear(); + name.toVector(tryName); // toStringRef may leave tryName unfilled + tryName.push_back('_'); + size_t baseLength = tryName.size(); + + // Get the initial number to start from. Since `:` is not a valid character + // in a verilog identifier, we use it separate the name and suffix. + // Next number for name+suffix is stored with key `name_:suffix`. + tryName.push_back(':'); + suffix.toVector(tryName); + + // Indexes less than nextIndex[tryName] are already used, so skip them. + // Indexes larger than nextIndex[tryName] may be used in another name. + size_t &i = nextIndex[tryName]; + do { + tryName.resize(baseLength); + Twine(i++).toVector(tryName); // append integer to tryName + tryName.push_back('_'); + suffix.toVector(tryName); + inserted = nextIndex.insert({tryName, 0}); + } while (!inserted.second); + + return inserted.first->getKey(); + } + + protected: + // The "next index" that will be tried when trying to unique a string within a + // namespace. It follows that all values less than the "next index" value are + // already used. + llvm::StringMap nextIndex; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_NAMESPACE_H diff --git a/third_party/circt/include/circt/Support/ParsingUtils.h b/third_party/circt/include/circt/Support/ParsingUtils.h new file mode 100644 index 000000000..476ca8905 --- /dev/null +++ b/third_party/circt/include/circt/Support/ParsingUtils.h @@ -0,0 +1,57 @@ +//===- ParsingUtils.h - CIRCT parsing common functions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities to help with parsing. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_PARSINGUTILS_H +#define CIRCT_SUPPORT_PARSINGUTILS_H + +#include "circt/Support/LLVM.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpImplementation.h" + +namespace circt { +namespace parsing_util { + +/// Get a name from an SSA value string, if said value name is not a +/// number. +static inline StringAttr getNameFromSSA(MLIRContext *context, StringRef name) { + if (!name.empty()) { + // Ignore numeric names like %42 + assert(name.size() > 1 && name[0] == '%' && "Unknown MLIR name"); + if (isdigit(name[1])) + name = StringRef(); + else + name = name.drop_front(); + } + return StringAttr::get(context, name); +} + +//===----------------------------------------------------------------------===// +// Initializer lists +//===----------------------------------------------------------------------===// + +/// Parses an initializer. +/// An initializer list is a list of operands, types and names on the format: +/// (%arg = %input : type, ...) +ParseResult parseInitializerList( + mlir::OpAsmParser &parser, + llvm::SmallVector &inputArguments, + llvm::SmallVector &inputOperands, + llvm::SmallVector &inputTypes, ArrayAttr &inputNames); + +// Prints an initializer list. +void printInitializerList(OpAsmPrinter &p, ValueRange ins, + ArrayRef args); + +} // namespace parsing_util +} // namespace circt + +#endif // CIRCT_SUPPORT_PARSINGUTILS_H diff --git a/third_party/circt/include/circt/Support/Passes.h b/third_party/circt/include/circt/Support/Passes.h new file mode 100644 index 000000000..1de582dd6 --- /dev/null +++ b/third_party/circt/include/circt/Support/Passes.h @@ -0,0 +1,64 @@ +//===- Passes.h - Helpers for pipeline instrumentation ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_PASSES_H +#define CIRCT_SUPPORT_PASSES_H + +#include "circt/Support/LLVM.h" +#include "llvm/Support/Chrono.h" +#include "llvm/Support/Format.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassInstrumentation.h" + +namespace circt { +// This class prints logs before and after of pass executions when its pass +// operation is in `LoggedOpTypes`. Note that `runBeforePass` and `runAfterPass` +// are not thread safe so `LoggedOpTypes` must be a set of operations whose +// passes are ran sequentially (e.g. mlir::ModuleOp, firrtl::CircuitOp). +template +class VerbosePassInstrumentation : public mlir::PassInstrumentation { + // This stores start time points of passes. + using TimePoint = llvm::sys::TimePoint<>; + llvm::SmallVector timePoints; + int level = 0; + const char *toolName; + + public: + VerbosePassInstrumentation(const char *toolName) : toolName(toolName){}; + void runBeforePass(Pass *pass, Operation *op) override { + if (isa(op)) { + timePoints.push_back(TimePoint::clock::now()); + auto &os = llvm::errs(); + os << llvm::format("[%s] ", toolName); + os.indent(2 * level++); + os << "Running \""; + pass->printAsTextualPipeline(llvm::errs()); + os << "\"\n"; + } + } + + void runAfterPass(Pass *pass, Operation *op) override { + using namespace std::chrono; + if (isa(op)) { + auto &os = llvm::errs(); + auto elapsed = duration(TimePoint::clock::now() - + timePoints.pop_back_val()) / + seconds(1); + os << llvm::format("[%s] ", toolName); + os.indent(2 * --level); + os << "-- Done in " << llvm::format("%.3f", elapsed) << " sec\n"; + } + } +}; + +/// Create a simple canonicalizer pass. +std::unique_ptr createSimpleCanonicalizerPass(); + +} // namespace circt + +#endif // CIRCT_SUPPORT_PASSES_H diff --git a/third_party/circt/include/circt/Support/Path.h b/third_party/circt/include/circt/Support/Path.h new file mode 100644 index 000000000..ddbe424a8 --- /dev/null +++ b/third_party/circt/include/circt/Support/Path.h @@ -0,0 +1,30 @@ +//===- Path.h - Path Utilities ----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for file system path handling, supplementing the ones from +// llvm::sys::path. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_PATH_H +#define CIRCT_SUPPORT_PATH_H + +#include "circt/Support/LLVM.h" + +namespace circt { + +/// Append a path to an existing path, replacing it if the other path is +/// absolute. This mimicks the behaviour of `foo/bar` and `/foo/bar` being used +/// in a working directory `/home`, resulting in `/home/foo/bar` and `/foo/bar`, +/// respectively. +void appendPossiblyAbsolutePath(llvm::SmallVectorImpl &base, + const llvm::Twine &suffix); + +} // namespace circt + +#endif // CIRCT_SUPPORT_PATH_H diff --git a/third_party/circt/include/circt/Support/PrettyPrinter.h b/third_party/circt/include/circt/Support/PrettyPrinter.h new file mode 100644 index 000000000..c801ecb72 --- /dev/null +++ b/third_party/circt/include/circt/Support/PrettyPrinter.h @@ -0,0 +1,325 @@ +//===- PrettyPrinter.h - Pretty printing ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements a pretty-printer. +// "PrettyPrinting", Derek C. Oppen, 1980. +// https://dx.doi.org/10.1145/357114.357115 +// +// This was selected as it is linear in number of tokens O(n) and requires +// memory O(linewidth). +// +// See PrettyPrinter.cpp for more information. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_PRETTYPRINTER_H +#define CIRCT_SUPPORT_PRETTYPRINTER_H + +#include +#include +#include + +#include "circt/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/SaveAndRestore.h" + +namespace circt { +namespace pretty { + +//===----------------------------------------------------------------------===// +// Tokens +//===----------------------------------------------------------------------===// + +/// Style of breaking within a group: +/// - Consistent: all fits or all breaks. +/// - Inconsistent: best fit, break where needed. +/// - Never: force no breaking including nested groups. +enum class Breaks { Consistent, Inconsistent, Never }; + +/// Style of indent when starting a group: +/// - Visual: offset is relative to current column. +/// - Block: offset is relative to current base indentation. +enum class IndentStyle { Visual, Block }; + +class Token { + public: + enum class Kind { String, Break, Begin, End, Callback }; + + struct TokenInfo { + Kind kind; // Common initial sequence. + }; + struct StringInfo : public TokenInfo { + uint32_t len; + const char *str; + }; + struct BreakInfo : public TokenInfo { + uint32_t spaces; // How many spaces to emit when not broken. + int32_t offset; // How many spaces to emit when broken. + bool neverbreak; // If set, behaves like break except this always 'fits'. + }; + struct BeginInfo : public TokenInfo { + int32_t offset; // Adjust base indentation by this amount. + Breaks breaks; + IndentStyle style; + }; + struct EndInfo : public TokenInfo { + // Nothing + }; + // This can be used to associate a callback with the print event on the + // tokens stream. Since tokens strictly follow FIFO order on a queue, each + // CallbackToken can uniquely identify the data it is associated with, when it + // is added and popped from the queue. + struct CallbackInfo : public TokenInfo {}; + + private: + union { + TokenInfo info; + StringInfo stringInfo; + BreakInfo breakInfo; + BeginInfo beginInfo; + EndInfo endInfo; + CallbackInfo callbackInfo; + } data; + + protected: + template + static auto &getInfoImpl(T &t) { + if constexpr (k == Kind::String) return t.data.stringInfo; + if constexpr (k == Kind::Break) return t.data.breakInfo; + if constexpr (k == Kind::Begin) return t.data.beginInfo; + if constexpr (k == Kind::End) return t.data.endInfo; + if constexpr (k == Kind::Callback) return t.data.callbackInfo; + llvm_unreachable("unhandled token kind"); + } + + Token(Kind k) { data.info.kind = k; } + + public: + Kind getKind() const { return data.info.kind; } +}; + +/// Helper class to CRTP-derive common functions. +template +struct TokenBase : public Token { + static bool classof(const Token *t) { return t->getKind() == DerivedKind; } + + protected: + TokenBase() : Token(DerivedKind) {} + + using InfoType = std::remove_reference_t), Token &>>; + + InfoType &getInfoMut() { return Token::getInfoImpl(*this); } + + const InfoType &getInfo() const { + return Token::getInfoImpl(*this); + } + + template + void initialize(Args &&...args) { + getInfoMut() = InfoType{{DerivedKind}, args...}; + } +}; + +/// Token types. + +struct StringToken : public TokenBase { + StringToken(llvm::StringRef text) { + assert(text.size() == (uint32_t)text.size()); + initialize((uint32_t)text.size(), text.data()); + } + StringRef text() const { return StringRef(getInfo().str, getInfo().len); } +}; + +struct BreakToken : public TokenBase { + BreakToken(uint32_t spaces = 1, int32_t offset = 0, bool neverbreak = false) { + initialize(spaces, offset, neverbreak); + } + uint32_t spaces() const { return getInfo().spaces; } + int32_t offset() const { return getInfo().offset; } + bool neverbreak() const { return getInfo().neverbreak; } +}; + +struct BeginToken : public TokenBase { + BeginToken(int32_t offset = 2, Breaks breaks = Breaks::Inconsistent, + IndentStyle style = IndentStyle::Visual) { + initialize(offset, breaks, style); + } + int32_t offset() const { return getInfo().offset; } + Breaks breaks() const { return getInfo().breaks; } + IndentStyle style() const { return getInfo().style; } +}; + +struct EndToken : public TokenBase {}; + +struct CallbackToken : public TokenBase { + CallbackToken() = default; +}; + +//===----------------------------------------------------------------------===// +// PrettyPrinter +//===----------------------------------------------------------------------===// + +class PrettyPrinter { + public: + /// Listener to Token storage events. + struct Listener { + virtual ~Listener(); + /// No tokens referencing external memory are present. + virtual void clear(){}; + /// Listener for print event. + virtual void print(){}; + }; + + /// PrettyPrinter for specified stream. + /// - margin: line width. + /// - baseIndent: always indent at least this much (starting 'indent' value). + /// - currentColumn: current column, used to calculate space remaining. + /// - maxStartingIndent: max column indentation starts at, must be >= margin. + PrettyPrinter(llvm::raw_ostream &os, uint32_t margin, uint32_t baseIndent = 0, + uint32_t currentColumn = 0, + uint32_t maxStartingIndent = kInfinity / 4, + Listener *listener = nullptr) + : space(margin - std::max(currentColumn, baseIndent)), + defaultFrame{baseIndent, PrintBreaks::Inconsistent}, + indent(baseIndent), + margin(margin), + maxStartingIndent(std::max(maxStartingIndent, margin)), + os(os), + listener(listener) { + assert(maxStartingIndent < kInfinity / 2); + assert(maxStartingIndent > baseIndent); + assert(margin > currentColumn); + // Ensure first print advances to at least baseIndent. + pendingIndentation = + baseIndent > currentColumn ? baseIndent - currentColumn : 0; + } + ~PrettyPrinter() { eof(); } + + /// Add token for printing. In Oppen, this is "scan". + void add(Token t); + + /// Add a range of tokens. + template + void addTokens(R &&tokens) { + // Don't invoke listener until range processed, we own it now. + { + llvm::SaveAndRestore save(donotClear, true); + for (Token &t : tokens) add(t); + } + // Invoke it now if appropriate. + if (scanStack.empty()) clear(); + } + + void eof(); + + void setListener(Listener *newListener) { listener = newListener; }; + auto *getListener() const { return listener; } + + static constexpr uint32_t kInfinity = (1U << 15) - 1; + + private: + /// Format token with tracked size. + struct FormattedToken { + Token token; /// underlying token + int32_t size; /// calculate size when positive. + }; + + /// Breaking style for a printStack entry. + /// This is "Breaks" values with extra for "Fits". + /// Breaks::Never is "AlwaysFits" here. + enum class PrintBreaks { Consistent, Inconsistent, AlwaysFits, Fits }; + + /// Printing information for active scope, stored in printStack. + struct PrintEntry { + uint32_t offset; + PrintBreaks breaks; + }; + + /// Print out tokens we know sizes for, and drop from token buffer. + void advanceLeft(); + + /// Break encountered, set sizes of begin/breaks in scanStack we now know. + void checkStack(); + + /// Check if there's enough tokens to hit width, if so print. + /// If scan size is wider than line, it's infinity. + void checkStream(); + + /// Print a token, maintaining printStack for context. + void print(const FormattedToken &f); + + /// Clear token buffer, scanStack must be empty. + void clear(); + + /// Reset leftTotal and tokenOffset, rebase size data and scanStack indices. + void rebaseIfNeeded(); + + /// Get current printing frame. + auto &getPrintFrame() { + return printStack.empty() ? defaultFrame : printStack.back(); + } + + /// Characters left on this line. + int32_t space; + + /// Sizes: printed, enqueued + int32_t leftTotal; + int32_t rightTotal; + + /// Unprinted tokens, combination of 'token' and 'size' in Oppen. + std::deque tokens; + /// index of first token, for resolving scanStack entries. + uint32_t tokenOffset = 0; + + /// Stack of begin/break tokens, adjust by tokenOffset to index into tokens. + std::deque scanStack; + + /// Stack of printing contexts (indentation + breaking behavior). + SmallVector printStack; + + /// Printing context when stack is empty. + const PrintEntry defaultFrame; + + /// Number of "AlwaysFits" on print stack. + uint32_t alwaysFits = 0; + + /// Current indentation level + uint32_t indent; + + /// Whitespace to print before next, tracked to avoid trailing whitespace. + uint32_t pendingIndentation; + + /// Target line width. + const uint32_t margin; + + /// Maximum starting indentation level (default=kInfinity/4). + /// Useful to continue indentation past margin while still providing a limit + /// to avoid pathological output and for consumption by tools with limits. + const uint32_t maxStartingIndent; + + /// Output stream. + llvm::raw_ostream &os; + + /// Hook for Token storage events. + Listener *listener = nullptr; + + /// Flag to identify a state when the clear cannot be called. + bool donotClear = false; + + /// Threshold for walking scan state and "rebasing" totals/offsets. + static constexpr decltype(leftTotal) rebaseThreshold = + 1UL << (std::numeric_limits::digits - 3); +}; + +} // end namespace pretty +} // end namespace circt + +#endif // CIRCT_SUPPORT_PRETTYPRINTER_H diff --git a/third_party/circt/include/circt/Support/PrettyPrinterHelpers.h b/third_party/circt/include/circt/Support/PrettyPrinterHelpers.h new file mode 100644 index 000000000..5cceff009 --- /dev/null +++ b/third_party/circt/include/circt/Support/PrettyPrinterHelpers.h @@ -0,0 +1,374 @@ +//===- PrettyPrinterHelpers.h - Pretty printing helpers -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper classes for using PrettyPrinter. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H +#define CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H + +#include + +#include "circt/Support/PrettyPrinter.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/StringSaver.h" +#include "llvm/Support/raw_ostream.h" + +namespace circt { +namespace pretty { + +//===----------------------------------------------------------------------===// +// PrettyPrinter standin that buffers tokens until flushed. +//===----------------------------------------------------------------------===// + +/// Buffer tokens for clients that need to adjust things. +struct BufferingPP { + using BufferVec = SmallVectorImpl; + BufferVec &tokens; + bool hasEOF = false; + + BufferingPP(BufferVec &tokens) : tokens(tokens) {} + + void add(Token t) { + assert(!hasEOF); + tokens.push_back(t); + } + + /// Add a range of tokens. + template + void addTokens(R &&newTokens) { + assert(!hasEOF); + llvm::append_range(tokens, newTokens); + } + + /// Buffer a final EOF, no tokens allowed after this. + void eof() { + assert(!hasEOF); + hasEOF = true; + } + + /// Flush buffered tokens to the specified pretty printer. + /// Emit the EOF is one was added. + void flush(PrettyPrinter &pp) { + pp.addTokens(tokens); + tokens.clear(); + if (hasEOF) { + pp.eof(); + hasEOF = false; + } + } +}; + +//===----------------------------------------------------------------------===// +// Convenience Token builders. +//===----------------------------------------------------------------------===// + +namespace detail { +void emitNBSP(unsigned n, llvm::function_ref add); +} // end namespace detail + +/// Add convenience methods for generating pretty-printing tokens. +template +class TokenBuilder { + PPTy &pp; + + public: + TokenBuilder(PPTy &pp) : pp(pp) {} + + //===- Add tokens -------------------------------------------------------===// + + /// Add new token. + template + typename std::enable_if_t> add(Args &&...args) { + pp.add(T(std::forward(args)...)); + } + void addToken(Token t) { pp.add(t); } + + /// End of a stream. + void eof() { pp.eof(); } + + //===- Strings ----------------------------------------------------------===// + + /// Add a literal (with external storage). + void literal(StringRef str) { add(str); } + + /// Add a non-breaking space. + void nbsp() { literal(" "); } + + /// Add multiple non-breaking spaces as a single token. + void nbsp(unsigned n) { + detail::emitNBSP(n, [&](Token t) { addToken(t); }); + } + + //===- Breaks -----------------------------------------------------------===// + + /// Add a 'neverbreak' break. Always 'fits'. + void neverbreak() { add(0, 0, true); } + + /// Add a newline (break too wide to fit, always breaks). + void newline() { add(PrettyPrinter::kInfinity); } + + /// Add breakable spaces. + void spaces(uint32_t n) { add(n); } + + /// Add a breakable space. + void space() { spaces(1); } + + /// Add a break that is zero-wide if not broken. + void zerobreak() { add(0); } + + //===- Groups -----------------------------------------------------------===// + + /// Start a IndentStyle::Block group with specified offset. + void bbox(int32_t offset = 0, Breaks breaks = Breaks::Consistent) { + add(offset, breaks, IndentStyle::Block); + } + + /// Start a consistent group with specified offset. + void cbox(int32_t offset = 0, IndentStyle style = IndentStyle::Visual) { + add(offset, Breaks::Consistent, style); + } + + /// Start an inconsistent group with specified offset. + void ibox(int32_t offset = 0, IndentStyle style = IndentStyle::Visual) { + add(offset, Breaks::Inconsistent, style); + } + + /// Start a group that cannot break, including nested groups. + /// Use sparingly. + void neverbox() { add(0, Breaks::Never); } + + /// End a group. + void end() { add(); } +}; + +/// PrettyPrinter::Listener that saves strings while live. +/// Once they're no longer referenced, memory is reset. +/// Allows differentiating between strings to save and external strings. +class TokenStringSaver : public PrettyPrinter::Listener { + llvm::BumpPtrAllocator alloc; + llvm::StringSaver strings; + + public: + TokenStringSaver() : strings(alloc) {} + + /// Add string, save in storage. + [[nodiscard]] StringRef save(StringRef str) { return strings.save(str); } + + /// PrettyPrinter::Listener::clear -- indicates no external refs. + void clear() override; +}; + +/// Note: Callable class must implement a callable with signature: +/// void (Data) +template +class PrintEventAndStorageListener : public TokenStringSaver { + /// List of all the unique data associated with each callback token. + /// The fact that tokens on a stream can never be printed out of order, + /// ensures that CallbackTokens are always added and invoked in FIFO order, + /// hence no need to record an index into the Data list. + std::queue dataQ; + /// The storage for the callback, as a function object. + CallableTy &callable; + + public: + PrintEventAndStorageListener(CallableTy &c) : callable(c) {} + + /// PrettyPrinter::Listener::print -- indicates all the preceding tokens on + /// the stream have been printed. + /// This is invoked when the CallbackToken is printed. + void print() override { + std::invoke(callable, dataQ.front()); + dataQ.pop(); + } + /// Get a token with the obj data. + CallbackToken getToken(DataTy obj) { + // Insert data onto the list. + dataQ.push(obj); + return CallbackToken(); + } +}; + +//===----------------------------------------------------------------------===// +// Streaming support. +//===----------------------------------------------------------------------===// + +/// Send one of these to TokenStream to add the corresponding token. +/// See TokenBuilder for details of each. +enum class PP { + bbox2, + cbox0, + cbox2, + end, + eof, + ibox0, + ibox2, + nbsp, + neverbox, + neverbreak, + newline, + space, + zerobreak, +}; + +/// String wrapper to indicate string has external storage. +struct PPExtString { + StringRef str; + explicit PPExtString(StringRef str) : str(str) {} +}; + +/// String wrapper to indicate string needs to be saved. +struct PPSaveString { + StringRef str; + explicit PPSaveString(StringRef str) : str(str) {} +}; + +/// Wrap a PrettyPrinter with TokenBuilder features as well as operator<<'s. +/// String behavior: +/// Strings streamed as `const char *` are assumed to have external storage, +/// and StringRef's are saved until no longer needed. +/// Use PPExtString() and PPSaveString() wrappers to specify/override behavior. +template +class TokenStream : public TokenBuilder { + using Base = TokenBuilder; + TokenStringSaver &saver; + + public: + /// Create a TokenStream using the specified PrettyPrinter and StringSaver + /// storage. Strings are saved in `saver`, which is generally the listener in + /// the PrettyPrinter, but may not be (e.g., using BufferingPP). + TokenStream(PPTy &pp, TokenStringSaver &saver) : Base(pp), saver(saver) {} + + /// Add a string literal (external storage). + TokenStream &operator<<(const char *s) { + Base::literal(s); + return *this; + } + /// Add a string token (saved to storage). + TokenStream &operator<<(StringRef s) { + Base::template add(saver.save(s)); + return *this; + } + + /// String has external storage. + TokenStream &operator<<(const PPExtString &str) { + Base::literal(str.str); + return *this; + } + + /// String must be saved. + TokenStream &operator<<(const PPSaveString &str) { + Base::template add(saver.save(str.str)); + return *this; + } + + /// Convenience for inline streaming of builder methods. + TokenStream &operator<<(PP s) { + switch (s) { + case PP::bbox2: + Base::bbox(2); + break; + case PP::cbox0: + Base::cbox(0); + break; + case PP::cbox2: + Base::cbox(2); + break; + case PP::end: + Base::end(); + break; + case PP::eof: + Base::eof(); + break; + case PP::ibox0: + Base::ibox(0); + break; + case PP::ibox2: + Base::ibox(2); + break; + case PP::nbsp: + Base::nbsp(); + break; + case PP::neverbox: + Base::neverbox(); + break; + case PP::neverbreak: + Base::neverbreak(); + break; + case PP::newline: + Base::newline(); + break; + case PP::space: + Base::space(); + break; + case PP::zerobreak: + Base::zerobreak(); + break; + } + return *this; + } + + /// Stream support for user-created Token's. + TokenStream &operator<<(Token t) { + Base::addToken(t); + return *this; + } + + /// General-purpose "format this" helper, for types not supported by + /// operator<< yet. + template + TokenStream &addAsString(T &&t) { + invokeWithStringOS([&](auto &os) { os << std::forward(t); }); + return *this; + } + + /// Helper to invoke code with a llvm::raw_ostream argument for compatibility. + /// All data is gathered into a single string token. + template + auto invokeWithStringOS(Callable &&c) { + SmallString ss; + llvm::raw_svector_ostream ssos(ss); + auto flush = llvm::make_scope_exit([&]() { + if (!ss.empty()) *this << ss; + }); + return std::invoke(std::forward(c), ssos); + } + + /// Write escaped versions of the string, saved in storage. + TokenStream &writeEscaped(StringRef str, bool useHexEscapes = false) { + return writeQuotedEscaped(str, useHexEscapes, "", ""); + } + TokenStream &writeQuotedEscaped(StringRef str, bool useHexEscapes = false, + StringRef left = "\"", + StringRef right = "\"") { + // Add as a single StringToken. + invokeWithStringOS([&](auto &os) { + os << left; + os.write_escaped(str, useHexEscapes); + os << right; + }); + return *this; + } + + /// Open a box, invoke the lambda, and close it after. + template + auto scopedBox(T &&t, Callable &&c, Token close = EndToken()) { + *this << std::forward(t); + auto done = llvm::make_scope_exit([&]() { *this << close; }); + return std::invoke(std::forward(c)); + } +}; + +} // end namespace pretty +} // end namespace circt + +#endif // CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H diff --git a/third_party/circt/include/circt/Support/SymCache.h b/third_party/circt/include/circt/Support/SymCache.h new file mode 100644 index 000000000..10cc62ba7 --- /dev/null +++ b/third_party/circt/include/circt/Support/SymCache.h @@ -0,0 +1,132 @@ +//===- SymCache.h - Declare Symbol Cache ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a Symbol Cache. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_SYMCACHE_H +#define CIRCT_SUPPORT_SYMCACHE_H + +#include "llvm/ADT/iterator.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/SymbolTable.h" + +namespace circt { + +/// Base symbol cache class to allow for cache lookup through a pointer to some +/// abstract cache. A symbol cache stores lookup tables to make manipulating and +/// working with the IR more efficient. +class SymbolCacheBase { + public: + virtual ~SymbolCacheBase(); + + /// Defines 'op' as associated with the 'symbol' in the cache. + virtual void addDefinition(mlir::Attribute symbol, mlir::Operation *op) = 0; + + /// Adds the symbol-defining 'op' to the cache. + void addSymbol(mlir::SymbolOpInterface op) { + addDefinition(op.getNameAttr(), op); + } + + /// Populate the symbol cache with all symbol-defining operations within the + /// 'top' operation. + void addDefinitions(mlir::Operation *top); + + /// Lookup a definition for 'symbol' in the cache. + virtual mlir::Operation *getDefinition(mlir::Attribute symbol) const = 0; + + /// Lookup a definition for 'symbol' in the cache. + mlir::Operation *getDefinition(mlir::FlatSymbolRefAttr symbol) const { + return getDefinition(symbol.getAttr()); + } + + /// Iterator support through a pointer to some abstract cache. + /// The implementing cache must provide an iterator that carries values on the + /// form of . + using CacheItem = std::pair; + struct CacheIteratorImpl { + virtual ~CacheIteratorImpl() {} + virtual void operator++() = 0; + virtual CacheItem operator*() = 0; + virtual bool operator==(CacheIteratorImpl *other) = 0; + }; + + struct Iterator + : public llvm::iterator_facade_base { + Iterator(std::unique_ptr &&impl) + : impl(std::move(impl)) {} + CacheItem operator*() const { return **impl; } + using llvm::iterator_facade_base::operator++; + bool operator==(const Iterator &other) const { + return *impl == other.impl.get(); + } + void operator++() { impl->operator++(); } + + private: + std::unique_ptr impl; + }; + virtual Iterator begin() = 0; + virtual Iterator end() = 0; +}; + +/// Default symbol cache implementation; stores associations between names +/// (StringAttr's) to mlir::Operation's. +/// Adding/getting definitions from the symbol cache is not +/// thread safe. If this is required, synchronizing cache acccess should be +/// ensured by the caller. +class SymbolCache : public SymbolCacheBase { + public: + /// In the building phase, add symbols. + void addDefinition(mlir::Attribute key, mlir::Operation *op) override { + symbolCache.try_emplace(key, op); + } + + // Pull in getDefinition(mlir::FlatSymbolRefAttr symbol) + using SymbolCacheBase::getDefinition; + mlir::Operation *getDefinition(mlir::Attribute attr) const override { + auto it = symbolCache.find(attr); + if (it == symbolCache.end()) return nullptr; + return it->second; + } + + protected: + /// This stores a lookup table from symbol attribute to the operation + /// that defines it. + llvm::DenseMap symbolCache; + + private: + /// Iterator support: A simple mapping between decltype(symbolCache)::iterator + /// to SymbolCacheBase::Iterator. + using Iterator = decltype(symbolCache)::iterator; + struct SymbolCacheIteratorImpl : public CacheIteratorImpl { + SymbolCacheIteratorImpl(Iterator it) : it(it) {} + CacheItem operator*() override { return {it->getFirst(), it->getSecond()}; } + void operator++() override { it++; } + bool operator==(CacheIteratorImpl *other) override { + return it == static_cast(other)->it; + } + Iterator it; + }; + + public: + SymbolCacheBase::Iterator begin() override { + return SymbolCacheBase::Iterator( + std::make_unique(symbolCache.begin())); + } + SymbolCacheBase::Iterator end() override { + return SymbolCacheBase::Iterator( + std::make_unique(symbolCache.end())); + } +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_SYMCACHE_H diff --git a/third_party/circt/include/circt/Support/ValueMapper.h b/third_party/circt/include/circt/Support/ValueMapper.h new file mode 100644 index 000000000..ffcd5346e --- /dev/null +++ b/third_party/circt/include/circt/Support/ValueMapper.h @@ -0,0 +1,63 @@ +//===- ValueMapper.h - Support for mapping SSA values -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides support for mapping SSA values between two domains. +// Provided a BackedgeBuilder, the ValueMapper supports mappings between +// GraphRegions, creating Backedges in cases of 'get'ing mapped values which are +// yet to be 'set'. +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_SUPPORT_VALUEMAPPER_H +#define CIRCT_SUPPORT_VALUEMAPPER_H + +#include +#include + +#include "circt/Support/BackedgeBuilder.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" + +namespace circt { + +/// The ValueMapper class facilitates the definition and connection of SSA +/// def-use chains between two location - a 'from' location (defining +/// use-def chains) and a 'to' location (where new operations are created based +/// on the 'from' location).´ +class ValueMapper { + public: + using TypeTransformer = llvm::function_ref; + static mlir::Type identity(mlir::Type t) { return t; }; + explicit ValueMapper(BackedgeBuilder *bb = nullptr) : bb(bb) {} + + // Get the mapped value of value 'from'. If no mapping has been registered, a + // new backedge is created. The type of the mapped value may optionally be + // modified through the 'typeTransformer'. + mlir::Value get(mlir::Value from, + TypeTransformer typeTransformer = ValueMapper::identity); + llvm::SmallVector get( + mlir::ValueRange from, + TypeTransformer typeTransformer = ValueMapper::identity); + + // Set the mapped value of 'from' to 'to'. If 'from' is already mapped to a + // backedge, replaces that backedge with 'to'. If 'replace' is not set, and a + // (non-backedge) mapping already exists, an assert is thrown. + void set(mlir::Value from, mlir::Value to, bool replace = false); + void set(mlir::ValueRange from, mlir::ValueRange to, bool replace = false); + + private: + BackedgeBuilder *bb = nullptr; + llvm::DenseMap> mapping; +}; + +} // namespace circt + +#endif // CIRCT_SUPPORT_VALUEMAPPER_H diff --git a/third_party/circt/include/circt/Support/Version.h b/third_party/circt/include/circt/Support/Version.h new file mode 100644 index 000000000..206f0688e --- /dev/null +++ b/third_party/circt/include/circt/Support/Version.h @@ -0,0 +1,16 @@ +#ifndef CIRCT_SUPPORT_VERSION_H +#define CIRCT_SUPPORT_VERSION_H + +#include + +namespace circt { +const char *getCirctVersion(); +const char *getCirctVersionComment(); + +/// A generic bug report message for CIRCT-related projects +constexpr const char *circtBugReportMsg = + "PLEASE submit a bug report to https://github.com/llvm/circt and include " + "the crash backtrace.\n"; +} // namespace circt + +#endif // CIRCT_SUPPORT_VERSION_H diff --git a/third_party/circt/lib/Dialect/Comb/BUILD b/third_party/circt/lib/Dialect/Comb/BUILD new file mode 100644 index 000000000..441a1decc --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/BUILD @@ -0,0 +1,23 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = glob([ + "*.cpp", + ]), + deps = [ + "@heir//third_party/circt/include/circt/Dialect/Comb:dialect_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:enum_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:headers", + "@heir//third_party/circt/include/circt/Dialect/Comb:ops_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/Comb:type_inc_gen", + "@heir//third_party/circt/lib/Dialect/HW:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) diff --git a/third_party/circt/lib/Dialect/Comb/CMakeLists.txt b/third_party/circt/lib/Dialect/Comb/CMakeLists.txt new file mode 100644 index 000000000..8034eff9b --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/CMakeLists.txt @@ -0,0 +1,26 @@ +add_circt_dialect_library(CIRCTComb + CombFolds.cpp + CombOps.cpp + CombAnalysis.cpp + CombDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb + + DEPENDS + CIRCTHW + MLIRCombIncGen + MLIRCombEnumsIncGen + + LINK_COMPONENTS + Support + + LINK_LIBS PUBLIC + CIRCTHW + MLIRIR + MLIRInferTypeOpInterface + ) + +add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen) + +add_subdirectory(Transforms) diff --git a/third_party/circt/lib/Dialect/Comb/CombAnalysis.cpp b/third_party/circt/lib/Dialect/Comb/CombAnalysis.cpp new file mode 100644 index 000000000..c2c96fea5 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/CombAnalysis.cpp @@ -0,0 +1,87 @@ +//===- CombAnalysis.cpp - Analysis Helpers for Comb+HW operations ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "llvm/Support/KnownBits.h" + +using namespace circt; +using namespace comb; + +/// Given an integer SSA value, check to see if we know anything about the +/// result of the computation. For example, we know that "and with a constant" +/// always returns zeros for the zero bits in a constant. +/// +/// Expression trees can be very large, so we need ot make sure to cap our +/// recursion, this is controlled by `depth`. +static KnownBits computeKnownBits(Value v, unsigned depth) { + Operation *op = v.getDefiningOp(); + if (!op || depth == 5) return KnownBits(v.getType().getIntOrFloatBitWidth()); + + // A constant has all bits known! + if (auto constant = dyn_cast(op)) + return KnownBits::makeConstant(constant.getValue()); + + // `concat(x, y, z)` has whatever is known about the operands concat'd. + if (auto concatOp = dyn_cast(op)) { + auto result = computeKnownBits(concatOp.getOperand(0), depth + 1); + for (size_t i = 1, e = concatOp.getNumOperands(); i != e; ++i) { + auto otherBits = computeKnownBits(concatOp.getOperand(i), depth + 1); + unsigned width = otherBits.getBitWidth(); + unsigned newWidth = result.getBitWidth() + width; + result.Zero = + (result.Zero.zext(newWidth) << width) | otherBits.Zero.zext(newWidth); + result.One = + (result.One.zext(newWidth) << width) | otherBits.One.zext(newWidth); + } + return result; + } + + // `and(x, y, z)` has whatever is known about the operands intersected. + if (auto andOp = dyn_cast(op)) { + auto result = computeKnownBits(andOp.getOperand(0), depth + 1); + for (size_t i = 1, e = andOp.getNumOperands(); i != e; ++i) + result &= computeKnownBits(andOp.getOperand(i), depth + 1); + return result; + } + + // `or(x, y, z)` has whatever is known about the operands unioned. + if (auto orOp = dyn_cast(op)) { + auto result = computeKnownBits(orOp.getOperand(0), depth + 1); + for (size_t i = 1, e = orOp.getNumOperands(); i != e; ++i) + result |= computeKnownBits(orOp.getOperand(i), depth + 1); + return result; + } + + // `xor(x, cst)` inverts known bits and passes through unmodified ones. + if (auto xorOp = dyn_cast(op)) { + auto result = computeKnownBits(xorOp.getOperand(0), depth + 1); + for (size_t i = 1, e = xorOp.getNumOperands(); i != e; ++i) { + // If we don't know anything, we don't need to evaluate more subexprs. + if (result.isUnknown()) return result; + result ^= computeKnownBits(xorOp.getOperand(i), depth + 1); + } + return result; + } + + // `mux(cond, x, y)` is the intersection of the known bits of `x` and `y`. + if (auto muxOp = dyn_cast(op)) { + auto lhs = computeKnownBits(muxOp.getTrueValue(), depth + 1); + auto rhs = computeKnownBits(muxOp.getFalseValue(), depth + 1); + return lhs.intersectWith(rhs); + } + + return KnownBits(v.getType().getIntOrFloatBitWidth()); +} + +/// Given an integer SSA value, check to see if we know anything about the +/// result of the computation. For example, we know that "and with a +/// constant" always returns zeros for the zero bits in a constant. +KnownBits comb::computeKnownBits(Value value) { + return ::computeKnownBits(value, 0); +} diff --git a/third_party/circt/lib/Dialect/Comb/CombDialect.cpp b/third_party/circt/lib/Dialect/Comb/CombDialect.cpp new file mode 100644 index 000000000..0d6211998 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/CombDialect.cpp @@ -0,0 +1,62 @@ +//===- CombDialect.cpp - Implement the Comb dialect -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombDialect.h" + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace circt; +using namespace comb; + +//===----------------------------------------------------------------------===// +// Dialect specification. +//===----------------------------------------------------------------------===// + +void CombDialect::initialize() { + // Register operations. + addOperations< +#define GET_OP_LIST +#include "circt/Dialect/Comb/Comb.cpp.inc" + >(); +} + +/// Registered hook to materialize a single constant operation from a given +/// attribute value with the desired resultant type. This method should use +/// the provided builder to create the operation without changing the +/// insertion position. The generated operation is expected to be constant +/// like, i.e. single result, zero operands, non side-effecting, etc. On +/// success, this hook should return the value generated to represent the +/// constant value. Otherwise, it should return null on failure. +Operation *CombDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + // Integer constants. + if (auto intType = type.dyn_cast()) + if (auto attrValue = value.dyn_cast()) + return builder.create(loc, type, attrValue); + + // Parameter expressions materialize into hw.param.value. + auto parentOp = builder.getBlock()->getParentOp(); + auto curModule = dyn_cast(parentOp); + if (!curModule) curModule = parentOp->getParentOfType(); + if (curModule && isValidParameterExpression(value, curModule)) + return builder.create(loc, type, value); + + return nullptr; +} + +// Provide implementations for the enums we use. +#include "circt/Dialect/Comb/CombDialect.cpp.inc" +#include "circt/Dialect/Comb/CombEnums.cpp.inc" diff --git a/third_party/circt/lib/Dialect/Comb/CombFolds.cpp b/third_party/circt/lib/Dialect/Comb/CombFolds.cpp new file mode 100644 index 000000000..ded65c1f6 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/CombFolds.cpp @@ -0,0 +1,3026 @@ +//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Dialect/HW/HWOps.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/KnownBits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace circt; +using namespace comb; +using namespace matchers; + +/// Create a new instance of a generic operation that only has value operands, +/// and has a single result value whose type matches the first operand. +/// +/// This should not be used to create instances of ops with attributes or with +/// more complicated type signatures. +static Value createGenericOp(Location loc, OperationName name, + ArrayRef operands, OpBuilder &builder) { + OperationState state(loc, name); + state.addOperands(operands); + state.addTypes(operands[0].getType()); + return builder.create(state)->getResult(0); +} + +static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) { + return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()), + value); +} + +/// Flatten concat and mux operands into a vector. +static void getConcatOperands(Value v, SmallVectorImpl &result) { + if (auto concat = v.getDefiningOp()) { + for (auto op : concat.getOperands()) getConcatOperands(op, result); + } else if (auto repl = v.getDefiningOp()) { + for (size_t i = 0, e = repl.getMultiple(); i != e; ++i) + getConcatOperands(repl.getOperand(), result); + } else { + result.push_back(v); + } +} + +/// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint" +/// attribute. If a replaced op has a "sv.namehint" attribute, this function +/// propagates the name to the new value. +static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, + Value newValue) { + if (auto *newOp = newValue.getDefiningOp()) { + auto name = op->getAttrOfType("sv.namehint"); + if (name && !newOp->hasAttr("sv.namehint")) + rewriter.updateRootInPlace(newOp, + [&] { newOp->setAttr("sv.namehint", name); }); + } + rewriter.replaceOp(op, newValue); +} + +/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate +/// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute, +/// this function propagates the name to the new value. +template +static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, + Operation *op, Args &&...args) { + auto name = op->getAttrOfType("sv.namehint"); + auto newOp = + rewriter.replaceOpWithNewOp(op, std::forward(args)...); + if (name && !newOp->hasAttr("sv.namehint")) + rewriter.updateRootInPlace(newOp, + [&] { newOp->setAttr("sv.namehint", name); }); + + return newOp; +} + +// Return true if the op has SV attributes. Note that we cannot use a helper +// function `hasSVAttributes` defined under SV dialect because of a cyclic +// dependency. +static bool hasSVAttributes(Operation *op) { + return op->hasAttr("sv.attributes"); +} + +namespace { +template +struct ComplementMatcher { + SubType lhs; + ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {} + bool match(Operation *op) { + auto xorOp = dyn_cast(op); + return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0)); + } +}; +} // end anonymous namespace + +template +static inline ComplementMatcher m_Complement(const SubType &subExpr) { + return ComplementMatcher(subExpr); +} + +/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined +/// as an Op. Returns true if successful, and false otherwise. +/// +/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true +/// +static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) { + auto inputs = op->getOperands(); + + for (size_t i = 0, size = inputs.size(); i != size; ++i) { + Operation *flattenOp = inputs[i].getDefiningOp(); + if (!flattenOp || flattenOp->getName() != op->getName()) continue; + + // Check for loops + if (flattenOp == op) continue; + + // Don't duplicate logic when it has multiple uses. + if (!inputs[i].hasOneUse()) { + // We can fold a multi-use binary operation into this one if this allows a + // constant to fold though. For example, fold + // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2) + // since the constants will both fold and we end up with the equiv cost. + // + // We don't do this for add/mul because the hardware won't be shared + // between the two ops if duplicated. + if (flattenOp->getNumOperands() != 2 || !isa(op) || + !flattenOp->getOperand(1).getDefiningOp() || + !inputs.back().getDefiningOp()) + continue; + } + + // Otherwise, flatten away. + auto flattenOpInputs = flattenOp->getOperands(); + + SmallVector newOperands; + newOperands.reserve(size + flattenOpInputs.size()); + + auto flattenOpIndex = inputs.begin() + i; + newOperands.append(inputs.begin(), flattenOpIndex); + newOperands.append(flattenOpInputs.begin(), flattenOpInputs.end()); + newOperands.append(flattenOpIndex + 1, inputs.end()); + + Value result = + createGenericOp(op->getLoc(), op->getName(), newOperands, rewriter); + + // If the original operation and flatten operand have bin flags, propagte + // the flag to new one. + if (op->hasAttrOfType("twoState") && + flattenOp->hasAttrOfType("twoState")) + result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr()); + + replaceOpAndCopyName(rewriter, op, result); + return true; + } + return false; +} + +// Given a range of uses of an operation, find the lowest and highest bits +// inclusive that are ever referenced. The range of uses must not be empty. +static std::pair getLowestBitAndHighestBitRequired( + Operation *op, bool narrowTrailingBits, size_t originalOpWidth) { + auto users = op->getUsers(); + assert(!users.empty() && + "getLowestBitAndHighestBitRequired cannot operate on " + "a empty list of uses."); + + // when we don't want to narrowTrailingBits (namely in arithmetic + // operations), forcing lowestBitRequired = 0 + size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0; + size_t highestBitRequired = 0; + + for (auto *user : users) { + if (auto extractOp = dyn_cast(user)) { + size_t lowBit = extractOp.getLowBit(); + size_t highBit = + extractOp.getType().cast().getWidth() + lowBit - 1; + highestBitRequired = std::max(highestBitRequired, highBit); + lowestBitRequired = std::min(lowestBitRequired, lowBit); + continue; + } + + highestBitRequired = originalOpWidth - 1; + lowestBitRequired = 0; + break; + } + + return {lowestBitRequired, highestBitRequired}; +} + +template +static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, + PatternRewriter &rewriter) { + IntegerType opType = + op.getResult().getType().template dyn_cast(); + if (!opType) return false; + + auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits, + opType.getWidth()); + if (range.second + 1 == opType.getWidth() && range.first == 0) return false; + + SmallVector args; + auto newType = rewriter.getIntegerType(range.second - range.first + 1); + for (auto inop : op.getOperands()) { + // deal with muxes here + if (inop.getType() != op.getType()) + args.push_back(inop); + else + args.push_back(rewriter.createOrFold(inop.getLoc(), newType, + inop, range.first)); + } + Value newop = rewriter.createOrFold(op.getLoc(), newType, args); + newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs()); + if (range.first) + newop = rewriter.createOrFold( + op.getLoc(), newop, + rewriter.create(op.getLoc(), + APInt::getZero(range.first))); + if (range.second + 1 < opType.getWidth()) + newop = rewriter.createOrFold( + op.getLoc(), + rewriter.create( + op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)), + newop); + rewriter.replaceOp(op, newop); + return true; +} + +//===----------------------------------------------------------------------===// +// Unary Operations +//===----------------------------------------------------------------------===// + +OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) { + // Replicate one time -> noop. + if (getType().cast().getWidth() == + getInput().getType().getIntOrFloatBitWidth()) + return getInput(); + + // Constant fold. + if (auto input = adaptor.getInput().dyn_cast_or_null()) { + if (input.getValue().getBitWidth() == 1) { + if (input.getValue().isZero()) + return getIntAttr( + APInt::getZero(getType().cast().getWidth()), + getContext()); + return getIntAttr( + APInt::getAllOnes(getType().cast().getWidth()), + getContext()); + } + + APInt result = APInt::getZeroWidth(); + for (auto i = getMultiple(); i != 0; --i) + result = result.concat(input.getValue()); + return getIntAttr(result, getContext()); + } + + return {}; +} + +OpFoldResult ParityOp::fold(FoldAdaptor adaptor) { + // Constant fold. + if (auto input = adaptor.getInput().dyn_cast_or_null()) + return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext()); + + return {}; +} + +//===----------------------------------------------------------------------===// +// Binary Operations +//===----------------------------------------------------------------------===// + +/// Performs constant folding `calculate` with element-wise behavior on the two +/// attributes in `operands` and returns the result if possible. +static Attribute constFoldBinaryOp(ArrayRef operands, + hw::PEO paramOpcode) { + assert(operands.size() == 2 && "binary op takes two operands"); + if (!operands[0] || !operands[1]) return {}; + + // Fold constants with ParamExprAttr::get which handles simple constants as + // well as parameter expressions. + return hw::ParamExprAttr::get(paramOpcode, operands[0].cast(), + operands[1].cast()); +} + +OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { + if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { + unsigned shift = rhs.getValue().getZExtValue(); + unsigned width = getType().getIntOrFloatBitWidth(); + if (shift == 0) return getOperand(0); + if (width <= shift) return getIntAttr(APInt::getZero(width), getContext()); + } + + return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl); +} + +LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) { + // ShlOp(x, cst) -> Concat(Extract(x), zeros) + APInt value; + if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); + + unsigned width = op.getLhs().getType().cast().getWidth(); + unsigned shift = value.getZExtValue(); + + // This case is handled by fold. + if (width <= shift || shift == 0) return failure(); + + auto zeros = + rewriter.create(op.getLoc(), APInt::getZero(shift)); + + // Remove the high bits which would be removed by the Shl. + auto extract = + rewriter.create(op.getLoc(), op.getLhs(), 0, width - shift); + + replaceOpWithNewOpAndCopyName(rewriter, op, extract, zeros); + return success(); +} + +OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { + if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { + unsigned shift = rhs.getValue().getZExtValue(); + if (shift == 0) return getOperand(0); + + unsigned width = getType().getIntOrFloatBitWidth(); + if (width <= shift) return getIntAttr(APInt::getZero(width), getContext()); + } + return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU); +} + +LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) { + // ShrUOp(x, cst) -> Concat(zeros, Extract(x)) + APInt value; + if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); + + unsigned width = op.getLhs().getType().cast().getWidth(); + unsigned shift = value.getZExtValue(); + + // This case is handled by fold. + if (width <= shift || shift == 0) return failure(); + + auto zeros = + rewriter.create(op.getLoc(), APInt::getZero(shift)); + + // Remove the low bits which would be removed by the Shr. + auto extract = rewriter.create(op.getLoc(), op.getLhs(), shift, + width - shift); + + replaceOpWithNewOpAndCopyName(rewriter, op, zeros, extract); + return success(); +} + +OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { + if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { + if (rhs.getValue().getZExtValue() == 0) return getOperand(0); + } + return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS); +} + +LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) { + // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x)) + APInt value; + if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); + + unsigned width = op.getLhs().getType().cast().getWidth(); + unsigned shift = value.getZExtValue(); + + auto topbit = + rewriter.createOrFold(op.getLoc(), op.getLhs(), width - 1, 1); + auto sext = rewriter.createOrFold(op.getLoc(), topbit, shift); + + if (width <= shift) { + replaceOpAndCopyName(rewriter, op, {sext}); + return success(); + } + + auto extract = rewriter.create(op.getLoc(), op.getLhs(), shift, + width - shift); + + replaceOpWithNewOpAndCopyName(rewriter, op, sext, extract); + return success(); +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// + +OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { + // If we are extracting the entire input, then return it. + if (getInput().getType() == getType()) return getInput(); + + // Constant fold. + if (auto input = adaptor.getInput().dyn_cast_or_null()) { + unsigned dstWidth = getType().cast().getWidth(); + return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth), + getContext()); + } + return {}; +} + +// Transforms extract(lo, cat(a, b, c, d, e)) into +// cat(extract(lo1, b), c, extract(lo2, d)). +// innerCat must be the argument of the provided ExtractOp. +static LogicalResult extractConcatToConcatExtract(ExtractOp op, + ConcatOp innerCat, + PatternRewriter &rewriter) { + auto reversedConcatArgs = llvm::reverse(innerCat.getInputs()); + size_t beginOfFirstRelevantElement = 0; + auto it = reversedConcatArgs.begin(); + size_t lowBit = op.getLowBit(); + + // This loop finds the first concatArg that is covered by the ExtractOp + for (; it != reversedConcatArgs.end(); it++) { + assert(beginOfFirstRelevantElement <= lowBit && + "incorrectly moved past an element that lowBit has coverage over"); + auto operand = *it; + + size_t operandWidth = operand.getType().getIntOrFloatBitWidth(); + if (lowBit < beginOfFirstRelevantElement + operandWidth) { + // A bit other than the first bit will be used in this element. + // ...... ........ ... + // ^---lowBit + // ^---beginOfFirstRelevantElement + // + // Edge-case close to the end of the range. + // ...... ........ ... + // ^---(position + operandWidth) + // ^---lowBit + // ^---beginOfFirstRelevantElement + // + // Edge-case close to the beginning of the rang + // ...... ........ ... + // ^---lowBit + // ^---beginOfFirstRelevantElement + // + break; + } + + // extraction discards this element. + // ...... ........ ... + // | ^---lowBit + // ^---beginOfFirstRelevantElement + beginOfFirstRelevantElement += operandWidth; + } + assert(it != reversedConcatArgs.end() && + "incorrectly failed to find an element which contains coverage of " + "lowBit"); + + SmallVector reverseConcatArgs; + size_t widthRemaining = op.getType().cast().getWidth(); + size_t extractLo = lowBit - beginOfFirstRelevantElement; + + // Transform individual arguments of innerCat(..., a, b, c,) into + // [ extract(a), b, extract(c) ], skipping an extract operation where + // possible. + for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) { + auto concatArg = *it; + size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth(); + size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo); + + if (widthToConsume == operandWidth && extractLo == 0) { + reverseConcatArgs.push_back(concatArg); + } else { + auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume); + reverseConcatArgs.push_back( + rewriter.create(op.getLoc(), resultType, *it, extractLo)); + } + + widthRemaining -= widthToConsume; + + // Beyond the first element, all elements are extracted from position 0. + extractLo = 0; + } + + if (reverseConcatArgs.size() == 1) { + replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]); + } else { + replaceOpWithNewOpAndCopyName( + rewriter, op, SmallVector(llvm::reverse(reverseConcatArgs))); + } + return success(); +} + +// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c). +static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, + PatternRewriter &rewriter) { + auto extractResultWidth = op.getType().cast().getWidth(); + auto replicateEltWidth = + replicate.getOperand().getType().getIntOrFloatBitWidth(); + + // If the extract starts at the base of an element and is an even multiple, + // we can replace the extract with a smaller replicate. + if (op.getLowBit() % replicateEltWidth == 0 && + extractResultWidth % replicateEltWidth == 0) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + replicate.getOperand()); + return true; + } + + // If the extract is completely contained in one element, extract from the + // element. + if (op.getLowBit() % replicateEltWidth + extractResultWidth <= + replicateEltWidth) { + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getType(), replicate.getOperand(), + op.getLowBit() % replicateEltWidth); + return true; + } + + // We don't currently handle the case of extracting from non-whole elements, + // e.g. `extract (replicate 2-bit-thing, N), 1`. + return false; +} + +LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) { + auto *inputOp = op.getInput().getDefiningOp(); + + // This turns out to be incredibly expensive. Disable until performance is + // addressed. +#if 0 + // If the extracted bits are all known, then return the result. + auto knownBits = computeKnownBits(op.getInput()) + .extractBits(op.getType().cast().getWidth(), + op.getLowBit()); + if (knownBits.isConstant()) { + replaceOpWithNewOpAndCopyName(rewriter, op, + knownBits.getConstant()); + return success(); + } +#endif + + // extract(olo, extract(ilo, x)) = extract(olo + ilo, x) + if (auto innerExtract = dyn_cast_or_null(inputOp)) { + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getType(), innerExtract.getInput(), + innerExtract.getLowBit() + op.getLowBit()); + return success(); + } + + // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d)) + if (auto innerCat = dyn_cast_or_null(inputOp)) + return extractConcatToConcatExtract(op, innerCat, rewriter); + + // extract(lo, replicate(a)) + if (auto replicate = dyn_cast_or_null(inputOp)) + if (extractFromReplicate(op, replicate, rewriter)) return success(); + + // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the + // and/or/xor are not modifying the extracted bits. + if (inputOp && inputOp->getNumOperands() == 2 && + isa(inputOp)) { + if (auto cstRHS = inputOp->getOperand(1).getDefiningOp()) { + auto extractedCst = cstRHS.getValue().extractBits( + op.getType().cast().getWidth(), op.getLowBit()); + if (isa(inputOp) && extractedCst.isZero()) { + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit()); + return success(); + } + + // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one + // extract to represent the result. Turning it into a pile of extracts is + // always fine by our cost model, but we don't want to explode things into + // a ton of bits because it will bloat the IR and generated Verilog. + if (isa(inputOp)) { + // For our cost model, we only do this if the bit pattern is a + // contiguous series of ones. + unsigned lz = extractedCst.countLeadingZeros(); + unsigned tz = extractedCst.countTrailingZeros(); + unsigned pop = extractedCst.popcount(); + if (extractedCst.getBitWidth() - lz - tz == pop) { + auto resultTy = rewriter.getIntegerType(pop); + SmallVector resultElts; + if (lz) + resultElts.push_back(rewriter.create( + op.getLoc(), APInt::getZero(lz))); + resultElts.push_back(rewriter.createOrFold( + op.getLoc(), resultTy, inputOp->getOperand(0), + op.getLowBit() + tz)); + if (tz) + resultElts.push_back(rewriter.create( + op.getLoc(), APInt::getZero(tz))); + replaceOpWithNewOpAndCopyName(rewriter, op, resultElts); + return success(); + } + } + } + } + + // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is + // extracted. + if (op.getType().cast().getWidth() == 1 && inputOp) + if (auto shlOp = dyn_cast(inputOp)) + if (auto lhsCst = shlOp.getOperand(0).getDefiningOp()) + if (lhsCst.getValue().isOne()) { + auto newCst = rewriter.create( + shlOp.getLoc(), + APInt(lhsCst.getValue().getBitWidth(), op.getLowBit())); + replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::eq, + shlOp->getOperand(1), newCst, + false); + return success(); + } + + return failure(); +} + +//===----------------------------------------------------------------------===// +// Associative Variadic operations +//===----------------------------------------------------------------------===// + +// Reduce all operands to a single value (either integer constant or parameter +// expression) if all the operands are constants. +static Attribute constFoldAssociativeOp(ArrayRef operands, + hw::PEO paramOpcode) { + assert(operands.size() > 1 && "caller should handle one-operand case"); + // We can only fold anything in the case where all operands are known to be + // constants. Check the least common one first for an early out. + if (!operands[1] || !operands[0]) return {}; + + // This will fold to a simple constant if all operands are constant. + if (llvm::all_of(operands.drop_front(2), + [&](Attribute in) { return !!in; })) { + SmallVector typedOperands; + typedOperands.reserve(operands.size()); + for (auto operand : operands) { + if (auto typedOperand = operand.dyn_cast()) + typedOperands.push_back(typedOperand); + else + break; + } + if (typedOperands.size() == operands.size()) + return hw::ParamExprAttr::get(paramOpcode, typedOperands); + } + + return {}; +} + +/// When we find a logical operation (and, or, xor) with a constant e.g. +/// `X & 42`, we want to push the constant into the computation of X if it leads +/// to simplification. +/// +/// This function handles the case where the logical operation has a concat +/// operand. We check to see if we can simplify the concat, e.g. when it has +/// constant operands. +/// +/// This returns true when a simplification happens. +static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, + size_t concatIdx, const APInt &cst, + PatternRewriter &rewriter) { + auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp(); + assert((isa(logicalOp) && concatOp)); + + // Check to see if any operands can be simplified by pushing the logical op + // into all parts of the concat. + bool canSimplify = + llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool { + auto *operandOp = operand.getDefiningOp(); + if (!operandOp) return false; + + // If the concat has a constant operand then we can transform this. + if (isa(operandOp)) return true; + // If the concat has the same logical operation and that operation has + // a constant operation than we can fold it into that suboperation. + return operandOp->getName() == logicalOp->getName() && + operandOp->hasOneUse() && operandOp->getNumOperands() != 0 && + operandOp->getOperands().back().getDefiningOp(); + }); + + if (!canSimplify) return false; + + // Create a new instance of the logical operation. We have to do this the + // hard way since we're generic across a family of different ops. + auto createLogicalOp = [&](ArrayRef operands) -> Value { + return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands, + rewriter); + }; + + // Ok, let's do the transformation. We do this by slicing up the constant + // for each unit of the concat and duplicate the operation into the + // sub-operand. + SmallVector newConcatOperands; + newConcatOperands.reserve(concatOp->getNumOperands()); + + // Work from MSB to LSB. + size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth(); + for (Value operand : concatOp->getOperands()) { + size_t operandWidth = operand.getType().getIntOrFloatBitWidth(); + nextOperandBit -= operandWidth; + // Take a slice of the constant. + auto eltCst = rewriter.create( + logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth)); + + newConcatOperands.push_back(createLogicalOp({operand, eltCst})); + } + + // Create the concat, and the rest of the logical op if we need it. + Value newResult = + rewriter.create(concatOp.getLoc(), newConcatOperands); + + // If we had a variadic logical op on the top level, then recreate it with the + // new concat and without the constant operand. + if (logicalOp->getNumOperands() > 2) { + auto origOperands = logicalOp->getOperands(); + SmallVector operands; + // Take any stuff before the concat. + operands.append(origOperands.begin(), origOperands.begin() + concatIdx); + // Take any stuff after the concat but before the constant. + operands.append(origOperands.begin() + concatIdx + 1, + origOperands.begin() + (origOperands.size() - 1)); + // Include the new concat. + operands.push_back(newResult); + newResult = createLogicalOp(operands); + } + + replaceOpAndCopyName(rewriter, logicalOp, newResult); + return true; +} + +OpFoldResult AndOp::fold(FoldAdaptor adaptor) { + APInt value = APInt::getAllOnes(getType().cast().getWidth()); + + auto inputs = adaptor.getInputs(); + + // and(x, 01, 10) -> 00 -- annulment. + for (auto operand : inputs) { + if (!operand) continue; + value &= operand.cast().getValue(); + if (value.isZero()) return getIntAttr(value, getContext()); + } + + // and(x, -1) -> x. + if (inputs.size() == 2 && inputs[1] && + inputs[1].cast().getValue().isAllOnes()) + return getInputs()[0]; + + // and(x, x, x) -> x. This also handles and(x) -> x. + if (llvm::all_of(getInputs(), + [&](auto in) { return in == this->getInputs()[0]; })) + return getInputs()[0]; + + // and(..., x, ..., ~x, ...) -> 0 + for (Value arg : getInputs()) { + Value subExpr; + if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) { + for (Value arg2 : getInputs()) + if (arg2 == subExpr) + return getIntAttr( + APInt::getZero(getType().cast().getWidth()), + getContext()); + } + } + + // Constant fold + return constFoldAssociativeOp(inputs, hw::PEO::And); +} + +/// Returns a single common operand that all inputs of the operation `op` can +/// be traced back to, or an empty `Value` if no such operand exists. +/// +/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a` +/// (assuming the bit-width of `a` is `n`). +template +static Value getCommonOperand(Op op) { + if (!op.getType().isInteger(1)) return Value(); + + auto inputs = op.getInputs(); + size_t size = inputs.size(); + + auto sourceOp = inputs[0].template getDefiningOp(); + if (!sourceOp) return Value(); + Value source = sourceOp.getOperand(); + + // Fast path: the input size is not equal to the width of the source. + if (size != source.getType().getIntOrFloatBitWidth()) return Value(); + + // Tracks the bits that were encountered. + llvm::BitVector bits(size); + bits.set(sourceOp.getLowBit()); + + for (size_t i = 1; i != size; ++i) { + auto extractOp = inputs[i].template getDefiningOp(); + if (!extractOp || extractOp.getOperand() != source) return Value(); + bits.set(extractOp.getLowBit()); + } + + return bits.all() ? source : Value(); +} + +/// Canonicalize an idempotent operation `op` so that only one input of any kind +/// occurs. +/// +/// Example: `and(x, y, x, z)` -> `and(x, y, z)` +template +static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + llvm::SmallSetVector uniqueInputs; + + for (const auto input : inputs) uniqueInputs.insert(input); + + if (uniqueInputs.size() < inputs.size()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + uniqueInputs.getArrayRef()); + return true; + } + + return false; +} + +LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands, `fold` should handle this"); + + // and(..., x, ..., x) -> and(..., x, ...) -- idempotent + // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above. + if (size > 2 && canonicalizeIdempotentInputs(op, rewriter)) return success(); + + // Patterns for and with a constant on RHS. + APInt value; + if (matchPattern(inputs.back(), m_ConstantInt(&value))) { + // and(..., '1) -> and(...) -- identity + if (value.isAllOnes()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back(), false); + return success(); + } + + // TODO: Combine multiple constants together even if they aren't at the + // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant + // folding + APInt value2; + if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { + auto cst = rewriter.create(op.getLoc(), value & value2); + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(cst); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + // Handle 'and' with a single bit constant on the RHS. + if (size == 2 && value.isPowerOf2()) { + // If the LHS is a replicate from a single bit, we can 'concat' it + // into place. e.g.: + // `replicate(x) & 4` -> `concat(zeros, x, zeros)` + // TODO: Generalize this for non-single-bit operands. + if (auto replicate = inputs[0].getDefiningOp()) { + auto replicateOperand = replicate.getOperand(); + if (replicateOperand.getType().isInteger(1)) { + unsigned resultWidth = op.getType().getIntOrFloatBitWidth(); + auto trailingZeros = value.countTrailingZeros(); + + // Don't add zero bit constants unnecessarily. + SmallVector concatOperands; + if (trailingZeros != resultWidth - 1) { + auto highZeros = rewriter.create( + op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1)); + concatOperands.push_back(highZeros); + } + concatOperands.push_back(replicateOperand); + if (trailingZeros != 0) { + auto lowZeros = rewriter.create( + op.getLoc(), APInt::getZero(trailingZeros)); + concatOperands.push_back(lowZeros); + } + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + concatOperands); + return success(); + } + } + } + + // If this is an and from an extract op, try shrinking the extract. + if (auto extractOp = inputs[0].getDefiningOp()) { + if (size == 2 && + // We can shrink it if the mask has leading or trailing zeros. + (value.countLeadingZeros() || value.countTrailingZeros())) { + unsigned lz = value.countLeadingZeros(); + unsigned tz = value.countTrailingZeros(); + + // Start by extracting the smaller number of bits. + auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz); + Value smallElt = rewriter.createOrFold( + extractOp.getLoc(), smallTy, extractOp->getOperand(0), + extractOp.getLowBit() + tz); + // Apply the 'and' mask if needed. + APInt smallMask = value.extractBits(smallTy.getWidth(), tz); + if (!smallMask.isAllOnes()) { + auto loc = inputs.back().getLoc(); + smallElt = rewriter.createOrFold( + loc, smallElt, rewriter.create(loc, smallMask), + false); + } + + // The final replacement will be a concat of the leading/trailing zeros + // along with the smaller extracted value. + SmallVector resultElts; + if (lz) + resultElts.push_back( + rewriter.create(op.getLoc(), APInt::getZero(lz))); + resultElts.push_back(smallElt); + if (tz) + resultElts.push_back( + rewriter.create(op.getLoc(), APInt::getZero(tz))); + replaceOpWithNewOpAndCopyName(rewriter, op, resultElts); + return success(); + } + } + + // and(concat(x, cst1), a, b, c, cst2) + // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')). + // We do this for even more multi-use concats since they are "just wiring". + for (size_t i = 0; i < size - 1; ++i) { + if (auto concat = inputs[i].getDefiningOp()) + if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) + return success(); + } + } + + // and(x, and(...)) -> and(x, ...) -- flatten + if (tryFlatteningOperands(op, rewriter)) return success(); + + // extracts only of and(...) -> and(extract()...) + if (narrowOperationWidth(op, true, rewriter)) return success(); + + // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1) + if (auto source = getCommonOperand(op)) { + auto cmpAgainst = + rewriter.create(op.getLoc(), APInt::getAllOnes(size)); + replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::eq, + source, cmpAgainst); + return success(); + } + + /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement + return failure(); +} + +OpFoldResult OrOp::fold(FoldAdaptor adaptor) { + auto value = APInt::getZero(getType().cast().getWidth()); + auto inputs = adaptor.getInputs(); + // or(x, 10, 01) -> 11 + for (auto operand : inputs) { + if (!operand) continue; + value |= operand.cast().getValue(); + if (value.isAllOnes()) return getIntAttr(value, getContext()); + } + + // or(x, 0) -> x + if (inputs.size() == 2 && inputs[1] && + inputs[1].cast().getValue().isZero()) + return getInputs()[0]; + + // or(x, x, x) -> x. This also handles or(x) -> x + if (llvm::all_of(getInputs(), + [&](auto in) { return in == this->getInputs()[0]; })) + return getInputs()[0]; + + // or(..., x, ..., ~x, ...) -> -1 + for (Value arg : getInputs()) { + Value subExpr; + if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) { + for (Value arg2 : getInputs()) + if (arg2 == subExpr) + return getIntAttr( + APInt::getAllOnes(getType().cast().getWidth()), + getContext()); + } + } + + // Constant fold + return constFoldAssociativeOp(inputs, hw::PEO::Or); +} + +/// Simplify concat ops in an or op when a constant operand is present in either +/// concat. +/// +/// This will invert an or(concat, concat) into concat(or, or, ...), which can +/// often be further simplified due to the smaller or ops being easier to fold. +/// +/// For example: +/// +/// or(..., concat(x, 0), concat(0, y)) +/// ==> or(..., concat(x, 0, y)), when x and y don't overlap. +/// +/// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1)) +/// ==> or(..., concat(or(x: i2, extract(cst2, 4..3)), +/// or(extract(cst1, 3..1), extract(cst2, 2..0)), +/// or(extract(cst1, 0..0), y: i1)) +static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, + size_t concatIdx2, + PatternRewriter &rewriter) { + assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2"); + + auto inputs = op.getInputs(); + auto concat1 = inputs[concatIdx1].getDefiningOp(); + auto concat2 = inputs[concatIdx2].getDefiningOp(); + + assert(concat1 && concat2 && "expected indexes to point to ConcatOps"); + + // We can simplify as long as a constant is present in either concat. + bool hasConstantOp1 = + llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool { + return operand.getDefiningOp(); + }); + if (!hasConstantOp1) { + bool hasConstantOp2 = + llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool { + return operand.getDefiningOp(); + }); + if (!hasConstantOp2) return false; + } + + SmallVector newConcatOperands; + + // Simultaneously iterate over the operands of both concat ops, from MSB to + // LSB, pushing out or's of overlapping ranges of the operands. When operands + // span different bit ranges, we extract only the maximum overlap. + auto operands1 = concat1->getOperands(); + auto operands2 = concat2->getOperands(); + // Number of bits already consumed from operands 1 and 2, respectively. + unsigned consumedWidth1 = 0; + unsigned consumedWidth2 = 0; + for (auto it1 = operands1.begin(), end1 = operands1.end(), + it2 = operands2.begin(), end2 = operands2.end(); + it1 != end1 && it2 != end2;) { + auto operand1 = *it1; + auto operand2 = *it2; + + unsigned remainingWidth1 = + hw::getBitWidth(operand1.getType()) - consumedWidth1; + unsigned remainingWidth2 = + hw::getBitWidth(operand2.getType()) - consumedWidth2; + unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2); + auto narrowedType = rewriter.getIntegerType(widthToConsume); + + auto extract1 = rewriter.createOrFold( + op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume); + auto extract2 = rewriter.createOrFold( + op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume); + + newConcatOperands.push_back( + rewriter.createOrFold(op.getLoc(), extract1, extract2, false)); + + consumedWidth1 += widthToConsume; + consumedWidth2 += widthToConsume; + + if (widthToConsume == remainingWidth1) { + ++it1; + consumedWidth1 = 0; + } + if (widthToConsume == remainingWidth2) { + ++it2; + consumedWidth2 = 0; + } + } + + ConcatOp newOp = rewriter.create(op.getLoc(), newConcatOperands); + + // Copy the old operands except for concatIdx1 and concatIdx2, and append the + // new ConcatOp to the end. + SmallVector newOrOperands; + newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1); + newOrOperands.append(inputs.begin() + concatIdx1 + 1, + inputs.begin() + concatIdx2); + newOrOperands.append(inputs.begin() + concatIdx2 + 1, + inputs.begin() + inputs.size()); + newOrOperands.push_back(newOp); + + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOrOperands); + return true; +} + +LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + // or(..., x, ..., x, ...) -> or(..., x) -- idempotent + // Trivial or(x), or(x, x) cases are handled by [OrOp::fold]. + if (size > 2 && canonicalizeIdempotentInputs(op, rewriter)) return success(); + + // Patterns for and with a constant on RHS. + APInt value; + if (matchPattern(inputs.back(), m_ConstantInt(&value))) { + // or(..., '0) -> or(...) -- identity + if (value.isZero()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back()); + return success(); + } + + // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding + APInt value2; + if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { + auto cst = rewriter.create(op.getLoc(), value | value2); + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(cst); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands); + return success(); + } + + // or(concat(x, cst1), a, b, c, cst2) + // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')). + // We do this for even more multi-use concats since they are "just wiring". + for (size_t i = 0; i < size - 1; ++i) { + if (auto concat = inputs[i].getDefiningOp()) + if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) + return success(); + } + } + + // or(x, or(...)) -> or(x, ...) -- flatten + if (tryFlatteningOperands(op, rewriter)) return success(); + + // or(..., concat(x, cst1), concat(cst2, y) + // ==> or(..., concat(x, cst3, y)), when x and y don't overlap. + for (size_t i = 0; i < size - 1; ++i) { + if (auto concat = inputs[i].getDefiningOp()) + for (size_t j = i + 1; j < size; ++j) + if (auto concat = inputs[j].getDefiningOp()) + if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter)) + return success(); + } + + // extracts only of or(...) -> or(extract()...) + if (narrowOperationWidth(op, true, rewriter)) return success(); + + // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0) + if (auto source = getCommonOperand(op)) { + auto cmpAgainst = + rewriter.create(op.getLoc(), APInt::getZero(size)); + replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::ne, + source, cmpAgainst); + return success(); + } + + // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2, + // .., c_n), a, 0) + if (auto firstMux = op.getOperand(0).getDefiningOp()) { + APInt value; + if (op.getTwoState() && firstMux.getTwoState() && + matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) && + value.isZero()) { + SmallVector conditions{firstMux.getCond()}; + auto check = [&](Value v) { + auto mux = v.getDefiningOp(); + if (!mux) return false; + conditions.push_back(mux.getCond()); + return mux.getTwoState() && + firstMux.getTrueValue() == mux.getTrueValue() && + firstMux.getFalseValue() == mux.getFalseValue(); + }; + if (llvm::all_of(op.getOperands().drop_front(), check)) { + auto cond = rewriter.create(op.getLoc(), conditions, true); + replaceOpWithNewOpAndCopyName( + rewriter, op, cond, firstMux.getTrueValue(), + firstMux.getFalseValue(), true); + return success(); + } + } + } + + /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement + return failure(); +} + +OpFoldResult XorOp::fold(FoldAdaptor adaptor) { + auto size = getInputs().size(); + auto inputs = adaptor.getInputs(); + + // xor(x) -> x -- noop + if (size == 1) return getInputs()[0]; + + // xor(x, x) -> 0 -- idempotent + if (size == 2 && getInputs()[0] == getInputs()[1]) + return IntegerAttr::get(getType(), 0); + + // xor(x, 0) -> x + if (inputs.size() == 2 && inputs[1] && + inputs[1].cast().getValue().isZero()) + return getInputs()[0]; + + // xor(xor(x,1),1) -> x + // but not self loop + if (isBinaryNot()) { + Value subExpr; + if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) && + subExpr != getResult()) + return subExpr; + } + + // Constant fold + return constFoldAssociativeOp(inputs, hw::PEO::Xor); +} + +// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user. +static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, + PatternRewriter &rewriter) { + auto icmp = op.getOperand(icmpOperand).getDefiningOp(); + auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate()); + + Value result = + rewriter.create(icmp.getLoc(), negatedPred, icmp.getOperand(0), + icmp.getOperand(1), icmp.getTwoState()); + + // If the xor had other operands, rebuild it. + if (op.getNumOperands() > 2) { + SmallVector newOperands(op.getOperands()); + newOperands.pop_back(); + newOperands.erase(newOperands.begin() + icmpOperand); + newOperands.push_back(result); + result = rewriter.create(op.getLoc(), newOperands, op.getTwoState()); + } + + replaceOpAndCopyName(rewriter, op, result); +} + +LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + // xor(..., x, x) -> xor (...) -- idempotent + if (inputs[size - 1] == inputs[size - 2]) { + assert(size > 2 && + "expected idempotent case for 2 elements handled already."); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back(/*n=*/2), false); + return success(); + } + + // Patterns for xor with a constant on RHS. + APInt value; + if (matchPattern(inputs.back(), m_ConstantInt(&value))) { + // xor(..., 0) -> xor(...) -- identity + if (value.isZero()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back(), false); + return success(); + } + + // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2. + APInt value2; + if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { + auto cst = rewriter.create(op.getLoc(), value ^ value2); + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(cst); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + bool isSingleBit = value.getBitWidth() == 1; + + // Check for subexpressions that we can simplify. + for (size_t i = 0; i < size - 1; ++i) { + Value operand = inputs[i]; + + // xor(concat(x, cst1), a, b, c, cst2) + // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')). + // We do this for even more multi-use concats since they are "just + // wiring". + if (auto concat = operand.getDefiningOp()) + if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) + return success(); + + // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user. + if (isSingleBit && operand.hasOneUse()) { + assert(value == 1 && "single bit constant has to be one if not zero"); + if (auto icmp = operand.getDefiningOp()) + return canonicalizeXorIcmpTrue(op, i, rewriter), success(); + } + } + } + + // xor(x, xor(...)) -> xor(x, ...) -- flatten + if (tryFlatteningOperands(op, rewriter)) return success(); + + // extracts only of xor(...) -> xor(extract()...) + if (narrowOperationWidth(op, true, rewriter)) return success(); + + // xor(a[0], a[1], ..., a[n]) -> parity(a) + if (auto source = getCommonOperand(op)) { + replaceOpWithNewOpAndCopyName(rewriter, op, source); + return success(); + } + + return failure(); +} + +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { + // sub(x - x) -> 0 + if (getRhs() == getLhs()) + return getIntAttr( + APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()), + getContext()); + + if (adaptor.getRhs()) { + // If both are constants, we can unconditionally fold. + if (adaptor.getLhs()) { + // Constant fold (c1 - c2) => (c1 + -1*c2). + auto negOne = getIntAttr( + APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()), + getContext()); + auto rhsNeg = hw::ParamExprAttr::get( + hw::PEO::Mul, adaptor.getRhs().cast(), negOne); + return hw::ParamExprAttr::get(hw::PEO::Add, + adaptor.getLhs().cast(), rhsNeg); + } + + // sub(x - 0) -> x + if (auto rhsC = adaptor.getRhs().dyn_cast()) { + if (rhsC.getValue().isZero()) return getLhs(); + } + } + + return {}; +} + +LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) { + // sub(x, cst) -> add(x, -cst) + APInt value; + if (matchPattern(op.getRhs(), m_ConstantInt(&value))) { + auto negCst = rewriter.create(op.getLoc(), -value); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), negCst, + false); + return success(); + } + + // extracts only of sub(...) -> sub(extract()...) + if (narrowOperationWidth(op, false, rewriter)) return success(); + + return failure(); +} + +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { + auto size = getInputs().size(); + + // add(x) -> x -- noop + if (size == 1u) return getInputs()[0]; + + // Constant fold constant operands. + return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add); +} + +LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + APInt value, value2; + + // add(..., 0) -> add(...) -- identity + if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back(), false); + return success(); + } + + // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding + if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) && + matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { + auto cst = rewriter.create(op.getLoc(), value + value2); + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(cst); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + // add(..., x, x) -> add(..., shl(x, 1)) + if (inputs[size - 1] == inputs[size - 2]) { + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + + auto one = rewriter.create(op.getLoc(), op.getType(), 1); + auto shiftLeftOp = + rewriter.create(op.getLoc(), inputs.back(), one, false); + + newOperands.push_back(shiftLeftOp); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + auto shlOp = inputs[size - 1].getDefiningOp(); + // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1)) + if (shlOp && shlOp.getLhs() == inputs[size - 2] && + matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) { + APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false); + auto rhs = + rewriter.create(op.getLoc(), (one << value) + one); + + std::array factors = {shlOp.getLhs(), rhs}; + auto mulOp = rewriter.create(op.getLoc(), factors, false); + + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(mulOp); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + auto mulOp = inputs[size - 1].getDefiningOp(); + // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1)) + if (mulOp && mulOp.getInputs().size() == 2 && + mulOp.getInputs()[0] == inputs[size - 2] && + matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) { + APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false); + auto rhs = rewriter.create(op.getLoc(), value + one); + std::array factors = {mulOp.getInputs()[0], rhs}; + auto newMulOp = rewriter.create(op.getLoc(), factors, false); + + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(newMulOp); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands, false); + return success(); + } + + // add(x, add(...)) -> add(x, ...) -- flatten + if (tryFlatteningOperands(op, rewriter)) return success(); + + // extracts only of add(...) -> add(extract()...) + if (narrowOperationWidth(op, false, rewriter)) return success(); + + // add(add(x, c1), c2) -> add(x, c1 + c2) + auto addOp = inputs[0].getDefiningOp(); + if (addOp && addOp.getInputs().size() == 2 && + matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) && + inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) { + auto rhs = rewriter.create(op.getLoc(), value + value2); + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getType(), ArrayRef{addOp.getInputs()[0], rhs}, + /*twoState=*/op.getTwoState() && addOp.getTwoState()); + return success(); + } + + return failure(); +} + +OpFoldResult MulOp::fold(FoldAdaptor adaptor) { + auto size = getInputs().size(); + auto inputs = adaptor.getInputs(); + + // mul(x) -> x -- noop + if (size == 1u) return getInputs()[0]; + + auto width = getType().cast().getWidth(); + APInt value(/*numBits=*/width, 1, /*isSigned=*/false); + + // mul(x, 0, 1) -> 0 -- annulment + for (auto operand : inputs) { + if (!operand) continue; + value *= operand.cast().getValue(); + if (value.isZero()) return getIntAttr(value, getContext()); + } + + // Constant fold + return constFoldAssociativeOp(inputs, hw::PEO::Mul); +} + +LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + APInt value, value2; + + // mul(x, c) -> shl(x, log2(c)), where c is a power of two. + if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) && + value.isPowerOf2()) { + auto shift = rewriter.create(op.getLoc(), op.getType(), + value.exactLogBase2()); + auto shlOp = + rewriter.create(op.getLoc(), inputs[0], shift, false); + + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + ArrayRef(shlOp), false); + return success(); + } + + // mul(..., 1) -> mul(...) -- identity + if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + inputs.drop_back()); + return success(); + } + + // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding + if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) && + matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { + auto cst = rewriter.create(op.getLoc(), value * value2); + SmallVector newOperands(inputs.drop_back(/*n=*/2)); + newOperands.push_back(cst); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands); + return success(); + } + + // mul(a, mul(...)) -> mul(a, ...) -- flatten + if (tryFlatteningOperands(op, rewriter)) return success(); + + // extracts only of mul(...) -> mul(extract()...) + if (narrowOperationWidth(op, false, rewriter)) return success(); + + return failure(); +} + +template +static OpFoldResult foldDiv(Op op, ArrayRef constants) { + if (auto rhsValue = constants[1].dyn_cast_or_null()) { + // divu(x, 1) -> x, divs(x, 1) -> x + if (rhsValue.getValue() == 1) return op.getLhs(); + + // If the divisor is zero, do not fold for now. + if (rhsValue.getValue().isZero()) return {}; + } + + return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU); +} + +OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { + return foldDiv(*this, adaptor.getOperands()); +} + +OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { + return foldDiv(*this, adaptor.getOperands()); +} + +template +static OpFoldResult foldMod(Op op, ArrayRef constants) { + if (auto rhsValue = constants[1].dyn_cast_or_null()) { + // modu(x, 1) -> 0, mods(x, 1) -> 0 + if (rhsValue.getValue() == 1) + return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()), + op.getContext()); + + // If the divisor is zero, do not fold for now. + if (rhsValue.getValue().isZero()) return {}; + } + + if (auto lhsValue = constants[0].dyn_cast_or_null()) { + // modu(0, x) -> 0, mods(0, x) -> 0 + if (lhsValue.getValue().isZero()) + return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()), + op.getContext()); + } + + return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU); +} + +OpFoldResult ModUOp::fold(FoldAdaptor adaptor) { + return foldMod(*this, adaptor.getOperands()); +} + +OpFoldResult ModSOp::fold(FoldAdaptor adaptor) { + return foldMod(*this, adaptor.getOperands()); +} +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +// Constant folding +OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { + if (getNumOperands() == 1) return getOperand(0); + + // If all the operands are constant, we can fold. + for (auto attr : adaptor.getInputs()) + if (!attr || !attr.isa()) return {}; + + // If we got here, we can constant fold. + unsigned resultWidth = getType().getIntOrFloatBitWidth(); + APInt result(resultWidth, 0); + + unsigned nextInsertion = resultWidth; + // Insert each chunk into the result. + for (auto attr : adaptor.getInputs()) { + auto chunk = attr.cast().getValue(); + nextInsertion -= chunk.getBitWidth(); + result.insertBits(chunk, nextInsertion); + } + + return getIntAttr(result, getContext()); +} + +LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) { + auto inputs = op.getInputs(); + auto size = inputs.size(); + assert(size > 1 && "expected 2 or more operands"); + + // This function is used when we flatten neighboring operands of a + // (variadic) concat into a new vesion of the concat. first/last indices + // are inclusive. + auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex, + ValueRange replacements) -> LogicalResult { + SmallVector newOperands; + newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex); + newOperands.append(replacements.begin(), replacements.end()); + newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end()); + if (newOperands.size() == 1) + replaceOpAndCopyName(rewriter, op, newOperands[0]); + else + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + newOperands); + return success(); + }; + + Value commonOperand = inputs[0]; + for (size_t i = 0; i != size; ++i) { + // Check to see if all operands are the same. + if (inputs[i] != commonOperand) commonOperand = Value(); + + // If an operand to the concat is itself a concat, then we can fold them + // together. + if (auto subConcat = inputs[i].getDefiningOp()) + return flattenConcat(i, i, subConcat->getOperands()); + + // Check for canonicalization due to neighboring operands. + if (i != 0) { + // Merge neighboring constants. + if (auto cst = inputs[i].getDefiningOp()) { + if (auto prevCst = inputs[i - 1].getDefiningOp()) { + unsigned prevWidth = prevCst.getValue().getBitWidth(); + unsigned thisWidth = cst.getValue().getBitWidth(); + auto resultCst = cst.getValue().zext(prevWidth + thisWidth); + resultCst |= prevCst.getValue().zext(prevWidth + thisWidth) + << thisWidth; + Value replacement = + rewriter.create(op.getLoc(), resultCst); + return flattenConcat(i - 1, i, replacement); + } + } + + // If the two operands are the same, turn them into a replicate. + if (inputs[i] == inputs[i - 1]) { + Value replacement = + rewriter.createOrFold(op.getLoc(), inputs[i], 2); + return flattenConcat(i - 1, i, replacement); + } + + // If this input is a replicate, see if we can fold it with the previous + // one. + if (auto repl = inputs[i].getDefiningOp()) { + // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ... + if (repl.getOperand() == inputs[i - 1]) { + Value replacement = rewriter.createOrFold( + op.getLoc(), repl.getOperand(), repl.getMultiple() + 1); + return flattenConcat(i - 1, i, replacement); + } + // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ... + if (auto prevRepl = inputs[i - 1].getDefiningOp()) { + if (prevRepl.getOperand() == repl.getOperand()) { + Value replacement = rewriter.createOrFold( + op.getLoc(), repl.getOperand(), + repl.getMultiple() + prevRepl.getMultiple()); + return flattenConcat(i - 1, i, replacement); + } + } + } + + // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ... + if (auto repl = inputs[i - 1].getDefiningOp()) { + if (repl.getOperand() == inputs[i]) { + Value replacement = rewriter.createOrFold( + op.getLoc(), inputs[i], repl.getMultiple() + 1); + return flattenConcat(i - 1, i, replacement); + } + } + + // Merge neighboring extracts of neighboring inputs, e.g. + // {A[3], A[2]} -> A[3:2] + if (auto extract = inputs[i].getDefiningOp()) { + if (auto prevExtract = inputs[i - 1].getDefiningOp()) { + if (extract.getInput() == prevExtract.getInput()) { + auto thisWidth = extract.getType().cast().getWidth(); + if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) { + auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth(); + auto resType = rewriter.getIntegerType(thisWidth + prevWidth); + Value replacement = rewriter.create( + op.getLoc(), resType, extract.getInput(), + extract.getLowBit()); + return flattenConcat(i - 1, i, replacement); + } + } + } + } + // Merge neighboring array extracts of neighboring inputs, e.g. + // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2]) + + // This represents a slice of an array. + struct ArraySlice { + Value input; + Value index; + size_t width; + static std::optional get(Value value) { + assert(value.getType().isa() && "expected integer type"); + if (auto arrayGet = value.getDefiningOp()) + return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1}; + // array slice op is wrapped with bitcast. + if (auto bitcast = value.getDefiningOp()) + if (auto arraySlice = + bitcast.getInput().getDefiningOp()) + return ArraySlice{ + arraySlice.getInput(), arraySlice.getLowIndex(), + hw::type_cast(arraySlice.getType()) + .getNumElements()}; + return std::nullopt; + } + }; + if (auto extractOpt = ArraySlice::get(inputs[i])) { + if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) { + // Check that two array slices are mergable. + if (prevExtractOpt->index.getType() == extractOpt->index.getType() && + prevExtractOpt->input == extractOpt->input && + hw::isOffset(extractOpt->index, prevExtractOpt->index, + extractOpt->width)) { + auto resType = hw::ArrayType::get( + hw::type_cast(prevExtractOpt->input.getType()) + .getElementType(), + extractOpt->width + prevExtractOpt->width); + auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType)); + Value replacement = rewriter.create( + op.getLoc(), resIntType, + rewriter.create(op.getLoc(), resType, + prevExtractOpt->input, + extractOpt->index)); + return flattenConcat(i - 1, i, replacement); + } + } + } + } + } + + // If all operands were the same, then this is a replicate. + if (commonOperand) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + commonOperand); + return success(); + } + + return failure(); +} + +//===----------------------------------------------------------------------===// +// MuxOp +//===----------------------------------------------------------------------===// + +OpFoldResult MuxOp::fold(FoldAdaptor adaptor) { + // mux (c, b, b) -> b + if (getTrueValue() == getFalseValue()) return getTrueValue(); + + // mux(0, a, b) -> b + // mux(1, a, b) -> a + if (auto pred = adaptor.getCond().dyn_cast_or_null()) { + if (pred.getValue().isZero()) return getFalseValue(); + return getTrueValue(); + } + + // mux(cond, 1, 0) -> cond + if (auto tv = adaptor.getTrueValue().dyn_cast_or_null()) + if (auto fv = adaptor.getFalseValue().dyn_cast_or_null()) + if (tv.getValue().isOne() && fv.getValue().isZero() && + hw::getBitWidth(getType()) == 1) + return getCond(); + + return {}; +} + +/// Check to see if the condition to the specified mux is an equality +/// comparison `indexValue` and one or more constants. If so, put the +/// constants in the constants vector and return true, otherwise return false. +/// +/// This is part of foldMuxChain. +/// +static bool getMuxChainCondConstant( + Value cond, Value indexValue, bool isInverted, + std::function constantFn) { + // Handle `idx == 42` and `idx != 42`. + if (auto cmp = cond.getDefiningOp()) { + // TODO: We could handle things like "x < 2" as two entries. + auto requiredPredicate = + (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne); + if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) { + if (auto cst = cmp.getRhs().getDefiningOp()) { + constantFn(cst); + return true; + } + } + return false; + } + + // Handle mux(`idx == 1 || idx == 3`, value, muxchain). + if (auto orOp = cond.getDefiningOp()) { + if (!isInverted) return false; + for (auto operand : orOp.getOperands()) + if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn)) + return false; + return true; + } + + // Handle mux(`idx != 1 && idx != 3`, muxchain, value). + if (auto andOp = cond.getDefiningOp()) { + if (isInverted) return false; + for (auto operand : andOp.getOperands()) + if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn)) + return false; + return true; + } + + return false; +} + +/// Given a mux, check to see if the "on true" value (or "on false" value if +/// isFalseSide=true) is a mux tree with the same condition. This allows us +/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into +/// `array_get (array_create(A, B, C), VAL)` which is far more compact and +/// allows synthesis tools to do more interesting optimizations. +/// +/// This returns false if we cannot form the mux tree (or do not want to) and +/// returns true if the mux was replaced. +static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, + PatternRewriter &rewriter) { + // Get the index value being compared. Later we check to see if it is + // compared to a constant with the right predicate. + auto rootCmp = rootMux.getCond().getDefiningOp(); + if (!rootCmp) return false; + Value indexValue = rootCmp.getLhs(); + + // Return the value to use if the equality match succeeds. + auto getCaseValue = [&](MuxOp mux) -> Value { + return mux.getOperand(1 + unsigned(!isFalseSide)); + }; + + // Return the value to use if the equality match fails. This is the next + // mux in the sequence or the "otherwise" value. + auto getTreeValue = [&](MuxOp mux) -> Value { + return mux.getOperand(1 + unsigned(isFalseSide)); + }; + + // Start scanning the mux tree to see what we've got. Keep track of the + // constant comparison value and the SSA value to use when equal to it. + SmallVector locationsFound; + SmallVector, 4> valuesFound; + + /// Extract constants and values into `valuesFound` and return true if this is + /// part of the mux tree, otherwise return false. + auto collectConstantValues = [&](MuxOp mux) -> bool { + return getMuxChainCondConstant( + mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) { + valuesFound.push_back({cst, getCaseValue(mux)}); + locationsFound.push_back(mux.getCond().getLoc()); + locationsFound.push_back(mux->getLoc()); + }); + }; + + // Make sure the root is a correct comparison with a constant. + if (!collectConstantValues(rootMux)) return false; + + // Make sure that we're not looking at the intermediate node in a mux tree. + if (rootMux->hasOneUse()) { + if (auto userMux = dyn_cast(*rootMux->user_begin())) { + if (getTreeValue(userMux) == rootMux.getResult() && + getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide, + [&](hw::ConstantOp cst) {})) + return false; + } + } + + // Scan up the tree linearly. + auto nextTreeValue = getTreeValue(rootMux); + while (1) { + auto nextMux = nextTreeValue.getDefiningOp(); + if (!nextMux || !nextMux->hasOneUse()) break; + if (!collectConstantValues(nextMux)) break; + nextTreeValue = getTreeValue(nextMux); + } + + // We need to have more than three values to create an array. This is an + // arbitrary threshold which is saying that one or two muxes together is ok, + // but three should be folded. + if (valuesFound.size() < 3) return false; + + // If the array is greater that 9 bits, it will take over 512 elements and + // it will be too large for a single expression. + auto indexWidth = indexValue.getType().cast().getWidth(); + if (indexWidth >= 9) return false; + + // Next we need to see if the values are dense-ish. We don't want to have + // a tremendous number of replicated entries in the array. Some sparsity is + // ok though, so we require the table to be at least 5/8 utilized. + uint64_t tableSize = 1ULL << indexWidth; + if (valuesFound.size() < (tableSize * 5) / 8) + return false; // Not dense enough. + + // Ok, we're going to do the transformation, start by building the table + // filled with the "otherwise" value. + SmallVector table(tableSize, nextTreeValue); + + // Fill in entries in the table from the leaf to the root of the expression. + // This ensures that any duplicate matches end up with the ultimate value, + // which is the one closer to the root. + for (auto &elt : llvm::reverse(valuesFound)) { + uint64_t idx = elt.first.getValue().getZExtValue(); + assert(idx < table.size() && "constant should be same bitwidth as index"); + table[idx] = elt.second; + } + + // The hw.array_create operation has the operand list in unintuitive order + // with a[0] stored as the last element, not the first. + std::reverse(table.begin(), table.end()); + + // Build the array_create and the array_get. + auto fusedLoc = rewriter.getFusedLoc(locationsFound); + auto array = rewriter.create(fusedLoc, table); + replaceOpWithNewOpAndCopyName(rewriter, rootMux, array, + indexValue); + return true; +} + +/// Given a fully associative variadic operation like (a+b+c+d), break the +/// expression into two parts, one without the specified operand (e.g. +/// `tmp = a+b+d`) and one that combines that into the full expression (e.g. +/// `tmp+c`), and return the inner expression. +/// +/// NOTE: This mutates the operation in place if it only has a single user, +/// which assumes that user will be removed. +/// +static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, + size_t operandNo, + PatternRewriter &rewriter) { + assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops"); + assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #"); + + // If this expression already has two operands (the common case) no splitting + // is necessary. + if (fullyAssoc->getNumOperands() == 2) + return fullyAssoc->getOperand(operandNo ^ 1); + + // If the operation has a single use, mutate it in place. + if (fullyAssoc->hasOneUse()) { + fullyAssoc->eraseOperand(operandNo); + return fullyAssoc->getResult(0); + } + + // Form the new operation with the operands that remain. + SmallVector operands; + operands.append(fullyAssoc->getOperands().begin(), + fullyAssoc->getOperands().begin() + operandNo); + operands.append(fullyAssoc->getOperands().begin() + operandNo + 1, + fullyAssoc->getOperands().end()); + Value opWithoutExcluded = createGenericOp( + fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter); + Value excluded = fullyAssoc->getOperand(operandNo); + + Value fullResult = + createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(), + ArrayRef{opWithoutExcluded, excluded}, rewriter); + replaceOpAndCopyName(rewriter, fullyAssoc, fullResult); + return opWithoutExcluded; +} + +/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and +/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand +/// is true. Return true on successful transformation, false if not. +/// +/// These are various forms of "predicated ops" that can be handled with a +/// replicate/and combination. +static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, + PatternRewriter &rewriter) { + // Check to see the operand in question is an operation. If it is a port, + // we can't simplify it. + Operation *subExpr = + (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp(); + if (!subExpr || subExpr->getNumOperands() < 2) return false; + + // If this isn't an operation we can handle, don't spend energy on it. + if (!isa(subExpr)) return false; + + // Check to see if the common value occurs in the operand list for the + // subexpression op. If so, then we can simplify it. + Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue(); + size_t opNo = 0, e = subExpr->getNumOperands(); + while (opNo != e && subExpr->getOperand(opNo) != commonValue) ++opNo; + if (opNo == e) return false; + + // If we got a hit, then go ahead and simplify it! + Value cond = op.getCond(); + + // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)` + // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)` + // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)` + // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)` + if (auto subMux = dyn_cast(subExpr)) { + Value otherValue; + Value subCond = subMux.getCond(); + + // Invert th subCond if needed and dig out the 'b' value. + if (subMux.getTrueValue() == commonValue) + otherValue = subMux.getFalseValue(); + else if (subMux.getFalseValue() == commonValue) { + otherValue = subMux.getTrueValue(); + subCond = createOrFoldNot(op.getLoc(), subCond, rewriter); + } else { + // We can't fold `mux(cond, a, mux(a, x, y))`. + return false; + } + + // Invert the outer cond if needed, and combine the mux conditions. + if (!isTrueOperand) cond = createOrFoldNot(op.getLoc(), cond, rewriter); + cond = rewriter.createOrFold(op.getLoc(), cond, subCond, false); + replaceOpWithNewOpAndCopyName(rewriter, op, cond, commonValue, + otherValue, op.getTwoState()); + return true; + } + + // Invert the condition if needed. Or/Xor invert when dealing with + // TrueOperand, And inverts for False operand. + bool isaAndOp = isa(subExpr); + if (isTrueOperand ^ isaAndOp) + cond = createOrFoldNot(op.getLoc(), cond, rewriter); + + auto extendedCond = + rewriter.createOrFold(op.getLoc(), op.getType(), cond); + + // Cache this information before subExpr is erased by extraction below. + bool isaXorOp = isa(subExpr); + bool isaOrOp = isa(subExpr); + + // Handle the fully associative ops, start by pulling out the subexpression + // from a many operand version of the op. + auto restOfAssoc = + extractOperandFromFullyAssociative(subExpr, opNo, rewriter); + + // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a` + // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a` + if (isaOrOp || isaXorOp) { + auto masked = rewriter.createOrFold(op.getLoc(), extendedCond, + restOfAssoc, false); + if (isaXorOp) + replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, + false); + else + replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, + false); + return true; + } + + // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a` + assert(isaAndOp && "unexpected operation here"); + auto masked = rewriter.createOrFold(op.getLoc(), extendedCond, + restOfAssoc, false); + replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, + false); + return true; +} + +/// This function is invoke when we find a mux with true/false operations that +/// have the same opcode. Check to see if we can strength reduce the mux by +/// applying it to less data by applying this transformation: +/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))` +static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, + Operation *falseOp, + PatternRewriter &rewriter) { + // Right now we only apply to concat. + // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice + if (!isa(trueOp)) return false; + + // Decode the operands, looking through recursive concats and replicates. + SmallVector trueOperands, falseOperands; + getConcatOperands(trueOp->getResult(0), trueOperands); + getConcatOperands(falseOp->getResult(0), falseOperands); + + size_t numTrueOperands = trueOperands.size(); + size_t numFalseOperands = falseOperands.size(); + + if (!numTrueOperands || !numFalseOperands || + (trueOperands.front() != falseOperands.front() && + trueOperands.back() != falseOperands.back())) + return false; + + // Pull all leading shared operands out into their own op if any are common. + if (trueOperands.front() == falseOperands.front()) { + SmallVector operands; + size_t i; + for (i = 0; i < numTrueOperands; ++i) { + Value trueOperand = trueOperands[i]; + if (trueOperand == falseOperands[i]) + operands.push_back(trueOperand); + else + break; + } + if (i == numTrueOperands) { + // Selecting between distinct, but lexically identical, concats. + replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0)); + return true; + } + + Value sharedMSB; + if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); })) + sharedMSB = rewriter.createOrFold( + mux->getLoc(), operands.front(), operands.size()); + else + sharedMSB = rewriter.createOrFold(mux->getLoc(), operands); + operands.clear(); + + // Get a concat of the LSB's on each side. + operands.append(trueOperands.begin() + i, trueOperands.end()); + Value trueLSB = rewriter.createOrFold(trueOp->getLoc(), operands); + operands.clear(); + operands.append(falseOperands.begin() + i, falseOperands.end()); + Value falseLSB = + rewriter.createOrFold(falseOp->getLoc(), operands); + // Merge the LSBs with a new mux and concat the MSB with the LSB to be + // done. + Value lsb = rewriter.createOrFold( + mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState()); + replaceOpWithNewOpAndCopyName(rewriter, mux, sharedMSB, lsb); + return true; + } + + // If trailing operands match, try to commonize them. + if (trueOperands.back() == falseOperands.back()) { + SmallVector operands; + size_t i; + for (i = 0;; ++i) { + Value trueOperand = trueOperands[numTrueOperands - i - 1]; + if (trueOperand == falseOperands[numFalseOperands - i - 1]) + operands.push_back(trueOperand); + else + break; + } + std::reverse(operands.begin(), operands.end()); + Value sharedLSB = rewriter.createOrFold(mux->getLoc(), operands); + operands.clear(); + + // Get a concat of the MSB's on each side. + operands.append(trueOperands.begin(), trueOperands.end() - i); + Value trueMSB = rewriter.createOrFold(trueOp->getLoc(), operands); + operands.clear(); + operands.append(falseOperands.begin(), falseOperands.end() - i); + Value falseMSB = + rewriter.createOrFold(falseOp->getLoc(), operands); + // Merge the MSBs with a new mux and concat the MSB with the LSB to be done. + Value msb = rewriter.createOrFold( + mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState()); + replaceOpWithNewOpAndCopyName(rewriter, mux, msb, sharedLSB); + return true; + } + + return false; +} + +// If both arguments of the mux are arrays with the same elements, sink the +// mux and return a uniform array initializing all elements to it. +static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) { + auto trueVec = op.getTrueValue().getDefiningOp(); + auto falseVec = op.getFalseValue().getDefiningOp(); + if (!trueVec || !falseVec) return false; + if (!trueVec.isUniform() || !falseVec.isUniform()) return false; + + auto mux = rewriter.create( + op.getLoc(), op.getCond(), trueVec.getUniformElement(), + falseVec.getUniformElement(), op.getTwoState()); + + SmallVector values(trueVec.getInputs().size(), mux); + rewriter.replaceOpWithNewOp(op, values); + return true; +} + +namespace { +struct MuxRewriter : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MuxOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, + PatternRewriter &rewriter) const { + // If the op has a SV attribute, don't optimize it. + if (hasSVAttributes(op)) return failure(); + APInt value; + + if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) { + if (value.getBitWidth() == 1) { + // mux(a, 0, b) -> and(~a, b) for single-bit values. + if (value.isZero()) { + auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter); + replaceOpWithNewOpAndCopyName(rewriter, op, notCond, + op.getFalseValue(), false); + return success(); + } + + // mux(a, 1, b) -> or(a, b) for single-bit values. + replaceOpWithNewOpAndCopyName(rewriter, op, op.getCond(), + op.getFalseValue(), false); + return success(); + } + + // Check for mux of two constants. There are many ways to simplify them. + APInt value2; + if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) { + // When both inputs are constants and differ by only one bit, we can + // simplify by splitting the mux into up to three contiguous chunks: one + // for the differing bit and up to two for the bits that are the same. + // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0) + APInt xorValue = value ^ value2; + if (xorValue.isPowerOf2()) { + unsigned leadingZeros = xorValue.countLeadingZeros(); + unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1; + SmallVector operands; + + // Concat operands go from MSB to LSB, so we handle chunks in reverse + // order of bit indexes. + // For the chunks that are identical (i.e. correspond to 0s in + // xorValue), we can extract directly from either input value, and we + // arbitrarily pick the trueValue(). + + if (leadingZeros > 0) + operands.push_back(rewriter.createOrFold( + op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros)); + + // Handle the differing bit, which should simplify into either cond or + // ~cond. + auto v1 = rewriter.createOrFold( + op.getLoc(), op.getTrueValue(), trailingZeros, 1); + auto v2 = rewriter.createOrFold( + op.getLoc(), op.getFalseValue(), trailingZeros, 1); + operands.push_back(rewriter.createOrFold( + op.getLoc(), op.getCond(), v1, v2, false)); + + if (trailingZeros > 0) + operands.push_back(rewriter.createOrFold( + op.getLoc(), op.getTrueValue(), 0, trailingZeros)); + + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + operands); + return success(); + } + + // If the true value is all ones and the false is all zeros then we have a + // replicate pattern. + if (value.isAllOnes() && value2.isZero()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), + op.getCond()); + return success(); + } + } + } + + if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) && + value.getBitWidth() == 1) { + // mux(a, b, 0) -> and(a, b) for single-bit values. + if (value.isZero()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getCond(), + op.getTrueValue(), false); + return success(); + } + + // mux(a, b, 1) -> or(~a, b) for single-bit values. + // falseValue() is known to be a single-bit 1, which we can use for + // the 1 in the representation of ~ using xor. + auto notCond = rewriter.createOrFold(op.getLoc(), op.getCond(), + op.getFalseValue(), false); + replaceOpWithNewOpAndCopyName(rewriter, op, notCond, + op.getTrueValue(), false); + return success(); + } + + // mux(!a, b, c) -> mux(a, c, b) + Value subExpr; + Operation *condOp = op.getCond().getDefiningOp(); + if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) && + op.getTwoState()) { + replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), subExpr, + op.getFalseValue(), op.getTrueValue(), + true); + return success(); + } + + // Same but with Demorgan's law. + // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x) + // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x) + if (condOp && condOp->hasOneUse()) { + SmallVector invertedOperands; + + /// Scan all the operands to see if they are complemented. If so, build a + /// vector of them and return true, otherwise return false. + auto getInvertedOperands = [&]() -> bool { + for (Value operand : condOp->getOperands()) { + if (matchPattern(operand, m_Complement(m_Any(&subExpr)))) + invertedOperands.push_back(subExpr); + else + return false; + } + return true; + }; + + if (isa(condOp) && getInvertedOperands()) { + auto newOr = + rewriter.createOrFold(op.getLoc(), invertedOperands, false); + replaceOpWithNewOpAndCopyName(rewriter, op, newOr, + op.getFalseValue(), + op.getTrueValue(), op.getTwoState()); + return success(); + } + if (isa(condOp) && getInvertedOperands()) { + auto newAnd = + rewriter.createOrFold(op.getLoc(), invertedOperands, false); + replaceOpWithNewOpAndCopyName(rewriter, op, newAnd, + op.getFalseValue(), + op.getTrueValue(), op.getTwoState()); + return success(); + } + } + + if (auto falseMux = + dyn_cast_or_null(op.getFalseValue().getDefiningOp())) { + // mux(selector, x, mux(selector, y, z) = mux(selector, x, z) + if (op.getCond() == falseMux.getCond()) { + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getCond(), op.getTrueValue(), + falseMux.getFalseValue(), op.getTwoStateAttr()); + return success(); + } + + // Check to see if we can fold a mux tree into an array_create/get pair. + if (foldMuxChain(op, /*isFalse*/ true, rewriter)) return success(); + } + + if (auto trueMux = + dyn_cast_or_null(op.getTrueValue().getDefiningOp())) { + // mux(selector, mux(selector, a, b), c) = mux(selector, a, c) + if (op.getCond() == trueMux.getCond()) { + replaceOpWithNewOpAndCopyName( + rewriter, op, op.getCond(), trueMux.getTrueValue(), + op.getFalseValue(), op.getTwoStateAttr()); + return success(); + } + + // Check to see if we can fold a mux tree into an array_create/get pair. + if (foldMuxChain(op, /*isFalseSide*/ false, rewriter)) return success(); + } + + // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c)) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && + trueMux.getTrueValue() == falseMux.getTrueValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), + op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue()); + replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), + trueMux.getTrueValue(), subMux, + op.getTwoStateAttr()); + return success(); + } + + // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && + trueMux.getFalseValue() == falseMux.getFalseValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), + op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue()); + replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), + subMux, trueMux.getFalseValue(), + op.getTwoStateAttr()); + return success(); + } + + // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b) + if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), + falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); + trueMux && falseMux && + trueMux.getTrueValue() == falseMux.getTrueValue() && + trueMux.getFalseValue() == falseMux.getFalseValue()) { + auto subMux = rewriter.create( + rewriter.getFusedLoc( + {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}), + op.getCond(), trueMux.getCond(), falseMux.getCond()); + replaceOpWithNewOpAndCopyName( + rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(), + op.getTwoStateAttr()); + return success(); + } + + // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a + if (foldCommonMuxValue(op, false, rewriter)) return success(); + // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a + if (foldCommonMuxValue(op, true, rewriter)) return success(); + + // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))` + if (Operation *trueOp = op.getTrueValue().getDefiningOp()) + if (Operation *falseOp = op.getFalseValue().getDefiningOp()) + if (trueOp->getName() == falseOp->getName()) + if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter)) + return success(); + + // extracts only of mux(...) -> mux(extract()...) + if (narrowOperationWidth(op, true, rewriter)) return success(); + + // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2)) + if (foldMuxOfUniformArrays(op, rewriter)) return success(); + + return failure(); +} + +static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) { + // Do not fold uniform or singleton arrays to avoid duplicating muxes. + if (op.getInputs().empty() || op.isUniform()) return false; + auto inputs = op.getInputs(); + if (inputs.size() <= 1) return false; + + // Check the operands to the array create. Ensure all of them are the + // same op with the same number of operands. + auto first = inputs[0].getDefiningOp(); + if (!first || hasSVAttributes(first)) return false; + + // Check whether all operands are muxes with the same condition. + for (size_t i = 1, n = inputs.size(); i < n; ++i) { + auto input = inputs[i].getDefiningOp(); + if (!input || first.getCond() != input.getCond()) return false; + } + + // Collect the true and the false branches into arrays. + SmallVector trues{first.getTrueValue()}; + SmallVector falses{first.getFalseValue()}; + SmallVector locs{first->getLoc()}; + bool isTwoState = true; + for (size_t i = 1, n = inputs.size(); i < n; ++i) { + auto input = inputs[i].getDefiningOp(); + trues.push_back(input.getTrueValue()); + falses.push_back(input.getFalseValue()); + locs.push_back(input->getLoc()); + if (!input.getTwoState()) isTwoState = false; + } + + // Define the location of the array create as the aggregate of all muxes. + auto loc = FusedLoc::get(op.getContext(), locs); + + // Replace the create with an aggregate operation. Push the create op + // into the operands of the aggregate operation. + auto arrayTy = op.getType(); + auto trueValues = rewriter.create(loc, arrayTy, trues); + auto falseValues = rewriter.create(loc, arrayTy, falses); + rewriter.replaceOpWithNewOp(op, arrayTy, first.getCond(), + trueValues, falseValues, isTwoState); + return true; +} + +struct ArrayRewriter : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(hw::ArrayCreateOp op, + PatternRewriter &rewriter) const override { + if (foldArrayOfMuxes(op, rewriter)) return success(); + return failure(); + } +}; + +} // namespace + +void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ICmpOp +//===----------------------------------------------------------------------===// + +// Calculate the result of a comparison when the LHS and RHS are both +// constants. +static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, + const APInt &rhs) { + switch (predicate) { + case ICmpPredicate::eq: + return lhs.eq(rhs); + case ICmpPredicate::ne: + return lhs.ne(rhs); + case ICmpPredicate::slt: + return lhs.slt(rhs); + case ICmpPredicate::sle: + return lhs.sle(rhs); + case ICmpPredicate::sgt: + return lhs.sgt(rhs); + case ICmpPredicate::sge: + return lhs.sge(rhs); + case ICmpPredicate::ult: + return lhs.ult(rhs); + case ICmpPredicate::ule: + return lhs.ule(rhs); + case ICmpPredicate::ugt: + return lhs.ugt(rhs); + case ICmpPredicate::uge: + return lhs.uge(rhs); + case ICmpPredicate::ceq: + return lhs.eq(rhs); + case ICmpPredicate::cne: + return lhs.ne(rhs); + case ICmpPredicate::weq: + return lhs.eq(rhs); + case ICmpPredicate::wne: + return lhs.ne(rhs); + } + llvm_unreachable("unknown comparison predicate"); +} + +// Returns the result of applying the predicate when the LHS and RHS are the +// exact same value. +static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::eq: + case ICmpPredicate::sle: + case ICmpPredicate::sge: + case ICmpPredicate::ule: + case ICmpPredicate::uge: + case ICmpPredicate::ceq: + case ICmpPredicate::weq: + return true; + case ICmpPredicate::ne: + case ICmpPredicate::slt: + case ICmpPredicate::sgt: + case ICmpPredicate::ult: + case ICmpPredicate::ugt: + case ICmpPredicate::cne: + case ICmpPredicate::wne: + return false; + } + llvm_unreachable("unknown comparison predicate"); +} + +OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { + // gt a, a -> false + // gte a, a -> true + if (getLhs() == getRhs()) { + auto val = applyCmpPredicateToEqualOperands(getPredicate()); + return IntegerAttr::get(getType(), val); + } + + // gt 1, 2 -> false + if (auto lhs = adaptor.getLhs().dyn_cast_or_null()) { + if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { + auto val = + applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); + return IntegerAttr::get(getType(), val); + } + } + return {}; +} + +// Given a range of operands, computes the number of matching prefix and +// suffix elements. This does not perform cross-element matching. +template +static size_t computeCommonPrefixLength(const Range &a, const Range &b) { + size_t commonPrefixLength = 0; + auto ia = a.begin(); + auto ib = b.begin(); + + for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) { + if (*ia != *ib) { + break; + } + } + + return commonPrefixLength; +} + +static size_t getTotalWidth(ArrayRef operands) { + size_t totalWidth = 0; + for (auto operand : operands) { + // getIntOrFloatBitWidth should never raise, since all arguments to + // ConcatOp are integers. + ssize_t width = operand.getType().getIntOrFloatBitWidth(); + assert(width >= 0); + totalWidth += width; + } + return totalWidth; +} + +/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise +/// comparison on common prefix and suffixes. Returns success() if a rewriting +/// happens. This handles both concat and replicate. +static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, + Operation *rhs, + PatternRewriter &rewriter) { + // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and + // all elements have non-zero length. Both these invariants are verified + // by the ConcatOp verifier. + SmallVector lhsOperands, rhsOperands; + getConcatOperands(lhs->getResult(0), lhsOperands); + getConcatOperands(rhs->getResult(0), rhsOperands); + ArrayRef lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands; + + auto formCatOrReplicate = [&](Location loc, + ArrayRef operands) -> Value { + assert(!operands.empty()); + Value sameElement = operands[0]; + for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i) + if (sameElement != operands[i]) sameElement = Value(); + if (sameElement) + return rewriter.createOrFold(loc, sameElement, + operands.size()); + return rewriter.createOrFold(loc, operands); + }; + + auto replaceWith = [&](ICmpPredicate predicate, Value lhs, + Value rhs) -> LogicalResult { + replaceOpWithNewOpAndCopyName(rewriter, op, predicate, lhs, rhs, + op.getTwoState()); + return success(); + }; + + size_t commonPrefixLength = + computeCommonPrefixLength(lhsOperands, rhsOperands); + if (commonPrefixLength == lhsOperands.size()) { + // cat(a, b, c) == cat(a, b, c) -> 1 + bool result = applyCmpPredicateToEqualOperands(op.getPredicate()); + replaceOpWithNewOpAndCopyName(rewriter, op, + APInt(1, result)); + return success(); + } + + size_t commonSuffixLength = computeCommonPrefixLength( + llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef)); + + size_t commonPrefixTotalWidth = + getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength)); + size_t commonSuffixTotalWidth = + getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength)); + auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength) + .drop_back(commonSuffixLength); + auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength) + .drop_back(commonSuffixLength); + + auto replaceWithoutReplicatingSignBit = [&]() { + auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly); + auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly); + return replaceWith(op.getPredicate(), newLhs, newRhs); + }; + + auto replaceWithReplicatingSignBit = [&]() { + auto firstNonEmptyValue = lhsOperands[0]; + auto firstNonEmptyElemWidth = + firstNonEmptyValue.getType().getIntOrFloatBitWidth(); + Value signBit = rewriter.createOrFold( + op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1); + + auto newLhs = rewriter.create(lhs->getLoc(), signBit, lhsOnly); + auto newRhs = rewriter.create(rhs->getLoc(), signBit, rhsOnly); + return replaceWith(op.getPredicate(), newLhs, newRhs); + }; + + if (ICmpOp::isPredicateSigned(op.getPredicate())) { + // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y)) + if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0) + return replaceWithoutReplicatingSignBit(); + + // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x), + // cat(sgn(b), ..y)) Note that we cannot perform this optimization if + // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign + // bit. Doing the rewrite can result in an infinite loop. + if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0) + return replaceWithReplicatingSignBit(); + + } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) { + // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y)) + return replaceWithoutReplicatingSignBit(); + } + + return failure(); +} + +/// Given an equality comparison with a constant value and some operand that has +/// known bits, simplify the comparison to check only the unknown bits of the +/// input. +/// +/// One simple example of this is that `concat(0, stuff) == 0` can be simplified +/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to +/// `extract x[1:0] == 0` +static void combineEqualityICmpWithKnownBitsAndConstant( + ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, + PatternRewriter &rewriter) { + // If any of the known bits disagree with any of the comparison bits, then + // we can constant fold this comparison right away. + APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One; + if ((bitsKnown & rhsCst) != bitAnalysis.One) { + // If we discover a mismatch then we know an "eq" comparison is false + // and a "ne" comparison is true! + bool result = cmpOp.getPredicate() == ICmpPredicate::ne; + replaceOpWithNewOpAndCopyName(rewriter, cmpOp, + APInt(1, result)); + return; + } + + // Check to see if we can prove the result entirely of the comparison (in + // which we bail out early), otherwise build a list of values to concat and a + // smaller constant to compare against. + SmallVector newConcatOperands; + auto newConstant = APInt::getZeroWidth(); + + // Ok, some (maybe all) bits are known and some others may be unknown. + // Extract out segments of the operand and compare against the + // corresponding bits. + unsigned knownMSB = bitsKnown.countLeadingOnes(); + + Value operand = cmpOp.getLhs(); + + // Ok, some bits are known but others are not. Extract out sequences of + // bits that are unknown and compare just those bits. We work from MSB to + // LSB. + while (knownMSB != bitsKnown.getBitWidth()) { + // Drop any high bits that are known. + if (knownMSB) + bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB); + + // Find the span of unknown bits, and extract it. + unsigned unknownBits = bitsKnown.countLeadingZeros(); + unsigned lowBit = bitsKnown.getBitWidth() - unknownBits; + auto spanOperand = rewriter.createOrFold( + operand.getLoc(), operand, /*lowBit=*/lowBit, + /*bitWidth=*/unknownBits); + auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits); + + // Add this info to the concat we're generating. + newConcatOperands.push_back(spanOperand); + // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values + // newConstant = newConstant.concat(spanConstant); + if (newConstant.getBitWidth() != 0) + newConstant = newConstant.concat(spanConstant); + else + newConstant = spanConstant; + + // Drop the unknown bits in prep for the next chunk. + unsigned newWidth = bitsKnown.getBitWidth() - unknownBits; + bitsKnown = bitsKnown.trunc(newWidth); + knownMSB = bitsKnown.countLeadingOnes(); + } + + // If all the operands to the concat are foldable then we have an identity + // situation where all the sub-elements equal each other. This implies that + // the overall result is foldable. + if (newConcatOperands.empty()) { + bool result = cmpOp.getPredicate() == ICmpPredicate::eq; + replaceOpWithNewOpAndCopyName(rewriter, cmpOp, + APInt(1, result)); + return; + } + + // If we have a single operand remaining, use it, otherwise form a concat. + Value concatResult = + rewriter.createOrFold(operand.getLoc(), newConcatOperands); + + // Form the comparison against the smaller constant. + auto newConstantOp = rewriter.create( + cmpOp.getOperand(1).getLoc(), newConstant); + + replaceOpWithNewOpAndCopyName(rewriter, cmpOp, cmpOp.getPredicate(), + concatResult, newConstantOp, + cmpOp.getTwoState()); +} + +// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2). +static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, + const APInt &rhs, + PatternRewriter &rewriter) { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(xorOp); + + auto xorRHS = xorOp.getOperands().back().getDefiningOp(); + auto newRHS = rewriter.create(xorRHS->getLoc(), + xorRHS.getValue() ^ rhs); + Value newLHS; + switch (xorOp.getNumOperands()) { + case 1: + // This isn't common but is defined so we need to handle it. + newLHS = rewriter.create( + xorOp.getLoc(), APInt::getZero(rhs.getBitWidth())); + break; + case 2: + // The binary case is the most common. + newLHS = xorOp.getOperand(0); + break; + default: + // The general case forces us to form a new xor with the remaining + // operands. + SmallVector newOperands(xorOp.getOperands()); + newOperands.pop_back(); + newLHS = rewriter.create(xorOp.getLoc(), newOperands, false); + break; + } + + bool xorMultipleUses = !xorOp->hasOneUse(); + + // If the xor has multiple uses (not just the compare, then we need/want to + // replace them as well. + if (xorMultipleUses) + replaceOpWithNewOpAndCopyName(rewriter, xorOp, newLHS, xorRHS, + false); + + // Replace the comparison. + rewriter.restoreInsertionPoint(ip); + replaceOpWithNewOpAndCopyName(rewriter, cmpOp, cmpOp.getPredicate(), + newLHS, newRHS, false); +} + +LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) { + APInt lhs, rhs; + + // icmp 1, x -> icmp x, 1 + if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) { + assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) && + "Should be folded"); + replaceOpWithNewOpAndCopyName( + rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()), + op.getRhs(), op.getLhs(), op.getTwoState()); + return success(); + } + + // Canonicalize with RHS constant + if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) { + auto getConstant = [&](APInt constant) -> Value { + return rewriter.create(op.getLoc(), std::move(constant)); + }; + + auto replaceWith = [&](ICmpPredicate predicate, Value lhs, + Value rhs) -> LogicalResult { + replaceOpWithNewOpAndCopyName(rewriter, op, predicate, lhs, rhs, + op.getTwoState()); + return success(); + }; + + auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult { + replaceOpWithNewOpAndCopyName(rewriter, op, + APInt(1, constant)); + return success(); + }; + + switch (op.getPredicate()) { + case ICmpPredicate::slt: + // x < max -> x != max + if (rhs.isMaxSignedValue()) + return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); + // x < min -> false + if (rhs.isMinSignedValue()) return replaceWithConstantI1(0); + // x < min+1 -> x == min + if ((rhs - 1).isMinSignedValue()) + return replaceWith(ICmpPredicate::eq, op.getLhs(), + getConstant(rhs - 1)); + break; + case ICmpPredicate::sgt: + // x > min -> x != min + if (rhs.isMinSignedValue()) + return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); + // x > max -> false + if (rhs.isMaxSignedValue()) return replaceWithConstantI1(0); + // x > max-1 -> x == max + if ((rhs + 1).isMaxSignedValue()) + return replaceWith(ICmpPredicate::eq, op.getLhs(), + getConstant(rhs + 1)); + break; + case ICmpPredicate::ult: + // x < max -> x != max + if (rhs.isAllOnes()) + return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); + // x < min -> false + if (rhs.isZero()) return replaceWithConstantI1(0); + // x < min+1 -> x == min + if ((rhs - 1).isZero()) + return replaceWith(ICmpPredicate::eq, op.getLhs(), + getConstant(rhs - 1)); + + // x < 0xE0 -> extract(x, 5..7) != 0b111 + if (rhs.countLeadingOnes() + rhs.countTrailingZeros() == + rhs.getBitWidth()) { + auto numOnes = rhs.countLeadingOnes(); + auto smaller = rewriter.create( + op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes); + return replaceWith(ICmpPredicate::ne, smaller, + getConstant(APInt::getAllOnes(numOnes))); + } + + break; + case ICmpPredicate::ugt: + // x > min -> x != min + if (rhs.isZero()) + return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); + // x > max -> false + if (rhs.isAllOnes()) return replaceWithConstantI1(0); + // x > max-1 -> x == max + if ((rhs + 1).isAllOnes()) + return replaceWith(ICmpPredicate::eq, op.getLhs(), + getConstant(rhs + 1)); + + // x > 0x07 -> extract(x, 3..7) != 0b00000 + if ((rhs + 1).isPowerOf2()) { + auto numOnes = rhs.countTrailingOnes(); + auto newWidth = rhs.getBitWidth() - numOnes; + auto smaller = rewriter.create(op.getLoc(), op.getLhs(), + numOnes, newWidth); + return replaceWith(ICmpPredicate::ne, smaller, + getConstant(APInt::getZero(newWidth))); + } + + break; + case ICmpPredicate::sle: + // x <= max -> true + if (rhs.isMaxSignedValue()) return replaceWithConstantI1(1); + // x <= c -> x < (c+1) + return replaceWith(ICmpPredicate::slt, op.getLhs(), + getConstant(rhs + 1)); + case ICmpPredicate::sge: + // x >= min -> true + if (rhs.isMinSignedValue()) return replaceWithConstantI1(1); + // x >= c -> x > (c-1) + return replaceWith(ICmpPredicate::sgt, op.getLhs(), + getConstant(rhs - 1)); + case ICmpPredicate::ule: + // x <= max -> true + if (rhs.isAllOnes()) return replaceWithConstantI1(1); + // x <= c -> x < (c+1) + return replaceWith(ICmpPredicate::ult, op.getLhs(), + getConstant(rhs + 1)); + case ICmpPredicate::uge: + // x >= min -> true + if (rhs.isZero()) return replaceWithConstantI1(1); + // x >= c -> x > (c-1) + return replaceWith(ICmpPredicate::ugt, op.getLhs(), + getConstant(rhs - 1)); + case ICmpPredicate::eq: + if (rhs.getBitWidth() == 1) { + if (rhs.isZero()) { + // x == 0 -> x ^ 1 + replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), + getConstant(APInt(1, 1)), + op.getTwoState()); + return success(); + } + if (rhs.isAllOnes()) { + // x == 1 -> x + replaceOpAndCopyName(rewriter, op, op.getLhs()); + return success(); + } + } + break; + case ICmpPredicate::ne: + if (rhs.getBitWidth() == 1) { + if (rhs.isZero()) { + // x != 0 -> x + replaceOpAndCopyName(rewriter, op, op.getLhs()); + return success(); + } + if (rhs.isAllOnes()) { + // x != 1 -> x ^ 1 + replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), + getConstant(APInt(1, 1)), + op.getTwoState()); + return success(); + } + } + break; + case ICmpPredicate::ceq: + case ICmpPredicate::cne: + case ICmpPredicate::weq: + case ICmpPredicate::wne: + break; + } + + // We have some specific optimizations for comparison with a constant that + // are only supported for equality comparisons. + if (op.getPredicate() == ICmpPredicate::eq || + op.getPredicate() == ICmpPredicate::ne) { + // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts + // with a smaller constant. We only support equality comparisons for + // this. + auto knownBits = computeKnownBits(op.getLhs()); + if (!knownBits.isUnknown()) + return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs, + rewriter), + success(); + + // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), + // cst1^cst2). + if (auto xorOp = op.getLhs().getDefiningOp()) + if (xorOp.getOperands().back().getDefiningOp()) + return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter), + success(); + + // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or + // all one. + if (auto replicateOp = op.getLhs().getDefiningOp()) + if (rhs.isAllOnes() || rhs.isZero()) { + auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth(); + auto cst = rewriter.create( + op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width) + : APInt::getZero(width)); + replaceOpWithNewOpAndCopyName(rewriter, op, op.getPredicate(), + replicateOp.getInput(), cst, + op.getTwoState()); + return success(); + } + } + } + + // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a, + // b), cat(c, d)). contains special handling for sign bit in signed + // compressions. + if (Operation *opLHS = op.getLhs().getDefiningOp()) + if (Operation *opRHS = op.getRhs().getDefiningOp()) + if (isa(opLHS) && + isa(opRHS)) { + if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter))) + return success(); + } + + return failure(); +} diff --git a/third_party/circt/lib/Dialect/Comb/CombOps.cpp b/third_party/circt/lib/Dialect/Comb/CombOps.cpp new file mode 100644 index 000000000..597aab6f4 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/CombOps.cpp @@ -0,0 +1,302 @@ +//===- CombOps.cpp - Implement the Comb operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements combinational ops. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Comb/CombOps.h" + +#include "circt/Dialect/HW/HWOps.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" + +using namespace circt; +using namespace comb; + +/// Create a sign extension operation from a value of integer type to an equal +/// or larger integer type. +Value comb::createOrFoldSExt(Location loc, Value value, Type destTy, + OpBuilder &builder) { + IntegerType valueType = value.getType().dyn_cast(); + assert(valueType && destTy.isa() && + valueType.getWidth() <= destTy.getIntOrFloatBitWidth() && + valueType.getWidth() != 0 && "invalid sext operands"); + // If already the right size, we are done. + if (valueType == destTy) return value; + + // sext is concat with a replicate of the sign bits and the bottom part. + auto signBit = + builder.createOrFold(loc, value, valueType.getWidth() - 1, 1); + auto signBits = builder.createOrFold( + loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth()); + return builder.createOrFold(loc, signBits, value); +} + +Value comb::createOrFoldSExt(Value value, Type destTy, + ImplicitLocOpBuilder &builder) { + return createOrFoldSExt(builder.getLoc(), value, destTy, builder); +} + +Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder, + bool twoState) { + auto allOnes = builder.create(loc, value.getType(), -1); + return builder.createOrFold(loc, value, allOnes, twoState); +} + +Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, + bool twoState) { + return createOrFoldNot(builder.getLoc(), value, builder, twoState); +} + +//===----------------------------------------------------------------------===// +// ICmpOp +//===----------------------------------------------------------------------===// + +ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::eq: + return ICmpPredicate::eq; + case ICmpPredicate::ne: + return ICmpPredicate::ne; + case ICmpPredicate::slt: + return ICmpPredicate::sgt; + case ICmpPredicate::sle: + return ICmpPredicate::sge; + case ICmpPredicate::sgt: + return ICmpPredicate::slt; + case ICmpPredicate::sge: + return ICmpPredicate::sle; + case ICmpPredicate::ult: + return ICmpPredicate::ugt; + case ICmpPredicate::ule: + return ICmpPredicate::uge; + case ICmpPredicate::ugt: + return ICmpPredicate::ult; + case ICmpPredicate::uge: + return ICmpPredicate::ule; + case ICmpPredicate::ceq: + return ICmpPredicate::ceq; + case ICmpPredicate::cne: + return ICmpPredicate::cne; + case ICmpPredicate::weq: + return ICmpPredicate::weq; + case ICmpPredicate::wne: + return ICmpPredicate::wne; + } + llvm_unreachable("unknown comparison predicate"); +} + +bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::ult: + case ICmpPredicate::ugt: + case ICmpPredicate::ule: + case ICmpPredicate::uge: + case ICmpPredicate::ne: + case ICmpPredicate::eq: + case ICmpPredicate::cne: + case ICmpPredicate::ceq: + case ICmpPredicate::wne: + case ICmpPredicate::weq: + return false; + case ICmpPredicate::slt: + case ICmpPredicate::sgt: + case ICmpPredicate::sle: + case ICmpPredicate::sge: + return true; + } + llvm_unreachable("unknown comparison predicate"); +} + +/// Returns the predicate for a logically negated comparison, e.g. mapping +/// EQ => NE and SLE => SGT. +ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::eq: + return ICmpPredicate::ne; + case ICmpPredicate::ne: + return ICmpPredicate::eq; + case ICmpPredicate::slt: + return ICmpPredicate::sge; + case ICmpPredicate::sle: + return ICmpPredicate::sgt; + case ICmpPredicate::sgt: + return ICmpPredicate::sle; + case ICmpPredicate::sge: + return ICmpPredicate::slt; + case ICmpPredicate::ult: + return ICmpPredicate::uge; + case ICmpPredicate::ule: + return ICmpPredicate::ugt; + case ICmpPredicate::ugt: + return ICmpPredicate::ule; + case ICmpPredicate::uge: + return ICmpPredicate::ult; + case ICmpPredicate::ceq: + return ICmpPredicate::cne; + case ICmpPredicate::cne: + return ICmpPredicate::ceq; + case ICmpPredicate::weq: + return ICmpPredicate::wne; + case ICmpPredicate::wne: + return ICmpPredicate::weq; + } + llvm_unreachable("unknown comparison predicate"); +} + +/// Return true if this is an equality test with -1, which is a "reduction +/// and" operation in Verilog. +bool ICmpOp::isEqualAllOnes() { + if (getPredicate() != ICmpPredicate::eq) return false; + + if (auto op1 = + dyn_cast_or_null(getOperand(1).getDefiningOp())) + return op1.getValue().isAllOnes(); + return false; +} + +/// Return true if this is a not equal test with 0, which is a "reduction +/// or" operation in Verilog. +bool ICmpOp::isNotEqualZero() { + if (getPredicate() != ICmpPredicate::ne) return false; + + if (auto op1 = + dyn_cast_or_null(getOperand(1).getDefiningOp())) + return op1.getValue().isZero(); + return false; +} + +//===----------------------------------------------------------------------===// +// Unary Operations +//===----------------------------------------------------------------------===// + +LogicalResult ReplicateOp::verify() { + // The source must be equal or smaller than the dest type, and an even + // multiple of it. Both are already known to be signless integers. + auto srcWidth = getOperand().getType().cast().getWidth(); + auto dstWidth = getType().cast().getWidth(); + if (srcWidth == 0) + return emitOpError("replicate does not take zero bit integer"); + + if (srcWidth > dstWidth) + return emitOpError("replicate cannot shrink bitwidth of operand"), + failure(); + + if (dstWidth % srcWidth) + return emitOpError("replicate must produce integer multiple of operand"), + failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Variadic operations +//===----------------------------------------------------------------------===// + +static LogicalResult verifyUTBinOp(Operation *op) { + if (op->getOperands().empty()) + return op->emitOpError("requires 1 or more args"); + return success(); +} + +LogicalResult AddOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult MulOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult AndOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult OrOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult XorOp::verify() { return verifyUTBinOp(*this); } + +/// Return true if this is a two operand xor with an all ones constant as its +/// RHS operand. +bool XorOp::isBinaryNot() { + if (getNumOperands() != 2) return false; + if (auto cst = getOperand(1).getDefiningOp()) + if (cst.getValue().isAllOnes()) return true; + return false; +} + +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +static unsigned getTotalWidth(ValueRange inputs) { + unsigned resultWidth = 0; + for (auto input : inputs) { + resultWidth += input.getType().cast().getWidth(); + } + return resultWidth; +} + +LogicalResult ConcatOp::verify() { + unsigned tyWidth = getType().cast().getWidth(); + unsigned operandsTotalWidth = getTotalWidth(getInputs()); + if (tyWidth != operandsTotalWidth) + return emitOpError( + "ConcatOp requires operands total width to " + "match type width. operands " + "totalWidth is") + << operandsTotalWidth << ", but concatOp type width is " << tyWidth; + + return success(); +} + +void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd, + ValueRange tl) { + result.addOperands(ValueRange{hd}); + result.addOperands(tl); + unsigned hdWidth = hd.getType().cast().getWidth(); + result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth)); +} + +LogicalResult ConcatOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attrs, mlir::OpaqueProperties properties, + mlir::RegionRange regions, SmallVectorImpl &results) { + unsigned resultWidth = getTotalWidth(operands); + results.push_back(IntegerType::get(context, resultWidth)); + return success(); +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// + +LogicalResult ExtractOp::verify() { + unsigned srcWidth = getInput().getType().cast().getWidth(); + unsigned dstWidth = getType().cast().getWidth(); + if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth) + return emitOpError("from bit too large for input"), failure(); + + return success(); +} + +LogicalResult TruthTableOp::verify() { + size_t numInputs = getInputs().size(); + if (numInputs >= sizeof(size_t) * 8) + return emitOpError("Truth tables support a maximum of ") + << sizeof(size_t) * 8 - 1 << " inputs on your platform"; + + ArrayAttr table = getLookupTable(); + if (table.size() != (1ull << numInputs)) + return emitOpError("Expected lookup table of 2^n length"); + return success(); +} + +//===----------------------------------------------------------------------===// +// TableGen generated logic. +//===----------------------------------------------------------------------===// + +// Provide the autogenerated implementation guts for the Op classes. +#define GET_OP_CLASSES +#include "circt/Dialect/Comb/Comb.cpp.inc" diff --git a/third_party/circt/lib/Dialect/Comb/Transforms/CMakeLists.txt b/third_party/circt/lib/Dialect/Comb/Transforms/CMakeLists.txt new file mode 100644 index 000000000..9043ccf5d --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/Transforms/CMakeLists.txt @@ -0,0 +1,15 @@ +add_circt_dialect_library(CIRCTCombTransforms + LowerComb.cpp + + DEPENDS + CIRCTCombTransformsIncGen + + LINK_LIBS PUBLIC + CIRCTHW + CIRCTSV + CIRCTComb + CIRCTSupport + MLIRIR + MLIRPass + MLIRTransformUtils +) diff --git a/third_party/circt/lib/Dialect/Comb/Transforms/LowerComb.cpp b/third_party/circt/lib/Dialect/Comb/Transforms/LowerComb.cpp new file mode 100644 index 000000000..a4114b922 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/Transforms/LowerComb.cpp @@ -0,0 +1,88 @@ +//===- LowerComb.cpp - Lower some ops in comb -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/Comb/CombPasses.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace circt; +using namespace circt::comb; + +namespace circt { +namespace comb { +#define GEN_PASS_DEF_LOWERCOMB +#include "circt/Dialect/Comb/Passes.h.inc" +} // namespace comb +} // namespace circt + +namespace { +/// Lower truth tables to mux trees. +struct TruthTableToMuxTree : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + private: + /// Get a mux tree for `inputs` corresponding to the given truth table. Do + /// this recursively by dividing the table in half for each input. + // NOLINTNEXTLINE(misc-no-recursion) + Value getMux(Location loc, OpBuilder &b, Value t, Value f, + ArrayRef table, Operation::operand_range inputs) const { + assert(table.size() == (1ull << inputs.size())); + if (table.size() == 1) return table.front() ? t : f; + + size_t half = table.size() / 2; + Value if1 = + getMux(loc, b, t, f, table.drop_front(half), inputs.drop_front()); + Value if0 = + getMux(loc, b, t, f, table.drop_back(half), inputs.drop_front()); + return b.create(loc, inputs.front(), if1, if0, false); + } + + public: + LogicalResult matchAndRewrite(TruthTableOp op, OpAdaptor adaptor, + ConversionPatternRewriter &b) const override { + Location loc = op.getLoc(); + SmallVector table( + llvm::map_range(op.getLookupTableAttr().getAsValueRange(), + [](const APInt &a) { return !a.isZero(); })); + Value t = b.create(loc, b.getIntegerAttr(b.getI1Type(), 1)); + Value f = b.create(loc, b.getIntegerAttr(b.getI1Type(), 0)); + + Value tree = getMux(loc, b, t, f, table, op.getInputs()); + b.updateRootInPlace(tree.getDefiningOp(), [&]() { + tree.getDefiningOp()->setDialectAttrs(op->getDialectAttrs()); + }); + b.replaceOp(op, tree); + return success(); + } +}; +} // namespace + +namespace { +class LowerCombPass : public impl::LowerCombBase { + public: + using LowerCombBase::LowerCombBase; + + void runOnOperation() override; +}; +} // namespace + +void LowerCombPass::runOnOperation() { + ModuleOp module = getOperation(); + + ConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); + target.addIllegalOp(); + + patterns.add(patterns.getContext()); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + return signalPassFailure(); +} diff --git a/third_party/circt/lib/Dialect/Comb/Transforms/PassDetails.h b/third_party/circt/lib/Dialect/Comb/Transforms/PassDetails.h new file mode 100644 index 000000000..53c7b4b31 --- /dev/null +++ b/third_party/circt/lib/Dialect/Comb/Transforms/PassDetails.h @@ -0,0 +1,27 @@ +//===- PassDetails.h - Comb pass class details ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_COMB_TRANSFORMS_PASSDETAILS_H +#define DIALECT_COMB_TRANSFORMS_PASSDETAILS_H + +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Pass/Pass.h" + +namespace circt { +namespace comb { + +#define GEN_PASS_CLASSES +#include "circt/Dialect/Comb/Passes.h.inc" + +} // namespace comb +} // namespace circt + +#endif // DIALECT_COMB_TRANSFORMS_PASSDETAILS_H diff --git a/third_party/circt/lib/Dialect/HW/BUILD b/third_party/circt/lib/Dialect/HW/BUILD new file mode 100644 index 000000000..05fdfea84 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/BUILD @@ -0,0 +1,35 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = glob( + [ + "*.cpp", + ], + exclude = [ + "HWReductions.cpp", + ], + ), + hdrs = [ + "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", + ], + deps = [ + "@heir//third_party/circt/include/circt/Dialect/Comb:headers", + "@heir//third_party/circt/include/circt/Dialect/HW:attributes_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:dialect_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:enum_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:headers", + "@heir//third_party/circt/include/circt/Dialect/HW:op_interfaces_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:ops_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:type_interfaces_inc_gen", + "@heir//third_party/circt/include/circt/Dialect/HW:types_inc_gen", + "@heir//third_party/circt/lib/Support", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + ], +) diff --git a/third_party/circt/lib/Dialect/HW/CMakeLists.txt b/third_party/circt/lib/Dialect/HW/CMakeLists.txt new file mode 100644 index 000000000..83cc1559a --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/CMakeLists.txt @@ -0,0 +1,52 @@ +set(CIRCT_HW_Sources + ConversionPatterns.cpp + CustomDirectiveImpl.cpp + HWAttributes.cpp + HWDialect.cpp + HWInstanceGraph.cpp + HWModuleOpInterface.cpp + HWOpInterfaces.cpp + HWOps.cpp + HWTypeInterfaces.cpp + HWTypes.cpp + InstanceImplementation.cpp + ModuleImplementation.cpp + InnerSymbolTable.cpp + PortConverter.cpp +) + +set(LLVM_OPTIONAL_SOURCES + ${CIRCT_HW_Sources} + HWReductions.cpp +) + +add_circt_dialect_library(CIRCTHW + ${CIRCT_HW_Sources} + + ADDITIONAL_HEADER_DIRS + ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/HW + + DEPENDS + MLIRHWIncGen + MLIRHWAttrIncGen + MLIRHWEnumsIncGen + + LINK_COMPONENTS + Support + + LINK_LIBS PUBLIC + CIRCTSupport + MLIRIR + MLIRInferTypeOpInterface +) + +add_circt_library(CIRCTHWReductions + HWReductions.cpp + + LINK_LIBS PUBLIC + CIRCTReduceLib + CIRCTHW + MLIRIR +) + +add_subdirectory(Transforms) diff --git a/third_party/circt/lib/Dialect/HW/ConversionPatterns.cpp b/third_party/circt/lib/Dialect/HW/ConversionPatterns.cpp new file mode 100644 index 000000000..6ae68b213 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/ConversionPatterns.cpp @@ -0,0 +1,103 @@ +//===- ConversionPatterns.cpp - Common Conversion patterns ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/ConversionPatterns.h" + +#include "circt/Dialect/HW/HWTypes.h" + +using namespace circt; + +// Converts a function type wrt. the given type converter. +static FunctionType convertFunctionType(const TypeConverter &typeConverter, + FunctionType type) { + // Convert the original function types. + llvm::SmallVector res, arg; + llvm::transform(type.getResults(), std::back_inserter(res), + [&](Type t) { return typeConverter.convertType(t); }); + llvm::transform(type.getInputs(), std::back_inserter(arg), + [&](Type t) { return typeConverter.convertType(t); }); + + return FunctionType::get(type.getContext(), arg, res); +} + +// Converts a function type wrt. the given type converter. +static hw::ModuleType convertModuleType(const TypeConverter &typeConverter, + hw::ModuleType type) { + // Convert the original function types. + SmallVector ports(type.getPorts()); + for (auto &p : ports) p.type = typeConverter.convertType(p.type); + return hw::ModuleType::get(type.getContext(), ports); +} + +LogicalResult TypeConversionPattern::matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + // Convert the TypeAttrs. + llvm::SmallVector newAttrs; + newAttrs.reserve(op->getAttrs().size()); + for (auto attr : op->getAttrs()) { + if (auto typeAttr = attr.getValue().dyn_cast()) { + auto innerType = typeAttr.getValue(); + // TypeConvert::convertType doesn't handle function types, so we need to + // handle them manually. + if (auto funcType = innerType.dyn_cast()) + innerType = convertFunctionType(*getTypeConverter(), funcType); + else if (auto modType = innerType.dyn_cast()) + innerType = convertModuleType(*getTypeConverter(), modType); + else + innerType = getTypeConverter()->convertType(innerType); + newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType)); + } else { + newAttrs.push_back(attr); + } + } + + // Convert the result types. + llvm::SmallVector newResults; + if (failed( + getTypeConverter()->convertTypes(op->getResultTypes(), newResults))) + return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed"); + + // Build the state for the edited clone. + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, newAttrs, op->getSuccessors()); + for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) state.addRegion(); + + // Must create the op before running any modifications on the regions so that + // we don't crash with '-debug' and so we have something to 'root update'. + Operation *newOp = rewriter.create(state); + + // Move the regions over, converting the signatures as we go. + rewriter.startRootUpdate(newOp); + for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) { + Region ®ion = op->getRegion(i); + Region *newRegion = &newOp->getRegion(i); + + // TypeConverter::SignatureConversion drops argument locations, so we need + // to manually copy them over (a verifier in e.g. HWModule checks this). + llvm::SmallVector argLocs; + for (auto arg : region.getArguments()) argLocs.push_back(arg.getLoc()); + + // Move the region and convert the region args. + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + if (failed(getTypeConverter()->convertSignatureArgs( + newRegion->getArgumentTypes(), result))) + return rewriter.notifyMatchFailure(op->getLoc(), + "type conversion failed"); + rewriter.applySignatureConversion(newRegion, result, getTypeConverter()); + + // Apply the argument locations. + for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs)) + arg.setLoc(loc); + } + rewriter.finalizeRootUpdate(newOp); + + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} diff --git a/third_party/circt/lib/Dialect/HW/CustomDirectiveImpl.cpp b/third_party/circt/lib/Dialect/HW/CustomDirectiveImpl.cpp new file mode 100644 index 000000000..3e1b0eac3 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/CustomDirectiveImpl.cpp @@ -0,0 +1,136 @@ +//===- CustomDirectiveImpl.cpp - Custom directive definitions -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/CustomDirectiveImpl.h" + +#include "circt/Dialect/HW/HWAttributes.h" + +using namespace circt; +using namespace circt::hw; + +ParseResult circt::parseInputPortList( + OpAsmParser &parser, + SmallVectorImpl &inputs, + SmallVectorImpl &inputTypes, ArrayAttr &inputNames) { + SmallVector argNames; + auto parseInputPort = [&]() -> ParseResult { + std::string portName; + if (parser.parseKeywordOrString(&portName)) return failure(); + argNames.push_back(StringAttr::get(parser.getContext(), portName)); + inputs.push_back({}); + inputTypes.push_back({}); + return failure(parser.parseColon() || parser.parseOperand(inputs.back()) || + parser.parseColon() || parser.parseType(inputTypes.back())); + }; + llvm::SMLoc inputsOperandsLoc; + if (parser.getCurrentLocation(&inputsOperandsLoc) || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, + parseInputPort)) + return failure(); + + inputNames = ArrayAttr::get(parser.getContext(), argNames); + + return success(); +} + +void circt::printInputPortList(OpAsmPrinter &p, Operation *op, + OperandRange inputs, TypeRange inputTypes, + ArrayAttr inputNames) { + p << "("; + llvm::interleaveComma(llvm::zip(inputs, inputNames), p, + [&](std::tuple input) { + Value val = std::get<0>(input); + p.printKeywordOrString( + std::get<1>(input).cast().getValue()); + p << ": " << val << ": " << val.getType(); + }); + p << ")"; +} + +ParseResult circt::parseOutputPortList(OpAsmParser &parser, + SmallVectorImpl &resultTypes, + ArrayAttr &resultNames) { + SmallVector names; + auto parseResultPort = [&]() -> ParseResult { + std::string portName; + if (parser.parseKeywordOrString(&portName)) return failure(); + names.push_back(StringAttr::get(parser.getContext(), portName)); + resultTypes.push_back({}); + return parser.parseColonType(resultTypes.back()); + }; + llvm::SMLoc inputsOperandsLoc; + if (parser.getCurrentLocation(&inputsOperandsLoc) || + parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, + parseResultPort)) + return failure(); + + resultNames = ArrayAttr::get(parser.getContext(), names); + + return success(); +} + +void circt::printOutputPortList(OpAsmPrinter &p, Operation *op, + TypeRange resultTypes, ArrayAttr resultNames) { + p << "("; + llvm::interleaveComma( + llvm::zip(resultTypes, resultNames), p, + [&](std::tuple result) { + p.printKeywordOrString( + std::get<1>(result).cast().getValue()); + p << ": " << std::get<0>(result); + }); + p << ")"; +} + +ParseResult circt::parseOptionalParameterList(OpAsmParser &parser, + ArrayAttr ¶meters) { + SmallVector params; + + auto parseParameter = [&]() { + std::string name; + Type type; + Attribute value; + + if (parser.parseKeywordOrString(&name) || parser.parseColonType(type)) + return failure(); + + // Parse the default value if present. + if (succeeded(parser.parseOptionalEqual())) { + if (parser.parseAttribute(value, type)) return failure(); + } + + auto &builder = parser.getBuilder(); + params.push_back(ParamDeclAttr::get( + builder.getContext(), builder.getStringAttr(name), type, value)); + return success(); + }; + + if (failed(parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::OptionalLessGreater, parseParameter))) + return failure(); + + parameters = ArrayAttr::get(parser.getContext(), params); + + return success(); +} + +void circt::printOptionalParameterList(OpAsmPrinter &p, Operation *op, + ArrayAttr parameters) { + if (parameters.empty()) return; + + p << '<'; + llvm::interleaveComma(parameters, p, [&](Attribute param) { + auto paramAttr = param.cast(); + p << paramAttr.getName().getValue() << ": " << paramAttr.getType(); + if (auto value = paramAttr.getValue()) { + p << " = "; + p.printAttributeWithoutType(value); + } + }); + p << '>'; +} diff --git a/third_party/circt/lib/Dialect/HW/HWAttributes.cpp b/third_party/circt/lib/Dialect/HW/HWAttributes.cpp new file mode 100644 index 000000000..b30448cb8 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWAttributes.cpp @@ -0,0 +1,1032 @@ +//===- HWAttributes.cpp - Implement HW attributes -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWAttributes.h" + +#include "circt/Dialect/HW/HWDialect.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "circt/Support/LLVM.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectImplementation.h" + +using namespace circt; +using namespace circt::hw; +using mlir::TypedAttr; + +// Internal method used for .mlir file parsing, defined below. +static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p, + Type type); + +//===----------------------------------------------------------------------===// +// ODS Boilerplate +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "circt/Dialect/HW/HWAttributes.cpp.inc" + +void HWDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "circt/Dialect/HW/HWAttributes.cpp.inc" + >(); +} + +Attribute HWDialect::parseAttribute(DialectAsmParser &p, Type type) const { + StringRef attrName; + Attribute attr; + auto parseResult = generatedAttributeParser(p, &attrName, type, attr); + if (parseResult.has_value()) return attr; + + // Parse "#hw.param.expr.add" as ParamExprAttr. + if (attrName.startswith(ParamExprAttr::getMnemonic())) { + auto string = attrName.drop_front(ParamExprAttr::getMnemonic().size()); + if (string.front() == '.') + return parseParamExprWithOpcode(string.drop_front(), p, type); + } + + p.emitError(p.getNameLoc(), "Unexpected hw attribute '" + attrName + "'"); + return {}; +} + +void HWDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { + if (succeeded(generatedAttributePrinter(attr, p))) return; + llvm_unreachable("Unexpected attribute"); +} + +//===----------------------------------------------------------------------===// +// OutputFileAttr +//===----------------------------------------------------------------------===// + +static std::string canonicalizeFilename(const Twine &directory, + const Twine &filename) { + // Convert the filename to a native style path. + SmallString<128> nativeFilename; + llvm::sys::path::native(filename, nativeFilename); + + // If the filename is an absolute path, ignore the directory. + // e.g. `directory/` + `/etc/filename` -> `/etc/filename`. + if (llvm::sys::path::is_absolute(nativeFilename)) + return std::string(nativeFilename); + + // Convert the directory to a native style path. + SmallString<128> nativeDirectory; + llvm::sys::path::native(directory, nativeDirectory); + + // If the filename component is empty, then ensure that the path ends in a + // separator and return it. + // e.g. `directory` + `` -> `directory/`. + auto separator = llvm::sys::path::get_separator(); + if (nativeFilename.empty() && !nativeDirectory.endswith(separator)) { + nativeDirectory += separator; + return std::string(nativeDirectory); + } + + // Append the directory and filename together. + // e.g. `/tmp/` + `out/filename` -> `/tmp/out/filename`. + SmallString<128> fullPath; + llvm::sys::path::append(fullPath, nativeDirectory, nativeFilename); + return std::string(fullPath); +} + +OutputFileAttr OutputFileAttr::getFromFilename(MLIRContext *context, + const Twine &filename, + bool excludeFromFileList, + bool includeReplicatedOps) { + return OutputFileAttr::getFromDirectoryAndFilename( + context, "", filename, excludeFromFileList, includeReplicatedOps); +} + +OutputFileAttr OutputFileAttr::getFromDirectoryAndFilename( + MLIRContext *context, const Twine &directory, const Twine &filename, + bool excludeFromFileList, bool includeReplicatedOps) { + auto canonicalized = canonicalizeFilename(directory, filename); + return OutputFileAttr::get(StringAttr::get(context, canonicalized), + BoolAttr::get(context, excludeFromFileList), + BoolAttr::get(context, includeReplicatedOps)); +} + +OutputFileAttr OutputFileAttr::getAsDirectory(MLIRContext *context, + const Twine &directory, + bool excludeFromFileList, + bool includeReplicatedOps) { + return getFromDirectoryAndFilename(context, directory, "", + excludeFromFileList, includeReplicatedOps); +} + +bool OutputFileAttr::isDirectory() { + return getFilename().getValue().endswith(llvm::sys::path::get_separator()); +} + +/// Option ::= 'excludeFromFileList' | 'includeReplicatedOp' +/// OutputFileAttr ::= 'output_file<' directory ',' name (',' Option)* '>' +Attribute OutputFileAttr::parse(AsmParser &p, Type type) { + StringAttr filename; + if (p.parseLess() || p.parseAttribute(filename)) + return Attribute(); + + // Parse the additional keyword attributes. Its easier to let people specify + // these more than once than to detect the problem and do something about it. + bool excludeFromFileList = false; + bool includeReplicatedOps = false; + while (true) { + if (p.parseOptionalComma()) break; + if (!p.parseOptionalKeyword("excludeFromFileList")) + excludeFromFileList = true; + else if (!p.parseKeyword("includeReplicatedOps", + "or 'excludeFromFileList'")) + includeReplicatedOps = true; + else + return Attribute(); + } + + if (p.parseGreater()) return Attribute(); + + return OutputFileAttr::getFromFilename(p.getContext(), filename.getValue(), + excludeFromFileList, + includeReplicatedOps); +} + +void OutputFileAttr::print(AsmPrinter &p) const { + p << "<" << getFilename(); + if (getExcludeFromFilelist().getValue()) p << ", excludeFromFileList"; + if (getIncludeReplicatedOps().getValue()) p << ", includeReplicatedOps"; + p << ">"; +} + +//===----------------------------------------------------------------------===// +// FileListAttr +//===----------------------------------------------------------------------===// + +FileListAttr FileListAttr::getFromFilename(MLIRContext *context, + const Twine &filename) { + auto canonicalized = canonicalizeFilename("", filename); + return FileListAttr::get(StringAttr::get(context, canonicalized)); +} + +//===----------------------------------------------------------------------===// +// EnumFieldAttr +//===----------------------------------------------------------------------===// + +Attribute EnumFieldAttr::parse(AsmParser &p, Type) { + StringRef field; + Type type; + if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() || + p.parseType(type) || p.parseGreater()) + return Attribute(); + return EnumFieldAttr::get(p.getEncodedSourceLoc(p.getCurrentLocation()), + StringAttr::get(p.getContext(), field), type); +} + +void EnumFieldAttr::print(AsmPrinter &p) const { + p << "<" << getField().getValue() << ", "; + p.printType(getType().getValue()); + p << ">"; +} + +EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value, + mlir::Type type) { + if (!hw::isHWEnumType(type)) emitError(loc) << "expected enum type"; + + // Check whether the provided value is a member of the enum type. + EnumType enumType = getCanonicalType(type).cast(); + if (!enumType.contains(value.getValue())) { + emitError(loc) << "enum value '" << value.getValue() + << "' is not a member of enum type " << enumType; + return nullptr; + } + + return Base::get(value.getContext(), value, TypeAttr::get(type)); +} + +//===----------------------------------------------------------------------===// +// InnerRefAttr +//===----------------------------------------------------------------------===// + +Attribute InnerRefAttr::parse(AsmParser &p, Type type) { + SymbolRefAttr attr; + if (p.parseLess() || p.parseAttribute(attr) || + p.parseGreater()) + return Attribute(); + if (attr.getNestedReferences().size() != 1) return Attribute(); + return InnerRefAttr::get(attr.getRootReference(), attr.getLeafReference()); +} + +void InnerRefAttr::print(AsmPrinter &p) const { + p << "<"; + p.printSymbolName(getModule().getValue()); + p << "::"; + p.printSymbolName(getName().getValue()); + p << ">"; +} + +//===----------------------------------------------------------------------===// +// InnerSymAttr and InnerSymPropertiesAttr +//===----------------------------------------------------------------------===// + +Attribute InnerSymPropertiesAttr::parse(AsmParser &parser, Type type) { + StringAttr name; + NamedAttrList dummyList; + int64_t fieldId = 0; + if (parser.parseLess() || parser.parseSymbolName(name, "name", dummyList) || + parser.parseComma() || parser.parseInteger(fieldId) || + parser.parseComma()) + return Attribute(); + + StringRef visibility; + auto loc = parser.getCurrentLocation(); + if (parser.parseOptionalKeyword(&visibility, + {"public", "private", "nested"})) { + parser.emitError(loc, "expected 'public', 'private', or 'nested'"); + return Attribute(); + } + auto visibilityAttr = parser.getBuilder().getStringAttr(visibility); + + if (parser.parseGreater()) return Attribute(); + + return parser.getChecked(parser.getContext(), name, + fieldId, visibilityAttr); +} + +void InnerSymPropertiesAttr::print(AsmPrinter &odsPrinter) const { + odsPrinter << "<@" << getName().getValue() << "," << getFieldID() << "," + << getSymVisibility().getValue() << ">"; +} + +LogicalResult InnerSymPropertiesAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + ::mlir::StringAttr name, uint64_t fieldID, + ::mlir::StringAttr symVisibility) { + if (!name || name.getValue().empty()) + return emitError() << "inner symbol cannot have empty name"; + return success(); +} + +StringAttr InnerSymAttr::getSymIfExists(uint64_t fieldId) const { + const auto *it = + llvm::find_if(getImpl()->props, [&](const InnerSymPropertiesAttr &p) { + return p.getFieldID() == fieldId; + }); + if (it != getProps().end()) return it->getName(); + return {}; +} + +InnerSymAttr InnerSymAttr::erase(uint64_t fieldID) const { + SmallVector syms(getProps()); + const auto *it = llvm::find_if(syms, [fieldID](InnerSymPropertiesAttr p) { + return p.getFieldID() == fieldID; + }); + assert(it != syms.end()); + syms.erase(it); + return InnerSymAttr::get(getContext(), syms); +} + +LogicalResult InnerSymAttr::walkSymbols( + llvm::function_ref callback) const { + for (auto p : getImpl()->props) + if (callback(p.getName()).failed()) return failure(); + return success(); +} + +Attribute InnerSymAttr::parse(AsmParser &parser, Type type) { + StringAttr sym; + NamedAttrList dummyList; + SmallVector names; + if (!parser.parseOptionalSymbolName(sym, "dummy", dummyList)) { + auto prop = parser.getChecked( + parser.getContext(), sym, 0, + StringAttr::get(parser.getContext(), "public")); + if (!prop) return {}; + names.push_back(prop); + } else if (parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Square, [&]() -> ParseResult { + InnerSymPropertiesAttr prop; + if (parser.parseCustomAttributeWithFallback( + prop, mlir::Type{}, "dummy", dummyList)) + return failure(); + + names.push_back(prop); + + return success(); + })) + return Attribute(); + + std::sort(names.begin(), names.end(), + [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) { + return a.getFieldID() < b.getFieldID(); + }); + + return InnerSymAttr::get(parser.getContext(), names); +} + +void InnerSymAttr::print(AsmPrinter &odsPrinter) const { + auto props = getProps(); + if (props.size() == 1 && + props[0].getSymVisibility().getValue().equals("public") && + props[0].getFieldID() == 0) { + odsPrinter << "@" << props[0].getName().getValue(); + return; + } + auto names = props.vec(); + + std::sort(names.begin(), names.end(), + [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) { + return a.getFieldID() < b.getFieldID(); + }); + odsPrinter << "["; + llvm::interleaveComma(names, odsPrinter, [&](InnerSymPropertiesAttr attr) { + attr.print(odsPrinter); + }); + odsPrinter << "]"; +} + +//===----------------------------------------------------------------------===// +// ParamDeclAttr +//===----------------------------------------------------------------------===// + +Attribute ParamDeclAttr::parse(AsmParser &p, Type trailing) { + std::string name; + Type type; + Attribute value; + // < "FOO" : i32 > : i32 + // < "FOO" : i32 = 0 > : i32 + // < "FOO" : none > + if (p.parseLess() || p.parseString(&name) || p.parseColonType(type)) + return Attribute(); + + if (succeeded(p.parseOptionalEqual())) { + if (p.parseAttribute(value, type)) return Attribute(); + } + + if (p.parseGreater()) return Attribute(); + + if (value) + return ParamDeclAttr::get(p.getContext(), + p.getBuilder().getStringAttr(name), type, value); + return ParamDeclAttr::get(name, type); +} + +void ParamDeclAttr::print(AsmPrinter &p) const { + p << "<" << getName() << ": " << getType(); + if (getValue()) { + p << " = "; + p.printAttributeWithoutType(getValue()); + } + p << ">"; +} + +//===----------------------------------------------------------------------===// +// ParamDeclRefAttr +//===----------------------------------------------------------------------===// + +Attribute ParamDeclRefAttr::parse(AsmParser &p, Type type) { + StringAttr name; + if (p.parseLess() || p.parseAttribute(name) || p.parseGreater() || + (!type && (p.parseColon() || p.parseType(type)))) + return Attribute(); + + return ParamDeclRefAttr::get(name, type); +} + +void ParamDeclRefAttr::print(AsmPrinter &p) const { + p << "<" << getName() << ">"; +} + +//===----------------------------------------------------------------------===// +// ParamVerbatimAttr +//===----------------------------------------------------------------------===// + +Attribute ParamVerbatimAttr::parse(AsmParser &p, Type type) { + StringAttr text; + if (p.parseLess() || p.parseAttribute(text) || p.parseGreater() || + (!type && (p.parseColon() || p.parseType(type)))) + return Attribute(); + + return ParamVerbatimAttr::get(p.getContext(), text, type); +} + +void ParamVerbatimAttr::print(AsmPrinter &p) const { + p << "<" << getValue() << ">"; +} + +//===----------------------------------------------------------------------===// +// ParamExprAttr +//===----------------------------------------------------------------------===// + +/// Given a binary function, if the two operands are known constant integers, +/// use the specified fold function to compute the result. +static TypedAttr foldBinaryOp( + ArrayRef operands, + llvm::function_ref calculate) { + assert(operands.size() == 2 && "binary operator always has two operands"); + if (auto lhs = operands[0].dyn_cast()) + if (auto rhs = operands[1].dyn_cast()) + return IntegerAttr::get(lhs.getType(), + calculate(lhs.getValue(), rhs.getValue())); + return {}; +} + +/// Given a unary function, if the operand is a known constant integer, +/// use the specified fold function to compute the result. +static TypedAttr foldUnaryOp( + ArrayRef operands, + llvm::function_ref calculate) { + assert(operands.size() == 1 && "unary operator always has one operand"); + if (auto intAttr = operands[0].dyn_cast()) + return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue())); + return {}; +} + +/// If the specified attribute is a ParamExprAttr with the specified opcode, +/// return it. Otherwise return null. +static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) { + if (auto expr = value.dyn_cast()) + if (expr.getOpcode() == opcode) return expr; + return {}; +} + +/// This implements a < comparison for two operands to an associative operation +/// imposing an ordering upon them. +/// +/// The ordering provided puts more complex things to the start of the list, +/// from left to right: +/// expressions :: verbatims :: decl.refs :: constant +/// +static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) { + // Simplify the code below - we never have to care about exactly equal values. + if (lhs == rhs) return false; + + // All expressions are "less than" a constant, since they appear on the right. + if (rhs.isa()) { + // We don't bother to order constants w.r.t. each other since they will be + // folded - they can all compare equal. + return !lhs.isa(); + } + if (lhs.isa()) return false; + + // Next up are named parameters. + if (auto rhsParam = rhs.dyn_cast()) { + // Parameters are sorted lexically w.r.t. each other. + if (auto lhsParam = lhs.dyn_cast()) + return lhsParam.getName().getValue() < rhsParam.getName().getValue(); + // They otherwise appear on the right of other things. + return true; + } + if (lhs.isa()) return false; + + // Next up are verbatim parameters. + if (auto rhsParam = rhs.dyn_cast()) { + // Verbatims are sorted lexically w.r.t. each other. + if (auto lhsParam = lhs.dyn_cast()) + return lhsParam.getValue().getValue() < rhsParam.getValue().getValue(); + // They otherwise appear on the right of other things. + return true; + } + if (lhs.isa()) return false; + + // The only thing left are nested expressions. + auto lhsExpr = lhs.cast(), rhsExpr = rhs.cast(); + // Sort by the string form of the opcode, e.g. add, .. mul,... then xor. + if (lhsExpr.getOpcode() != rhsExpr.getOpcode()) + return stringifyPEO(lhsExpr.getOpcode()) < + stringifyPEO(rhsExpr.getOpcode()); + + // If they are the same opcode, then sort by arity: more complex to the left. + ArrayRef lhsOperands = lhsExpr.getOperands(), + rhsOperands = rhsExpr.getOperands(); + if (lhsOperands.size() != rhsOperands.size()) + return lhsOperands.size() > rhsOperands.size(); + + // We know the two subexpressions are different (they'd otherwise be pointer + // equivalent) so just go compare all of the elements. + for (size_t i = 0, e = lhsOperands.size(); i != e; ++i) { + if (paramExprOperandSortPredicate(lhsOperands[i], rhsOperands[i])) + return true; + if (paramExprOperandSortPredicate(rhsOperands[i], lhsOperands[i])) + return false; + } + + llvm_unreachable("expressions should never be equivalent"); + return false; +} + +/// Given a fully associative variadic integer operation, constant fold any +/// constant operands and move them to the right. If the whole expression is +/// constant, then return that, otherwise update the operands list. +static TypedAttr simplifyAssocOp( + PEO opcode, SmallVector &operands, + llvm::function_ref calculateFn, + llvm::function_ref identityConstantFn, + llvm::function_ref destructiveConstantFn = {}) { + auto type = operands[0].getType(); + assert(isHWIntegerType(type)); + if (operands.size() == 1) return operands[0]; + + // Flatten any of the same operation into the operand list: + // `(add x, (add y, z))` => `(add x, y, z)`. + for (size_t i = 0, e = operands.size(); i != e; ++i) { + if (auto subexpr = dyn_castPE(opcode, operands[i])) { + std::swap(operands[i], operands.back()); + operands.pop_back(); + --e; + --i; + operands.append(subexpr.getOperands().begin(), + subexpr.getOperands().end()); + } + } + + // Impose an ordering on the operands, pushing subexpressions to the left and + // constants to the right, with verbatims and parameters in the middle - but + // predictably ordered w.r.t. each other. + llvm::stable_sort(operands, paramExprOperandSortPredicate); + + // Merge any constants, they will appear at the back of the operand list now. + if (operands.back().isa()) { + while (operands.size() >= 2 && + operands[operands.size() - 2].isa()) { + APInt c1 = operands.pop_back_val().cast().getValue(); + APInt c2 = operands.pop_back_val().cast().getValue(); + auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2)); + operands.push_back(resultConstant); + } + + auto resultCst = operands.back().cast(); + + // If the resulting constant is the destructive constant (e.g. `x*0`), then + // return it. + if (destructiveConstantFn && destructiveConstantFn(resultCst.getValue())) + return resultCst; + + // Remove the constant back to our operand list if it is the identity + // constant for this operator (e.g. `x*1`) and there are other operands. + if (identityConstantFn(resultCst.getValue()) && operands.size() != 1) + operands.pop_back(); + } + + return operands.size() == 1 ? operands[0] : TypedAttr(); +} + +/// Analyze an operand to an add. If it is a multiplication by a constant (e.g. +/// `(a*b*42)` then split it into the non-constant and the constant portions +/// (e.g. `a*b` and `42`). Otherwise return the operand as the first value and +/// null as the second (standin for "multiplication by 1"). +static std::pair decomposeAddend(TypedAttr operand) { + if (auto mul = dyn_castPE(PEO::Mul, operand)) + if (auto cst = mul.getOperands().back().dyn_cast()) { + auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back()); + return {nonCst, cst}; + } + return {operand, TypedAttr()}; +} + +static TypedAttr getOneOfType(Type type) { + return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), 1)); +} + +static TypedAttr simplifyAdd(SmallVector &operands) { + if (auto result = simplifyAssocOp( + PEO::Add, operands, [](auto a, auto b) { return a + b; }, + /*identityCst*/ [](auto cst) { return cst.isZero(); })) + return result; + + // Canonicalize the add by splitting all addends into their variable and + // constant factors. + SmallVector> decomposedOperands; + llvm::SmallDenseSet nonConstantParts; + for (auto &op : operands) { + decomposedOperands.push_back(decomposeAddend(op)); + + // Keep track of non-constant parts we've already seen. If we see multiple + // uses of the same value, then we can fold them together with a multiply. + // This handles things like `(a+b+a)` => `(a*2 + b)` and `(a*2 + b + a)` => + // `(a*3 + b)`. + if (!nonConstantParts.insert(decomposedOperands.back().first).second) { + // The thing we multiply will be the common expression. + TypedAttr mulOperand = decomposedOperands.back().first; + + // Find the index of the first occurrence. + size_t i = 0; + while (decomposedOperands[i].first != mulOperand) ++i; + // Remove both occurrences from the operand list. + operands.erase(operands.begin() + (&op - &operands[0])); + operands.erase(operands.begin() + i); + + auto type = mulOperand.getType(); + auto c1 = decomposedOperands[i].second, + c2 = decomposedOperands.back().second; + // Fill in missing constant multiplicands with 1. + if (!c1) c1 = getOneOfType(type); + if (!c2) c2 = getOneOfType(type); + // Re-add the "a"*(c1+c2) expression to the operand list and + // re-canonicalize. + auto constant = ParamExprAttr::get(PEO::Add, c1, c2); + auto mulCst = ParamExprAttr::get(PEO::Mul, mulOperand, constant); + operands.push_back(mulCst); + return ParamExprAttr::get(PEO::Add, operands); + } + } + + return {}; +} + +static TypedAttr simplifyMul(SmallVector &operands) { + if (auto result = simplifyAssocOp( + PEO::Mul, operands, [](auto a, auto b) { return a * b; }, + /*identityCst*/ [](auto cst) { return cst.isOne(); }, + /*destructiveCst*/ [](auto cst) { return cst.isZero(); })) + return result; + + // We always build a sum-of-products representation, so if we see an addition + // as a subexpr, we need to pull it out: (a+b)*c*d ==> (a*c*d + b*c*d). + for (size_t i = 0, e = operands.size(); i != e; ++i) { + if (auto subexpr = dyn_castPE(PEO::Add, operands[i])) { + // Pull the `c*d` operands out - it is whatever operands remain after + // removing the `(a+b)` term. + SmallVector mulOperands(operands.begin(), operands.end()); + mulOperands.erase(mulOperands.begin() + i); + + // Build each add operand. + SmallVector addOperands; + for (auto addOperand : subexpr.getOperands()) { + mulOperands.push_back(addOperand); + addOperands.push_back(ParamExprAttr::get(PEO::Mul, mulOperands)); + mulOperands.pop_back(); + } + // Canonicalize and form the add expression. + return ParamExprAttr::get(PEO::Add, addOperands); + } + } + + return {}; +} +static TypedAttr simplifyAnd(SmallVector &operands) { + return simplifyAssocOp( + PEO::And, operands, [](auto a, auto b) { return a & b; }, + /*identityCst*/ [](auto cst) { return cst.isAllOnes(); }, + /*destructiveCst*/ [](auto cst) { return cst.isZero(); }); +} + +static TypedAttr simplifyOr(SmallVector &operands) { + return simplifyAssocOp( + PEO::Or, operands, [](auto a, auto b) { return a | b; }, + /*identityCst*/ [](auto cst) { return cst.isZero(); }, + /*destructiveCst*/ [](auto cst) { return cst.isAllOnes(); }); +} + +static TypedAttr simplifyXor(SmallVector &operands) { + return simplifyAssocOp( + PEO::Xor, operands, [](auto a, auto b) { return a ^ b; }, + /*identityCst*/ [](auto cst) { return cst.isZero(); }); +} + +static TypedAttr simplifyShl(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + + if (auto rhs = operands[1].dyn_cast()) { + // Constant fold simple integers. + if (auto lhs = operands[0].dyn_cast()) + return IntegerAttr::get(lhs.getType(), + lhs.getValue().shl(rhs.getValue())); + + // Canonicalize `x << cst` => `x * (1< &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x >> 0`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isZero()) return operands[0]; + + return foldBinaryOp(operands, [](auto a, auto b) { return a.lshr(b); }); +} + +static TypedAttr simplifyShrS(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x >> 0`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isZero()) return operands[0]; + + return foldBinaryOp(operands, [](auto a, auto b) { return a.ashr(b); }); +} + +static TypedAttr simplifyDivU(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x/1`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isOne()) return operands[0]; + + return foldBinaryOp(operands, [](auto a, auto b) { return a.udiv(b); }); +} + +static TypedAttr simplifyDivS(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x/1`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isOne()) return operands[0]; + + return foldBinaryOp(operands, [](auto a, auto b) { return a.sdiv(b); }); +} + +static TypedAttr simplifyModU(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x%1`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); + + return foldBinaryOp(operands, [](auto a, auto b) { return a.urem(b); }); +} + +static TypedAttr simplifyModS(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + // Implement support for identities like `x%1`. + if (auto rhs = operands[1].dyn_cast()) + if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); + + return foldBinaryOp(operands, [](auto a, auto b) { return a.srem(b); }); +} + +static TypedAttr simplifyCLog2(SmallVector &operands) { + assert(isHWIntegerType(operands[0].getType())); + return foldUnaryOp(operands, [](auto a) { + // Following the Verilog spec, clog2(0) is 0 + return APInt(a.getBitWidth(), a == 0 ? 0 : a.ceilLogBase2()); + }); +} + +static TypedAttr simplifyStrConcat(SmallVector &operands) { + // Combine all adjacent strings. + SmallVector newOperands; + SmallVector stringsToCombine; + auto combineAndPush = [&]() { + if (stringsToCombine.empty()) return; + // Concatenate buffered strings, push to ops. + SmallString<32> newString; + for (auto part : stringsToCombine) newString.append(part.getValue()); + newOperands.push_back( + StringAttr::get(stringsToCombine[0].getContext(), newString)); + stringsToCombine.clear(); + }; + + for (TypedAttr op : operands) { + if (auto strOp = op.dyn_cast()) { + // Queue up adjacent strings. + stringsToCombine.push_back(strOp); + } else { + combineAndPush(); + newOperands.push_back(op); + } + } + combineAndPush(); + + assert(!newOperands.empty()); + if (newOperands.size() == 1) return newOperands[0]; + if (newOperands.size() < operands.size()) + return ParamExprAttr::get(PEO::StrConcat, newOperands); + return {}; +} + +/// Build a parameter expression. This automatically canonicalizes and +/// folds, so it may not necessarily return a ParamExprAttr. +TypedAttr ParamExprAttr::get(PEO opcode, ArrayRef operandsIn) { + assert(!operandsIn.empty() && "Cannot have expr with no operands"); + // All operands must have the same type, which is the type of the result. + auto type = operandsIn.front().getType(); + assert(llvm::all_of(operandsIn.drop_front(), + [&](auto op) { return op.getType() == type; })); + + SmallVector operands(operandsIn.begin(), operandsIn.end()); + + // Verify and canonicalize parameter expressions. + TypedAttr result; + switch (opcode) { + case PEO::Add: + result = simplifyAdd(operands); + break; + case PEO::Mul: + result = simplifyMul(operands); + break; + case PEO::And: + result = simplifyAnd(operands); + break; + case PEO::Or: + result = simplifyOr(operands); + break; + case PEO::Xor: + result = simplifyXor(operands); + break; + case PEO::Shl: + result = simplifyShl(operands); + break; + case PEO::ShrU: + result = simplifyShrU(operands); + break; + case PEO::ShrS: + result = simplifyShrS(operands); + break; + case PEO::DivU: + result = simplifyDivU(operands); + break; + case PEO::DivS: + result = simplifyDivS(operands); + break; + case PEO::ModU: + result = simplifyModU(operands); + break; + case PEO::ModS: + result = simplifyModS(operands); + break; + case PEO::CLog2: + result = simplifyCLog2(operands); + break; + case PEO::StrConcat: + result = simplifyStrConcat(operands); + break; + } + + // If we folded to an operand, return it. + if (result) return result; + + return Base::get(operands[0].getContext(), opcode, operands, type); +} + +Attribute ParamExprAttr::parse(AsmParser &p, Type type) { + // We require an opcode suffix like `#hw.param.expr.add`, we don't allow + // parsing a plain `#hw.param.expr` on its own. + p.emitError(p.getNameLoc(), "#hw.param.expr should have opcode suffix"); + return {}; +} + +/// Internal method used for .mlir file parsing when parsing the +/// "#hw.param.expr.mul" form of the attribute. +static Attribute parseParamExprWithOpcode(StringRef opcodeStr, + DialectAsmParser &p, Type type) { + SmallVector operands; + if (p.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { + operands.push_back({}); + return p.parseAttribute(operands.back(), type); + })) + return {}; + + std::optional opcode = symbolizePEO(opcodeStr); + if (!opcode.has_value()) { + p.emitError(p.getNameLoc(), "unknown parameter expr operator name"); + return {}; + } + + return ParamExprAttr::get(*opcode, operands); +} + +void ParamExprAttr::print(AsmPrinter &p) const { + p << "." << stringifyPEO(getOpcode()) << '<'; + llvm::interleaveComma(getOperands(), p.getStream(), + [&](Attribute op) { p.printAttributeWithoutType(op); }); + p << '>'; +} + +// Replaces any ParamDeclRefAttr within a parametric expression with its +// corresponding value from the map of provided parameters. +static FailureOr replaceDeclRefInExpr( + Location loc, const std::map ¶meters, + Attribute paramAttr) { + if (paramAttr.dyn_cast()) { + // Nothing to do, constant value. + return paramAttr; + } + if (auto paramRefAttr = paramAttr.dyn_cast()) { + // Get the value from the provided parameters. + auto it = parameters.find(paramRefAttr.getName().str()); + if (it == parameters.end()) + return emitError(loc) + << "Could not find parameter " << paramRefAttr.getName().str() + << " in the provided parameters for the expression!"; + return it->second; + } + if (auto paramExprAttr = paramAttr.dyn_cast()) { + // Recurse into all operands of the expression. + llvm::SmallVector replacedOperands; + for (auto operand : paramExprAttr.getOperands()) { + auto res = replaceDeclRefInExpr(loc, parameters, operand); + if (failed(res)) return {failure()}; + replacedOperands.push_back(res->cast()); + } + return { + hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)}; + } + llvm_unreachable("Unhandled parametric attribute"); + return {}; +} + +FailureOr hw::evaluateParametricAttr(Location loc, + ArrayAttr parameters, + Attribute paramAttr) { + // Create a map of the provided parameters for faster lookup. + std::map parameterMap; + for (auto param : parameters) { + auto paramDecl = param.cast(); + parameterMap[paramDecl.getName().str()] = paramDecl.getValue(); + } + + // First, replace any ParamDeclRefAttr in the expression with its + // corresponding value in 'parameters'. + auto paramAttrRes = replaceDeclRefInExpr(loc, parameterMap, paramAttr); + if (failed(paramAttrRes)) return {failure()}; + paramAttr = *paramAttrRes; + + // Then, evaluate the parametric attribute. + if (paramAttr.isa()) + return paramAttr.cast(); + if (auto paramExprAttr = paramAttr.dyn_cast()) { + // Since any ParamDeclRefAttr was replaced within the expression, + // we re-evaluate the expression through the existing ParamExprAttr + // canonicalizer. + return ParamExprAttr::get(paramExprAttr.getOpcode(), + paramExprAttr.getOperands()); + } + + llvm_unreachable("Unhandled parametric attribute"); + return TypedAttr(); +} + +template +FailureOr evaluateParametricArrayType(Location loc, ArrayAttr parameters, + TArray arrayType) { + auto size = evaluateParametricAttr(loc, parameters, arrayType.getSizeAttr()); + if (failed(size)) return failure(); + auto elementType = + evaluateParametricType(loc, parameters, arrayType.getElementType()); + if (failed(elementType)) return failure(); + + // If the size was evaluated to a constant, use a 64-bit integer + // attribute version of it + if (auto intAttr = size->template dyn_cast()) + return TArray::get( + arrayType.getContext(), *elementType, + IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64), + intAttr.getValue().getSExtValue())); + + // Otherwise parameter references are still involved + return TArray::get(arrayType.getContext(), *elementType, *size); +} + +FailureOr hw::evaluateParametricType(Location loc, ArrayAttr parameters, + Type type) { + return llvm::TypeSwitch>(type) + .Case([&](hw::IntType t) -> FailureOr { + auto evaluatedWidth = + evaluateParametricAttr(loc, parameters, t.getWidth()); + if (failed(evaluatedWidth)) return {failure()}; + + // If the width was evaluated to a constant, return an `IntegerType` + if (auto intAttr = evaluatedWidth->dyn_cast()) + return {IntegerType::get(type.getContext(), + intAttr.getValue().getSExtValue())}; + + // Otherwise parameter references are still involved + return hw::IntType::get(evaluatedWidth->cast()); + }) + .Case( + [&](auto arrayType) -> FailureOr { + return evaluateParametricArrayType(loc, parameters, arrayType); + }) + .Default([&](auto) { return type; }); +} + +// Returns true if any part of this parametric attribute contains a reference +// to a parameter declaration. +static bool isParamAttrWithParamRef(Attribute expr) { + return llvm::TypeSwitch(expr) + .Case([](ParamExprAttr attr) { + return llvm::any_of(attr.getOperands(), isParamAttrWithParamRef); + }) + .Case([](ParamDeclRefAttr) { return true; }) + .Default([](auto) { return false; }); +} + +bool hw::isParametricType(mlir::Type t) { + return llvm::TypeSwitch(t) + .Case( + [&](hw::IntType t) { return isParamAttrWithParamRef(t.getWidth()); }) + .Case([&](auto arrayType) { + return isParametricType(arrayType.getElementType()) || + isParamAttrWithParamRef(arrayType.getSizeAttr()); + }) + .Default([](auto) { return false; }); +} diff --git a/third_party/circt/lib/Dialect/HW/HWDialect.cpp b/third_party/circt/lib/Dialect/HW/HWDialect.cpp new file mode 100644 index 000000000..618891a9c --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWDialect.cpp @@ -0,0 +1,116 @@ +//===- HWDialect.cpp - Implement the HW dialect ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the HW dialect. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWDialect.h" + +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Transforms/InliningUtils.h" + +using namespace circt; +using namespace hw; + +//===----------------------------------------------------------------------===// +// Dialect specification. +//===----------------------------------------------------------------------===// + +// Pull in the dialect definition. +#include "circt/Dialect/HW/HWDialect.cpp.inc" + +namespace { + +// We implement the OpAsmDialectInterface so that HW dialect operations +// automatically interpret the name attribute on operations as their SSA name. +struct HWOpAsmDialectInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + + /// Get a special name to use when printing the given operation. See + /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. + void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const {} +}; +} // end anonymous namespace + +namespace { +/// This class defines the interface for handling inlining with HW operations. +struct HWInlinerInterface : public mlir::DialectInlinerInterface { + using mlir::DialectInlinerInterface::DialectInlinerInterface; + + bool isLegalToInline(Operation *op, Region *, bool, + mlir::IRMapping &) const final { + return isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op); + } + + bool isLegalToInline(Region *, Region *, bool, + mlir::IRMapping &) const final { + return false; + } +}; +} // end anonymous namespace + +void HWDialect::initialize() { + // Register types and attributes. + registerTypes(); + registerAttributes(); + + // Register operations. + addOperations< +#define GET_OP_LIST +#include "circt/Dialect/HW/HW.cpp.inc" + >(); + + // Register interface implementations. + addInterfaces(); +} + +// Registered hook to materialize a single constant operation from a given +/// attribute value with the desired resultant type. This method should use +/// the provided builder to create the operation without changing the +/// insertion position. The generated operation is expected to be constant +/// like, i.e. single result, zero operands, non side-effecting, etc. On +/// success, this hook should return the value generated to represent the +/// constant value. Otherwise, it should return null on failure. +Operation *HWDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + // Integer constants can materialize into hw.constant + if (auto intType = type.dyn_cast()) + if (auto attrValue = value.dyn_cast()) + return builder.create(loc, type, attrValue); + + // Aggregate constants. + if (auto arrayAttr = value.dyn_cast()) { + if (type.isa()) + return builder.create(loc, type, arrayAttr); + } + + // Parameter expressions materialize into hw.param.value. + auto parentOp = builder.getBlock()->getParentOp(); + auto curModule = dyn_cast(parentOp); + if (!curModule) curModule = parentOp->getParentOfType(); + if (curModule && isValidParameterExpression(value, curModule)) + return builder.create(loc, type, value); + + return nullptr; +} + +// Provide implementations for the enums we use. +#include "circt/Dialect/HW/HWEnums.cpp.inc" diff --git a/third_party/circt/lib/Dialect/HW/HWInstanceGraph.cpp b/third_party/circt/lib/Dialect/HW/HWInstanceGraph.cpp new file mode 100644 index 000000000..087bfa8b9 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWInstanceGraph.cpp @@ -0,0 +1,33 @@ +//===- HWInstanceGraph.cpp - Instance Graph ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWInstanceGraph.h" + +using namespace circt; +using namespace hw; + +InstanceGraph::InstanceGraph(Operation *operation) + : igraph::InstanceGraph(operation) { + for (auto &node : nodes) + if (cast(node.getModule().getOperation()).isPublic()) + entry.addInstance({}, &node); +} + +igraph::InstanceGraphNode *InstanceGraph::addHWModule(HWModuleLike module) { + auto *node = igraph::InstanceGraph::addModule( + cast(module.getOperation())); + if (module.isPublic()) entry.addInstance({}, node); + return node; +} + +void InstanceGraph::erase(igraph::InstanceGraphNode *node) { + for (auto *instance : llvm::make_early_inc_range(entry)) { + if (instance->getTarget() == node) instance->erase(); + } + igraph::InstanceGraph::erase(node); +} diff --git a/third_party/circt/lib/Dialect/HW/HWModuleOpInterface.cpp b/third_party/circt/lib/Dialect/HW/HWModuleOpInterface.cpp new file mode 100644 index 000000000..e0d8c7ed0 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWModuleOpInterface.cpp @@ -0,0 +1,88 @@ +//===- HWModuleOpInterface.cpp.h - Implement HWModuleLike ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements HWModuleLike related functionality. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWOpInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace circt; +using namespace hw; + +//===----------------------------------------------------------------------===// +// HWModuleLike Signature Conversion +//===----------------------------------------------------------------------===// + +static LogicalResult convertModuleOpTypes(HWModuleLike modOp, + const TypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + ModuleType type = modOp.getHWModuleType(); + if (!type) return failure(); + + // Convert the original port types. + // Update the module signature in-place. + SmallVector newPorts; + TypeConverter::SignatureConversion result(type.getNumInputs()); + unsigned atInput = 0; + unsigned curInputs = 0; + for (auto &p : type.getPorts()) { + if (p.dir == ModulePort::Direction::Output) { + SmallVector newResults; + if (failed(typeConverter.convertType(p.type, newResults))) + return failure(); + for (auto np : newResults) newPorts.push_back({p.name, np, p.dir}); + } else { + if (failed(typeConverter.convertSignatureArg( + atInput++, + /* inout ports need to be wrapped in the appropriate type */ + p.dir == ModulePort::Direction::Input ? p.type + : InOutType::get(p.type), + result))) + return failure(); + for (auto np : result.getConvertedTypes().drop_front(curInputs)) + newPorts.push_back({p.name, np, p.dir}); + curInputs = result.getConvertedTypes().size(); + } + } + + if (failed(rewriter.convertRegionTypes(&modOp->getRegion(0), typeConverter, + &result))) + return failure(); + + auto newType = ModuleType::get(rewriter.getContext(), newPorts); + rewriter.updateRootInPlace(modOp, [&] { modOp.setHWModuleType(newType); }); + + return success(); +} + +/// Create a default conversion pattern that rewrites the type signature of a +/// FunctionOpInterface op. This only supports ops which use FunctionType to +/// represent their type. +namespace { +struct HWModuleLikeSignatureConversion : public ConversionPattern { + HWModuleLikeSignatureConversion(StringRef moduleLikeOpName, MLIRContext *ctx, + const TypeConverter &converter) + : ConversionPattern(converter, moduleLikeOpName, /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef /*operands*/, + ConversionPatternRewriter &rewriter) const override { + HWModuleLike modOp = cast(op); + return convertModuleOpTypes(modOp, *typeConverter, rewriter); + } +}; +} // namespace + +void circt::hw::populateHWModuleLikeTypeConversionPattern( + StringRef moduleLikeOpName, RewritePatternSet &patterns, + TypeConverter &converter) { + patterns.add( + moduleLikeOpName, patterns.getContext(), converter); +} diff --git a/third_party/circt/lib/Dialect/HW/HWOpInterfaces.cpp b/third_party/circt/lib/Dialect/HW/HWOpInterfaces.cpp new file mode 100644 index 000000000..0245a2d81 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWOpInterfaces.cpp @@ -0,0 +1,99 @@ +//===- HWOpInterfaces.cpp - Implement the HW op interfaces ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implement the HW operation interfaces. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWOpInterfaces.h" + +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWTypeInterfaces.h" +#include "circt/Support/LLVM.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace circt; + +LogicalResult hw::verifyInnerSymAttr(InnerSymbolOpInterface op) { + auto innerSym = op.getInnerSymAttr(); + // If does not have any inner sym then ignore. + if (!innerSym) return success(); + + if (innerSym.empty()) + return op->emitOpError("has empty list of inner symbols"); + + if (!op.supportsPerFieldSymbols()) { + // The inner sym can only be specified on fieldID=0. + if (innerSym.size() > 1 || !innerSym.getSymName()) { + op->emitOpError("does not support per-field inner symbols"); + return failure(); + } + return success(); + } + + auto result = op.getTargetResult(); + // If op supports per-field symbols, but does not have a target result, + // its up to the operation to verify itself. + // (there are no uses for this presently, but be open to this anyway.) + if (!result) return success(); + auto resultType = result.getType(); + auto maxFields = FieldIdImpl::getMaxFieldID(resultType); + llvm::SmallBitVector indices(maxFields + 1); + llvm::SmallPtrSet symNames; + // Ensure fieldID and symbol names are unique. + auto uniqSyms = [&](InnerSymPropertiesAttr p) { + if (maxFields < p.getFieldID()) { + op->emitOpError("field id:'" + Twine(p.getFieldID()) + + "' is greater than the maximum field id:'" + + Twine(maxFields) + "'"); + return false; + } + if (indices.test(p.getFieldID())) { + op->emitOpError("cannot assign multiple symbol names to the field id:'" + + Twine(p.getFieldID()) + "'"); + return false; + } + indices.set(p.getFieldID()); + auto it = symNames.insert(p.getName()); + if (!it.second) { + op->emitOpError("cannot reuse symbol name:'" + p.getName().getValue() + + "'"); + return false; + } + return true; + }; + + if (!llvm::all_of(innerSym.getProps(), uniqSyms)) return failure(); + + return success(); +} + +raw_ostream &circt::hw::operator<<(raw_ostream &printer, PortInfo port) { + StringRef dirstr; + switch (port.dir) { + case ModulePort::Direction::Input: + dirstr = "input"; + break; + case ModulePort::Direction::Output: + dirstr = "output"; + break; + case ModulePort::Direction::InOut: + dirstr = "inout"; + break; + } + printer << dirstr << " " << port.name << " : " << port.type << " (argnum " + << port.argNum << ", sym " << port.sym << ", loc " << port.loc + << ", args " << port.attrs << ")"; + return printer; +} + +#include "circt/Dialect/HW/HWOpInterfaces.cpp.inc" diff --git a/third_party/circt/lib/Dialect/HW/HWOps.cpp b/third_party/circt/lib/Dialect/HW/HWOps.cpp new file mode 100644 index 000000000..9270a54e7 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWOps.cpp @@ -0,0 +1,3376 @@ +//===- HWOps.cpp - Implement the HW operations ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implement the HW ops. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWOps.h" + +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/CustomDirectiveImpl.h" +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Dialect/HW/HWSymCache.h" +#include "circt/Dialect/HW/HWVisitors.h" +#include "circt/Dialect/HW/InstanceImplementation.h" +#include "circt/Dialect/HW/ModuleImplementation.h" +#include "circt/Support/CustomDirectiveImpl.h" +#include "circt/Support/Namespace.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/FunctionImplementation.h" + +using namespace circt; +using namespace hw; +using mlir::TypedAttr; + +/// Flip a port direction. +ModulePort::Direction hw::flip(ModulePort::Direction direction) { + switch (direction) { + case ModulePort::Direction::Input: + return ModulePort::Direction::Output; + case ModulePort::Direction::Output: + return ModulePort::Direction::Input; + case ModulePort::Direction::InOut: + return ModulePort::Direction::InOut; + } + llvm_unreachable("unknown PortDirection"); +} + +bool hw::isValidIndexBitWidth(Value index, Value array) { + hw::ArrayType arrayType = + hw::getCanonicalType(array.getType()).dyn_cast(); + assert(arrayType && "expected array type"); + unsigned indexWidth = index.getType().getIntOrFloatBitWidth(); + auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements()); + return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1) + : indexWidth == requiredWidth; +} + +/// Return true if the specified operation is a combinational logic op. +bool hw::isCombinational(Operation *op) { + struct IsCombClassifier : public TypeOpVisitor { + bool visitInvalidTypeOp(Operation *op) { return false; } + bool visitUnhandledTypeOp(Operation *op) { return true; } + }; + + return (op->getDialect() && op->getDialect()->getNamespace() == "comb") || + IsCombClassifier().dispatchTypeOpVisitor(op); +} + +static Value foldStructExtract(Operation *inputOp, StringRef field) { + // A struct extract of a struct create -> corresponding struct create operand. + if (auto structCreate = dyn_cast_or_null(inputOp)) { + auto ty = type_cast(structCreate.getResult().getType()); + if (auto idx = ty.getFieldIndex(field)) + return structCreate.getOperand(*idx); + return {}; + } + // Extracting injected field -> corresponding field + if (auto structInject = dyn_cast_or_null(inputOp)) { + if (structInject.getField() != field) return {}; + return structInject.getNewValue(); + } + return {}; +} + +/// Get a special name to use when printing the entry block arguments of the +/// region contained by an operation in this dialect. +static void getAsmBlockArgumentNamesImpl(mlir::Region ®ion, + OpAsmSetValueNameFn setNameFn) { + if (region.empty()) return; + // Assign port names to the bbargs. + // auto *module = region.getParentOp(); + auto module = cast(region.getParentOp()); + + auto *block = ®ion.front(); + for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) { + auto name = module.getInputName(i); + if (!name.empty()) setNameFn(block->getArgument(i), name); + } +} + +enum class Delimiter { + None, + Paren, // () enclosed list + OptionalLessGreater, // <> enclosed list or absent +}; + +/// Check parameter specified by `value` to see if it is valid according to the +/// module's parameters. If not, emit an error to the diagnostic provided as an +/// argument to the lambda 'instanceError' and return failure, otherwise return +/// success. +/// +/// If `disallowParamRefs` is true, then parameter references are not allowed. +LogicalResult hw::checkParameterInContext( + Attribute value, ArrayAttr moduleParameters, + const instance_like_impl::EmitErrorFn &instanceError, + bool disallowParamRefs) { + // Literals are always ok. Their types are already known to match + // expectations. + if (value.isa() || value.isa() || + value.isa() || value.isa()) + return success(); + + // Check both subexpressions of an expression. + if (auto expr = value.dyn_cast()) { + for (auto op : expr.getOperands()) + if (failed(checkParameterInContext(op, moduleParameters, instanceError, + disallowParamRefs))) + return failure(); + return success(); + } + + // Parameter references need more analysis to make sure they are valid within + // this module. + if (auto parameterRef = value.dyn_cast()) { + auto nameAttr = parameterRef.getName(); + + // Don't allow references to parameters from the default values of a + // parameter list. + if (disallowParamRefs) { + instanceError([&](auto &diag) { + diag << "parameter " << nameAttr + << " cannot be used as a default value for a parameter"; + return false; + }); + return failure(); + } + + // Find the corresponding attribute in the module. + for (auto param : moduleParameters) { + auto paramAttr = param.cast(); + if (paramAttr.getName() != nameAttr) continue; + + // If the types match then the reference is ok. + if (paramAttr.getType() == parameterRef.getType()) return success(); + + instanceError([&](auto &diag) { + diag << "parameter " << nameAttr << " used with type " + << parameterRef.getType() << "; should have type " + << paramAttr.getType(); + return true; + }); + return failure(); + } + + instanceError([&](auto &diag) { + diag << "use of unknown parameter " << nameAttr; + return true; + }); + return failure(); + } + + instanceError([&](auto &diag) { + diag << "invalid parameter value " << value; + return false; + }); + return failure(); +} + +/// Check parameter specified by `value` to see if it is valid within the scope +/// of the specified module `module`. If not, emit an error at the location of +/// `usingOp` and return failure, otherwise return success. If `usingOp` is +/// null, then no diagnostic is generated. +/// +/// If `disallowParamRefs` is true, then parameter references are not allowed. +LogicalResult hw::checkParameterInContext(Attribute value, Operation *module, + Operation *usingOp, + bool disallowParamRefs) { + instance_like_impl::EmitErrorFn emitError = + [&](const std::function &fn) { + if (usingOp) { + auto diag = usingOp->emitOpError(); + if (fn(diag)) + diag.attachNote(module->getLoc()) << "module declared here"; + } + }; + + return checkParameterInContext(value, + module->getAttrOfType("parameters"), + emitError, disallowParamRefs); +} + +/// Return true if the specified attribute tree is made up of nodes that are +/// valid in a parameter expression. +bool hw::isValidParameterExpression(Attribute attr, Operation *module) { + return succeeded(checkParameterInContext(attr, module, nullptr, false)); +} + +/// Return the name of the arg attributes list used for both modules and +/// instances. Normally we'd use the FunctionOpInterface for this, but both +/// modules and instances use the same attribute name, and instances don't +/// implement that interface. +StringAttr getArgAttrsName(MLIRContext *context) { + return HWModuleOp::getArgAttrsAttrName( + mlir::OperationName(HWModuleOp::getOperationName(), context)); +} + +/// Return the name of the result attributes list used for both modules and +/// instances. Normally we'd use the FunctionOpInterface for this, but both +/// modules and instances use the same attribute name, and instances don't +/// implement that interface. +StringAttr getResAttrsName(MLIRContext *context) { + return HWModuleOp::getResAttrsAttrName( + mlir::OperationName(HWModuleOp::getOperationName(), context)); +} + +HWModulePortAccessor::HWModulePortAccessor(Location loc, + const ModulePortInfo &info, + Region &bodyRegion) + : info(info) { + inputArgs.resize(info.sizeInputs()); + for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) { + inputIdx[info.at(i).name.str()] = i; + inputArgs[i] = barg; + } + + outputOperands.resize(info.sizeOutputs()); + for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) { + outputIdx[outputInfo.name.str()] = i; + } +} + +void HWModulePortAccessor::setOutput(unsigned i, Value v) { + assert(outputOperands.size() > i && "invalid output index"); + assert(outputOperands[i] == Value() && "output already set"); + outputOperands[i] = v; +} + +Value HWModulePortAccessor::getInput(unsigned i) { + assert(inputArgs.size() > i && "invalid input index"); + return inputArgs[i]; +} +Value HWModulePortAccessor::getInput(StringRef name) { + return getInput(inputIdx.find(name.str())->second); +} +void HWModulePortAccessor::setOutput(StringRef name, Value v) { + setOutput(outputIdx.find(name.str())->second, v); +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +void ConstantOp::print(OpAsmPrinter &p) { + p << " "; + p.printAttribute(getValueAttr()); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); +} + +ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { + IntegerAttr valueAttr; + + if (parser.parseAttribute(valueAttr, "value", result.attributes) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + result.addTypes(valueAttr.getType()); + return success(); +} + +LogicalResult ConstantOp::verify() { + // If the result type has a bitwidth, then the attribute must match its width. + if (getValue().getBitWidth() != getType().cast().getWidth()) + return emitError( + "hw.constant attribute bitwidth doesn't match return type"); + + return success(); +} + +/// Build a ConstantOp from an APInt, infering the result type from the +/// width of the APInt. +void ConstantOp::build(OpBuilder &builder, OperationState &result, + const APInt &value) { + auto type = IntegerType::get(builder.getContext(), value.getBitWidth()); + auto attr = builder.getIntegerAttr(type, value); + return build(builder, result, type, attr); +} + +/// Build a ConstantOp from an APInt, infering the result type from the +/// width of the APInt. +void ConstantOp::build(OpBuilder &builder, OperationState &result, + IntegerAttr value) { + return build(builder, result, value.getType(), value); +} + +/// This builder allows construction of small signed integers like 0, 1, -1 +/// matching a specified MLIR IntegerType. This shouldn't be used for general +/// constant folding because it only works with values that can be expressed in +/// an int64_t. Use APInt's instead. +void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type, + int64_t value) { + auto numBits = type.cast().getWidth(); + build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true)); +} + +void ConstantOp::getAsmResultNames( + function_ref setNameFn) { + auto intTy = getType(); + auto intCst = getValue(); + + // Sugar i1 constants with 'true' and 'false'. + if (intTy.cast().getWidth() == 1) + return setNameFn(getResult(), intCst.isZero() ? "false" : "true"); + + // Otherwise, build a complex name with the value and type. + SmallVector specialNameBuffer; + llvm::raw_svector_ostream specialName(specialNameBuffer); + specialName << 'c' << intCst << '_' << intTy; + setNameFn(getResult(), specialName.str()); +} + +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getValueAttr(); +} + +//===----------------------------------------------------------------------===// +// WireOp +//===----------------------------------------------------------------------===// + +/// Check whether an operation has any additional attributes set beyond its +/// standard list of attributes returned by `getAttributeNames`. +template +static bool hasAdditionalAttributes(Op op, + ArrayRef ignoredAttrs = {}) { + auto names = op.getAttributeNames(); + llvm::SmallDenseSet nameSet; + nameSet.reserve(names.size() + ignoredAttrs.size()); + nameSet.insert(names.begin(), names.end()); + nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end()); + return llvm::any_of(op->getAttrs(), [&](auto namedAttr) { + return !nameSet.contains(namedAttr.getName()); + }); +} + +void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + // If the wire has an optional 'name' attribute, use it. + auto nameAttr = (*this)->getAttrOfType("name"); + if (!nameAttr.getValue().empty()) setNameFn(getResult(), nameAttr.getValue()); +} + +std::optional WireOp::getTargetResultIndex() { return 0; } + +OpFoldResult WireOp::fold(FoldAdaptor adaptor) { + // If the wire has no additional attributes, no name, and no symbol, just + // forward its input. + if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() && + !getInnerSymAttr()) + return getInput(); + return {}; +} + +LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) { + // Block if the wire has any attributes. + if (hasAdditionalAttributes(wire, {"sv.namehint"})) return failure(); + + // If the wire has a symbol, then we can't delete it. + if (wire.getInnerSymAttr()) return failure(); + + // If the wire has a name or an `sv.namehint` attribute, propagate it as an + // `sv.namehint` to the expression. + if (auto *inputOp = wire.getInput().getDefiningOp()) { + auto name = wire.getNameAttr(); + if (!name || name.getValue().empty()) + name = wire->getAttrOfType("sv.namehint"); + if (name) + rewriter.updateRootInPlace( + inputOp, [&] { inputOp->setAttr("sv.namehint", name); }); + } + + rewriter.replaceOp(wire, wire.getInput()); + return success(); +} + +//===----------------------------------------------------------------------===// +// AggregateConstantOp +//===----------------------------------------------------------------------===// + +static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { + // If this is a type alias, get the underlying type. + if (auto typeAlias = type.dyn_cast()) + type = typeAlias.getCanonicalType(); + + if (auto structType = type.dyn_cast()) { + auto arrayAttr = attr.dyn_cast(); + if (!arrayAttr) + return op->emitOpError("expected array attribute for constant of type ") + << type; + for (auto [attr, fieldInfo] : + llvm::zip(arrayAttr.getValue(), structType.getElements())) { + if (failed(checkAttributes(op, attr, fieldInfo.type))) return failure(); + } + } else if (auto arrayType = type.dyn_cast()) { + auto arrayAttr = attr.dyn_cast(); + if (!arrayAttr) + return op->emitOpError("expected array attribute for constant of type ") + << type; + auto elementType = arrayType.getElementType(); + for (auto attr : arrayAttr.getValue()) { + if (failed(checkAttributes(op, attr, elementType))) return failure(); + } + } else if (auto arrayType = type.dyn_cast()) { + auto arrayAttr = attr.dyn_cast(); + if (!arrayAttr) + return op->emitOpError("expected array attribute for constant of type ") + << type; + auto elementType = arrayType.getElementType(); + for (auto attr : arrayAttr.getValue()) { + if (failed(checkAttributes(op, attr, elementType))) return failure(); + } + } else if (auto enumType = type.dyn_cast()) { + auto stringAttr = attr.dyn_cast(); + if (!stringAttr) + return op->emitOpError("expected string attribute for constant of type ") + << type; + } else if (auto intType = type.dyn_cast()) { + // Check the attribute kind is correct. + auto intAttr = attr.dyn_cast(); + if (!intAttr) + return op->emitOpError("expected integer attribute for constant of type ") + << type; + // Check the bitwidth is correct. + if (intAttr.getValue().getBitWidth() != intType.getWidth()) + return op->emitOpError( + "hw.constant attribute bitwidth " + "doesn't match return type"); + } else { + return op->emitOpError("unknown element type") << type; + } + return success(); +} + +LogicalResult AggregateConstantOp::verify() { + return checkAttributes(*this, getFieldsAttr(), getType()); +} + +OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); } + +//===----------------------------------------------------------------------===// +// ParamValueOp +//===----------------------------------------------------------------------===// + +static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, + Type &resultType) { + if (p.parseType(resultType) || p.parseEqual() || + p.parseAttribute(value, resultType)) + return failure(); + return success(); +} + +static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, + Type resultType) { + p << resultType << " = "; + p.printAttributeWithoutType(value); +} + +LogicalResult ParamValueOp::verify() { + // Check that the attribute expression is valid in this module. + return checkParameterInContext( + getValue(), (*this)->getParentOfType(), *this); +} + +OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "hw.param.value has no operands"); + return getValueAttr(); +} + +//===----------------------------------------------------------------------===// +// HWModuleOp +//===----------------------------------------------------------------------===/ + +/// Return true if isAnyModule or instance. +bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) { + return isa(moduleOrInstance); +} + +/// Return the signature for a module as a function type from the module itself +/// or from an hw::InstanceOp. +FunctionType hw::getModuleType(Operation *moduleOrInstance) { + if (auto instance = dyn_cast(moduleOrInstance)) { + SmallVector inputs(instance->getOperandTypes()); + SmallVector results(instance->getResultTypes()); + return FunctionType::get(instance->getContext(), inputs, results); + } + + if (auto mod = dyn_cast(moduleOrInstance)) + return mod.getModuleType().getFuncType(); + + if (auto mod = dyn_cast(moduleOrInstance)) + return mod.getHWModuleType().getFuncType(); + + return cast(moduleOrInstance) + .getFunctionType() + .cast(); +} + +/// Return the name to use for the Verilog module that we're referencing +/// here. This is typically the symbol, but can be overridden with the +/// verilogName attribute. +StringAttr hw::getVerilogModuleNameAttr(Operation *module) { + auto nameAttr = module->getAttrOfType("verilogName"); + if (nameAttr) return nameAttr; + + return module->getAttrOfType(SymbolTable::getSymbolAttrName()); +} + +// Flag for parsing different module types +enum ExternModKind { PlainMod, ExternMod, GenMod }; + +template +static void buildModule(OpBuilder &builder, OperationState &result, + StringAttr name, const ModulePortInfo &ports, + ArrayAttr parameters, + ArrayRef attributes, + StringAttr comment) { + using namespace mlir::function_interface_impl; + LocationAttr unknownLoc = builder.getUnknownLoc(); + + // Add an attribute for the name. + result.addAttribute(SymbolTable::getSymbolAttrName(), name); + + SmallVector argNames, resultNames; + SmallVector argTypes, resultTypes; + SmallVector argAttrs, resultAttrs; + SmallVector argLocs, resultLocs; + SmallVector portTypes; + auto exportPortIdent = StringAttr::get(builder.getContext(), "hw.exportPort"); + + for (auto elt : ports.getInputs()) { + portTypes.push_back(elt); + if (elt.dir == ModulePort::Direction::InOut && + !elt.type.isa()) + elt.type = hw::InOutType::get(elt.type); + argTypes.push_back(elt.type); + argNames.push_back(elt.name); + argLocs.push_back(elt.loc ? elt.loc : unknownLoc); + Attribute attr; + if (elt.sym && !elt.sym.empty()) + attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}}); + else + attr = builder.getDictionaryAttr({}); + argAttrs.push_back(attr); + } + + for (auto elt : ports.getOutputs()) { + portTypes.push_back(elt); + resultTypes.push_back(elt.type); + resultNames.push_back(elt.name); + resultLocs.push_back(elt.loc ? elt.loc : unknownLoc); + Attribute attr; + if (elt.sym && !elt.sym.empty()) + attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}}); + else + attr = builder.getDictionaryAttr({}); + resultAttrs.push_back(attr); + } + + // Allow clients to pass in null for the parameters list. + if (!parameters) parameters = builder.getArrayAttr({}); + + // Record the argument and result types as an attribute. + auto type = ModuleType::get(builder.getContext(), portTypes); + result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), + TypeAttr::get(type)); + result.addAttribute("argLocs", builder.getArrayAttr(argLocs)); + result.addAttribute("resultLocs", builder.getArrayAttr(resultLocs)); + result.addAttribute(ModuleTy::getArgAttrsAttrName(result.name), + builder.getArrayAttr(argAttrs)); + result.addAttribute(ModuleTy::getResAttrsAttrName(result.name), + builder.getArrayAttr(resultAttrs)); + result.addAttribute("parameters", parameters); + if (!comment) comment = builder.getStringAttr(""); + result.addAttribute("comment", comment); + result.addAttributes(attributes); + result.addRegion(); +} + +/// Internal implementation of argument/result insertion and removal on modules. +static void modifyModuleArgs( + MLIRContext *context, ArrayRef> insertArgs, + ArrayRef removeArgs, ArrayRef oldArgNames, + ArrayRef oldArgTypes, ArrayRef oldArgAttrs, + ArrayRef oldArgLocs, SmallVector &newArgNames, + SmallVector &newArgTypes, SmallVector &newArgAttrs, + SmallVector &newArgLocs, Block *body = nullptr) { +#ifndef NDEBUG + // Check that the `insertArgs` and `removeArgs` indices are in ascending + // order. + assert(llvm::is_sorted(insertArgs, + [](auto &a, auto &b) { return a.first < b.first; }) && + "insertArgs must be in ascending order"); + assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) && + "removeArgs must be in ascending order"); +#endif + + auto oldArgCount = oldArgTypes.size(); + auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size(); + assert((int)newArgCount >= 0); + + newArgNames.reserve(newArgCount); + newArgTypes.reserve(newArgCount); + newArgAttrs.reserve(newArgCount); + newArgLocs.reserve(newArgCount); + + auto exportPortAttrName = StringAttr::get(context, "hw.exportPort"); + auto emptyDictAttr = DictionaryAttr::get(context, {}); + auto unknownLoc = UnknownLoc::get(context); + + BitVector erasedIndices; + if (body) erasedIndices.resize(oldArgCount + insertArgs.size()); + + for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) { + // Insert new ports at this position. + while (!insertArgs.empty() && insertArgs[0].first == argIdx) { + auto port = insertArgs[0].second; + if (port.dir == ModulePort::Direction::InOut && + !port.type.isa()) + port.type = InOutType::get(port.type); + Attribute attr = + (port.sym && !port.sym.empty()) + ? DictionaryAttr::get(context, {{exportPortAttrName, port.sym}}) + : emptyDictAttr; + newArgNames.push_back(port.name); + newArgTypes.push_back(port.type); + newArgAttrs.push_back(attr); + insertArgs = insertArgs.drop_front(); + LocationAttr loc = port.loc ? port.loc : unknownLoc; + newArgLocs.push_back(loc); + if (body) body->insertArgument(idx++, port.type, loc); + } + if (argIdx == oldArgCount) break; + + // Migrate the old port at this position. + bool removed = false; + while (!removeArgs.empty() && removeArgs[0] == argIdx) { + removeArgs = removeArgs.drop_front(); + removed = true; + } + + if (removed) { + if (body) erasedIndices.set(idx); + } else { + newArgNames.push_back(oldArgNames[argIdx]); + newArgTypes.push_back(oldArgTypes[argIdx]); + newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr + : oldArgAttrs[argIdx]); + newArgLocs.push_back(oldArgLocs[argIdx]); + } + } + + if (body) body->eraseArguments(erasedIndices); + + assert(newArgNames.size() == newArgCount); + assert(newArgTypes.size() == newArgCount); + assert(newArgAttrs.size() == newArgCount); + assert(newArgLocs.size() == newArgCount); +} + +/// Insert and remove ports of a module. The insertion and removal indices must +/// be in ascending order. The indices refer to the port positions before any +/// insertion or removal occurs. Ports inserted at the same index will appear in +/// the module in the same order as they were listed in the `insert*` array. +/// +/// The operation must be any of the module-like operations. +void hw::modifyModulePorts( + Operation *op, ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef removeInputs, ArrayRef removeOutputs, + Block *body) { + auto moduleOp = cast(op); + auto *context = moduleOp.getContext(); + + // Dig up the old argument and result data. + auto oldArgNames = moduleOp.getInputNames(); + auto oldArgTypes = moduleOp.getInputTypes(); + auto oldArgAttrs = moduleOp.getAllInputAttrs(); + auto oldArgLocs = moduleOp.getInputLocs(); + + auto oldResultNames = moduleOp.getOutputNames(); + auto oldResultTypes = moduleOp.getOutputTypes(); + auto oldResultAttrs = moduleOp.getAllOutputAttrs(); + auto oldResultLocs = moduleOp.getOutputLocs(); + + // Modify the ports. + SmallVector newArgNames, newResultNames; + SmallVector newArgTypes, newResultTypes; + SmallVector newArgAttrs, newResultAttrs; + SmallVector newArgLocs, newResultLocs; + + modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames, + oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames, + newArgTypes, newArgAttrs, newArgLocs, body); + + modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames, + oldResultTypes, oldResultAttrs, oldResultLocs, + newResultNames, newResultTypes, newResultAttrs, + newResultLocs); + + // Update the module operation types and attributes. + auto fnty = FunctionType::get(context, newArgTypes, newResultTypes); + auto modty = detail::fnToMod(fnty, newArgNames, newResultNames); + moduleOp.setHWModuleType(modty); + moduleOp.setAllInputAttrs(newArgAttrs); + moduleOp.setInputLocs(newArgLocs); + moduleOp.setAllOutputAttrs(newResultAttrs); + moduleOp.setOutputLocs(newResultLocs); +} + +void HWModuleOp::build(OpBuilder &builder, OperationState &result, + StringAttr name, const ModulePortInfo &ports, + ArrayAttr parameters, + ArrayRef attributes, StringAttr comment, + bool shouldEnsureTerminator) { + buildModule(builder, result, name, ports, parameters, attributes, + comment); + + // Create a region and a block for the body. + auto *bodyRegion = result.regions[0].get(); + Block *body = new Block(); + bodyRegion->push_back(body); + + // Add arguments to the body block. + auto unknownLoc = builder.getUnknownLoc(); + for (auto port : ports.getInputs()) { + auto loc = port.loc ? Location(port.loc) : unknownLoc; + auto type = port.type; + if (port.isInOut() && !type.isa()) type = InOutType::get(type); + body->addArgument(type, loc); + } + + if (shouldEnsureTerminator) + HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location); +} + +void HWModuleOp::build(OpBuilder &builder, OperationState &result, + StringAttr name, ArrayRef ports, + ArrayAttr parameters, + ArrayRef attributes, + StringAttr comment) { + build(builder, result, name, ModulePortInfo(ports), parameters, attributes, + comment); +} + +void HWModuleOp::build(OpBuilder &builder, OperationState &odsState, + StringAttr name, const ModulePortInfo &ports, + HWModuleBuilder modBuilder, ArrayAttr parameters, + ArrayRef attributes, + StringAttr comment) { + build(builder, odsState, name, ports, parameters, attributes, comment, + /*shouldEnsureTerminator=*/false); + auto *bodyRegion = odsState.regions[0].get(); + OpBuilder::InsertionGuard guard(builder); + auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion); + builder.setInsertionPointToEnd(&bodyRegion->front()); + modBuilder(builder, accessor); + // Create output operands. + llvm::SmallVector outputOperands = accessor.getOutputOperands(); + builder.create(odsState.location, outputOperands); +} + +void HWModuleOp::modifyPorts( + ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef eraseInputs, ArrayRef eraseOutputs) { + hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, + eraseOutputs); +} + +/// Return the name to use for the Verilog module that we're referencing +/// here. This is typically the symbol, but can be overridden with the +/// verilogName attribute. +StringAttr HWModuleExternOp::getVerilogModuleNameAttr() { + if (auto vName = getVerilogNameAttr()) return vName; + + return (*this)->getAttrOfType(SymbolTable::getSymbolAttrName()); +} + +StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() { + if (auto vName = getVerilogNameAttr()) { + return vName; + } + return (*this)->getAttrOfType( + ::mlir::SymbolTable::getSymbolAttrName()); +} + +void HWModuleExternOp::build(OpBuilder &builder, OperationState &result, + StringAttr name, const ModulePortInfo &ports, + StringRef verilogName, ArrayAttr parameters, + ArrayRef attributes) { + buildModule(builder, result, name, ports, parameters, + attributes, {}); + + if (!verilogName.empty()) + result.addAttribute("verilogName", builder.getStringAttr(verilogName)); +} + +void HWModuleExternOp::build(OpBuilder &builder, OperationState &result, + StringAttr name, ArrayRef ports, + StringRef verilogName, ArrayAttr parameters, + ArrayRef attributes) { + build(builder, result, name, ModulePortInfo(ports), verilogName, parameters, + attributes); +} + +void HWModuleExternOp::modifyPorts( + ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef eraseInputs, ArrayRef eraseOutputs) { + hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, + eraseOutputs); +} + +void HWModuleExternOp::appendOutputs( + ArrayRef> outputs) {} + +void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result, + FlatSymbolRefAttr genKind, StringAttr name, + const ModulePortInfo &ports, + StringRef verilogName, ArrayAttr parameters, + ArrayRef attributes) { + buildModule(builder, result, name, ports, parameters, + attributes, {}); + result.addAttribute("generatorKind", genKind); + if (!verilogName.empty()) + result.addAttribute("verilogName", builder.getStringAttr(verilogName)); +} + +void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result, + FlatSymbolRefAttr genKind, StringAttr name, + ArrayRef ports, StringRef verilogName, + ArrayAttr parameters, + ArrayRef attributes) { + build(builder, result, genKind, name, ModulePortInfo(ports), verilogName, + parameters, attributes); +} + +void HWModuleGeneratedOp::modifyPorts( + ArrayRef> insertInputs, + ArrayRef> insertOutputs, + ArrayRef eraseInputs, ArrayRef eraseOutputs) { + hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, + eraseOutputs); +} + +void HWModuleGeneratedOp::appendOutputs( + ArrayRef> outputs) {} + +static InnerSymAttr extractSym(DictionaryAttr attrs) { + if (attrs) + if (auto symRef = attrs.get("hw.exportPort")) + return symRef.cast(); + return {}; +} + +static bool hasAttribute(StringRef name, ArrayRef attrs) { + for (auto &argAttr : attrs) + if (argAttr.getName() == name) return true; + return false; +} + +static Attribute getAttribute(StringRef name, ArrayRef attrs) { + for (auto &argAttr : attrs) + if (argAttr.getName() == name) return argAttr.getValue(); + return {}; +} + +template +static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result, + ExternModKind modKind = PlainMod) { + using namespace mlir::function_interface_impl; + auto loc = parser.getCurrentLocation(); + + // Parse the visibility attribute. + (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the generator information. + FlatSymbolRefAttr kindAttr; + if (modKind == GenMod) { + if (parser.parseComma() || + parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) { + return failure(); + } + } + + // Parse the parameters. + ArrayAttr parameters; + if (parseOptionalParameterList(parser, parameters)) return failure(); + + // Parse the function signature. + bool isVariadic = false; + SmallVector entryArgs; + SmallVector argNames; + SmallVector argLocs; + SmallVector resultNames; + SmallVector resultAttrs; + SmallVector resultLocs; + TypeAttr functionType; + if (failed(module_like_impl::parseModuleFunctionSignature( + parser, isVariadic, entryArgs, argNames, argLocs, resultNames, + resultAttrs, resultLocs, functionType))) + return failure(); + + // Parse the attribute dict. + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + + if (hasAttribute("resultNames", result.attributes) || + hasAttribute("parameters", result.attributes)) { + parser.emitError( + loc, "explicit `resultNames` / `parameters` attributes not allowed"); + return failure(); + } + + auto *context = result.getContext(); + // prefer the attribute over the ssa values + auto attr = getAttribute("argNames", result.attributes); + auto modType = detail::fnToMod( + cast(functionType.getValue()), + attr ? cast(attr).getValue() : argNames, resultNames); + result.attributes.erase("argNames"); + result.addAttribute("argLocs", ArrayAttr::get(context, argLocs)); + result.addAttribute("resultLocs", ArrayAttr::get(context, resultLocs)); + result.addAttribute("parameters", parameters); + if (!hasAttribute("comment", result.attributes)) + result.addAttribute("comment", StringAttr::get(context, "")); + result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), + TypeAttr::get(modType)); + + // Add the attributes to the function arguments. + addArgAndResultAttrs(parser.getBuilder(), result, entryArgs, resultAttrs, + ModuleTy::getArgAttrsAttrName(result.name), + ModuleTy::getResAttrsAttrName(result.name)); + + // Parse the optional function body. + auto *body = result.addRegion(); + if (modKind == PlainMod) { + if (parser.parseRegion(*body, entryArgs)) return failure(); + + HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); + } + return success(); +} + +ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseHWModuleOp(parser, result); +} + +ParseResult HWModuleExternOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseHWModuleOp(parser, result, ExternMod); +} + +ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseHWModuleOp(parser, result, GenMod); +} + +FunctionType getHWModuleOpType(Operation *op) { + if (auto mod = dyn_cast(op)) + return mod.getHWModuleType().getFuncType(); + return cast(op) + .getFunctionType() + .cast(); +} + +template +static void printModuleOp(OpAsmPrinter &p, ModuleTy mod, + ExternModKind modKind) { + using namespace mlir::function_interface_impl; + + FunctionType fnType = mod.getHWModuleType().getFuncType(); + auto argTypes = fnType.getInputs(); + auto resultTypes = fnType.getResults(); + + p << ' '; + + // Print the visibility of the module. + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = mod.getOperation()->template getAttrOfType( + visibilityAttrName)) + p << visibility.getValue() << ' '; + + // Print the operation and the function name. + p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue()); + if (modKind == GenMod) { + p << ", "; + p.printSymbolName( + cast(mod.getOperation()).getGeneratorKind()); + } + + // Print the parameter list if present. + printOptionalParameterList( + p, mod.getOperation(), + mod.getOperation()->template getAttrOfType("parameters")); + + bool needArgNamesAttr = false; + module_like_impl::printModuleSignature(p, mod.getOperation(), argTypes, + /*isVariadic=*/false, resultTypes, + needArgNamesAttr); + + SmallVector omittedAttrs; + if (modKind == GenMod) omittedAttrs.push_back("generatorKind"); + if (!needArgNamesAttr) omittedAttrs.push_back("argNames"); + omittedAttrs.push_back("argLocs"); + omittedAttrs.push_back( + ModuleTy::getModuleTypeAttrName(mod.getOperation()->getName())); + omittedAttrs.push_back( + ModuleTy::getArgAttrsAttrName(mod.getOperation()->getName())); + omittedAttrs.push_back( + ModuleTy::getResAttrsAttrName(mod.getOperation()->getName())); + omittedAttrs.push_back("resultNames"); + omittedAttrs.push_back("resultLocs"); + omittedAttrs.push_back("parameters"); + omittedAttrs.push_back(visibilityAttrName); + omittedAttrs.push_back(SymbolTable::getSymbolAttrName()); + if (mod.getOperation() + ->template getAttrOfType("comment") + .getValue() + .empty()) + omittedAttrs.push_back("comment"); + // inject argNames + auto attrs = mod->getAttrs(); + SmallVector realAttrs(attrs.begin(), attrs.end()); + realAttrs.push_back( + NamedAttribute(StringAttr::get(mod.getContext(), "argNames"), + ArrayAttr::get(mod.getContext(), mod.getInputNames()))); + p.printOptionalAttrDictWithKeyword(realAttrs, omittedAttrs); +} + +void HWModuleExternOp::print(OpAsmPrinter &p) { + printModuleOp(p, *this, ExternMod); +} +void HWModuleGeneratedOp::print(OpAsmPrinter &p) { + printModuleOp(p, *this, GenMod); +} + +void HWModuleOp::print(OpAsmPrinter &p) { + printModuleOp(p, *this, PlainMod); + + // Print the body if this is not an external function. + Region &body = getBody(); + if (!body.empty()) { + p << " "; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +static LogicalResult verifyModuleCommon(HWModuleLike module) { + assert(isa(module) && + "verifier hook should only be called on modules"); + + auto moduleType = module.getHWModuleType(); + + auto argLocs = module.getInputLocs(); + if (argLocs.size() != moduleType.getNumInputs()) + return module->emitOpError("incorrect number of argument locations"); + + auto resultLocs = module.getOutputLocs(); + if (resultLocs.size() != moduleType.getNumOutputs()) + return module->emitOpError("incorrect number of result locations"); + + SmallPtrSet paramNames; + + // Check parameter default values are sensible. + for (auto param : module->getAttrOfType("parameters")) { + auto paramAttr = param.cast(); + + // Check that we don't have any redundant parameter names. These are + // resolved by string name: reuse of the same name would cause ambiguities. + if (!paramNames.insert(paramAttr.getName()).second) + return module->emitOpError("parameter ") + << paramAttr << " has the same name as a previous parameter"; + + // Default values are allowed to be missing, check them if present. + auto value = paramAttr.getValue(); + if (!value) continue; + + auto typedValue = value.dyn_cast(); + if (!typedValue) + return module->emitOpError("parameter ") + << paramAttr << " should have a typed value; has value " << value; + + if (typedValue.getType() != paramAttr.getType()) + return module->emitOpError("parameter ") + << paramAttr << " should have type " << paramAttr.getType() + << "; has type " << typedValue.getType(); + + // Verify that this is a valid parameter value, disallowing parameter + // references. We could allow parameters to refer to each other in the + // future with lexical ordering if there is a need. + if (failed(checkParameterInContext(value, module, module, + /*disallowParamRefs=*/true))) + return failure(); + } + return success(); +} + +LogicalResult HWModuleOp::verify() { + if (failed(verifyModuleCommon(*this))) return failure(); + + auto type = getModuleType(); + auto *body = getBodyBlock(); + + // Verify the number of block arguments. + auto numInputs = type.getNumInputs(); + if (body->getNumArguments() != numInputs) + return emitOpError("entry block must have") + << numInputs << " arguments to match module signature"; + + // Verify that the block arguments match the op's attributes. + for (auto [arg, type, loc] : llvm::zip(getBodyBlock()->getArguments(), + getInputTypes(), getInputLocs())) { + if (arg.getType() != type) + return emitOpError("block argument types should match signature types"); + if (arg.getLoc() != loc.cast()) + return emitOpError( + "block argument locations should match signature locations"); + } + + return success(); +} + +LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); } + +std::pair HWModuleOp::insertInput(unsigned index, + StringAttr name, + Type ty) { + // Find a unique name for the wire. + Namespace ns; + auto ports = getPortList(); + for (auto port : ports) ns.newName(port.name.getValue()); + auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue())); + + Block *body = getBodyBlock(); + + // Create a new port for the host clock. + PortInfo port; + port.name = nameAttr; + port.dir = ModulePort::Direction::Input; + port.type = ty; + hw::modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, + {}, body); + + // Add a new argument. + return {nameAttr, body->getArgument(index)}; +} + +void HWModuleOp::insertOutputs(unsigned index, + ArrayRef> outputs) { + auto output = cast(getBodyBlock()->getTerminator()); + assert(index <= output->getNumOperands() && "invalid output index"); + + // Rewrite the port list of the module. + SmallVector> indexedNewPorts; + for (auto &[name, value] : outputs) { + PortInfo port; + port.name = name; + port.dir = ModulePort::Direction::Output; + port.type = value.getType(); + indexedNewPorts.emplace_back(index, port); + } + hw::modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {}, + getBodyBlock()); + + // Rewrite the output op. + for (auto &[name, value] : outputs) output->insertOperands(index++, value); +} + +void HWModuleOp::appendOutputs(ArrayRef> outputs) { + return insertOutputs(getNumOutputPorts(), outputs); +} + +void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn) { + getAsmBlockArgumentNamesImpl(region, setNameFn); +} + +void HWModuleExternOp::getAsmBlockArgumentNames( + mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { + getAsmBlockArgumentNamesImpl(region, setNameFn); +} + +template +static SmallVector getAllPortLocs(ModTy module) { + SmallVector retval; + auto empty = UnknownLoc::get(module.getContext()); + auto locs = module.getArgLocs(); + if (locs) + for (auto l : locs) retval.push_back(cast(l)); + retval.resize(module.getNumInputPorts(), empty); + locs = module.getResultLocs(); + if (locs) + for (auto l : locs) retval.push_back(cast(l)); + retval.resize(module.getNumInputPorts() + module.getNumOutputPorts(), empty); + return retval; +} + +SmallVector HWModuleOp::getAllPortLocs() { + return ::getAllPortLocs(*this); +} + +SmallVector HWModuleExternOp::getAllPortLocs() { + return ::getAllPortLocs(*this); +} + +SmallVector HWModuleGeneratedOp::getAllPortLocs() { + return ::getAllPortLocs(*this); +} + +template +static void setAllPortLocs(ArrayRef locs, ModTy module) { + auto numInputs = module.getNumInputPorts(); + SmallVector argLocs(locs.begin(), locs.begin() + numInputs); + SmallVector resLocs(locs.begin() + numInputs, locs.end()); + module.setArgLocsAttr(ArrayAttr::get(module.getContext(), argLocs)); + module.setResultLocsAttr(ArrayAttr::get(module.getContext(), resLocs)); +} + +void HWModuleOp::setAllPortLocs(ArrayRef locs) { + ::setAllPortLocs(locs, *this); +} + +void HWModuleExternOp::setAllPortLocs(ArrayRef locs) { + ::setAllPortLocs(locs, *this); +} + +void HWModuleGeneratedOp::setAllPortLocs(ArrayRef locs) { + ::setAllPortLocs(locs, *this); +} + +template +static void setAllPortNames(ArrayRef names, ModTy module) { + auto numInputs = module.getNumInputPorts(); + SmallVector argNames(names.begin(), names.begin() + numInputs); + SmallVector resNames(names.begin() + numInputs, names.end()); + auto oldType = module.getModuleType(); + SmallVector newPorts(oldType.getPorts().begin(), + oldType.getPorts().end()); + for (size_t i = 0UL, e = newPorts.size(); i != e; ++i) + newPorts[i].name = cast(names[i]); + auto newType = ModuleType::get(module.getContext(), newPorts); + module.setModuleType(newType); +} + +void HWModuleOp::setAllPortNames(ArrayRef names) { + ::setAllPortNames(names, *this); +} + +void HWModuleExternOp::setAllPortNames(ArrayRef names) { + ::setAllPortNames(names, *this); +} + +void HWModuleGeneratedOp::setAllPortNames(ArrayRef names) { + ::setAllPortNames(names, *this); +} + +template +static SmallVector getAllPortAttrs(ModTy &mod) { + SmallVector retval; + auto empty = DictionaryAttr::get(mod.getContext()); + auto attrs = mod.getArgAttrs(); + if (attrs) + for (auto a : *attrs) retval.push_back(a); + retval.resize(mod.getNumInputPorts(), empty); + attrs = mod.getResAttrs(); + if (attrs) + for (auto a : *attrs) retval.push_back(a); + retval.resize(mod.getNumInputPorts() + mod.getNumOutputPorts(), empty); + return retval; +} + +SmallVector HWModuleOp::getAllPortAttrs() { + return ::getAllPortAttrs(*this); +} + +SmallVector HWModuleExternOp::getAllPortAttrs() { + return ::getAllPortAttrs(*this); +} + +SmallVector HWModuleGeneratedOp::getAllPortAttrs() { + return ::getAllPortAttrs(*this); +} + +template +static void setAllPortAttrs(ModTy &mod, ArrayRef attrs) { + auto numInputs = mod.getNumInputPorts(); + SmallVector argAttrs(attrs.begin(), attrs.begin() + numInputs); + SmallVector resAttrs(attrs.begin() + numInputs, attrs.end()); + + mod.setArgAttrsAttr(ArrayAttr::get(mod.getContext(), argAttrs)); + mod.setResAttrsAttr(ArrayAttr::get(mod.getContext(), resAttrs)); +} + +void HWModuleOp::setAllPortAttrs(ArrayRef attrs) { + return ::setAllPortAttrs(*this, attrs); +} + +void HWModuleExternOp::setAllPortAttrs(ArrayRef attrs) { + return ::setAllPortAttrs(*this, attrs); +} + +void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef attrs) { + return ::setAllPortAttrs(*this, attrs); +} + +template +static void removeAllPortAttrs(ModTy &mod) { + mod.setArgAttrsAttr(ArrayAttr::get(mod.getContext(), {})); + mod.setResAttrsAttr(ArrayAttr::get(mod.getContext(), {})); +} + +void HWModuleOp::removeAllPortAttrs() { return ::removeAllPortAttrs(*this); } + +void HWModuleExternOp::removeAllPortAttrs() { + return ::removeAllPortAttrs(*this); +} + +void HWModuleGeneratedOp::removeAllPortAttrs() { + return ::removeAllPortAttrs(*this); +} + +// This probably does really unexpected stuff when you change the number of + +template +static void setHWModuleType(ModTy &mod, ModuleType type) { + auto argAttrs = mod.getAllInputAttrs(); + auto resAttrs = mod.getAllOutputAttrs(); + mod.setModuleTypeAttr(TypeAttr::get(type)); + unsigned newNumArgs = type.getNumInputs(); + unsigned newNumResults = type.getNumOutputs(); + + auto emptyDict = DictionaryAttr::get(mod.getContext()); + argAttrs.resize(newNumArgs, emptyDict); + resAttrs.resize(newNumResults, emptyDict); + + SmallVector attrs; + attrs.append(argAttrs.begin(), argAttrs.end()); + attrs.append(resAttrs.begin(), resAttrs.end()); + + if (attrs.empty()) return mod.removeAllPortAttrs(); + mod.setAllPortAttrs(attrs); +} + +void HWModuleOp::setHWModuleType(ModuleType type) { + return ::setHWModuleType(*this, type); +} + +void HWModuleExternOp::setHWModuleType(ModuleType type) { + return ::setHWModuleType(*this, type); +} + +void HWModuleGeneratedOp::setHWModuleType(ModuleType type) { + return ::setHWModuleType(*this, type); +} + +/// Lookup the generator for the symbol. This returns null on +/// invalid IR. +Operation *HWModuleGeneratedOp::getGeneratorKindOp() { + auto topLevelModuleOp = (*this)->getParentOfType(); + return topLevelModuleOp.lookupSymbol(getGeneratorKind()); +} + +LogicalResult HWModuleGeneratedOp::verifySymbolUses( + SymbolTableCollection &symbolTable) { + auto *referencedKind = + symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr()); + + if (referencedKind == nullptr) + return emitError("Cannot find generator definition '") + << getGeneratorKind() << "'"; + + if (!isa(referencedKind)) + return emitError("Symbol resolved to '") + << referencedKind->getName() + << "' which is not a HWGeneratorSchemaOp"; + + auto referencedKindOp = dyn_cast(referencedKind); + auto paramRef = referencedKindOp.getRequiredAttrs(); + auto dict = (*this)->getAttrDictionary(); + for (auto str : paramRef) { + auto strAttr = str.dyn_cast(); + if (!strAttr) return emitError("Unknown attribute type, expected a string"); + if (!dict.get(strAttr.getValue())) + return emitError("Missing attribute '") << strAttr.getValue() << "'"; + } + + return success(); +} + +LogicalResult HWModuleGeneratedOp::verify() { + return verifyModuleCommon(*this); +} + +void HWModuleGeneratedOp::getAsmBlockArgumentNames( + mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { + getAsmBlockArgumentNamesImpl(region, setNameFn); +} + +LogicalResult HWModuleOp::verifyBody() { return success(); } + +//===----------------------------------------------------------------------===// +// InstanceOp +//===----------------------------------------------------------------------===// + +/// Create a instance that refers to a known module. +void InstanceOp::build(OpBuilder &builder, OperationState &result, + Operation *module, StringAttr name, + ArrayRef inputs, ArrayAttr parameters, + InnerSymAttr innerSym) { + if (!parameters) parameters = builder.getArrayAttr({}); + + auto mod = cast(module); + auto argNames = builder.getArrayAttr(mod.getInputNames()); + auto resultNames = builder.getArrayAttr(mod.getOutputNames()); + FunctionType modType = mod.getHWModuleType().getFuncType(); + build(builder, result, modType.getResults(), name, + FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs, + argNames, resultNames, parameters, innerSym); +} + +std::optional InstanceOp::getTargetResultIndex() { + // Inner symbols on instance operations target the op not any result. + return std::nullopt; +} + +/// Lookup the module or extmodule for the symbol. This returns null on +/// invalid IR. +Operation *InstanceOp::getReferencedModule(const HWSymbolCache *cache) { + return instance_like_impl::getReferencedModule(cache, *this, + getModuleNameAttr()); +} + +Operation *InstanceOp::getReferencedModule(SymbolTable &symtbl) { + return symtbl.lookup(getModuleNameAttr().getValue()); +} + +Operation *InstanceOp::getReferencedModuleSlow() { + return getReferencedModule(/*cache=*/nullptr); +} + +LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + return instance_like_impl::verifyInstanceOfHWModule( + *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(), + getResultNames(), getParameters(), symbolTable); +} + +LogicalResult InstanceOp::verify() { + auto module = (*this)->getParentOfType(); + if (!module) return success(); + + auto moduleParameters = module->getAttrOfType("parameters"); + instance_like_impl::EmitErrorFn emitError = + [&](const std::function &fn) { + auto diag = emitOpError(); + if (fn(diag)) + diag.attachNote(module->getLoc()) << "module declared here"; + }; + return instance_like_impl::verifyParameterStructure( + getParameters(), moduleParameters, emitError); +} + +ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) { + StringAttr instanceNameAttr; + InnerSymAttr innerSym; + FlatSymbolRefAttr moduleNameAttr; + SmallVector inputsOperands; + SmallVector inputsTypes, allResultTypes; + ArrayAttr argNames, resultNames, parameters; + auto noneType = parser.getBuilder().getType(); + + if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName", + result.attributes)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("sym"))) { + // Parsing an optional symbol name doesn't fail, so no need to check the + // result. + if (parser.parseCustomAttributeWithFallback(innerSym)) return failure(); + result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym); + } + + llvm::SMLoc parametersLoc, inputsOperandsLoc; + if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName", + result.attributes) || + parser.getCurrentLocation(¶metersLoc) || + parseOptionalParameterList(parser, parameters) || + parseInputPortList(parser, inputsOperands, inputsTypes, argNames) || + parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, + result.operands) || + parser.parseArrow() || + parseOutputPortList(parser, allResultTypes, resultNames) || + parser.parseOptionalAttrDict(result.attributes)) { + return failure(); + } + + result.addAttribute("argNames", argNames); + result.addAttribute("resultNames", resultNames); + result.addAttribute("parameters", parameters); + result.addTypes(allResultTypes); + return success(); +} + +void InstanceOp::print(OpAsmPrinter &p) { + p << ' '; + p.printAttributeWithoutType(getInstanceNameAttr()); + if (auto attr = getInnerSymAttr()) { + p << " sym "; + attr.print(p); + } + p << ' '; + p.printAttributeWithoutType(getModuleNameAttr()); + printOptionalParameterList(p, *this, getParameters()); + printInputPortList(p, *this, getInputs(), getInputs().getTypes(), + getArgNames()); + p << " -> "; + printOutputPortList(p, *this, getResultTypes(), getResultNames()); + + p.printOptionalAttrDict( + (*this)->getAttrs(), + /*elidedAttrs=*/{"instanceName", + InnerSymbolTable::getInnerSymbolAttrName(), "moduleName", + "argNames", "resultNames", "parameters"}); +} + +/// Return the name of the specified input port or null if it cannot be +/// determined. +StringAttr InstanceOp::getArgumentName(size_t idx) { + return instance_like_impl::getName(getArgNames(), idx); +} + +/// Return the name of the specified result or null if it cannot be +/// determined. +StringAttr InstanceOp::getResultName(size_t idx) { + return instance_like_impl::getName(getResultNames(), idx); +} + +/// Change the name of the specified input port. +void InstanceOp::setArgumentName(size_t i, StringAttr name) { + setInputNames(instance_like_impl::updateName(getArgNames(), i, name)); +} + +/// Change the name of the specified output port. +void InstanceOp::setResultName(size_t i, StringAttr name) { + setOutputNames(instance_like_impl::updateName(getResultNames(), i, name)); +} + +/// Suggest a name for each result value based on the saved result names +/// attribute. +void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { + instance_like_impl::getAsmResultNames(setNameFn, getInstanceName(), + getResultNames(), getResults()); +} + +ModulePortInfo InstanceOp::getPortList() { + SmallVector inputs, outputs; + auto emptyDict = DictionaryAttr::get(getContext()); + auto argNames = (*this)->getAttrOfType("argNames"); + auto argTypes = getModuleType(*this).getInputs(); + auto argLocs = (*this)->getAttrOfType("argLocs"); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + auto type = argTypes[i]; + auto direction = ModulePort::Direction::Input; + + if (auto inout = type.dyn_cast()) { + type = inout.getElementType(); + direction = ModulePort::Direction::InOut; + } + + LocationAttr loc; + if (argLocs) loc = argLocs[i].cast(); + inputs.push_back({{argNames[i].cast(), type, direction}, + i, + {}, + emptyDict, + loc}); + } + + auto resultNames = (*this)->getAttrOfType("resultNames"); + auto resultTypes = getModuleType(*this).getResults(); + auto resultLocs = (*this)->getAttrOfType("resultLocs"); + for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) { + LocationAttr loc; + if (resultLocs) loc = resultLocs[i].cast(); + outputs.push_back({{resultNames[i].cast(), resultTypes[i], + ModulePort::Direction::Output}, + i, + {}, + emptyDict, + loc}); + } + return ModulePortInfo(inputs, outputs); +} + +size_t InstanceOp::getNumPorts() { + return getNumInputPorts() + getNumOutputPorts(); +} + +size_t InstanceOp::getNumInputPorts() { return getNumOperands(); } + +size_t InstanceOp::getNumOutputPorts() { return getNumResults(); } + +size_t InstanceOp::getPortIdForInputId(size_t idx) { return idx; } + +size_t InstanceOp::getPortIdForOutputId(size_t idx) { + return idx + getNumInputPorts(); +} + +void InstanceOp::getValues(SmallVectorImpl &values, + const ModulePortInfo &mpi) { + size_t inputPort = 0, resultPort = 0; + values.resize(mpi.size()); + auto results = getResults(); + auto inputs = getInputs(); + for (auto [idx, port] : llvm::enumerate(mpi)) + if (mpi.at(idx).isOutput()) + values[idx] = results[resultPort++]; + else + values[idx] = inputs[inputPort++]; +} + +//===----------------------------------------------------------------------===// +// HWOutputOp +//===----------------------------------------------------------------------===// + +/// Verify that the num of operands and types fit the declared results. +LogicalResult OutputOp::verify() { + // Check that the we (hw.output) have the same number of operands as our + // region has results. + ModuleType modType; + if (auto mod = dyn_cast((*this)->getParentOp())) + modType = mod.getHWModuleType(); + else if (auto mod = dyn_cast((*this)->getParentOp())) + modType = mod.getModuleType(); + else { + emitOpError("must have a module parent"); + return failure(); + } + auto modResults = modType.getOutputTypes(); + OperandRange outputValues = getOperands(); + if (modResults.size() != outputValues.size()) { + emitOpError("must have same number of operands as region results."); + return failure(); + } + + // Check that the types of our operands and the region's results match. + for (size_t i = 0, e = modResults.size(); i < e; ++i) { + if (modResults[i] != outputValues[i].getType()) { + emitOpError( + "output types must match module. In " + "operand ") + << i << ", expected " << modResults[i] << ", but got " + << outputValues[i].getType() << "."; + return failure(); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// + +LogicalResult GlobalRefOp::verifySymbolUses( + mlir::SymbolTableCollection &symTables) { + Operation *parent = (*this)->getParentOp(); + SymbolTable &symTable = symTables.getSymbolTable(parent); + StringAttr symNameAttr = (*this).getSymNameAttr(); + auto hasGlobalRef = [&](Attribute attr) -> bool { + if (!attr) return false; + for (auto ref : attr.cast().getAsRange()) + if (ref.getGlblSym().getAttr() == symNameAttr) return true; + return false; + }; + // For all inner refs in the namepath, ensure they have a corresponding + // GlobalRefAttr to this GlobalRefOp. + for (auto innerRef : getNamepath().getAsRange()) { + StringAttr modName = innerRef.getModule(); + StringAttr symName = innerRef.getName(); + Operation *mod = symTable.lookup(modName); + if (!mod) { + (*this)->emitOpError("module:'" + modName.str() + "' not found"); + return failure(); + } + bool glblSymNotFound = true; + bool innerSymOpNotFound = true; + mod->walk([&](InnerSymbolOpInterface op) -> WalkResult { + // If this is one of the ops in the instance path for the GlobalRefOp. + if (op.getInnerNameAttr() == symName) { + innerSymOpNotFound = false; + // Each op can have an array of GlobalRefAttr, check if this op is one + // of them. + if (hasGlobalRef(op->getAttr(GlobalRefAttr::DialectAttrName))) { + glblSymNotFound = false; + return WalkResult::interrupt(); + } + // If cannot find the ref, then its an error. + return failure(); + } + return WalkResult::advance(); + }); + if (glblSymNotFound) { + // TODO: Doesn't yet work for symbls on FIRRTL module ports. Need to + // implement an interface. + if (isa(mod)) { + auto hwmod = cast(mod); + auto inAttrs = hwmod.getAllInputAttrs(); + for (auto attr : inAttrs) + if (auto symRef = cast(attr).getAs( + "hw.exportPort")) + if (symRef.getSymName() == symName) + if (hasGlobalRef(cast(attr).get( + GlobalRefAttr::DialectAttrName))) + return success(); + + auto outAttrs = hwmod.getAllOutputAttrs(); + for (auto attr : outAttrs) + if (auto symRef = cast(attr).getAs( + "hw.exportPort")) + if (symRef.getSymName() == symName) + if (hasGlobalRef(cast(attr).get( + GlobalRefAttr::DialectAttrName))) + return success(); + } + } + if (innerSymOpNotFound) + return (*this)->emitOpError("operation:'" + symName.str() + + "' in module:'" + modName.str() + + "' could not be found"); + if (glblSymNotFound) + return (*this)->emitOpError( + "operation:'" + symName.str() + "' in module:'" + modName.str() + + "' does not contain a reference to '" + symNameAttr.str() + "'"); + } + return success(); +} + +static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, + Type &idxType) { + Type type; + if (p.parseType(type)) + return p.emitError(p.getCurrentLocation(), "Expected type"); + auto arrType = type_dyn_cast(type); + if (!arrType) + return p.emitError(p.getCurrentLocation(), "Expected !hw.array type"); + srcType = type; + unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements()); + idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth); + return success(); +} + +static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, + Type idxType) { + p.printType(srcType); +} + +ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + llvm::SmallVector operands; + Type elemType; + + if (parser.parseOperandList(operands) || + parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || + parser.parseType(elemType)) + return failure(); + + if (operands.size() == 0) + return parser.emitError(inputOperandsLoc, + "Cannot construct an array of length 0"); + result.addTypes({ArrayType::get(elemType, operands.size())}); + + for (auto operand : operands) + if (parser.resolveOperand(operand, elemType, result.operands)) + return failure(); + return success(); +} + +void ArrayCreateOp::print(OpAsmPrinter &p) { + p << " "; + p.printOperands(getInputs()); + p.printOptionalAttrDict((*this)->getAttrs()); + p << " : " << getInputs()[0].getType(); +} + +void ArrayCreateOp::build(OpBuilder &b, OperationState &state, + ValueRange values) { + assert(values.size() > 0 && "Cannot build array of zero elements"); + Type elemType = values[0].getType(); + assert(llvm::all_of( + values, + [elemType](Value v) -> bool { return v.getType() == elemType; }) && + "All values must have same type."); + build(b, state, ArrayType::get(elemType, values.size()), values); +} + +LogicalResult ArrayCreateOp::verify() { + unsigned returnSize = getType().cast().getNumElements(); + if (getInputs().size() != returnSize) return failure(); + return success(); +} + +OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) { + if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; })) + return {}; + return ArrayAttr::get(getContext(), adaptor.getInputs()); +} + +// Check whether an integer value is an offset from a base. +bool hw::isOffset(Value base, Value index, uint64_t offset) { + if (auto constBase = base.getDefiningOp()) { + if (auto constIndex = index.getDefiningOp()) { + // If both values are a constant, check if index == base + offset. + // To account for overflow, the addition is performed with an extra bit + // and the offset is asserted to fit in the bit width of the base. + auto baseValue = constBase.getValue(); + auto indexValue = constIndex.getValue(); + + unsigned bits = baseValue.getBitWidth(); + assert(bits == indexValue.getBitWidth() && "mismatched widths"); + + if (bits < 64 && offset >= (1ull << bits)) return false; + + APInt baseExt = baseValue.zextOrTrunc(bits + 1); + APInt indexExt = indexValue.zextOrTrunc(bits + 1); + return baseExt + offset == indexExt; + } + } + return false; +} + +// Canonicalize a create of consecutive elements to a slice. +static LogicalResult foldCreateToSlice(ArrayCreateOp op, + PatternRewriter &rewriter) { + // Do not canonicalize create of get into a slice. + auto arrayTy = hw::type_cast(op.getType()); + if (arrayTy.getNumElements() <= 1) return failure(); + auto elemTy = arrayTy.getElementType(); + + // Check if create arguments are consecutive elements of the same array. + // Attempt to break a create of gets into a sequence of consecutive intervals. + struct Chunk { + Value input; + Value index; + size_t size; + }; + SmallVector chunks; + for (Value value : llvm::reverse(op.getInputs())) { + auto get = value.getDefiningOp(); + if (!get) return failure(); + + Value input = get.getInput(); + Value index = get.getIndex(); + if (!chunks.empty()) { + auto &c = *chunks.rbegin(); + if (c.input == get.getInput() && isOffset(c.index, index, c.size)) { + c.size++; + continue; + } + } + + chunks.push_back(Chunk{input, index, 1}); + } + + // If there is a single slice, eliminate the create. + if (chunks.size() == 1) { + auto &chunk = chunks[0]; + rewriter.replaceOp(op, rewriter.createOrFold( + op.getLoc(), arrayTy, chunk.input, chunk.index)); + return success(); + } + + // If the number of chunks is significantly less than the number of + // elements, replace the create with a concat of the identified slices. + if (chunks.size() * 2 < arrayTy.getNumElements()) { + SmallVector slices; + for (auto &chunk : llvm::reverse(chunks)) { + auto sliceTy = ArrayType::get(elemTy, chunk.size); + slices.push_back(rewriter.createOrFold( + op.getLoc(), sliceTy, chunk.input, chunk.index)); + } + rewriter.replaceOpWithNewOp(op, arrayTy, slices); + return success(); + } + + return failure(); +} + +LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op, + PatternRewriter &rewriter) { + if (succeeded(foldCreateToSlice(op, rewriter))) return success(); + return failure(); +} + +Value ArrayCreateOp::getUniformElement() { + if (!getInputs().empty() && llvm::all_equal(getInputs())) + return getInputs()[0]; + return {}; +} + +static std::optional getUIntFromValue(Value value) { + auto idxOp = dyn_cast_or_null(value.getDefiningOp()); + if (!idxOp) return std::nullopt; + APInt idxAttr = idxOp.getValue(); + if (idxAttr.getBitWidth() > 64) return std::nullopt; + return idxAttr.getLimitedValue(); +} + +LogicalResult ArraySliceOp::verify() { + unsigned inputSize = + type_cast(getInput().getType()).getNumElements(); + if (llvm::Log2_64_Ceil(inputSize) != + getLowIndex().getType().getIntOrFloatBitWidth()) + return emitOpError( + "ArraySlice: index width must match clog2 of array size"); + return success(); +} + +OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) { + // If we are slicing the entire input, then return it. + if (getType() == getInput().getType()) return getInput(); + return {}; +} + +LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, + PatternRewriter &rewriter) { + auto sliceTy = hw::type_cast(op.getType()); + auto elemTy = sliceTy.getElementType(); + uint64_t sliceSize = sliceTy.getNumElements(); + if (sliceSize == 0) return failure(); + + if (sliceSize == 1) { + // slice(a, n) -> create(a[n]) + auto get = rewriter.create(op.getLoc(), op.getInput(), + op.getLowIndex()); + rewriter.replaceOpWithNewOp(op, op.getType(), + get.getResult()); + return success(); + } + + auto offsetOpt = getUIntFromValue(op.getLowIndex()); + if (!offsetOpt) return failure(); + + auto inputOp = op.getInput().getDefiningOp(); + if (auto inputSlice = dyn_cast_or_null(inputOp)) { + // slice(slice(a, n), m) -> slice(a, n + m) + if (inputSlice == op) return failure(); + + auto inputIndex = inputSlice.getLowIndex(); + auto inputOffsetOpt = getUIntFromValue(inputIndex); + if (!inputOffsetOpt) return failure(); + + uint64_t offset = *offsetOpt + *inputOffsetOpt; + auto lowIndex = + rewriter.create(op.getLoc(), inputIndex.getType(), offset); + rewriter.replaceOpWithNewOp(op, op.getType(), + inputSlice.getInput(), lowIndex); + return success(); + } + + if (auto inputCreate = dyn_cast_or_null(inputOp)) { + // slice(create(a0, a1, ..., an), m) -> create(am, ...) + auto inputs = inputCreate.getInputs(); + + uint64_t begin = inputs.size() - *offsetOpt - sliceSize; + rewriter.replaceOpWithNewOp(op, op.getType(), + inputs.slice(begin, sliceSize)); + return success(); + } + + if (auto inputConcat = dyn_cast_or_null(inputOp)) { + // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...) + SmallVector chunks; + uint64_t sliceStart = *offsetOpt; + for (auto input : llvm::reverse(inputConcat.getInputs())) { + // Check whether the input intersects with the slice. + uint64_t inputSize = + hw::type_cast(input.getType()).getNumElements(); + if (inputSize == 0 || inputSize <= sliceStart) { + sliceStart -= inputSize; + continue; + } + + // Find the indices to slice from this input by intersection. + uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize); + uint64_t cutSize = cutEnd - sliceStart; + assert(cutSize != 0 && "slice cannot be empty"); + + if (cutSize == inputSize) { + // The whole input fits in the slice, add it. + assert(sliceStart == 0 && "invalid cut size"); + chunks.push_back(input); + } else { + // Slice the required bits from the input. + unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize); + auto lowIndex = rewriter.create( + op.getLoc(), rewriter.getIntegerType(width), sliceStart); + chunks.push_back(rewriter.create( + op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex)); + } + + sliceStart = 0; + sliceSize -= cutSize; + if (sliceSize == 0) break; + } + + assert(chunks.size() > 0 && "missing sliced items"); + if (chunks.size() == 1) + rewriter.replaceOp(op, chunks[0]); + else + rewriter.replaceOpWithNewOp( + op, llvm::to_vector(llvm::reverse(chunks))); + return success(); + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// ArrayConcatOp +//===----------------------------------------------------------------------===// + +static ParseResult parseArrayConcatTypes(OpAsmParser &p, + SmallVectorImpl &inputTypes, + Type &resultType) { + Type elemType; + uint64_t resultSize = 0; + + auto parseElement = [&]() -> ParseResult { + Type ty; + if (p.parseType(ty)) return failure(); + auto arrTy = type_dyn_cast(ty); + if (!arrTy) + return p.emitError(p.getCurrentLocation(), "Expected !hw.array type"); + if (elemType && elemType != arrTy.getElementType()) + return p.emitError(p.getCurrentLocation(), "Expected array element type ") + << elemType; + + elemType = arrTy.getElementType(); + inputTypes.push_back(ty); + resultSize += arrTy.getNumElements(); + return success(); + }; + + if (p.parseCommaSeparatedList(parseElement)) return failure(); + + resultType = ArrayType::get(elemType, resultSize); + return success(); +} + +static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, + TypeRange inputTypes, Type resultType) { + llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; }); +} + +void ArrayConcatOp::build(OpBuilder &b, OperationState &state, + ValueRange values) { + assert(!values.empty() && "Cannot build array of zero elements"); + ArrayType arrayTy = values[0].getType().cast(); + Type elemTy = arrayTy.getElementType(); + assert(llvm::all_of(values, + [elemTy](Value v) -> bool { + return v.getType().isa() && + v.getType().cast().getElementType() == + elemTy; + }) && + "All values must be of ArrayType with the same element type."); + + uint64_t resultSize = 0; + for (Value val : values) + resultSize += val.getType().cast().getNumElements(); + build(b, state, ArrayType::get(elemTy, resultSize), values); +} + +OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) { + auto inputs = adaptor.getInputs(); + SmallVector array; + for (size_t i = 0, e = getNumOperands(); i < e; ++i) { + if (!inputs[i]) return {}; + llvm::copy(inputs[i].cast(), std::back_inserter(array)); + } + return ArrayAttr::get(getContext(), array); +} + +// Flatten a concatenation of array creates into a single create. +static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) { + for (auto input : op.getInputs()) + if (!input.getDefiningOp()) return false; + + SmallVector items; + for (auto input : op.getInputs()) { + auto create = cast(input.getDefiningOp()); + for (auto item : create.getInputs()) items.push_back(item); + } + + rewriter.replaceOpWithNewOp(op, items); + return true; +} + +// Merge consecutive slice expressions in a concatenation. +static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) { + struct Slice { + Value input; + Value index; + size_t size; + Value op; + SmallVector locs; + }; + + SmallVector items; + std::optional last; + bool changed = false; + + auto concatenate = [&] { + // If there is only one op in the slice, place it to the items list. + if (!last) return; + if (last->op) { + items.push_back(last->op); + last.reset(); + return; + } + + // Otherwise, create a new slice of with the given size and place it. + // In this case, the concat op is replaced, using the new argument. + changed = true; + auto loc = FusedLoc::get(op.getContext(), last->locs); + auto origTy = hw::type_cast(last->input.getType()); + auto arrayTy = ArrayType::get(origTy.getElementType(), last->size); + items.push_back(rewriter.createOrFold( + loc, arrayTy, last->input, last->index)); + + last.reset(); + }; + + auto append = [&](Value op, Value input, Value index, size_t size) { + // If this slice is an extension of the previous one, extend the size + // saved. In this case, a new slice of is created and the concatenation + // operator is rewritten. Otherwise, flush the last slice. + if (last) { + if (last->input == input && isOffset(last->index, index, last->size)) { + last->size += size; + last->op = {}; + last->locs.push_back(op.getLoc()); + return; + } + concatenate(); + } + last.emplace(Slice{input, index, size, op, {op.getLoc()}}); + }; + + for (auto item : llvm::reverse(op.getInputs())) { + if (auto slice = item.getDefiningOp()) { + auto size = hw::type_cast(slice.getType()).getNumElements(); + append(item, slice.getInput(), slice.getLowIndex(), size); + continue; + } + + if (auto create = item.getDefiningOp()) { + if (create.getInputs().size() == 1) { + if (auto get = create.getInputs()[0].getDefiningOp()) { + append(item, get.getInput(), get.getIndex(), 1); + continue; + } + } + } + + concatenate(); + items.push_back(item); + } + concatenate(); + + if (!changed) return false; + + if (items.size() == 1) { + rewriter.replaceOp(op, items[0]); + } else { + std::reverse(items.begin(), items.end()); + rewriter.replaceOpWithNewOp(op, items); + } + return true; +} + +LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op, + PatternRewriter &rewriter) { + // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...) + if (flattenConcatOp(op, rewriter)) return success(); + + // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p)) + if (mergeConcatSlices(op, rewriter)) return success(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// EnumConstantOp +//===----------------------------------------------------------------------===// + +ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse a Type instead of an EnumType since the type might be a type alias. + // The validity of the canonical type is checked during construction of the + // EnumFieldAttr. + Type type; + StringRef field; + + auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + if (parser.parseKeyword(&field) || parser.parseColonType(type)) + return failure(); + + auto fieldAttr = EnumFieldAttr::get( + loc, StringAttr::get(parser.getContext(), field), type); + + if (!fieldAttr) return failure(); + + result.addAttribute("field", fieldAttr); + result.addTypes(type); + + return success(); +} + +void EnumConstantOp::print(OpAsmPrinter &p) { + p << " " << getField().getField().getValue() << " : " + << getField().getType().getValue(); +} + +void EnumConstantOp::getAsmResultNames( + function_ref setNameFn) { + setNameFn(getResult(), getField().getField().str()); +} + +void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState, + EnumFieldAttr field) { + return build(builder, odsState, field.getType().getValue(), field); +} + +OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) { + assert(adaptor.getOperands().empty() && "constant has no operands"); + return getFieldAttr(); +} + +LogicalResult EnumConstantOp::verify() { + auto fieldAttr = getFieldAttr(); + auto fieldType = fieldAttr.getType().getValue(); + // This check ensures that we are using the exact same type, without looking + // through type aliases. + if (fieldType != getType()) + emitOpError("return type ") + << getType() << " does not match attribute type " << fieldAttr; + return success(); +} + +//===----------------------------------------------------------------------===// +// EnumCmpOp +//===----------------------------------------------------------------------===// + +LogicalResult EnumCmpOp::verify() { + // Compare the canonical types. + auto lhsType = type_cast(getLhs().getType()); + auto rhsType = type_cast(getRhs().getType()); + if (rhsType != lhsType) emitOpError("types do not match"); + return success(); +} + +//===----------------------------------------------------------------------===// +// StructCreateOp +//===----------------------------------------------------------------------===// + +ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + llvm::SmallVector operands; + Type declOrAliasType; + + if (parser.parseLParen() || parser.parseOperandList(operands) || + parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(declOrAliasType)) + return failure(); + + auto declType = type_dyn_cast(declOrAliasType); + if (!declType) + return parser.emitError(parser.getNameLoc(), + "expected !hw.struct type or alias"); + + llvm::SmallVector structInnerTypes; + declType.getInnerTypes(structInnerTypes); + result.addTypes(declOrAliasType); + + if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc, + result.operands)) + return failure(); + return success(); +} + +void StructCreateOp::print(OpAsmPrinter &printer) { + printer << " ("; + printer.printOperands(getInput()); + printer << ")"; + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getType(); +} + +LogicalResult StructCreateOp::verify() { + auto elements = hw::type_cast(getType()).getElements(); + + if (elements.size() != getInput().size()) + return emitOpError("structure field count mismatch"); + + for (const auto &[field, value] : llvm::zip(elements, getInput())) + if (field.type != value.getType()) + return emitOpError("structure field `") + << field.name << "` type does not match"; + + return success(); +} + +OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) { + // struct_create(struct_explode(x)) => x + if (!getInput().empty()) + if (auto explodeOp = getInput()[0].getDefiningOp(); + explodeOp && getInput() == explodeOp.getResults() && + getResult().getType() == explodeOp.getInput().getType()) + return explodeOp.getInput(); + + auto inputs = adaptor.getInput(); + if (llvm::any_of(inputs, [](Attribute attr) { return !attr; })) return {}; + return ArrayAttr::get(getContext(), inputs); +} + +//===----------------------------------------------------------------------===// +// StructExplodeOp +//===----------------------------------------------------------------------===// + +ParseResult StructExplodeOp::parse(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + Type declType; + + if (parser.parseOperand(operand) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(declType)) + return failure(); + auto structType = type_dyn_cast(declType); + if (!structType) + return parser.emitError(parser.getNameLoc(), + "invalid kind of type specified"); + + llvm::SmallVector structInnerTypes; + structType.getInnerTypes(structInnerTypes); + result.addTypes(structInnerTypes); + + if (parser.resolveOperand(operand, declType, result.operands)) + return failure(); + return success(); +} + +void StructExplodeOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printOperand(getInput()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : " << getInput().getType(); +} + +LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + auto input = adaptor.getInput(); + if (!input) return failure(); + llvm::copy(input.cast(), std::back_inserter(results)); + return success(); +} + +LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op, + PatternRewriter &rewriter) { + auto *inputOp = op.getInput().getDefiningOp(); + auto elements = type_cast(op.getInput().getType()).getElements(); + auto result = failure(); + for (auto [element, res] : llvm::zip(elements, op.getResults())) { + if (auto foldResult = foldStructExtract(inputOp, element.name.str())) { + rewriter.replaceAllUsesWith(res, foldResult); + result = success(); + } + } + return result; +} + +void StructExplodeOp::getAsmResultNames( + function_ref setNameFn) { + auto structType = type_cast(getInput().getType()); + for (auto [res, field] : llvm::zip(getResults(), structType.getElements())) + setNameFn(res, field.name.str()); +} + +void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + Value input) { + StructType inputType = input.getType().dyn_cast(); + assert(inputType); + SmallVector fieldTypes; + for (auto field : inputType.getElements()) fieldTypes.push_back(field.type); + build(odsBuilder, odsState, fieldTypes, input); +} + +//===----------------------------------------------------------------------===// +// StructExtractOp +//===----------------------------------------------------------------------===// + +/// Use the same parser for both struct_extract and union_extract since the +/// syntax is identical. +template +static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::UnresolvedOperand operand; + StringAttr fieldName; + Type declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || + parser.parseAttribute(fieldName, "field", result.attributes) || + parser.parseRSquare() || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(declType)) + return failure(); + auto aggType = type_dyn_cast(declType); + if (!aggType) + return parser.emitError(parser.getNameLoc(), + "invalid kind of type specified"); + + Type resultType = aggType.getFieldType(fieldName.getValue()); + if (!resultType) { + parser.emitError(parser.getNameLoc(), "invalid field name specified"); + return failure(); + } + result.addTypes(resultType); + + if (parser.resolveOperand(operand, declType, result.operands)) + return failure(); + return success(); +} + +/// Use the same printer for both struct_extract and union_extract since the +/// syntax is identical. +template +static void printExtractOp(OpAsmPrinter &printer, AggType op) { + printer << " "; + printer.printOperand(op.getInput()); + printer << "[\"" << op.getField() << "\"]"; + printer.printOptionalAttrDict(op->getAttrs(), {"field"}); + printer << " : " << op.getInput().getType(); +} + +ParseResult StructExtractOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseExtractOp(parser, result); +} + +void StructExtractOp::print(OpAsmPrinter &printer) { + printExtractOp(printer, *this); +} + +void StructExtractOp::build(OpBuilder &builder, OperationState &odsState, + Value input, StructType::FieldInfo field) { + build(builder, odsState, field.type, input, field.name); +} + +void StructExtractOp::build(OpBuilder &builder, OperationState &odsState, + Value input, StringAttr fieldAttr) { + auto structType = type_cast(input.getType()); + auto resultType = structType.getFieldType(fieldAttr); + build(builder, odsState, resultType, input, fieldAttr); +} + +OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) { + if (auto constOperand = adaptor.getInput()) { + // Fold extract from aggregate constant + auto operandType = type_cast(getOperand().getType()); + auto fieldIdx = operandType.getFieldIndex(getField()); + auto operandAttr = llvm::cast(constOperand); + return operandAttr.getValue()[*fieldIdx]; + } + + if (auto foldResult = + foldStructExtract(getInput().getDefiningOp(), getField())) + return foldResult; + return {}; +} + +LogicalResult StructExtractOp::canonicalize(StructExtractOp op, + PatternRewriter &rewriter) { + auto inputOp = op.getInput().getDefiningOp(); + + // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b") + if (auto structInject = dyn_cast_or_null(inputOp)) { + if (structInject.getField() != op.getField()) { + rewriter.replaceOpWithNewOp( + op, op.getType(), structInject.getInput(), op.getField()); + return success(); + } + } + + return failure(); +} + +void StructExtractOp::getAsmResultNames( + function_ref setNameFn) { + auto structType = type_cast(getInput().getType()); + for (auto field : structType.getElements()) { + if (field.name == getField()) { + setNameFn(getResult(), field.name.str()); + return; + } + } +} + +//===----------------------------------------------------------------------===// +// StructInjectOp +//===----------------------------------------------------------------------===// + +ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) { + llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); + OpAsmParser::UnresolvedOperand operand, val; + StringAttr fieldName; + Type declType; + + if (parser.parseOperand(operand) || parser.parseLSquare() || + parser.parseAttribute(fieldName, "field", result.attributes) || + parser.parseRSquare() || parser.parseComma() || + parser.parseOperand(val) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(declType)) + return failure(); + auto structType = type_dyn_cast(declType); + if (!structType) + return parser.emitError(inputOperandsLoc, "invalid kind of type specified"); + + Type resultType = structType.getFieldType(fieldName.getValue()); + if (!resultType) { + parser.emitError(inputOperandsLoc, "invalid field name specified"); + return failure(); + } + result.addTypes(declType); + + if (parser.resolveOperands({operand, val}, {declType, resultType}, + inputOperandsLoc, result.operands)) + return failure(); + return success(); +} + +void StructInjectOp::print(OpAsmPrinter &printer) { + printer << " "; + printer.printOperand(getInput()); + printer << "[\"" << getField() << "\"], "; + printer.printOperand(getNewValue()); + printer.printOptionalAttrDict((*this)->getAttrs(), {"field"}); + printer << " : " << getInput().getType(); +} + +OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) { + auto input = adaptor.getInput(); + auto newValue = adaptor.getNewValue(); + if (!input || !newValue) return {}; + SmallVector array; + llvm::copy(input.cast(), std::back_inserter(array)); + StructType structType = getInput().getType(); + auto index = *structType.getFieldIndex(getField()); + array[index] = newValue; + return ArrayAttr::get(getContext(), array); +} + +LogicalResult StructInjectOp::canonicalize(StructInjectOp op, + PatternRewriter &rewriter) { + // Canonicalize multiple injects into a create op and eliminate overwrites. + SmallPtrSet injects; + DenseMap fields; + + // Chase a chain of injects. Bail out if cycles are present. + StructInjectOp inject = op; + Value input; + do { + if (!injects.insert(inject).second) return failure(); + + fields.try_emplace(inject.getFieldAttr(), inject.getNewValue()); + input = inject.getInput(); + inject = dyn_cast_or_null(input.getDefiningOp()); + } while (inject); + assert(input && "missing input to inject chain"); + + auto ty = hw::type_cast(op.getType()); + auto elements = ty.getElements(); + + // If the inject chain sets all fields, canonicalize to create. + if (fields.size() == elements.size()) { + SmallVector createFields; + for (const auto &field : elements) { + auto it = fields.find(field.name); + assert(it != fields.end() && "missing field"); + createFields.push_back(it->second); + } + rewriter.replaceOpWithNewOp(op, ty, createFields); + return success(); + } + + // Nothing to canonicalize, only the original inject in the chain. + if (injects.size() == fields.size()) return failure(); + + // Eliminate overwrites. The hash map contains the last write to each field. + for (const auto &field : elements) { + auto it = fields.find(field.name); + if (it == fields.end()) continue; + input = rewriter.create(op.getLoc(), ty, input, field.name, + it->second); + } + + rewriter.replaceOp(op, input); + return success(); +} + +//===----------------------------------------------------------------------===// +// UnionCreateOp +//===----------------------------------------------------------------------===// + +ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) { + Type declOrAliasType; + StringAttr field; + OpAsmParser::UnresolvedOperand input; + llvm::SMLoc fieldLoc = parser.getCurrentLocation(); + + if (parser.parseAttribute(field, "field", result.attributes) || + parser.parseComma() || parser.parseOperand(input) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(declOrAliasType)) + return failure(); + + auto declType = type_dyn_cast(declOrAliasType); + if (!declType) + return parser.emitError(parser.getNameLoc(), + "expected !hw.union type or alias"); + + Type inputType = declType.getFieldType(field.getValue()); + if (!inputType) { + parser.emitError(fieldLoc, "cannot find union field '") + << field.getValue() << '\''; + return failure(); + } + + if (parser.resolveOperand(input, inputType, result.operands)) + return failure(); + result.addTypes({declOrAliasType}); + return success(); +} + +void UnionCreateOp::print(OpAsmPrinter &printer) { + printer << " \"" << getField() << "\", "; + printer.printOperand(getInput()); + printer.printOptionalAttrDict((*this)->getAttrs(), {"field"}); + printer << " : " << getType(); +} + +//===----------------------------------------------------------------------===// +// UnionExtractOp +//===----------------------------------------------------------------------===// + +ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) { + return parseExtractOp(parser, result); +} + +void UnionExtractOp::print(OpAsmPrinter &printer) { + printExtractOp(printer, *this); +} + +LogicalResult UnionExtractOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attrs, mlir::OpaqueProperties properties, + mlir::RegionRange regions, SmallVectorImpl &results) { + results.push_back(cast(getCanonicalType(operands[0].getType())) + .getFieldType(attrs.getAs("field"))); + return success(); +} + +//===----------------------------------------------------------------------===// +// ArrayGetOp +//===----------------------------------------------------------------------===// + +// An array_get of an array_create with a constant index can just be the +// array_create operand at the constant index. If the array_create has a +// single uniform value for each element, just return that value regardless of +// the index. If the array is constructed from a constant by a bitcast +// operation, we can fold into a constant. +OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) { + auto inputCst = adaptor.getInput().dyn_cast_or_null(); + auto indexCst = adaptor.getIndex().dyn_cast_or_null(); + + if (inputCst) { + // Constant array index. + if (indexCst) { + auto indexVal = indexCst.getValue(); + if (indexVal.getBitWidth() < 64) { + auto index = indexVal.getZExtValue(); + return inputCst[inputCst.size() - 1 - index]; + } + } + // If all elements of the array are the same, we can return any element of + // array. + if (!inputCst.empty() && llvm::all_equal(inputCst)) return inputCst[0]; + } + + // array_get(bitcast(c), i) -> c[i*w+w-1:i*w] + if (auto bitcast = getInput().getDefiningOp()) { + auto intTy = getType().dyn_cast(); + if (!intTy) return {}; + auto bitcastInputOp = bitcast.getInput().getDefiningOp(); + if (!bitcastInputOp) return {}; + if (!indexCst) return {}; + auto bitcastInputCst = bitcastInputOp.getValue(); + // Calculate the index. Make sure to zero-extend the index value before + // multiplying the element width. + auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) * + getType().getIntOrFloatBitWidth(); + // Extract [startIdx + width - 1: startIdx]. + return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc( + intTy.getIntOrFloatBitWidth())); + } + + auto inputCreate = getInput().getDefiningOp(); + if (!inputCreate) return {}; + + if (auto uniformValue = inputCreate.getUniformElement()) return uniformValue; + + if (!indexCst || indexCst.getValue().getBitWidth() > 64) return {}; + + uint64_t index = indexCst.getValue().getLimitedValue(); + auto createInputs = inputCreate.getInputs(); + if (index >= createInputs.size()) return {}; + return createInputs[createInputs.size() - index - 1]; +} + +LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op, + PatternRewriter &rewriter) { + auto idxOpt = getUIntFromValue(op.getIndex()); + if (!idxOpt) return failure(); + + auto *inputOp = op.getInput().getDefiningOp(); + if (auto inputSlice = dyn_cast_or_null(inputOp)) { + // get(slice(a, n), m) -> get(a, n + m) + auto offsetOp = inputSlice.getLowIndex(); + auto offsetOpt = getUIntFromValue(offsetOp); + if (!offsetOpt) return failure(); + + uint64_t offset = *offsetOpt + *idxOpt; + auto newOffset = + rewriter.create(op.getLoc(), offsetOp.getType(), offset); + rewriter.replaceOpWithNewOp(op, inputSlice.getInput(), + newOffset); + return success(); + } + + if (auto inputConcat = dyn_cast_or_null(inputOp)) { + // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...) + uint64_t elemIndex = *idxOpt; + for (auto input : llvm::reverse(inputConcat.getInputs())) { + size_t size = hw::type_cast(input.getType()).getNumElements(); + if (elemIndex >= size) { + elemIndex -= size; + continue; + } + + unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size); + auto newIdxOp = rewriter.create( + op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex); + + rewriter.replaceOpWithNewOp(op, input, newIdxOp); + return success(); + } + return failure(); + } + + // array_get const, (array_get sel, (array_create a, b, c, d)) --> + // array_get sel, (array_create (array_get const a), (array_get const b), + // (array_get const, c), (array_get const, d)) + if (auto innerGet = dyn_cast_or_null(inputOp)) { + if (!innerGet.getIndex().getDefiningOp()) { + if (auto create = + innerGet.getInput().getDefiningOp()) { + SmallVector newValues; + for (auto operand : create.getOperands()) + newValues.push_back(rewriter.createOrFold( + op.getLoc(), operand, op.getIndex())); + + rewriter.replaceOpWithNewOp( + op, + rewriter.createOrFold(op.getLoc(), newValues), + innerGet.getIndex()); + return success(); + } + } + } + + return failure(); +} + +//===----------------------------------------------------------------------===// +// TypedeclOp +//===----------------------------------------------------------------------===// + +StringRef TypedeclOp::getPreferredName() { + return getVerilogName().value_or(getName()); +} + +Type TypedeclOp::getAliasType() { + auto parentScope = cast(getOperation()->getParentOp()); + return hw::TypeAliasType::get( + SymbolRefAttr::get(parentScope.getSymNameAttr(), + {FlatSymbolRefAttr::get(*this)}), + getType()); +} + +//===----------------------------------------------------------------------===// +// BitcastOp +//===----------------------------------------------------------------------===// + +OpFoldResult BitcastOp::fold(FoldAdaptor) { + // Identity. + // bitcast(%a) : A -> A ==> %a + if (getOperand().getType() == getType()) return getOperand(); + + return {}; +} + +LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) { + // Composition. + // %b = bitcast(%a) : A -> B + // bitcast(%b) : B -> C + // ===> bitcast(%a) : A -> C + auto inputBitcast = + dyn_cast_or_null(op.getInput().getDefiningOp()); + if (!inputBitcast) return failure(); + auto bitcast = rewriter.createOrFold(op.getLoc(), op.getType(), + inputBitcast.getInput()); + rewriter.replaceOp(op, bitcast); + return success(); +} + +LogicalResult BitcastOp::verify() { + if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType())) + return this->emitOpError("Bitwidth of input must match result"); + return success(); +} + +//===----------------------------------------------------------------------===// +// HierPathOp helpers. +//===----------------------------------------------------------------------===// + +bool HierPathOp::dropModule(StringAttr moduleToDrop) { + SmallVector newPath; + bool updateMade = false; + for (auto nameRef : getNamepath()) { + // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. + if (auto ref = nameRef.dyn_cast()) { + if (ref.getModule() == moduleToDrop) + updateMade = true; + else + newPath.push_back(ref); + } else { + if (nameRef.cast().getAttr() == moduleToDrop) + updateMade = true; + else + newPath.push_back(nameRef); + } + } + if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); + return updateMade; +} + +bool HierPathOp::inlineModule(StringAttr moduleToDrop) { + SmallVector newPath; + bool updateMade = false; + StringRef inlinedInstanceName = ""; + for (auto nameRef : getNamepath()) { + // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. + if (auto ref = nameRef.dyn_cast()) { + if (ref.getModule() == moduleToDrop) { + inlinedInstanceName = ref.getName().getValue(); + updateMade = true; + } else if (!inlinedInstanceName.empty()) { + newPath.push_back(hw::InnerRefAttr::get( + ref.getModule(), + StringAttr::get(getContext(), inlinedInstanceName + "_" + + ref.getName().getValue()))); + inlinedInstanceName = ""; + } else + newPath.push_back(ref); + } else { + if (nameRef.cast().getAttr() == moduleToDrop) + updateMade = true; + else + newPath.push_back(nameRef); + } + } + if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); + return updateMade; +} + +bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) { + SmallVector newPath; + bool updateMade = false; + for (auto nameRef : getNamepath()) { + // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. + if (auto ref = nameRef.dyn_cast()) { + if (ref.getModule() == oldMod) { + newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName())); + updateMade = true; + } else + newPath.push_back(ref); + } else { + if (nameRef.cast().getAttr() == oldMod) { + newPath.push_back(FlatSymbolRefAttr::get(newMod)); + updateMade = true; + } else + newPath.push_back(nameRef); + } + } + if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); + return updateMade; +} + +bool HierPathOp::updateModuleAndInnerRef( + StringAttr oldMod, StringAttr newMod, + const llvm::DenseMap &innerSymRenameMap) { + auto fromRef = FlatSymbolRefAttr::get(oldMod); + if (oldMod == newMod) return false; + + auto namepathNew = getNamepath().getValue().vec(); + bool updateMade = false; + // Break from the loop if the module is found, since it can occur only once. + for (auto &element : namepathNew) { + if (auto innerRef = element.dyn_cast()) { + if (innerRef.getModule() != oldMod) continue; + auto symName = innerRef.getName(); + // Since the module got updated, the old innerRef symbol inside oldMod + // should also be updated to the new symbol inside the newMod. + auto to = innerSymRenameMap.find(symName); + if (to != innerSymRenameMap.end()) symName = to->second; + updateMade = true; + element = hw::InnerRefAttr::get(newMod, symName); + break; + } + if (element != fromRef) continue; + + updateMade = true; + element = FlatSymbolRefAttr::get(newMod); + break; + } + if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), namepathNew)); + return updateMade; +} + +bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) { + SmallVector newPath; + bool updateMade = false; + for (auto nameRef : getNamepath()) { + // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. + if (auto ref = nameRef.dyn_cast()) { + if (ref.getModule() == atMod) { + updateMade = true; + if (includeMod) newPath.push_back(ref); + } else + newPath.push_back(ref); + } else { + if (nameRef.cast().getAttr() == atMod && !includeMod) + updateMade = true; + else + newPath.push_back(nameRef); + } + if (updateMade) break; + } + if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); + return updateMade; +} + +/// Return just the module part of the namepath at a specific index. +StringAttr HierPathOp::modPart(unsigned i) { + return TypeSwitch(getNamepath()[i]) + .Case([](auto a) { return a.getAttr(); }) + .Case([](auto a) { return a.getModule(); }); +} + +/// Return the root module. +StringAttr HierPathOp::root() { + assert(!getNamepath().empty()); + return modPart(0); +} + +/// Return true if the NLA has the module in its path. +bool HierPathOp::hasModule(StringAttr modName) { + for (auto nameRef : getNamepath()) { + // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. + if (auto ref = nameRef.dyn_cast()) { + if (ref.getModule() == modName) return true; + } else { + if (nameRef.cast().getAttr() == modName) return true; + } + } + return false; +} + +/// Return true if the NLA has the InnerSym . +bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const { + for (auto nameRef : const_cast(this)->getNamepath()) + if (auto ref = nameRef.dyn_cast()) + if (ref.getName() == symName && ref.getModule() == modName) return true; + + return false; +} + +/// Return just the reference part of the namepath at a specific index. This +/// will return an empty attribute if this is the leaf and the leaf is a module. +StringAttr HierPathOp::refPart(unsigned i) { + return TypeSwitch(getNamepath()[i]) + .Case([](auto a) { return StringAttr({}); }) + .Case([](auto a) { return a.getName(); }); +} + +/// Return the leaf reference. This returns an empty attribute if the leaf +/// reference is a module. +StringAttr HierPathOp::ref() { + assert(!getNamepath().empty()); + return refPart(getNamepath().size() - 1); +} + +/// Return the leaf module. +StringAttr HierPathOp::leafMod() { + assert(!getNamepath().empty()); + return modPart(getNamepath().size() - 1); +} + +/// Returns true if this NLA targets an instance of a module (as opposed to +/// an instance's port or something inside an instance). +bool HierPathOp::isModule() { return !ref(); } + +/// Returns true if this NLA targets something inside a module (as opposed +/// to a module or an instance of a module); +bool HierPathOp::isComponent() { return (bool)ref(); } + +// Verify the HierPathOp. +// 1. Iterate over the namepath. +// 2. The namepath should be a valid instance path, specified either on a +// module or a declaration inside a module. +// 3. Each element in the namepath is an InnerRefAttr except possibly the +// last element. +// 4. Make sure that the InnerRefAttr is legal, by verifying the module name +// and the corresponding inner_sym on the instance. +// 5. Make sure that the instance path is legal, by verifying the sequence of +// instance and the expected module occurs as the next element in the path. +// 6. The last element of the namepath, can be an InnerRefAttr on either a +// module port or a declaration inside the module. +// 7. The last element of the namepath can also be a module symbol. +LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) { + StringAttr expectedModuleName = {}; + if (!getNamepath() || getNamepath().empty()) + return emitOpError() << "the instance path cannot be empty"; + for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) { + hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast(); + if (!innerRef) + return emitOpError() + << "the instance path can only contain inner sym reference" + << ", only the leaf can refer to a module symbol"; + + if (expectedModuleName && expectedModuleName != innerRef.getModule()) + return emitOpError() << "instance path is incorrect. Expected module: " + << expectedModuleName + << " instead found: " << innerRef.getModule(); + auto instOp = ns.lookupOp(innerRef); + if (!instOp) + return emitOpError() << " module: " << innerRef.getModule() + << " does not contain any instance with symbol: " + << innerRef.getName(); + expectedModuleName = instOp.getReferencedModuleNameAttr(); + } + // The instance path has been verified. Now verify the last element. + auto leafRef = getNamepath()[getNamepath().size() - 1]; + if (auto innerRef = leafRef.dyn_cast()) { + if (!ns.lookup(innerRef)) { + return emitOpError() << " operation with symbol: " << innerRef + << " was not found "; + } + if (expectedModuleName && expectedModuleName != innerRef.getModule()) + return emitOpError() << "instance path is incorrect. Expected module: " + << expectedModuleName + << " instead found: " << innerRef.getModule(); + } else if (expectedModuleName && + expectedModuleName != + leafRef.cast().getAttr()) { + // This is the case when the nla is applied to a module. + return emitOpError() << "instance path is incorrect. Expected module: " + << expectedModuleName << " instead found: " + << leafRef.cast().getAttr(); + } + return success(); +} + +void HierPathOp::print(OpAsmPrinter &p) { + p << " "; + + // Print visibility if present. + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = + getOperation()->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + + p.printSymbolName(getSymName()); + p << " ["; + llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) { + if (auto ref = attr.dyn_cast()) { + p.printSymbolName(ref.getModule().getValue()); + p << "::"; + p.printSymbolName(ref.getName().getValue()); + } else { + p.printSymbolName(attr.cast().getValue()); + } + }); + p << "]"; + p.printOptionalAttrDict( + (*this)->getAttrs(), + {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName}); +} + +ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse the visibility attribute. + (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the symbol name. + StringAttr symName; + if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the namepath. + SmallVector namepath; + if (parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Square, [&]() -> ParseResult { + auto loc = parser.getCurrentLocation(); + SymbolRefAttr ref; + if (parser.parseAttribute(ref)) return failure(); + + // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error. + auto pathLength = ref.getNestedReferences().size(); + if (pathLength == 0) + namepath.push_back( + FlatSymbolRefAttr::get(ref.getRootReference())); + else if (pathLength == 1) + namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(), + ref.getLeafReference())); + else + return parser.emitError(loc, + "only one nested reference is allowed"); + return success(); + })) + return failure(); + result.addAttribute("namepath", + ArrayAttr::get(parser.getContext(), namepath)); + + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TriggeredOp +//===----------------------------------------------------------------------===// + +void TriggeredOp::build(OpBuilder &builder, OperationState &odsState, + EventControlAttr event, Value trigger, + ValueRange inputs) { + odsState.addOperands(trigger); + odsState.addOperands(inputs); + odsState.addAttribute(getEventAttrName(odsState.name), event); + auto *r = odsState.addRegion(); + Block *b = new Block(); + r->push_back(b); + + llvm::SmallVector argLocs; + llvm::transform(inputs, std::back_inserter(argLocs), + [&](Value v) { return v.getLoc(); }); + b->addArguments(inputs.getTypes(), argLocs); +} + +//===----------------------------------------------------------------------===// +// Temporary test module +//===----------------------------------------------------------------------===// + +static void addPortAttrsAndLocs( + Builder &builder, OperationState &result, + SmallVectorImpl &ports, + StringAttr portAttrsName, StringAttr portLocsName) { + auto unknownLoc = builder.getUnknownLoc(); + auto nonEmptyAttrsFn = [](Attribute attr) { + return attr && !cast(attr).empty(); + }; + auto nonEmptyLocsFn = [unknownLoc](Attribute attr) { + return attr && cast(attr) != unknownLoc; + }; + + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + SmallVector attrs; + SmallVector locs; + for (auto &port : ports) { + attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({})); + locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc); + } + + // Add the attributes to the ports. + if (llvm::any_of(attrs, nonEmptyAttrsFn)) + result.addAttribute(portAttrsName, builder.getArrayAttr(attrs)); + + if (llvm::any_of(locs, nonEmptyLocsFn)) + result.addAttribute(portLocsName, builder.getArrayAttr(locs)); +} + +void HWTestModuleOp::print(OpAsmPrinter &p) { + p << ' '; + // Print the visibility of the module. + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = (*this)->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + + // Print the operation and the function name. + p.printSymbolName(SymbolTable::getSymbolName(*this).getValue()); + + // Print the parameter list if present. + printOptionalParameterList(p, *this, getParameters()); + + module_like_impl::printModuleSignatureNew(p, *this); + SmallVector omittedAttrs; + omittedAttrs.push_back(getPortLocsAttrName()); + omittedAttrs.push_back(getModuleTypeAttrName()); + omittedAttrs.push_back(getPortAttrsAttrName()); + omittedAttrs.push_back(getParametersAttrName()); + omittedAttrs.push_back(visibilityAttrName); + if (auto cmt = (*this)->getAttrOfType("comment")) + if (cmt.getValue().empty()) omittedAttrs.push_back("comment"); + + mlir::function_interface_impl::printFunctionAttributes(p, *this, + omittedAttrs); + + // Print the body if this is not an external function. + Region &body = getBody(); + if (!body.empty()) { + p << " "; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +ParseResult HWTestModuleOp::parse(OpAsmParser &parser, OperationState &result) { + auto loc = parser.getCurrentLocation(); + + // Parse the visibility attribute. + (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the parameters. + ArrayAttr parameters; + if (parseOptionalParameterList(parser, parameters)) return failure(); + + SmallVector ports; + TypeAttr modType; + if (failed(module_like_impl::parseModuleSignature(parser, ports, modType))) + return failure(); + + // Parse the attribute dict. + if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) + return failure(); + + if (hasAttribute("parameters", result.attributes)) { + parser.emitError(loc, "explicit `parameters` attributes not allowed"); + return failure(); + } + + result.addAttribute("parameters", parameters); + result.addAttribute(getModuleTypeAttrName(result.name), modType); + addPortAttrsAndLocs(parser.getBuilder(), result, ports, + getPortAttrsAttrName(result.name), + getPortLocsAttrName(result.name)); + + SmallVector entryArgs; + for (auto &port : ports) + if (port.direction != ModulePort::Direction::Output) + entryArgs.push_back(port); + + // Parse the optional function body. + auto *body = result.addRegion(); + if (parser.parseRegion(*body, entryArgs)) return failure(); + + HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); + + return success(); +} + +void HWTestModuleOp::getAsmBlockArgumentNames( + mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { + if (region.empty()) return; + // Assign port names to the bbargs. + auto *block = ®ion.front(); + auto mt = getModuleType(); + for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) { + auto name = mt.getInputName(i); + if (!name.empty()) setNameFn(block->getArgument(i), name); + } +} + +ModulePortInfo HWTestModuleOp::getPortList() { + SmallVector ports; + auto refPorts = getModuleType().getPorts(); + for (auto [i, port] : enumerate(refPorts)) { + auto loc = getPortLocs() ? cast((*getPortLocs())[i]) + : LocationAttr(); + auto attr = getPortAttrs() ? cast((*getPortAttrs())[i]) + : DictionaryAttr(); + InnerSymAttr sym = {}; + ports.push_back({{port}, i, sym, attr, loc}); + } + return ModulePortInfo(ports); +} + +size_t HWTestModuleOp::getNumPorts() { return getModuleType().getNumPorts(); } +size_t HWTestModuleOp::getNumInputPorts() { + return getModuleType().getNumInputs(); +} +size_t HWTestModuleOp::getNumOutputPorts() { + return getModuleType().getNumOutputs(); +} + +size_t HWTestModuleOp::getPortIdForInputId(size_t idx) { + return getModuleType().getPortIdForInputId(idx); +} + +size_t HWTestModuleOp::getPortIdForOutputId(size_t idx) { + return getModuleType().getPortIdForOutputId(idx); +} + +hw::InnerSymAttr HWTestModuleOp::getPortSymbolAttr(size_t portIndex) { + auto pa = getPortAttrs(); + if (pa) return cast((*pa)[portIndex]); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// TableGen generated logic. +//===----------------------------------------------------------------------===// + +// Provide the autogenerated implementation guts for the Op classes. +#define GET_OP_CLASSES +#include "circt/Dialect/HW/HW.cpp.inc" diff --git a/third_party/circt/lib/Dialect/HW/HWReductions.cpp b/third_party/circt/lib/Dialect/HW/HWReductions.cpp new file mode 100644 index 000000000..3c5123073 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWReductions.cpp @@ -0,0 +1,157 @@ +//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWReductions.h" + +#include "circt/Dialect/HW/HWInstanceGraph.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Reduce/ReductionUtils.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "hw-reductions" + +using namespace mlir; +using namespace circt; +using namespace hw; + +//===----------------------------------------------------------------------===// +// Utilities +//===----------------------------------------------------------------------===// + +/// Utility to track the transitive size of modules. +struct ModuleSizeCache { + void clear() { moduleSizes.clear(); } + + uint64_t getModuleSize(HWModuleLike module, + hw::InstanceGraph &instanceGraph) { + if (auto it = moduleSizes.find(module); it != moduleSizes.end()) + return it->second; + uint64_t size = 1; + module->walk([&](Operation *op) { + size += 1; + if (auto instOp = dyn_cast(op)) + if (auto instModule = + instanceGraph.getReferencedModule(instOp)) + size += getModuleSize(instModule, instanceGraph); + }); + moduleSizes.insert({module, size}); + return size; + } + + private: + llvm::DenseMap moduleSizes; +}; + +//===----------------------------------------------------------------------===// +// Reduction patterns +//===----------------------------------------------------------------------===// + +/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`. +struct ModuleExternalizer : public OpReduction { + void beforeReduction(mlir::ModuleOp op) override { + instanceGraph = std::make_unique(op); + moduleSizes.clear(); + } + + uint64_t match(HWModuleOp op) override { + return moduleSizes.getModuleSize(op, *instanceGraph); + } + + LogicalResult rewrite(HWModuleOp op) override { + OpBuilder builder(op); + builder.create(op->getLoc(), op.getModuleNameAttr(), + op.getPortList(), StringRef(), + op.getParameters()); + op->erase(); + return success(); + } + + std::string getName() const override { return "hw-module-externalizer"; } + + std::unique_ptr instanceGraph; + ModuleSizeCache moduleSizes; +}; + +/// A sample reduction pattern that replaces all uses of an operation with one +/// of its operands. This can help pruning large parts of the expression tree +/// rapidly. +template +struct HWOperandForwarder : public Reduction { + uint64_t match(Operation *op) override { + if (op->getNumResults() != 1 || op->getNumOperands() < 2 || + OpNum >= op->getNumOperands()) + return 0; + auto resultTy = op->getResult(0).getType().dyn_cast(); + auto opTy = op->getOperand(OpNum).getType().dyn_cast(); + return resultTy && opTy && resultTy == opTy && + op->getResult(0) != op->getOperand(OpNum); + } + LogicalResult rewrite(Operation *op) override { + assert(match(op)); + ImplicitLocOpBuilder builder(op->getLoc(), op); + auto result = op->getResult(0); + auto operand = op->getOperand(OpNum); + LLVM_DEBUG(llvm::dbgs() + << "Forwarding " << operand << " in " << *op << "\n"); + result.replaceAllUsesWith(operand); + reduce::pruneUnusedOps(op, *this); + return success(); + } + std::string getName() const override { + return ("hw-operand" + Twine(OpNum) + "-forwarder").str(); + } +}; + +/// A sample reduction pattern that replaces integer operations with a constant +/// zero of their type. +struct HWConstantifier : public Reduction { + uint64_t match(Operation *op) override { + if (op->getNumResults() == 0 || op->getNumOperands() == 0) return 0; + return llvm::all_of(op->getResults(), [](Value result) { + return result.getType().isa(); + }); + } + LogicalResult rewrite(Operation *op) override { + assert(match(op)); + OpBuilder builder(op); + for (auto result : op->getResults()) { + auto type = result.getType().cast(); + auto newOp = builder.create(op->getLoc(), type, 0); + result.replaceAllUsesWith(newOp); + } + reduce::pruneUnusedOps(op, *this); + return success(); + } + std::string getName() const override { return "hw-constantifier"; } +}; + +//===----------------------------------------------------------------------===// +// Reduction Registration +//===----------------------------------------------------------------------===// + +void HWReducePatternDialectInterface::populateReducePatterns( + circt::ReducePatternSet &patterns) const { + // Gather a list of reduction patterns that we should try. Ideally these are + // assigned reasonable benefit indicators (higher benefit patterns are + // prioritized). For example, things that can knock out entire modules while + // being cheap should be tried first (and thus have higher benefit), before + // trying to tweak operands of individual arithmetic ops. + patterns.add(); + patterns.add(); + patterns.add, 4>(); + patterns.add, 3>(); + patterns.add, 2>(); +} + +void hw::registerReducePatternDialectInterface( + mlir::DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) { + dialect->addInterfaces(); + }); +} diff --git a/third_party/circt/lib/Dialect/HW/HWTypeInterfaces.cpp b/third_party/circt/lib/Dialect/HW/HWTypeInterfaces.cpp new file mode 100644 index 000000000..70dd40f6f --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWTypeInterfaces.cpp @@ -0,0 +1,74 @@ +//===- HWTypeInterfaces.cpp - Implement HW type interfaces ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements type interfaces of the HW Dialect. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWTypeInterfaces.h" + +using namespace mlir; +using namespace circt; +using namespace hw; +using namespace FieldIdImpl; + +Type circt::hw::FieldIdImpl::getFinalTypeByFieldID(Type type, + uint64_t fieldID) { + std::pair pair(type, fieldID); + while (pair.second) { + if (auto ftype = dyn_cast(pair.first)) + pair = ftype.getSubTypeByFieldID(pair.second); + else + llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); + } + return pair.first; +} + +std::pair circt::hw::FieldIdImpl::getSubTypeByFieldID( + Type type, uint64_t fieldID) { + if (!fieldID) return {type, 0}; + if (auto ftype = dyn_cast(type)) + return ftype.getSubTypeByFieldID(fieldID); + + llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); +} + +uint64_t circt::hw::FieldIdImpl::getMaxFieldID(Type type) { + if (auto ftype = dyn_cast(type)) + return ftype.getMaxFieldID(); + return 0; +} + +std::pair circt::hw::FieldIdImpl::projectToChildFieldID( + Type type, uint64_t fieldID, uint64_t index) { + if (auto ftype = dyn_cast(type)) + return ftype.projectToChildFieldID(fieldID, index); + return {0, fieldID == 0}; +} + +uint64_t circt::hw::FieldIdImpl::getIndexForFieldID(Type type, + uint64_t fieldID) { + if (auto ftype = dyn_cast(type)) + return ftype.getIndexForFieldID(fieldID); + return 0; +} + +uint64_t circt::hw::FieldIdImpl::getFieldID(Type type, uint64_t fieldID) { + if (auto ftype = dyn_cast(type)) + return ftype.getFieldID(fieldID); + return 0; +} + +std::pair circt::hw::FieldIdImpl::getIndexAndSubfieldID( + Type type, uint64_t fieldID) { + if (auto ftype = dyn_cast(type)) + return ftype.getIndexAndSubfieldID(fieldID); + return {0, fieldID == 0}; +} + +#include "circt/Dialect/HW/HWTypeInterfaces.cpp.inc" diff --git a/third_party/circt/lib/Dialect/HW/HWTypes.cpp b/third_party/circt/lib/Dialect/HW/HWTypes.cpp new file mode 100644 index 000000000..407f39eee --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/HWTypes.cpp @@ -0,0 +1,974 @@ +//===- HWTypes.cpp - HW types code defs -----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation logic for HW data types. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/HWTypes.h" + +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Dialect/HW/HWDialect.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWSymCache.h" +#include "circt/Support/LLVM.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StorageUniquerSupport.h" +#include "mlir/IR/Types.h" + +using namespace circt; +using namespace circt::hw; +using namespace circt::hw::detail; + +#define GET_TYPEDEF_CLASSES +#include "circt/Dialect/HW/HWTypes.cpp.inc" + +//===----------------------------------------------------------------------===// +// Type Helpers +//===----------------------------------------------------------------------===/ + +mlir::Type circt::hw::getCanonicalType(mlir::Type type) { + Type canonicalType; + if (auto typeAlias = type.dyn_cast()) + canonicalType = typeAlias.getCanonicalType(); + else + canonicalType = type; + return canonicalType; +} + +/// Return true if the specified type is a value HW Integer type. This checks +/// that it is a signless standard dialect type or a hw::IntType. +bool circt::hw::isHWIntegerType(mlir::Type type) { + Type canonicalType = getCanonicalType(type); + + if (canonicalType.isa()) return true; + + auto intType = canonicalType.dyn_cast(); + if (!intType || !intType.isSignless()) return false; + + return true; +} + +bool circt::hw::isHWEnumType(mlir::Type type) { + return getCanonicalType(type).isa(); +} + +/// Return true if the specified type can be used as an HW value type, that is +/// the set of types that can be composed together to represent synthesized, +/// hardware but not marker types like InOutType. +bool circt::hw::isHWValueType(Type type) { + // Signless and signed integer types are both valid. + if (type.isa()) return true; + + if (auto array = type.dyn_cast()) + return isHWValueType(array.getElementType()); + + if (auto array = type.dyn_cast()) + return isHWValueType(array.getElementType()); + + if (auto t = type.dyn_cast()) + return llvm::all_of(t.getElements(), + [](auto f) { return isHWValueType(f.type); }); + + if (auto t = type.dyn_cast()) + return llvm::all_of(t.getElements(), + [](auto f) { return isHWValueType(f.type); }); + + if (auto t = type.dyn_cast()) + return isHWValueType(t.getCanonicalType()); + + return false; +} + +/// Return the hardware bit width of a type. Does not reflect any encoding, +/// padding, or storage scheme, just the bit (and wire width) of a +/// statically-size type. Reflects the number of wires needed to transmit a +/// value of this type. Returns -1 if the type is not known or cannot be +/// statically computed. +int64_t circt::hw::getBitWidth(mlir::Type type) { + return llvm::TypeSwitch<::mlir::Type, size_t>(type) + .Case( + [](IntegerType t) { return t.getIntOrFloatBitWidth(); }) + .Case([](auto a) { + int64_t elementBitWidth = getBitWidth(a.getElementType()); + if (elementBitWidth < 0) return elementBitWidth; + int64_t dimBitWidth = a.getNumElements(); + if (dimBitWidth < 0) return static_cast(-1L); + return (int64_t)a.getNumElements() * elementBitWidth; + }) + .Case([](StructType s) { + int64_t total = 0; + for (auto field : s.getElements()) { + int64_t fieldSize = getBitWidth(field.type); + if (fieldSize < 0) return fieldSize; + total += fieldSize; + } + return total; + }) + .Case([](UnionType u) { + int64_t maxSize = 0; + for (auto field : u.getElements()) { + int64_t fieldSize = getBitWidth(field.type) + field.offset; + if (fieldSize > maxSize) maxSize = fieldSize; + } + return maxSize; + }) + .Case([](EnumType e) { return e.getBitWidth(); }) + .Case( + [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); }) + .Default([](Type) { return -1; }); +} + +/// Return true if the specified type contains known marker types like +/// InOutType. Unlike isHWValueType, this is not conservative, it only returns +/// false on known InOut types, rather than any unknown types. +bool circt::hw::hasHWInOutType(Type type) { + if (auto array = type.dyn_cast()) + return hasHWInOutType(array.getElementType()); + + if (auto array = type.dyn_cast()) + return hasHWInOutType(array.getElementType()); + + if (auto t = type.dyn_cast()) { + return std::any_of(t.getElements().begin(), t.getElements().end(), + [](const auto &f) { return hasHWInOutType(f.type); }); + } + + if (auto t = type.dyn_cast()) + return hasHWInOutType(t.getCanonicalType()); + + return type.isa(); +} + +/// Parse and print nested HW types nicely. These helper methods allow eliding +/// the "hw." prefix on array, inout, and other types when in a context that +/// expects HW subelement types. +static ParseResult parseHWElementType(Type &result, AsmParser &p) { + // If this is an HW dialect type, then we don't need/want the !hw. prefix + // redundantly specified. + auto fullString = static_cast(p).getFullSymbolSpec(); + auto *curPtr = p.getCurrentLocation().getPointer(); + auto typeString = + StringRef(curPtr, fullString.size() - (curPtr - fullString.data())); + + if (typeString.startswith("array<") || typeString.startswith("inout<") || + typeString.startswith("uarray<") || typeString.startswith("struct<") || + typeString.startswith("typealias<") || typeString.startswith("int<") || + typeString.startswith("enum<")) { + llvm::StringRef mnemonic; + auto parseResult = generatedTypeParser(p, &mnemonic, result); + return parseResult.has_value() ? success() : failure(); + } + + return p.parseType(result); +} + +static void printHWElementType(Type element, AsmPrinter &p) { + if (succeeded(generatedTypePrinter(element, p))) return; + p.printType(element); +} + +//===----------------------------------------------------------------------===// +// Int Type +//===----------------------------------------------------------------------===// + +Type IntType::get(mlir::TypedAttr width) { + // The width expression must always be a 32-bit wide integer type itself. + auto widthWidth = width.getType().dyn_cast(); + assert(widthWidth && widthWidth.getWidth() == 32 && + "!hw.int width must be 32-bits"); + (void)widthWidth; + + if (auto cstWidth = width.dyn_cast()) + return IntegerType::get(width.getContext(), + cstWidth.getValue().getZExtValue()); + + return Base::get(width.getContext(), width); +} + +Type IntType::parse(AsmParser &p) { + // The bitwidth of the parameter size is always 32 bits. + auto int32Type = p.getBuilder().getIntegerType(32); + + mlir::TypedAttr width; + if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater()) + return Type(); + return get(width); +} + +void IntType::print(AsmPrinter &p) const { + p << "<"; + p.printAttributeWithoutType(getWidth()); + p << '>'; +} + +//===----------------------------------------------------------------------===// +// Struct Type +//===----------------------------------------------------------------------===// + +namespace circt { +namespace hw { +namespace detail { +bool operator==(const FieldInfo &a, const FieldInfo &b) { + return a.name == b.name && a.type == b.type; +} +llvm::hash_code hash_value(const FieldInfo &fi) { + return llvm::hash_combine(fi.name, fi.type); +} +} // namespace detail +} // namespace hw +} // namespace circt + +/// Parse a list of field names and types within <>. E.g.: +/// +static ParseResult parseFields(AsmParser &p, + SmallVectorImpl ¶meters) { + return p.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { + StringRef name; + Type type; + if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type)) + return failure(); + parameters.push_back( + FieldInfo{StringAttr::get(p.getContext(), name), type}); + return success(); + }); +} + +/// Print out a list of named fields surrounded by <>. +static void printFields(AsmPrinter &p, ArrayRef fields) { + p << '<'; + llvm::interleaveComma(fields, p, [&](const FieldInfo &field) { + p << field.name.getValue() << ": " << field.type; + }); + p << ">"; +} + +Type StructType::parse(AsmParser &p) { + llvm::SmallVector parameters; + if (parseFields(p, parameters)) return Type(); + return get(p.getContext(), parameters); +} + +void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); } + +Type StructType::getFieldType(mlir::StringRef fieldName) { + for (const auto &field : getElements()) + if (field.name == fieldName) return field.type; + return Type(); +} + +std::optional StructType::getFieldIndex(mlir::StringRef fieldName) { + ArrayRef elems = getElements(); + for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) + if (elems[idx].name == fieldName) return idx; + return {}; +} + +std::optional StructType::getFieldIndex(mlir::StringAttr fieldName) { + ArrayRef elems = getElements(); + for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) + if (elems[idx].name == fieldName) return idx; + return {}; +} + +static std::pair> getFieldIDsStruct( + const StructType &st) { + uint64_t fieldID = 0; + auto elements = st.getElements(); + SmallVector fieldIDs; + fieldIDs.reserve(elements.size()); + for (auto &element : elements) { + auto type = element.type; + fieldID += 1; + fieldIDs.push_back(fieldID); + // Increment the field ID for the next field by the number of subfields. + fieldID += hw::FieldIdImpl::getMaxFieldID(type); + } + return {fieldID, fieldIDs}; +} + +void StructType::getInnerTypes(SmallVectorImpl &types) { + for (const auto &field : getElements()) types.push_back(field.type); +} + +uint64_t StructType::getMaxFieldID() const { + uint64_t fieldID = 0; + for (const auto &field : getElements()) + fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type); + return fieldID; +} + +std::pair StructType::getSubTypeByFieldID( + uint64_t fieldID) const { + if (fieldID == 0) return {*this, 0}; + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); + auto subfieldIndex = std::distance(fieldIDs.begin(), it); + auto subfieldType = getElements()[subfieldIndex].type; + auto subfieldID = fieldID - fieldIDs[subfieldIndex]; + return {subfieldType, subfieldID}; +} + +std::pair StructType::projectToChildFieldID( + uint64_t fieldID, uint64_t index) const { + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto childRoot = fieldIDs[index]; + auto rangeEnd = + index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1); + return std::make_pair(fieldID - childRoot, + fieldID >= childRoot && fieldID <= rangeEnd); +} + +uint64_t StructType::getFieldID(uint64_t index) const { + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + return fieldIDs[index]; +} + +uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const { + assert(!getElements().empty() && "Bundle must have >0 fields"); + auto [maxId, fieldIDs] = getFieldIDsStruct(*this); + auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); + return std::distance(fieldIDs.begin(), it); +} + +std::pair StructType::getIndexAndSubfieldID( + uint64_t fieldID) const { + auto index = getIndexForFieldID(fieldID); + auto elementFieldID = getFieldID(index); + return {index, fieldID - elementFieldID}; +} + +//===----------------------------------------------------------------------===// +// Union Type +//===----------------------------------------------------------------------===// + +namespace circt { +namespace hw { +namespace detail { +bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) { + return a.name == b.name && a.type == b.type && a.offset == b.offset; +} +// NOLINTNEXTLINE +llvm::hash_code hash_value(const OffsetFieldInfo &fi) { + return llvm::hash_combine(fi.name, fi.type, fi.offset); +} +} // namespace detail +} // namespace hw +} // namespace circt + +Type UnionType::parse(AsmParser &p) { + llvm::SmallVector parameters; + if (p.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { + StringRef name; + Type type; + if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type)) + return failure(); + size_t offset = 0; + if (succeeded(p.parseOptionalKeyword("offset"))) + if (p.parseInteger(offset)) return failure(); + parameters.push_back(UnionType::FieldInfo{ + StringAttr::get(p.getContext(), name), type, offset}); + return success(); + })) + return Type(); + return get(p.getContext(), parameters); +} + +void UnionType::print(AsmPrinter &odsPrinter) const { + odsPrinter << '<'; + llvm::interleaveComma( + getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) { + odsPrinter << field.name.getValue() << ": " << field.type; + if (field.offset) odsPrinter << " offset " << field.offset; + }); + odsPrinter << ">"; +} + +UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) { + for (const auto &field : getElements()) + if (field.name == fieldName) return field; + return FieldInfo(); +} + +Type UnionType::getFieldType(mlir::StringRef fieldName) { + return getFieldInfo(fieldName).type; +} + +//===----------------------------------------------------------------------===// +// Enum Type +//===----------------------------------------------------------------------===// + +Type EnumType::parse(AsmParser &p) { + llvm::SmallVector fields; + + if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() { + StringRef name; + if (p.parseKeyword(&name)) return failure(); + fields.push_back(StringAttr::get(p.getContext(), name)); + return success(); + })) + return Type(); + + return get(p.getContext(), ArrayAttr::get(p.getContext(), fields)); +} + +void EnumType::print(AsmPrinter &p) const { + p << '<'; + llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) { + p << enumerator.cast().getValue(); + }); + p << ">"; +} + +bool EnumType::contains(mlir::StringRef field) { + return indexOf(field).has_value(); +} + +std::optional EnumType::indexOf(mlir::StringRef field) { + for (auto it : llvm::enumerate(getFields())) + if (it.value().cast().getValue() == field) return it.index(); + return {}; +} + +size_t EnumType::getBitWidth() { + auto w = getFields().size(); + if (w > 1) return llvm::Log2_64_Ceil(getFields().size()); + return 1; +} + +//===----------------------------------------------------------------------===// +// ArrayType +//===----------------------------------------------------------------------===// + +static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) { + if (p.parseLess()) return failure(); + + uint64_t dimLiteral; + auto int64Type = p.getBuilder().getIntegerType(64); + + if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value()) + dim = p.getBuilder().getI64IntegerAttr(dimLiteral); + else if (!p.parseOptionalAttribute(dim, int64Type).has_value()) + return failure(); + + if (!dim.isa()) { + p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array"); + return failure(); + } + + if (p.parseXInDimensionList() || parseHWElementType(inner, p) || + p.parseGreater()) + return failure(); + + return success(); +} + +Type ArrayType::parse(AsmParser &p) { + Attribute dim; + Type inner; + + if (failed(parseArray(p, dim, inner))) return Type(); + + auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); + if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim))) + return Type(); + + return get(inner.getContext(), inner, dim); +} + +void ArrayType::print(AsmPrinter &p) const { + p << "<"; + p.printAttributeWithoutType(getSizeAttr()); + p << "x"; + printHWElementType(getElementType(), p); + p << '>'; +} + +size_t ArrayType::getNumElements() const { + if (auto intAttr = getSizeAttr().dyn_cast()) + return intAttr.getInt(); + return -1; +} + +LogicalResult ArrayType::verify(function_ref emitError, + Type innerType, Attribute size) { + if (hasHWInOutType(innerType)) + return emitError() << "hw.array cannot contain InOut types"; + return success(); +} + +uint64_t ArrayType::getMaxFieldID() const { + return getNumElements() * + (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +std::pair ArrayType::getSubTypeByFieldID( + uint64_t fieldID) const { + if (fieldID == 0) return {*this, 0}; + return {getElementType(), getIndexAndSubfieldID(fieldID).second}; +} + +std::pair ArrayType::projectToChildFieldID( + uint64_t fieldID, uint64_t index) const { + auto childRoot = getFieldID(index); + auto rangeEnd = + index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1); + return std::make_pair(fieldID - childRoot, + fieldID >= childRoot && fieldID <= rangeEnd); +} + +uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const { + assert(fieldID && "fieldID must be at least 1"); + // Divide the field ID by the number of fieldID's per element. + return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +std::pair ArrayType::getIndexAndSubfieldID( + uint64_t fieldID) const { + auto index = getIndexForFieldID(fieldID); + auto elementFieldID = getFieldID(index); + return {index, fieldID - elementFieldID}; +} + +uint64_t ArrayType::getFieldID(uint64_t index) const { + return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +//===----------------------------------------------------------------------===// +// UnpackedArrayType +//===----------------------------------------------------------------------===// + +Type UnpackedArrayType::parse(AsmParser &p) { + Attribute dim; + Type inner; + + if (failed(parseArray(p, dim, inner))) return Type(); + + auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); + if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim))) + return Type(); + + return get(inner.getContext(), inner, dim); +} + +void UnpackedArrayType::print(AsmPrinter &p) const { + p << "<"; + p.printAttributeWithoutType(getSizeAttr()); + p << "x"; + printHWElementType(getElementType(), p); + p << '>'; +} + +LogicalResult UnpackedArrayType::verify( + function_ref emitError, Type innerType, + Attribute size) { + if (!isHWValueType(innerType)) + return emitError() << "invalid element for uarray type"; + return success(); +} + +size_t UnpackedArrayType::getNumElements() const { + return getSizeAttr().cast().getInt(); +} + +uint64_t UnpackedArrayType::getMaxFieldID() const { + return getNumElements() * + (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +std::pair UnpackedArrayType::getSubTypeByFieldID( + uint64_t fieldID) const { + if (fieldID == 0) return {*this, 0}; + return {getElementType(), getIndexAndSubfieldID(fieldID).second}; +} + +std::pair UnpackedArrayType::projectToChildFieldID( + uint64_t fieldID, uint64_t index) const { + auto childRoot = getFieldID(index); + auto rangeEnd = + index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1); + return std::make_pair(fieldID - childRoot, + fieldID >= childRoot && fieldID <= rangeEnd); +} + +uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const { + assert(fieldID && "fieldID must be at least 1"); + // Divide the field ID by the number of fieldID's per element. + return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +std::pair UnpackedArrayType::getIndexAndSubfieldID( + uint64_t fieldID) const { + auto index = getIndexForFieldID(fieldID); + auto elementFieldID = getFieldID(index); + return {index, fieldID - elementFieldID}; +} + +uint64_t UnpackedArrayType::getFieldID(uint64_t index) const { + return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); +} + +//===----------------------------------------------------------------------===// +// InOutType +//===----------------------------------------------------------------------===// + +Type InOutType::parse(AsmParser &p) { + Type inner; + if (p.parseLess() || parseHWElementType(inner, p) || p.parseGreater()) + return Type(); + + auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); + if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner))) + return Type(); + + return get(p.getContext(), inner); +} + +void InOutType::print(AsmPrinter &p) const { + p << "<"; + printHWElementType(getElementType(), p); + p << '>'; +} + +LogicalResult InOutType::verify(function_ref emitError, + Type innerType) { + if (!isHWValueType(innerType)) + return emitError() << "invalid element for hw.inout type " << innerType; + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeAliasType +//===----------------------------------------------------------------------===// + +static Type computeCanonicalType(Type type) { + return llvm::TypeSwitch(type) + .Case([](TypeAliasType t) { + return computeCanonicalType(t.getCanonicalType()); + }) + .Case([](ArrayType t) { + return ArrayType::get(computeCanonicalType(t.getElementType()), + t.getNumElements()); + }) + .Case([](UnpackedArrayType t) { + return UnpackedArrayType::get(computeCanonicalType(t.getElementType()), + t.getNumElements()); + }) + .Case([](StructType t) { + SmallVector fieldInfo; + for (auto field : t.getElements()) + fieldInfo.push_back(StructType::FieldInfo{ + field.name, computeCanonicalType(field.type)}); + return StructType::get(t.getContext(), fieldInfo); + }) + .Default([](Type t) { return t; }); +} + +TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) { + return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType)); +} + +Type TypeAliasType::parse(AsmParser &p) { + SymbolRefAttr ref; + Type type; + if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() || + p.parseType(type) || p.parseGreater()) + return Type(); + + return get(ref, type); +} + +void TypeAliasType::print(AsmPrinter &p) const { + p << "<" << getRef() << ", " << getInnerType() << ">"; +} + +/// Return the Typedecl referenced by this TypeAlias, given the module to look +/// in. This returns null when the IR is malformed. +TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) { + SymbolRefAttr ref = getRef(); + auto typeScope = ::dyn_cast_or_null( + cache.getDefinition(ref.getRootReference())); + if (!typeScope) return {}; + + return typeScope.lookupSymbol(ref.getLeafReference()); +} + +//////////////////////////////////////////////////////////////////////////////// +// ModuleType +//////////////////////////////////////////////////////////////////////////////// + +LogicalResult ModuleType::verify(function_ref emitError, + ArrayRef ports) { + if (llvm::any_of(ports, [](const ModulePort &port) { + return hasHWInOutType(port.type); + })) + return emitError() << "Ports cannot be inout types"; + return success(); +} + +size_t ModuleType::getPortIdForInputId(size_t idx) { + for (auto [i, p] : llvm::enumerate(getPorts())) { + if (p.dir != ModulePort::Direction::Output) { + if (!idx) return i; + --idx; + } + } + assert(0 && "Out of bounds input port id"); + return ~0UL; +} + +size_t ModuleType::getPortIdForOutputId(size_t idx) { + for (auto [i, p] : llvm::enumerate(getPorts())) { + if (p.dir == ModulePort::Direction::Output) { + if (!idx) return i; + --idx; + } + } + assert(0 && "Out of bounds output port id"); + return ~0UL; +} + +size_t ModuleType::getNumInputs() { + return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { + return p.dir != ModulePort::Direction::Output; + }); +} + +size_t ModuleType::getNumOutputs() { + return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { + return p.dir == ModulePort::Direction::Output; + }); +} + +size_t ModuleType::getNumPorts() { return getPorts().size(); } + +SmallVector ModuleType::getInputTypes() { + SmallVector retval; + for (auto &p : getPorts()) { + if (p.dir == ModulePort::Direction::Input) + retval.push_back(p.type); + else if (p.dir == ModulePort::Direction::InOut) { + retval.push_back(hw::InOutType::get(p.type)); + } + } + return retval; +} + +SmallVector ModuleType::getOutputTypes() { + SmallVector retval; + for (auto &p : getPorts()) + if (p.dir == ModulePort::Direction::Output) retval.push_back(p.type); + return retval; +} + +SmallVector ModuleType::getPortTypes() { + SmallVector retval; + for (auto &p : getPorts()) retval.push_back(p.type); + return retval; +} + +Type ModuleType::getInputType(size_t idx) { + return getPorts()[getPortIdForInputId(idx)].type; +} + +Type ModuleType::getOutputType(size_t idx) { + return getPorts()[getPortIdForOutputId(idx)].type; +} + +SmallVector ModuleType::getInputNamesStr() { + SmallVector retval; + for (auto &p : getPorts()) + if (p.dir != ModulePort::Direction::Output) retval.push_back(p.name); + return retval; +} + +SmallVector ModuleType::getOutputNamesStr() { + SmallVector retval; + for (auto &p : getPorts()) + if (p.dir == ModulePort::Direction::Output) retval.push_back(p.name); + return retval; +} + +SmallVector ModuleType::getInputNames() { + SmallVector retval; + for (auto &p : getPorts()) + if (p.dir != ModulePort::Direction::Output) retval.push_back(p.name); + return retval; +} + +SmallVector ModuleType::getOutputNames() { + SmallVector retval; + for (auto &p : getPorts()) + if (p.dir == ModulePort::Direction::Output) retval.push_back(p.name); + return retval; +} + +StringAttr ModuleType::getPortNameAttr(size_t idx) { + return getPorts()[idx].name; +} + +StringRef ModuleType::getPortName(size_t idx) { + auto sa = getPortNameAttr(idx); + if (sa) return sa.getValue(); + return {}; +} + +StringAttr ModuleType::getInputNameAttr(size_t idx) { + return getPorts()[getPortIdForInputId(idx)].name; +} + +StringRef ModuleType::getInputName(size_t idx) { + auto sa = getInputNameAttr(idx); + if (sa) return sa.getValue(); + return {}; +} + +StringAttr ModuleType::getOutputNameAttr(size_t idx) { + return getPorts()[getPortIdForOutputId(idx)].name; +} + +StringRef ModuleType::getOutputName(size_t idx) { + auto sa = getOutputNameAttr(idx); + if (sa) return sa.getValue(); + return {}; +} + +FunctionType ModuleType::getFuncType() { + SmallVector inputs, outputs; + for (auto p : getPorts()) + if (p.dir == ModulePort::Input) + inputs.push_back(p.type); + else if (p.dir == ModulePort::InOut) + inputs.push_back(InOutType::get(p.type)); + else + outputs.push_back(p.type); + return FunctionType::get(getContext(), inputs, outputs); +} + +static StringRef dirToStr(ModulePort::Direction dir) { + switch (dir) { + case ModulePort::Direction::Input: + return "input"; + case ModulePort::Direction::Output: + return "output"; + case ModulePort::Direction::InOut: + return "inout"; + } +} + +static ModulePort::Direction strToDir(StringRef str) { + if (str == "input") return ModulePort::Direction::Input; + if (str == "output") return ModulePort::Direction::Output; + if (str == "inout") return ModulePort::Direction::InOut; + llvm::report_fatal_error("invalid direction"); +} + +/// Parse a list of field names and types within <>. E.g.: +/// +static ParseResult parsePorts(AsmParser &p, + SmallVectorImpl &ports) { + return p.parseCommaSeparatedList( + mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { + StringRef dir; + StringRef name; + Type type; + if (p.parseKeyword(&dir) || p.parseKeyword(&name) || p.parseColon() || + p.parseType(type)) + return failure(); + ports.push_back( + {StringAttr::get(p.getContext(), name), type, strToDir(dir)}); + return success(); + }); +} + +/// Print out a list of named fields surrounded by <>. +static void printPorts(AsmPrinter &p, ArrayRef ports) { + p << '<'; + llvm::interleaveComma(ports, p, [&](const ModulePort &port) { + p << dirToStr(port.dir) << " " << port.name.getValue() << " : " + << port.type; + }); + p << ">"; +} + +Type ModuleType::parse(AsmParser &odsParser) { + llvm::SmallVector ports; + if (parsePorts(odsParser, ports)) return Type(); + return get(odsParser.getContext(), ports); +} + +void ModuleType::print(AsmPrinter &odsPrinter) const { + printPorts(odsPrinter, getPorts()); +} + +namespace circt { +namespace hw { + +static bool operator==(const ModulePort &a, const ModulePort &b) { + return a.dir == b.dir && a.name == b.name && a.type == b.type; +} +static llvm::hash_code hash_value(const ModulePort &port) { + return llvm::hash_combine(port.dir, port.name, port.type); +} +} // namespace hw +} // namespace circt + +ModuleType circt::hw::detail::fnToMod(Operation *op, + ArrayRef inputNames, + ArrayRef outputNames) { + return fnToMod( + cast(cast(op).getFunctionType()), + inputNames, outputNames); +} + +ModuleType circt::hw::detail::fnToMod(FunctionType fnty, + ArrayRef inputNames, + ArrayRef outputNames) { + SmallVector ports; + if (!inputNames.empty()) { + for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames)) + if (auto iot = dyn_cast(t)) + ports.push_back({cast(n), iot.getElementType(), + ModulePort::Direction::InOut}); + else + ports.push_back({cast(n), t, ModulePort::Direction::Input}); + } else { + for (auto t : fnty.getInputs()) + if (auto iot = dyn_cast(t)) + ports.push_back( + {{}, iot.getElementType(), ModulePort::Direction::InOut}); + else + ports.push_back({{}, t, ModulePort::Direction::Input}); + } + if (!outputNames.empty()) { + for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames)) + ports.push_back({cast(n), t, ModulePort::Direction::Output}); + } else { + for (auto t : fnty.getResults()) + ports.push_back({{}, t, ModulePort::Direction::Output}); + } + return ModuleType::get(fnty.getContext(), ports); +} + +//////////////////////////////////////////////////////////////////////////////// +// BoilerPlate +//////////////////////////////////////////////////////////////////////////////// + +void HWDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "circt/Dialect/HW/HWTypes.cpp.inc" + >(); +} diff --git a/third_party/circt/lib/Dialect/HW/InnerSymbolTable.cpp b/third_party/circt/lib/Dialect/HW/InnerSymbolTable.cpp new file mode 100644 index 000000000..6cdf7f5f6 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/InnerSymbolTable.cpp @@ -0,0 +1,251 @@ +//===- InnerSymbolTable.cpp - InnerSymbolTable and InnerRef verification --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements InnerSymbolTable and verification for InnerRef's. +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/InnerSymbolTable.h" + +#include "circt/Dialect/HW/HWOpInterfaces.h" +#include "llvm/Support/Debug.h" +#include "mlir/IR/Threading.h" + +using namespace circt; +using namespace hw; + +namespace circt { +namespace hw { + +//===----------------------------------------------------------------------===// +// InnerSymbolTable +//===----------------------------------------------------------------------===// +InnerSymbolTable::InnerSymbolTable(Operation *op) { + assert(op->hasTrait()); + // Save the operation this table is for. + this->innerSymTblOp = op; + + walkSymbols(op, [&](StringAttr name, const InnerSymTarget &target) { + auto it = symbolTable.try_emplace(name, target); + (void)it; + assert(it.second && "repeated symbol found"); + }); +} + +FailureOr InnerSymbolTable::get(Operation *op) { + assert(op); + if (!op->hasTrait()) + return op->emitError("expected operation to have InnerSymbolTable trait"); + + TableTy table; + auto result = walkSymbols( + op, [&](StringAttr name, const InnerSymTarget &target) -> LogicalResult { + auto it = table.try_emplace(name, target); + if (it.second) return success(); + auto existing = it.first->second; + return target.getOp() + ->emitError() + .append("redefinition of inner symbol named '", name.strref(), "'") + .attachNote(existing.getOp()->getLoc()) + .append("see existing inner symbol definition here"); + }); + if (failed(result)) return failure(); + return InnerSymbolTable(op, std::move(table)); +} + +LogicalResult InnerSymbolTable::walkSymbols(Operation *op, + InnerSymCallbackFn callback) { + auto walkSym = [&](StringAttr name, const InnerSymTarget &target) { + assert(name && !name.getValue().empty()); + return callback(name, target); + }; + + auto walkSyms = [&](hw::InnerSymAttr symAttr, + const InnerSymTarget &baseTarget) -> LogicalResult { + assert(baseTarget.getField() == 0); + for (auto symProp : symAttr) { + if (failed(walkSym(symProp.getName(), + InnerSymTarget::getTargetForSubfield( + baseTarget, symProp.getFieldID())))) + return failure(); + } + return success(); + }; + + // Walk the operation and add InnerSymbolTarget's to the table. + return success( + !op->walk([&](Operation *curOp) -> WalkResult { + if (auto symOp = dyn_cast(curOp)) + if (auto symAttr = symOp.getInnerSymAttr()) + if (failed(walkSyms(symAttr, InnerSymTarget(symOp)))) + return WalkResult::interrupt(); + + // Check for ports + if (auto mod = dyn_cast(curOp)) { + for (size_t i = 0, e = mod.getNumPorts(); i < e; ++i) { + if (auto symAttr = mod.getPortSymbolAttr(i)) + if (failed(walkSyms(symAttr, InnerSymTarget(i, curOp)))) + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }).wasInterrupted()); +} + +/// Look up a symbol with the specified name, returning empty InnerSymTarget if +/// no such name exists. Names never include the @ on them. +InnerSymTarget InnerSymbolTable::lookup(StringRef name) const { + return lookup(StringAttr::get(innerSymTblOp->getContext(), name)); +} +InnerSymTarget InnerSymbolTable::lookup(StringAttr name) const { + return symbolTable.lookup(name); +} + +/// Look up a symbol with the specified name, returning null if no such +/// name exists or doesn't target just an operation. +Operation *InnerSymbolTable::lookupOp(StringRef name) const { + return lookupOp(StringAttr::get(innerSymTblOp->getContext(), name)); +} +Operation *InnerSymbolTable::lookupOp(StringAttr name) const { + auto result = lookup(name); + if (result.isOpOnly()) return result.getOp(); + return nullptr; +} + +/// Get InnerSymbol for an operation. +StringAttr InnerSymbolTable::getInnerSymbol(Operation *op) { + if (auto innerSymOp = dyn_cast(op)) + return innerSymOp.getInnerNameAttr(); + return {}; +} + +/// Get InnerSymbol for a target. Be robust to queries on unexpected +/// operations to avoid users needing to know the details. +StringAttr InnerSymbolTable::getInnerSymbol(const InnerSymTarget &target) { + // Assert on misuse, but try to handle queries otherwise. + assert(target); + + // Obtain the base InnerSymAttr for the specified target. + auto getBase = [](auto &target) -> hw::InnerSymAttr { + if (target.isPort()) { + if (auto mod = dyn_cast(target.getOp())) { + assert(target.getPort() < mod.getNumPorts()); + return mod.getPortSymbolAttr(target.getPort()); + } + } else { + // InnerSymbols only supported if op implements the interface. + if (auto symOp = dyn_cast(target.getOp())) + return symOp.getInnerSymAttr(); + } + return {}; + }; + + if (auto base = getBase(target)) + return base.getSymIfExists(target.getField()); + return {}; +} + +//===----------------------------------------------------------------------===// +// InnerSymbolTableCollection +//===----------------------------------------------------------------------===// + +InnerSymbolTable &InnerSymbolTableCollection::getInnerSymbolTable( + Operation *op) { + auto it = symbolTables.try_emplace(op, nullptr); + if (it.second) it.first->second = ::std::make_unique(op); + return *it.first->second; +} + +LogicalResult InnerSymbolTableCollection::populateAndVerifyTables( + Operation *innerRefNSOp) { + // Gather top-level operations that have the InnerSymbolTable trait. + SmallVector innerSymTableOps(llvm::make_filter_range( + llvm::make_pointer_range(innerRefNSOp->getRegion(0).front()), + [&](Operation *op) { + return op->hasTrait(); + })); + + // Ensure entries exist for each operation. + llvm::for_each(innerSymTableOps, + [&](auto *op) { symbolTables.try_emplace(op, nullptr); }); + + // Construct the tables in parallel (if context allows it). + return mlir::failableParallelForEach( + innerRefNSOp->getContext(), innerSymTableOps, [&](auto *op) { + auto it = symbolTables.find(op); + assert(it != symbolTables.end()); + if (!it->second) { + auto result = InnerSymbolTable::get(op); + if (failed(result)) return failure(); + it->second = std::make_unique(std::move(*result)); + return success(); + } + return failure(); + }); +} + +//===----------------------------------------------------------------------===// +// InnerRefNamespace +//===----------------------------------------------------------------------===// + +InnerSymTarget InnerRefNamespace::lookup(hw::InnerRefAttr inner) { + auto *mod = symTable.lookup(inner.getModule()); + if (!mod) return {}; + assert(mod->hasTrait()); + return innerSymTables.getInnerSymbolTable(mod).lookup(inner.getName()); +} + +Operation *InnerRefNamespace::lookupOp(hw::InnerRefAttr inner) { + auto *mod = symTable.lookup(inner.getModule()); + if (!mod) return nullptr; + assert(mod->hasTrait()); + return innerSymTables.getInnerSymbolTable(mod).lookupOp(inner.getName()); +} + +//===----------------------------------------------------------------------===// +// InnerRefNamespace verification +//===----------------------------------------------------------------------===// + +namespace detail { + +LogicalResult verifyInnerRefNamespace(Operation *op) { + // Construct the symbol tables. + InnerSymbolTableCollection innerSymTables; + if (failed(innerSymTables.populateAndVerifyTables(op))) return failure(); + + SymbolTable symbolTable(op); + InnerRefNamespace ns{symbolTable, innerSymTables}; + + // Conduct parallel walks of the top-level children of this + // InnerRefNamespace, verifying all InnerRefUserOp's discovered within. + auto verifySymbolUserFn = [&](Operation *op) -> WalkResult { + if (auto user = dyn_cast(op)) + return WalkResult(user.verifyInnerRefs(ns)); + return WalkResult::advance(); + }; + return mlir::failableParallelForEach( + op->getContext(), op->getRegion(0).front(), [&](auto &op) { + return success(!op.walk(verifySymbolUserFn).wasInterrupted()); + }); +} + +} // namespace detail + +bool InnerRefNamespaceLike::classof(mlir::Operation *op) { + return op->hasTrait() || + op->hasTrait(); +} + +bool InnerRefNamespaceLike::classof( + const mlir::RegisteredOperationName *opInfo) { + return opInfo->hasTrait() || + opInfo->hasTrait(); +} + +} // namespace hw +} // namespace circt diff --git a/third_party/circt/lib/Dialect/HW/InstanceImplementation.cpp b/third_party/circt/lib/Dialect/HW/InstanceImplementation.cpp new file mode 100644 index 000000000..31ac85bdc --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/InstanceImplementation.cpp @@ -0,0 +1,347 @@ +//===- InstanceImplementation.cpp - Utilities for instance-like ops -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/InstanceImplementation.h" + +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWSymCache.h" + +using namespace circt; +using namespace circt::hw; + +Operation *instance_like_impl::getReferencedModule( + const HWSymbolCache *cache, Operation *instanceOp, + mlir::FlatSymbolRefAttr moduleName) { + if (cache) + if (auto *result = cache->getDefinition(moduleName)) return result; + + auto topLevelModuleOp = instanceOp->getParentOfType(); + return topLevelModuleOp.lookupSymbol(moduleName.getValue()); +} + +LogicalResult instance_like_impl::verifyReferencedModule( + Operation *instanceOp, SymbolTableCollection &symbolTable, + mlir::FlatSymbolRefAttr moduleName, Operation *&module) { + module = symbolTable.lookupNearestSymbolFrom(instanceOp, moduleName); + if (module == nullptr) + return instanceOp->emitError("Cannot find module definition '") + << moduleName.getValue() << "'"; + + // It must be some sort of module. + if (!isa(module)) + return instanceOp->emitError("symbol reference '") + << moduleName.getValue() << "' isn't a module"; + + return success(); +} + +LogicalResult instance_like_impl::resolveParametricTypes( + Location loc, ArrayAttr parameters, ArrayRef types, + SmallVectorImpl &resolvedTypes, const EmitErrorFn &emitError) { + for (auto type : types) { + auto expectedType = evaluateParametricType(loc, parameters, type); + if (failed(expectedType)) { + emitError([&](auto &diag) { + diag << "failed to resolve parametric input of instantiated module"; + return true; + }); + return failure(); + } + + resolvedTypes.push_back(*expectedType); + } + + return success(); +} + +LogicalResult instance_like_impl::verifyInputs(ArrayAttr argNames, + ArrayAttr moduleArgNames, + TypeRange inputTypes, + ArrayRef moduleInputTypes, + const EmitErrorFn &emitError) { + // Check operand types first. + if (moduleInputTypes.size() != inputTypes.size()) { + emitError([&](auto &diag) { + diag << "has a wrong number of operands; expected " + << moduleInputTypes.size() << " but got " << inputTypes.size(); + return true; + }); + return failure(); + } + + if (argNames.size() != inputTypes.size()) { + emitError([&](auto &diag) { + diag << "has a wrong number of input port names; expected " + << inputTypes.size() << " but got " << argNames.size(); + return true; + }); + return failure(); + } + + for (size_t i = 0; i != inputTypes.size(); ++i) { + auto expectedType = moduleInputTypes[i]; + auto operandType = inputTypes[i]; + + if (operandType != expectedType) { + emitError([&](auto &diag) { + diag << "operand type #" << i << " must be " << expectedType + << ", but got " << operandType; + return true; + }); + return failure(); + } + + if (argNames[i] != moduleArgNames[i]) { + emitError([&](auto &diag) { + diag << "input label #" << i << " must be " << moduleArgNames[i] + << ", but got " << argNames[i]; + return true; + }); + return failure(); + } + } + + return success(); +} + +LogicalResult instance_like_impl::verifyOutputs( + ArrayAttr resultNames, ArrayAttr moduleResultNames, TypeRange resultTypes, + ArrayRef moduleResultTypes, const EmitErrorFn &emitError) { + // Check result types and labels. + if (moduleResultTypes.size() != resultTypes.size()) { + emitError([&](auto &diag) { + diag << "has a wrong number of results; expected " + << moduleResultTypes.size() << " but got " << resultTypes.size(); + return true; + }); + return failure(); + } + + if (resultNames.size() != resultTypes.size()) { + emitError([&](auto &diag) { + diag << "has a wrong number of results port labels; expected " + << resultTypes.size() << " but got " << resultNames.size(); + return true; + }); + return failure(); + } + + for (size_t i = 0; i != resultTypes.size(); ++i) { + auto expectedType = moduleResultTypes[i]; + auto resultType = resultTypes[i]; + + if (resultType != expectedType) { + emitError([&](auto &diag) { + diag << "result type #" << i << " must be " << expectedType + << ", but got " << resultType; + return true; + }); + return failure(); + } + + if (resultNames[i] != moduleResultNames[i]) { + emitError([&](auto &diag) { + diag << "result label #" << i << " must be " << moduleResultNames[i] + << ", but got " << resultNames[i]; + return true; + }); + return failure(); + } + } + + return success(); +} + +LogicalResult instance_like_impl::verifyParameters( + ArrayAttr parameters, ArrayAttr moduleParameters, + const EmitErrorFn &emitError) { + // Check parameters match up. + auto numParameters = parameters.size(); + if (numParameters != moduleParameters.size()) { + emitError([&](auto &diag) { + diag << "expected " << moduleParameters.size() << " parameters but had " + << numParameters; + return true; + }); + return failure(); + } + + for (size_t i = 0; i != numParameters; ++i) { + auto param = parameters[i].cast(); + auto modParam = moduleParameters[i].cast(); + + auto paramName = param.getName(); + if (paramName != modParam.getName()) { + emitError([&](auto &diag) { + diag << "parameter #" << i << " should have name " << modParam.getName() + << " but has name " << paramName; + return true; + }); + return failure(); + } + + if (param.getType() != modParam.getType()) { + emitError([&](auto &diag) { + diag << "parameter " << paramName << " should have type " + << modParam.getType() << " but has type " << param.getType(); + return true; + }); + return failure(); + } + + // All instance parameters must have a value. Specify the same value as + // a module's default value if you want the default. + if (!param.getValue()) { + emitError([&](auto &diag) { + diag << "parameter " << paramName << " must have a value"; + return false; + }); + return failure(); + } + } + + return success(); +} + +LogicalResult instance_like_impl::verifyInstanceOfHWModule( + Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, + TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, + ArrayAttr parameters, SymbolTableCollection &symbolTable) { + // Verify that we reference some kind of HW module and get the module on + // success. + Operation *module; + if (failed(instance_like_impl::verifyReferencedModule(instance, symbolTable, + moduleRef, module))) + return failure(); + + // Emit an error message on the instance, with a note indicating which module + // is being referenced. The error message on the instance is added by the + // verification function this lambda is passed to. + EmitErrorFn emitError = + [&](const std::function &fn) { + auto diag = instance->emitOpError(); + if (fn(diag)) + diag.attachNote(module->getLoc()) << "module declared here"; + }; + + // Check that input types are consistent with the referenced module. + auto mod = cast(module); + auto modArgNames = + ArrayAttr::get(instance->getContext(), mod.getInputNames()); + auto modResultNames = + ArrayAttr::get(instance->getContext(), mod.getOutputNames()); + + ArrayRef resolvedModInputTypesRef = getModuleType(module).getInputs(); + SmallVector resolvedModInputTypes; + if (parameters) { + if (failed(instance_like_impl::resolveParametricTypes( + instance->getLoc(), parameters, getModuleType(module).getInputs(), + resolvedModInputTypes, emitError))) + return failure(); + resolvedModInputTypesRef = resolvedModInputTypes; + } + if (failed(instance_like_impl::verifyInputs( + argNames, modArgNames, inputs.getTypes(), resolvedModInputTypesRef, + emitError))) + return failure(); + + // Check that result types are consistent with the referenced module. + ArrayRef resolvedModResultTypesRef = getModuleType(module).getResults(); + SmallVector resolvedModResultTypes; + if (parameters) { + if (failed(instance_like_impl::resolveParametricTypes( + instance->getLoc(), parameters, getModuleType(module).getResults(), + resolvedModResultTypes, emitError))) + return failure(); + resolvedModResultTypesRef = resolvedModResultTypes; + } + if (failed(instance_like_impl::verifyOutputs( + resultNames, modResultNames, results, resolvedModResultTypesRef, + emitError))) + return failure(); + + if (parameters) { + // Check that the parameters are consistent with the referenced module. + ArrayAttr modParameters = module->getAttrOfType("parameters"); + if (failed(instance_like_impl::verifyParameters(parameters, modParameters, + emitError))) + return failure(); + } + + return success(); +} + +LogicalResult instance_like_impl::verifyParameterStructure( + ArrayAttr parameters, ArrayAttr moduleParameters, + const EmitErrorFn &emitError) { + // Check that all the parameter values specified to the instance are + // structurally valid. + for (auto param : parameters) { + auto paramAttr = param.cast(); + auto value = paramAttr.getValue(); + // The SymbolUses verifier which checks that this exists may not have been + // run yet. Let it issue the error. + if (!value) continue; + + auto typedValue = value.dyn_cast(); + if (!typedValue) { + emitError([&](auto &diag) { + diag << "parameter " << paramAttr + << " should have a typed value; has value " << value; + return false; + }); + return failure(); + } + + if (typedValue.getType() != paramAttr.getType()) { + emitError([&](auto &diag) { + diag << "parameter " << paramAttr << " should have type " + << paramAttr.getType() << "; has type " << typedValue.getType(); + return false; + }); + return failure(); + } + + if (failed(checkParameterInContext(value, moduleParameters, emitError))) + return failure(); + } + return success(); +} + +StringAttr instance_like_impl::getName(ArrayAttr names, size_t idx) { + // Tolerate malformed IR here to enable debug printing etc. + if (names && idx < names.size()) return names[idx].cast(); + return StringAttr(); +} + +ArrayAttr instance_like_impl::updateName(ArrayAttr oldNames, size_t i, + StringAttr name) { + SmallVector newNames(oldNames.begin(), oldNames.end()); + if (newNames[i] == name) return oldNames; + newNames[i] = name; + return ArrayAttr::get(oldNames.getContext(), oldNames); +} + +void instance_like_impl::getAsmResultNames(OpAsmSetValueNameFn setNameFn, + StringRef instanceName, + ArrayAttr resultNames, + ValueRange results) { + // Provide default names for instance results. + std::string name = instanceName.str() + "."; + size_t baseNameLen = name.size(); + + for (size_t i = 0, e = resultNames.size(); i != e; ++i) { + auto resName = getName(resultNames, i); + name.resize(baseNameLen); + if (resName && !resName.getValue().empty()) + name += resName.getValue().str(); + else + name += std::to_string(i); + setNameFn(results[i], name); + } +} diff --git a/third_party/circt/lib/Dialect/HW/ModuleImplementation.cpp b/third_party/circt/lib/Dialect/HW/ModuleImplementation.cpp new file mode 100644 index 000000000..8a07e2494 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/ModuleImplementation.cpp @@ -0,0 +1,328 @@ +//===- ModuleImplementation.cpp - Utilities for module-like ops -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/ModuleImplementation.h" + +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Support/LLVM.h" +#include "circt/Support/ParsingUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/FunctionImplementation.h" + +using namespace circt; +using namespace circt::hw; + +/// Parse a function result list. +/// +/// function-result-list ::= function-result-list-parens +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= (percent-identifier `:`) type attribute-dict? +/// +static ParseResult parseFunctionResultList( + OpAsmParser &parser, SmallVectorImpl &resultNames, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs, + SmallVectorImpl &resultLocs) { + auto parseElt = [&]() -> ParseResult { + // Stash the current location parser location. + auto irLoc = parser.getCurrentLocation(); + + // Parse the result name. + std::string portName; + if (parser.parseKeywordOrString(&portName)) return failure(); + resultNames.push_back(StringAttr::get(parser.getContext(), portName)); + + // Parse the results type. + resultTypes.emplace_back(); + if (parser.parseColonType(resultTypes.back())) return failure(); + + // Parse the result attributes. + NamedAttrList attrs; + if (failed(parser.parseOptionalAttrDict(attrs))) return failure(); + resultAttrs.push_back(attrs.getDictionary(parser.getContext())); + + // Parse the result location. + std::optional maybeLoc; + if (failed(parser.parseOptionalLocationSpecifier(maybeLoc))) + return failure(); + Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc); + resultLocs.push_back(loc); + + return success(); + }; + + return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, + parseElt); +} + +/// Return the port name for the specified argument or result. +static StringRef getModuleArgumentName(Operation *module, size_t argNo) { + if (auto mod = dyn_cast(module)) { + if (argNo < mod.getNumInputPorts()) return mod.getInputName(argNo); + return StringRef(); + } + auto argNames = module->getAttrOfType("argNames"); + // Tolerate malformed IR here to enable debug printing etc. + if (argNames && argNo < argNames.size()) + return argNames[argNo].cast().getValue(); + return StringRef(); +} + +static StringRef getModuleResultName(Operation *module, size_t resultNo) { + if (auto mod = dyn_cast(module)) { + if (resultNo < mod.getNumOutputPorts()) return mod.getOutputName(resultNo); + return StringRef(); + } + auto resultNames = module->getAttrOfType("resultNames"); + // Tolerate malformed IR here to enable debug printing etc. + if (resultNames && resultNo < resultNames.size()) + return resultNames[resultNo].cast().getValue(); + return StringRef(); +} + +void module_like_impl::printModuleSignature(OpAsmPrinter &p, Operation *op, + ArrayRef argTypes, + bool isVariadic, + ArrayRef resultTypes, + bool &needArgNamesAttr) { + using namespace mlir::function_interface_impl; + + Region &body = op->getRegion(0); + bool isExternal = body.empty(); + SmallString<32> resultNameStr; + mlir::OpPrintingFlags flags; + + // Handle either old FunctionOpInterface modules or new-style hwmodulelike + // This whole thing should be split up into two functions, but the delta is + // so small, we are leaving this for now. + auto modOp = dyn_cast(op); + auto funcOp = dyn_cast(op); + SmallVector inputAttrs, outputAttrs; + if (funcOp) { + if (auto args = funcOp.getAllArgAttrs()) + for (auto a : args.getValue()) inputAttrs.push_back(a); + inputAttrs.resize(funcOp.getNumArguments()); + if (auto results = funcOp.getAllResultAttrs()) + for (auto a : results.getValue()) outputAttrs.push_back(a); + outputAttrs.resize(funcOp.getNumResults()); + } else { + inputAttrs = modOp.getAllInputAttrs(); + outputAttrs = modOp.getAllOutputAttrs(); + } + + p << '('; + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) p << ", "; + + auto argName = modOp ? modOp.getInputName(i) : getModuleArgumentName(op, i); + if (!isExternal) { + // Get the printed format for the argument name. + resultNameStr.clear(); + llvm::raw_svector_ostream tmpStream(resultNameStr); + p.printOperand(body.front().getArgument(i), tmpStream); + // If the name wasn't printable in a way that agreed with argName, make + // sure to print out an explicit argNames attribute. + if (tmpStream.str().drop_front() != argName) needArgNamesAttr = true; + + p << tmpStream.str() << ": "; + } else if (!argName.empty()) { + p << '%' << argName << ": "; + } + + p.printType(argTypes[i]); + auto inputAttr = inputAttrs[i]; + p.printOptionalAttrDict(inputAttr + ? cast(inputAttr).getValue() + : ArrayRef()); + + // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, + // even if they are not printed. This will have to be fixed upstream. For + // now, use what was specified on the command line. + if (flags.shouldPrintDebugInfo()) { + auto loc = modOp.getInputLoc(i); + if (!isa(loc)) p.printOptionalLocationSpecifier(loc); + } + } + + if (isVariadic) { + if (!argTypes.empty()) p << ", "; + p << "..."; + } + + p << ')'; + + // We print result types specially since we support named arguments. + if (!resultTypes.empty()) { + p << " -> ("; + for (size_t i = 0, e = resultTypes.size(); i < e; ++i) { + if (i != 0) p << ", "; + p.printKeywordOrString(getModuleResultName(op, i)); + p << ": "; + p.printType(resultTypes[i]); + auto outputAttr = outputAttrs[i]; + p.printOptionalAttrDict(outputAttr + ? cast(outputAttr).getValue() + : ArrayRef()); + + // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, + // even if they are not printed. This will have to be fixed upstream. For + // now, use what was specified on the command line. + if (flags.shouldPrintDebugInfo()) { + auto loc = modOp.getOutputLoc(i); + if (!isa(loc)) p.printOptionalLocationSpecifier(loc); + } + } + p << ')'; + } +} + +ParseResult module_like_impl::parseModuleFunctionSignature( + OpAsmParser &parser, bool &isVariadic, + SmallVectorImpl &args, + SmallVectorImpl &argNames, SmallVectorImpl &argLocs, + SmallVectorImpl &resultNames, + SmallVectorImpl &resultAttrs, + SmallVectorImpl &resultLocs, TypeAttr &type) { + using namespace mlir::function_interface_impl; + auto *context = parser.getContext(); + + // Parse the argument list. + if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, + /*allowType=*/true, /*allowAttrs=*/true)) + return failure(); + + // Parse the result list. + SmallVector resultTypes; + if (succeeded(parser.parseOptionalArrow())) + if (failed(parseFunctionResultList(parser, resultNames, resultTypes, + resultAttrs, resultLocs))) + return failure(); + + // Process the ssa args for the information we're looking for. + SmallVector argTypes; + for (auto &arg : args) { + argNames.push_back(parsing_util::getNameFromSSA(context, arg.ssaName.name)); + argTypes.push_back(arg.type); + if (!arg.sourceLoc) + arg.sourceLoc = parser.getEncodedSourceLoc(arg.ssaName.location); + argLocs.push_back(*arg.sourceLoc); + } + + type = TypeAttr::get(FunctionType::get(context, argTypes, resultTypes)); + + return success(); +} + +//////////////////////////////////////////////////////////////////////////////// +// New Style +//////////////////////////////////////////////////////////////////////////////// + +static ParseResult parseDirection(OpAsmParser &p, ModulePort::Direction &dir) { + StringRef key; + if (failed(p.parseKeyword(&key))) + return p.emitError(p.getCurrentLocation(), "expected port direction"); + if (key == "input") + dir = ModulePort::Direction::Input; + else if (key == "output") + dir = ModulePort::Direction::Output; + else if (key == "inout") + dir = ModulePort::Direction::InOut; + else + return p.emitError(p.getCurrentLocation(), "unknown port direction '") + << key << "'"; + return success(); +} + +/// Parse a single argument with the following syntax: +/// +/// direction `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)` +/// +/// If `allowType` is false or `allowAttrs` are false then the respective +/// parts of the grammar are not parsed. +static ParseResult parsePort(OpAsmParser &p, + module_like_impl::PortParse &result) { + NamedAttrList attrs; + if (parseDirection(p, result.direction) || + p.parseOperand(result.ssaName, /*allowResultNumber=*/false) || + p.parseColonType(result.type) || p.parseOptionalAttrDict(attrs) || + p.parseOptionalLocationSpecifier(result.sourceLoc)) + return failure(); + result.attrs = attrs.getDictionary(p.getContext()); + return success(); +} + +static ParseResult parsePortList( + OpAsmParser &p, SmallVectorImpl &result) { + auto parseOnePort = [&]() -> ParseResult { + return parsePort(p, result.emplace_back()); + }; + return p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOnePort, + " in port list"); +} + +ParseResult module_like_impl::parseModuleSignature( + OpAsmParser &parser, SmallVectorImpl &args, TypeAttr &modType) { + auto *context = parser.getContext(); + + // Parse the port list. + if (parsePortList(parser, args)) return failure(); + + // Process the ssa args for the information we're looking for. + SmallVector ports; + for (auto &arg : args) { + ports.push_back({parsing_util::getNameFromSSA(context, arg.ssaName.name), + arg.type, arg.direction}); + if (!arg.sourceLoc) + arg.sourceLoc = parser.getEncodedSourceLoc(arg.ssaName.location); + } + modType = TypeAttr::get(ModuleType::get(context, ports)); + + return success(); +} + +static const char *directionAsString(ModulePort::Direction dir) { + if (dir == ModulePort::Direction::Input) return "input"; + if (dir == ModulePort::Direction::Output) return "output"; + if (dir == ModulePort::Direction::InOut) return "inout"; + assert(0 && "Unknown port direction"); + abort(); + return "unknown"; +} + +void module_like_impl::printModuleSignatureNew(OpAsmPrinter &p, Operation *op) { + mlir::OpPrintingFlags flags; + + auto typeAttr = op->getAttrOfType("module_type"); + auto modType = cast(typeAttr.getValue()); + auto portAttrs = op->getAttrOfType("port_attrs"); + auto locAttrs = op->getAttrOfType("port_locs"); + + p << '('; + for (auto [i, port] : llvm::enumerate(modType.getPorts())) { + if (i > 0) p << ", "; + p.printKeywordOrString(directionAsString(port.dir)); + p << " %"; + p.printKeywordOrString(port.name); + p << " : "; + p.printType(port.type); + if (auto attr = dyn_cast(portAttrs[i])) + p.printOptionalAttrDict(attr.getValue()); + + // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, + // even if they are not printed. This will have to be fixed upstream. For + // now, use what was specified on the command line. + if (flags.shouldPrintDebugInfo()) + if (auto loc = locAttrs[i]) + p.printOptionalLocationSpecifier(cast(loc)); + } + + p << ')'; +} diff --git a/third_party/circt/lib/Dialect/HW/PortConverter.cpp b/third_party/circt/lib/Dialect/HW/PortConverter.cpp new file mode 100644 index 000000000..b9b84c98f --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/PortConverter.cpp @@ -0,0 +1,233 @@ +//===- PortConverter.cpp - Module I/O rewriting utility ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/HW/PortConverter.h" + +#include + +using namespace circt; +using namespace hw; + +/// Return a attribute with the specified suffix appended. +static StringAttr append(StringAttr base, const Twine &suffix) { + if (suffix.isTriviallyEmpty()) return base; + auto *context = base.getContext(); + return StringAttr::get(context, base.getValue() + suffix); +} + +namespace { + +/// We consider non-caught ports to be ad-hoc signaling or 'untouched'. (Which +/// counts as a signaling protocol if one squints pretty hard). We mostly do +/// this since it allows us a more consistent internal API. +class UntouchedPortConversion : public PortConversion { + public: + UntouchedPortConversion(PortConverterImpl &converter, hw::PortInfo origPort) + : PortConversion(converter, origPort) { + // Set the "RTTI flag" to true (see comment in header for this variable). + isUntouchedFlag = true; + } + + void mapInputSignals(OpBuilder &b, Operation *inst, Value instValue, + SmallVectorImpl &newOperands, + ArrayRef newResults) override { + newOperands[portInfo.argNum] = instValue; + } + void mapOutputSignals(OpBuilder &b, Operation *inst, Value instValue, + SmallVectorImpl &newOperands, + ArrayRef newResults) override { + instValue.replaceAllUsesWith(newResults[portInfo.argNum]); + } + + private: + void buildInputSignals() override { + Value newValue = + converter.createNewInput(origPort, "", origPort.type, portInfo); + if (body) body->getArgument(origPort.argNum).replaceAllUsesWith(newValue); + } + + void buildOutputSignals() override { + Value output; + if (body) output = body->getTerminator()->getOperand(origPort.argNum); + converter.createNewOutput(origPort, "", origPort.type, output, portInfo); + } + + hw::PortInfo portInfo; +}; + +} // namespace + +FailureOr> PortConversionBuilder::build( + hw::PortInfo port) { + // Default builder is the 'untouched' port conversion which will simply + // pass ports through unmodified. + return {std::make_unique(converter, port)}; +} + +PortConverterImpl::PortConverterImpl(igraph::InstanceGraphNode *moduleNode) + : moduleNode(moduleNode), b(moduleNode->getModule()->getContext()) { + mod = dyn_cast(*moduleNode->getModule()); + assert(mod && "PortConverter only works on HWMutableModuleLike"); + + if (mod->getNumRegions() == 1 && mod->getRegion(0).hasOneBlock()) { + body = &mod->getRegion(0).front(); + terminator = body->getTerminator(); + } +} + +Value PortConverterImpl::createNewInput(PortInfo origPort, const Twine &suffix, + Type type, PortInfo &newPort) { + newPort = PortInfo{ + {append(origPort.name, suffix), type, ModulePort::Direction::Input}, + newInputs.size(), + {}, + {}, + origPort.loc}; + newInputs.emplace_back(0, newPort); + + if (!body) return {}; + return body->addArgument(type, origPort.loc); +} + +void PortConverterImpl::createNewOutput(PortInfo origPort, const Twine &suffix, + Type type, Value output, + PortInfo &newPort) { + newPort = PortInfo{ + {append(origPort.name, suffix), type, ModulePort::Direction::Output}, + newOutputs.size(), + {}, + {}, + origPort.loc}; + newOutputs.emplace_back(0, newPort); + + if (!body) return; + + OpBuilder::InsertionGuard g(b); + b.setInsertionPointToStart(body); + terminator->insertOperands(terminator->getNumOperands(), output); +} + +LogicalResult PortConverterImpl::run() { + ModulePortInfo ports = mod.getPortList(); + + bool foundLoweredPorts = false; + + auto createPortLowering = [&](PortInfo port) { + auto &loweredPorts = port.dir == ModulePort::Direction::Output + ? loweredOutputs + : loweredInputs; + + auto loweredPort = ssb->build(port); + if (failed(loweredPort)) return failure(); + + foundLoweredPorts |= !(*loweredPort)->isUntouched(); + loweredPorts.emplace_back(std::move(*loweredPort)); + + if (failed(loweredPorts.back()->init())) return failure(); + + return success(); + }; + + // Dispatch the port conversion builder on the I/O of the module. + for (PortInfo port : ports) + if (failed(createPortLowering(port))) return failure(); + + // Bail early if we didn't find anything to convert. + if (!foundLoweredPorts) { + // Memory optimization. + loweredInputs.clear(); + loweredOutputs.clear(); + return success(); + } + + // Lower the ports -- this mutates the body directly and builds the port + // lists. + for (auto &lowering : loweredInputs) lowering->lowerPort(); + for (auto &lowering : loweredOutputs) lowering->lowerPort(); + + // Set up vectors to erase _all_ the ports. It's easier to rebuild everything + // than reason about interleaving the newly lowered ports with the non lowered + // ports. Also, the 'modifyPorts' method ends up rebuilding the port lists + // anyway, so this isn't nearly as expensive as it may seem. + SmallVector inputsToErase(mod.getNumInputPorts()); + std::iota(inputsToErase.begin(), inputsToErase.end(), 0); + SmallVector outputsToErase(mod.getNumOutputPorts()); + std::iota(outputsToErase.begin(), outputsToErase.end(), 0); + + mod.modifyPorts(newInputs, newOutputs, inputsToErase, outputsToErase); + + if (body) { + // We should only erase the original arguments. New ones were appended + // with the `createInput` method call. + body->eraseArguments([&ports](BlockArgument arg) { + return arg.getArgNumber() < ports.sizeInputs(); + }); + + // And erase the first ports.sizeOutputs operands from the terminator. + terminator->eraseOperands(0, ports.sizeOutputs()); + } + + // Rewrite instances pointing to this module. + for (auto *instance : moduleNode->uses()) { + auto instanceLike = instance->getInstance(); + if (!instanceLike) continue; + hw::InstanceOp hwInstance = dyn_cast_or_null(*instanceLike); + if (!hwInstance) { + return instanceLike->emitOpError( + "This code only converts hw.instance instances - ask your friendly " + "neighborhood compiler engineers to implement support for something " + "like an hw::HWMutableInstanceLike interface"); + } + updateInstance(hwInstance); + } + + // Memory optimization -- we don't need these anymore. + newInputs.clear(); + newOutputs.clear(); + return success(); +} + +void PortConverterImpl::updateInstance(hw::InstanceOp inst) { + ImplicitLocOpBuilder b(inst.getLoc(), inst); + BackedgeBuilder beb(b, inst.getLoc()); + ModulePortInfo ports = mod.getPortList(); + + // Create backedges for the future instance results so the signal mappers can + // use the future results as values. + SmallVector newResults; + for (PortInfo outputPort : ports.getOutputs()) + newResults.push_back(beb.get(outputPort.type)); + + // Map the operands. + SmallVector newOperands(ports.sizeInputs(), {}); + for (size_t oldOpIdx = 0, e = inst.getNumOperands(); oldOpIdx < e; ++oldOpIdx) + loweredInputs[oldOpIdx]->mapInputSignals( + b, inst, inst->getOperand(oldOpIdx), newOperands, newResults); + + // Map the results. + for (size_t oldResIdx = 0, e = inst.getNumResults(); oldResIdx < e; + ++oldResIdx) + loweredOutputs[oldResIdx]->mapOutputSignals( + b, inst, inst->getResult(oldResIdx), newOperands, newResults); + + // Clone the instance. We cannot just modifiy the existing one since the + // result types might have changed types and number of them. + assert(llvm::none_of(newOperands, [](Value v) { return !v; })); + b.setInsertionPointAfter(inst); + auto newInst = + b.create(mod, inst.getInstanceNameAttr(), newOperands, + inst.getParameters(), inst.getInnerSymAttr()); + newInst->setDialectAttrs(inst->getDialectAttrs()); + + // Assign the backedges to the new results. + for (auto [idx, be] : llvm::enumerate(newResults)) + be.setValue(newInst.getResult(idx)); + + // Erase the old instance. + inst.erase(); +} diff --git a/third_party/circt/lib/Dialect/HW/Transforms/CMakeLists.txt b/third_party/circt/lib/Dialect/HW/Transforms/CMakeLists.txt new file mode 100644 index 000000000..86fbb1b64 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/CMakeLists.txt @@ -0,0 +1,20 @@ +add_circt_dialect_library(CIRCTHWTransforms + HWPrintInstanceGraph.cpp + HWSpecialize.cpp + PrintHWModuleGraph.cpp + FlattenIO.cpp + VerifyInnerRefNamespace.cpp + + DEPENDS + CIRCTHWTransformsIncGen + + LINK_LIBS PUBLIC + CIRCTHW + CIRCTSV + CIRCTSeq + CIRCTComb + CIRCTSupport + MLIRIR + MLIRPass + MLIRTransformUtils +) diff --git a/third_party/circt/lib/Dialect/HW/Transforms/FlattenIO.cpp b/third_party/circt/lib/Dialect/HW/Transforms/FlattenIO.cpp new file mode 100644 index 000000000..51f889fd2 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/FlattenIO.cpp @@ -0,0 +1,429 @@ +//===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace circt; + +static bool isStructType(Type type) { + return hw::getCanonicalType(type).isa(); +} + +static hw::StructType getStructType(Type type) { + return hw::getCanonicalType(type).dyn_cast(); +} + +// Legal if no in- or output type is a struct. +static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) { + return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(), + isStructType); +} + +static llvm::SmallVector getInnerTypes(hw::StructType t) { + llvm::SmallVector inner; + t.getInnerTypes(inner); + for (auto [index, innerType] : llvm::enumerate(inner)) + inner[index] = hw::getCanonicalType(innerType); + return inner; +} + +namespace { + +// Replaces an output op with a new output with flattened (exploded) structs. +struct OutputOpConversion : public OpConversionPattern { + OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context, + DenseSet *opVisited) + : OpConversionPattern(typeConverter, context), opVisited(opVisited) {} + + LogicalResult matchAndRewrite( + hw::OutputOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector convOperands; + + // Flatten the operands. + for (auto operand : adaptor.getOperands()) { + if (auto structType = getStructType(operand.getType())) { + auto explodedStruct = rewriter.create( + op.getLoc(), getInnerTypes(structType), operand); + llvm::copy(explodedStruct.getResults(), + std::back_inserter(convOperands)); + } else { + convOperands.push_back(operand); + } + } + + // And replace. + rewriter.replaceOpWithNewOp(op, convOperands); + opVisited->insert(op->getParentOp()); + return success(); + } + DenseSet *opVisited; +}; + +struct InstanceOpConversion : public OpConversionPattern { + InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context, + DenseSet *convertedOps) + : OpConversionPattern(typeConverter, context), + convertedOps(convertedOps) {} + + LogicalResult matchAndRewrite( + hw::InstanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + // Flatten the operands. + llvm::SmallVector convOperands; + for (auto operand : adaptor.getOperands()) { + if (auto structType = getStructType(operand.getType())) { + auto explodedStruct = rewriter.create( + loc, getInnerTypes(structType), operand); + llvm::copy(explodedStruct.getResults(), + std::back_inserter(convOperands)); + } else { + convOperands.push_back(operand); + } + } + + // Create the new instance... + auto newInstance = rewriter.create( + loc, op.getReferencedModuleSlow(), op.getInstanceName(), convOperands); + + // re-create any structs in the result. + llvm::SmallVector convResults; + size_t oldResultCntr = 0; + for (size_t resIndex = 0; resIndex < newInstance.getNumResults(); + ++resIndex) { + Type oldResultType = op.getResultTypes()[oldResultCntr]; + if (auto structType = getStructType(oldResultType)) { + size_t nElements = structType.getElements().size(); + auto implodedStruct = rewriter.create( + loc, structType, + newInstance.getResults().slice(resIndex, nElements)); + convResults.push_back(implodedStruct.getResult()); + resIndex += nElements - 1; + } else + convResults.push_back(newInstance.getResult(resIndex)); + + ++oldResultCntr; + } + rewriter.replaceOp(op, convResults); + convertedOps->insert(newInstance); + return success(); + } + + DenseSet *convertedOps; +}; + +using IOTypes = std::pair; + +struct IOInfo { + // A mapping between an arg/res index and the struct type of the given field. + DenseMap argStructs, resStructs; + + // Records of the original arg/res types. + SmallVector argTypes, resTypes; +}; + +class FlattenIOTypeConverter : public TypeConverter { + public: + FlattenIOTypeConverter() { + addConversion([](Type type, SmallVectorImpl &results) { + auto structType = getStructType(type); + if (!structType) + results.push_back(type); + else { + for (auto field : structType.getElements()) + results.push_back(field.type); + } + return success(); + }); + + addTargetMaterialization([](OpBuilder &builder, hw::StructType type, + ValueRange inputs, Location loc) { + auto result = builder.create(loc, type, inputs); + return result.getResult(); + }); + + addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type, + ValueRange inputs, Location loc) { + auto structType = getStructType(type); + assert(structType && "expected struct type"); + auto result = builder.create(loc, structType, inputs); + return result.getResult(); + }); + } +}; + +} // namespace + +template +static void addSignatureConversion(DenseMap &ioMap, + ConversionTarget &target, + RewritePatternSet &patterns, + FlattenIOTypeConverter &typeConverter) { + (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(), + patterns, typeConverter), + ...); + + // Legality is defined by a module having been processed once. This is due to + // that a pattern cannot be applied multiple times (a 'pattern was already + // applied' error - a case that would occur for nested structs). Additionally, + // if a pattern could be applied multiple times, this would complicate + // updating arg/res names. + + // Instead, we define legality as when a module has had a modification to its + // top-level i/o. This ensures that only a single level of structs are + // processed during signature conversion, which then allows us to use the + // signature conversion in a recursive manner. + target.addDynamicallyLegalOp([&](hw::HWModuleLike moduleLikeOp) { + if (isLegalModLikeOp(moduleLikeOp)) return true; + + // This op is involved in conversion. Check if the signature has changed. + auto ioInfoIt = ioMap.find(moduleLikeOp); + if (ioInfoIt == ioMap.end()) { + // Op wasn't primed in the map. Do the safe thing, assume + // that it's not considered in this pass, and mark it as legal + return true; + } + auto ioInfo = ioInfoIt->second; + + auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) { + return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) { + auto oldType = std::get<0>(typePair); + auto newType = std::get<1>(typePair); + return oldType != newType; + }); + }; + auto mtype = moduleLikeOp.getHWModuleType(); + if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) || + compareTypes(mtype.getInputTypes(), ioInfo.argTypes)) + return true; + + // We're pre-conversion for an op that was primed in the map - it will + // always be illegal since it has to-be-converted struct types at its I/O. + return false; + }); +} + +template +static bool hasUnconvertedOps(mlir::ModuleOp module) { + return llvm::any_of(module.getBody()->getOps(), + [](T op) { return !isLegalModLikeOp(op); }); +} + +template +static DenseMap populateIOMap(mlir::ModuleOp module) { + DenseMap ioMap; + for (auto op : module.getOps()) + ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()}; + return ioMap; +} + +template +static llvm::SmallVector updateNameAttribute( + ModTy op, StringRef attrName, DenseMap &structMap, + T oldNames) { + llvm::SmallVector newNames; + for (auto [i, oldName] : llvm::enumerate(oldNames)) { + // Was this arg/res index a struct? + auto it = structMap.find(i); + if (it == structMap.end()) { + // No, keep old name. + newNames.push_back(StringAttr::get(op->getContext(), oldName)); + continue; + } + + // Yes - create new names from the struct fields and the old name at the + // index. + auto structType = it->second; + for (auto field : structType.getElements()) + newNames.push_back( + StringAttr::get(op->getContext(), oldName + "." + field.name.str())); + } + return newNames; +} + +static llvm::SmallVector updateLocAttribute( + DenseMap &structMap, ArrayAttr oldLocs) { + llvm::SmallVector newLocs; + if (!oldLocs) return newLocs; + for (auto [i, oldLoc] : llvm::enumerate(oldLocs.getAsRange())) { + // Was this arg/res index a struct? + auto it = structMap.find(i); + if (it == structMap.end()) { + // No, keep old name. + newLocs.push_back(oldLoc); + continue; + } + + auto structType = it->second; + for (size_t i = 0, e = structType.getElements().size(); i < e; ++i) + newLocs.push_back(oldLoc); + } + return newLocs; +} + +/// The conversion framework seems to throw away block argument locations. We +/// use this function to copy the location from the original argument to the +/// set of flattened arguments. +static void updateBlockLocations( + hw::HWModuleLike op, StringRef attrName, + DenseMap &structMap) { + auto locs = op.getOperation()->getAttrOfType(attrName); + if (!locs || op.getModuleBody().empty()) return; + for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), + locs.getAsRange())) + arg.setLoc(loc); +} + +template +static DenseMap populateIOInfoMap(mlir::ModuleOp module) { + DenseMap ioInfoMap; + for (auto op : module.getOps()) { + IOInfo ioInfo; + ioInfo.argTypes = op.getInputTypes(); + ioInfo.resTypes = op.getOutputTypes(); + for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) { + if (auto structType = getStructType(arg)) + ioInfo.argStructs[i] = structType; + } + for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) { + if (auto structType = getStructType(res)) + ioInfo.resStructs[i] = structType; + } + ioInfoMap[op] = ioInfo; + } + return ioInfoMap; +} + +template +static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) { + auto *ctx = module.getContext(); + FlattenIOTypeConverter typeConverter; + + // Recursively (in case of nested structs) lower the module. We do this one + // conversion at a time to allow for updating the arg/res names of the + // module in between flattening each level of structs. + while (hasUnconvertedOps(module)) { + ConversionTarget target(*ctx); + RewritePatternSet patterns(ctx); + target.addLegalDialect(); + + // Record any struct types at the module signature. This will be used + // post-conversion to update the argument and result names. + auto ioInfoMap = populateIOInfoMap(module); + + // Record the instances that were converted. We keep these around since we + // need to update their arg/res attribute names after the modules themselves + // have been updated. + llvm::DenseSet convertedInstances; + + // Argument conversion for output ops. Similarly to the signature + // conversion, legality is based on the op having been visited once, due to + // the possibility of nested structs. + DenseSet opVisited; + patterns.add(typeConverter, ctx, &opVisited); + + patterns.add(typeConverter, ctx, &convertedInstances); + target.addDynamicallyLegalOp( + [&](auto op) { return opVisited.contains(op->getParentOp()); }); + target.addDynamicallyLegalOp([&](auto op) { + return llvm::none_of(op->getOperands(), [](auto operand) { + return isStructType(operand.getType()); + }); + }); + + DenseMap oldArgNames, oldResNames, oldArgLocs, + oldResLocs; + for (auto op : module.getOps()) { + oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames()); + oldResNames[op] = + ArrayAttr::get(module.getContext(), op.getOutputNames()); + oldArgLocs[op] = op.getInputLocsAttr(); + oldResLocs[op] = op.getOutputLocsAttr(); + } + + // Signature conversion and legalization patterns. + addSignatureConversion(ioInfoMap, target, patterns, typeConverter); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + return failure(); + + // Update the arg/res names of the module. + for (auto op : module.getOps()) { + auto ioInfo = ioInfoMap[op]; + auto newArgNames = updateNameAttribute( + op, "argNames", ioInfo.argStructs, + oldArgNames[op].template getAsValueRange()); + auto newResNames = updateNameAttribute( + op, "resultNames", ioInfo.resStructs, + oldResNames[op].template getAsValueRange()); + newArgNames.append(newResNames.begin(), newResNames.end()); + op.setAllPortNames(newArgNames); + op.setInputLocs(updateLocAttribute(ioInfo.argStructs, oldArgLocs[op])); + op.setOutputLocs(updateLocAttribute(ioInfo.resStructs, oldResLocs[op])); + updateBlockLocations(op, "argLocs", ioInfo.argStructs); + } + + // And likewise with the converted instance ops. + for (auto instanceOp : convertedInstances) { + Operation *targetModule = instanceOp.getReferencedModuleSlow(); + auto ioInfo = ioInfoMap[targetModule]; + instanceOp.setInputNames(ArrayAttr::get( + instanceOp.getContext(), + updateNameAttribute(instanceOp, "argNames", ioInfo.argStructs, + oldArgNames[targetModule] + .template getAsValueRange()))); + instanceOp.setOutputNames(ArrayAttr::get( + instanceOp.getContext(), + updateNameAttribute(instanceOp, "resultNames", ioInfo.resStructs, + oldResNames[targetModule] + .template getAsValueRange()))); + instanceOp.dump(); + } + + // Break if we've only lowering a single level of structs. + if (!recursive) break; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Pass driver +//===----------------------------------------------------------------------===// + +template +static bool flattenIO(ModuleOp module, bool recursive) { + return (failed(flattenOpsOfType(module, recursive)) || ...); +} + +namespace { + +class FlattenIOPass : public circt::hw::FlattenIOBase { + public: + void runOnOperation() override { + ModuleOp module = getOperation(); + if (flattenIO(module, recursive)) + signalPassFailure(); + }; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Pass initialization +//===----------------------------------------------------------------------===// + +std::unique_ptr circt::hw::createFlattenIOPass() { + return std::make_unique(); +} diff --git a/third_party/circt/lib/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp b/third_party/circt/lib/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp new file mode 100644 index 000000000..99dededf7 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp @@ -0,0 +1,36 @@ +//===- HWPrintInstanceGraph.cpp - Print the instance graph ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// +// Print the module hierarchy. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/HW/HWInstanceGraph.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" + +using namespace circt; +using namespace hw; + +namespace { +struct PrintInstanceGraphPass + : public PrintInstanceGraphBase { + PrintInstanceGraphPass(raw_ostream &os) : os(os) {} + void runOnOperation() override { + InstanceGraph &instanceGraph = getAnalysis(); + llvm::WriteGraph(os, &instanceGraph, /*ShortNames=*/false); + markAllAnalysesPreserved(); + } + raw_ostream &os; +}; +} // end anonymous namespace + +std::unique_ptr circt::hw::createPrintInstanceGraphPass() { + return std::make_unique(llvm::errs()); +} diff --git a/third_party/circt/lib/Dialect/HW/Transforms/HWSpecialize.cpp b/third_party/circt/lib/Dialect/HW/Transforms/HWSpecialize.cpp new file mode 100644 index 000000000..a691024c3 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/HWSpecialize.cpp @@ -0,0 +1,422 @@ +//===- HWSpecialize.cpp - hw module specialization ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transform performs specialization of parametric hw.module's. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/HW/HWAttributes.h" +#include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "circt/Dialect/HW/HWSymCache.h" +#include "circt/Support/Namespace.h" +#include "circt/Support/ValueMapper.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace llvm; +using namespace mlir; +using namespace circt; +using namespace hw; + +namespace { + +// Generates a module name by composing the name of 'moduleOp' and the set of +// provided 'parameters'. +static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp, + ArrayAttr parameters) { + assert(parameters.size() != 0); + std::string name = moduleOp.getName().str(); + for (auto param : parameters) { + auto paramAttr = param.cast(); + int64_t paramValue = paramAttr.getValue().cast().getInt(); + name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue); + } + + // Query the namespace to generate a unique name. + return ns.newName(name).str(); +} + +// Returns true if any operand or result of 'op' is parametric. +static bool isParametricOp(Operation *op) { + return llvm::any_of(op->getOperandTypes(), isParametricType) || + llvm::any_of(op->getResultTypes(), isParametricType); +} + +// Narrows 'value' using a comb.extract operation to the width of the +// hw.array-typed 'array'. +static FailureOr narrowValueToArrayWidth(OpBuilder &builder, Value array, + Value value) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfterValue(value); + auto arrayType = array.getType().cast(); + unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements()); + + return hiBit == 0 + ? builder + .create(value.getLoc(), + APInt(arrayType.getNumElements(), 0)) + .getResult() + : builder + .create(value.getLoc(), value, + /*lowBit=*/0, hiBit) + .getResult(); +} + +static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp, + const SymbolCache &sc) { + auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr()); + auto targetHWModule = dyn_cast(targetOp); + if (!targetHWModule) return {}; // Won't specialize external modules. + + if (targetHWModule.getParameters().size() == 0) + return {}; // nothing to record or specialize + + return targetHWModule; +} + +// Stores unique module parameters and references to them +struct ParameterSpecializationRegistry { + llvm::MapVector> + uniqueModuleParameters; + + bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const { + auto it = uniqueModuleParameters.find(moduleOp); + return it != uniqueModuleParameters.end() && + it->second.contains(parameters); + } + + void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) { + uniqueModuleParameters[moduleOp].insert(parameters); + } +}; + +struct EliminateParamValueOpPattern : public OpRewritePattern { + EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters) + : OpRewritePattern(context), parameters(parameters) {} + + LogicalResult matchAndRewrite(ParamValueOp op, + PatternRewriter &rewriter) const override { + // Substitute the param value op with an evaluated constant operation. + FailureOr evaluated = + evaluateParametricAttr(op.getLoc(), parameters, op.getValue()); + if (failed(evaluated)) return failure(); + rewriter.replaceOpWithNewOp( + op, op.getType(), + evaluated->cast().getValue().getSExtValue()); + return success(); + } + + ArrayAttr parameters; +}; + +// hw.array_get operations require indexes to be of equal width of the +// array itself. Since indexes may originate from constants or parameters, +// emit comb.extract operations to fulfill this invariant. +struct NarrowArrayGetIndexPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ArrayGetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto inputType = type_cast(op.getInput().getType()); + Type targetIndexType = IntegerType::get( + getContext(), inputType.getNumElements() == 1 + ? 1 + : llvm::Log2_64_Ceil(inputType.getNumElements())); + + if (op.getIndex().getType().getIntOrFloatBitWidth() == + targetIndexType.getIntOrFloatBitWidth()) + return failure(); // nothing to do + + // Narrow the index value. + FailureOr narrowedIndex = + narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex()); + if (failed(narrowedIndex)) return failure(); + rewriter.replaceOpWithNewOp(op, op.getInput(), *narrowedIndex); + return success(); + } +}; + +// Generic pattern to convert parametric result types. +struct ParametricTypeConversionPattern : public ConversionPattern { + ParametricTypeConversionPattern(MLIRContext *ctx, + TypeConverter &typeConverter, + ArrayAttr parameters) + : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, + ctx), + parameters(parameters) {} + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + llvm::SmallVector convertedOperands; + // Update the result types of the operation + bool ok = true; + rewriter.updateRootInPlace(op, [&]() { + // Mutate result types + for (auto it : llvm::enumerate(op->getResultTypes())) { + FailureOr res = + evaluateParametricType(op->getLoc(), parameters, it.value()); + ok &= succeeded(res); + if (!ok) return; + op->getResult(it.index()).setType(*res); + } + + // Note: 'operands' have already been converted with the supplied type + // converter to this pattern. Make sure that we materialize this + // conversion by updating the operands to op. + op->setOperands(operands); + }); + + return success(ok); + }; + ArrayAttr parameters; +}; + +struct HWSpecializePass : public hw::HWSpecializeBase { + void runOnOperation() override; +}; + +static void populateTypeConversion(Location loc, TypeConverter &typeConverter, + ArrayAttr parameters) { + // Possibly parametric types + typeConverter.addConversion([=](hw::IntType type) { + return evaluateParametricType(loc, parameters, type); + }); + typeConverter.addConversion([=](hw::ArrayType type) { + return evaluateParametricType(loc, parameters, type); + }); + + // Valid target types. + typeConverter.addConversion([](mlir::IntegerType type) { return type; }); +} + +// Registers any nested parametric instance ops of `target` for the next +// specialization loop +static LogicalResult registerNestedParametricInstanceOps( + HWModuleOp target, ArrayAttr parameters, SymbolCache &sc, + const ParameterSpecializationRegistry ¤tRegistry, + ParameterSpecializationRegistry &nextRegistry, + llvm::DenseMap>> + ¶metersUsers) { + // Register any nested parametric instance ops for the next loop + auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult { + auto instanceParameters = instanceOp.getParameters(); + // We can ignore non-parametric instances + if (instanceParameters.empty()) return WalkResult::advance(); + + // Replace instance parameters with evaluated versions + llvm::SmallVector evaluatedInstanceParameters; + evaluatedInstanceParameters.reserve(instanceParameters.size()); + for (auto instanceParameter : instanceParameters) { + auto instanceParameterDecl = instanceParameter.cast(); + auto instanceParameterValue = instanceParameterDecl.getValue(); + auto evaluated = evaluateParametricAttr(target.getLoc(), parameters, + instanceParameterValue); + if (failed(evaluated)) return WalkResult::interrupt(); + evaluatedInstanceParameters.push_back( + hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated)); + } + + auto evaluatedInstanceParametersAttr = + ArrayAttr::get(target.getContext(), evaluatedInstanceParameters); + + if (auto targetHWModule = targetModuleOp(instanceOp, sc)) { + if (!currentRegistry.isRegistered(targetHWModule, + evaluatedInstanceParametersAttr)) + nextRegistry.registerModuleOp(targetHWModule, + evaluatedInstanceParametersAttr); + parametersUsers[targetHWModule][evaluatedInstanceParametersAttr] + .push_back(instanceOp); + } + + return WalkResult::advance(); + }); + + return failure(walkResult.wasInterrupted()); +} + +// Specializes the provided 'base' module into the 'target' module. By doing +// so, we create a new module which +// 1. has no parameters +// 2. has a name composing the name of 'base' as well as the 'parameters' +// parameters. +// 3. Has a top-level interface with any parametric types resolved. +// 4. Any references to module parameters have been replaced with the +// parameter value. +static LogicalResult specializeModule( + OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns, + HWModuleOp source, HWModuleOp &target, + const ParameterSpecializationRegistry ¤tRegistry, + ParameterSpecializationRegistry &nextRegistry, + llvm::DenseMap>> + ¶metersUsers) { + auto *ctx = builder.getContext(); + // Update the types of the source module ports based on evaluating any + // parametric in/output ports. + auto ports = source.getPortList(); + for (auto in : llvm::enumerate(source.getInputTypes())) { + FailureOr resType = + evaluateParametricType(source.getLoc(), parameters, in.value()); + if (failed(resType)) return failure(); + ports.atInput(in.index()).type = *resType; + } + for (auto out : llvm::enumerate(source.getOutputTypes())) { + FailureOr resolvedType = + evaluateParametricType(source.getLoc(), parameters, out.value()); + if (failed(resolvedType)) return failure(); + ports.atOutput(out.index()).type = *resolvedType; + } + + // Create the specialized module using the evaluated port info. + target = builder.create( + source.getLoc(), + StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports); + + // Erase the default created hw.output op - we'll copy the correct operation + // during body elaboration. + (*target.getOps().begin()).erase(); + + // Clone body of the source into the target. Use ValueMapper to ensure safe + // cloning in the presence of backedges. + BackedgeBuilder bb(builder, source.getLoc()); + ValueMapper mapper(&bb); + for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(), + target.getBodyBlock()->getArguments())) + mapper.set(src, dst); + builder.setInsertionPointToStart(target.getBodyBlock()); + + for (auto &op : source.getOps()) { + IRMapping bvMapper; + for (auto operand : op.getOperands()) + bvMapper.map(operand, mapper.get(operand)); + auto *newOp = builder.clone(op, bvMapper); + for (auto &&[oldRes, newRes] : + llvm::zip(op.getResults(), newOp->getResults())) + mapper.set(oldRes, newRes); + } + + // Register any nested parametric instance ops for the next loop + auto nestedRegistrationResult = registerNestedParametricInstanceOps( + target, parameters, sc, currentRegistry, nextRegistry, parametersUsers); + if (failed(nestedRegistrationResult)) return failure(); + + // We've now created a separate copy of the source module with a rewritten + // top-level interface. Next, we enter the module to convert parametric + // types within operations. + RewritePatternSet patterns(ctx); + TypeConverter t; + populateTypeConversion(target.getLoc(), t, parameters); + patterns.add(ctx, parameters); + patterns.add(ctx); + patterns.add(ctx, t, parameters); + ConversionTarget convTarget(*ctx); + convTarget.addLegalOp(); + convTarget.addIllegalOp(); + + // Generic legalization of converted operations. + convTarget.markUnknownOpDynamicallyLegal( + [](Operation *op) { return !isParametricOp(op); }); + + return applyPartialConversion(target, convTarget, std::move(patterns)); +} + +void HWSpecializePass::runOnOperation() { + ModuleOp module = getOperation(); + + // Record unique module parameters and references to these. + llvm::DenseMap>> + parametersUsers; + ParameterSpecializationRegistry registry; + + // Maintain a symbol cache for fast lookup during module specialization. + SymbolCache sc; + sc.addDefinitions(module); + Namespace ns; + ns.add(sc); + + for (auto hwModule : module.getOps()) { + // If this module is parametric, defer registering its parametric + // instantiations until this module is specialized + if (!hwModule.getParameters().empty()) continue; + for (auto instanceOp : hwModule.getOps()) { + if (auto targetHWModule = targetModuleOp(instanceOp, sc)) { + auto parameters = instanceOp.getParameters(); + registry.registerModuleOp(targetHWModule, parameters); + + parametersUsers[targetHWModule][parameters].push_back(instanceOp); + } + } + } + + // Create specialized modules. + OpBuilder builder = OpBuilder(&getContext()); + builder.setInsertionPointToStart(module.getBody()); + llvm::DenseMap> + specializations; + + // For every module specialization, any nested parametric modules will be + // registered for the next loop. We loop until no new nested modules have been + // registered. + while (!registry.uniqueModuleParameters.empty()) { + // The registry for the next specialization loop + ParameterSpecializationRegistry nextRegistry; + for (auto it : registry.uniqueModuleParameters) { + for (auto parameters : it.second) { + HWModuleOp specializedModule; + if (failed(specializeModule(builder, parameters, sc, ns, it.first, + specializedModule, registry, nextRegistry, + parametersUsers))) { + signalPassFailure(); + return; + } + + // Extend the symbol cache with the newly created module. + sc.addDefinition(specializedModule.getNameAttr(), specializedModule); + + // Add the specialization + specializations[it.first][parameters] = specializedModule; + } + } + + // Transfer newly registered specializations to iterate over + registry.uniqueModuleParameters = + std::move(nextRegistry.uniqueModuleParameters); + } + + // Rewrite instances of specialized modules to the specialized module. + for (auto it : specializations) { + auto unspecialized = it.getFirst(); + auto &users = parametersUsers[unspecialized]; + for (auto specialization : it.getSecond()) { + auto parameters = specialization.getFirst(); + auto specializedModule = specialization.getSecond(); + for (auto instanceOp : users[parameters]) { + instanceOp->setAttr("moduleName", + FlatSymbolRefAttr::get(specializedModule)); + instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {})); + } + } + } +} + +} // namespace + +std::unique_ptr circt::hw::createHWSpecializePass() { + return std::make_unique(); +} diff --git a/third_party/circt/lib/Dialect/HW/Transforms/PassDetails.h b/third_party/circt/lib/Dialect/HW/Transforms/PassDetails.h new file mode 100644 index 000000000..3b8b69433 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/PassDetails.h @@ -0,0 +1,31 @@ +//===- PassDetails.h - HW pass class details ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Stuff shared between the different HW passes. +// +//===----------------------------------------------------------------------===// + +// clang-tidy seems to expect the absolute path in the header guard on some +// systems, so just disable it. +// NOLINTNEXTLINE(llvm-header-guard) +#ifndef DIALECT_HW_TRANSFORMS_PASSDETAILS_H +#define DIALECT_HW_TRANSFORMS_PASSDETAILS_H + +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Pass/Pass.h" + +namespace circt { +namespace hw { + +#define GEN_PASS_CLASSES +#include "circt/Dialect/HW/Passes.h.inc" + +} // namespace hw +} // namespace circt + +#endif // DIALECT_HW_TRANSFORMS_PASSDETAILS_H diff --git a/third_party/circt/lib/Dialect/HW/Transforms/PrintHWModuleGraph.cpp b/third_party/circt/lib/Dialect/HW/Transforms/PrintHWModuleGraph.cpp new file mode 100644 index 000000000..196aeaf79 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/PrintHWModuleGraph.cpp @@ -0,0 +1,42 @@ +//===- PrintHWModuleGraph.cpp - Print the instance graph --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// +// +// Prints an HW module as a .dot graph. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/HW/HWModuleGraph.h" +#include "circt/Dialect/HW/HWPasses.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" + +using namespace circt; +using namespace hw; + +namespace { +struct PrintHWModuleGraphPass + : public PrintHWModuleGraphBase { + PrintHWModuleGraphPass(raw_ostream &os) : os(os) {} + void runOnOperation() override { + getOperation().walk([&](hw::HWModuleOp module) { + // We don't really have any other way of forwarding draw arguments to the + // DOTGraphTraits for HWModule except through the module itself - as an + // attribute. + module->setAttr("dot_verboseEdges", + BoolAttr::get(module.getContext(), verboseEdges)); + + llvm::WriteGraph(os, module, /*ShortNames=*/false); + }); + } + raw_ostream &os; +}; +} // end anonymous namespace + +std::unique_ptr circt::hw::createPrintHWModuleGraphPass() { + return std::make_unique(llvm::errs()); +} diff --git a/third_party/circt/lib/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp b/third_party/circt/lib/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp new file mode 100644 index 000000000..1a4e75282 --- /dev/null +++ b/third_party/circt/lib/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp @@ -0,0 +1,44 @@ +//===- VerifyInnerRefNamespace.cpp - InnerRefNamespace verification Pass --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a simple pass to drive verification of operations +// that want to be InnerRefNamespace's but don't have the trait to verify +// themselves. +// +//===----------------------------------------------------------------------===// + +#include "PassDetails.h" +#include "circt/Dialect/HW/HWOpInterfaces.h" +#include "circt/Dialect/HW/HWPasses.h" + +/// VerifyInnerRefNamespace pass until have container operation. + +namespace { + +class VerifyInnerRefNamespacePass + : public circt::hw::VerifyInnerRefNamespaceBase< + VerifyInnerRefNamespacePass> { + public: + void runOnOperation() override { + auto *irnLike = getOperation(); + if (!irnLike->hasTrait()) + if (failed(circt::hw::detail::verifyInnerRefNamespace(irnLike))) + return signalPassFailure(); + + return markAllAnalysesPreserved(); + }; + bool canScheduleOn(mlir::RegisteredOperationName opInfo) const override { + return llvm::isa(opInfo); + } +}; + +} // namespace + +std::unique_ptr circt::hw::createVerifyInnerRefNamespacePass() { + return std::make_unique(); +} diff --git a/third_party/circt/lib/Support/APInt.cpp b/third_party/circt/lib/Support/APInt.cpp new file mode 100644 index 000000000..f19f9fec8 --- /dev/null +++ b/third_party/circt/lib/Support/APInt.cpp @@ -0,0 +1,27 @@ +//===- APInt.h - CIRCT Lowering Options -------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for working around limitations of upstream LLVM APInts. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/APInt.h" + +#include "llvm/ADT/APSInt.h" + +using namespace circt; + +APInt circt::sextZeroWidth(APInt value, unsigned width) { + return value.getBitWidth() ? value.sext(width) : value.zext(width); +} + +APSInt circt::extOrTruncZeroWidth(APSInt value, unsigned width) { + return value.getBitWidth() + ? value.extOrTrunc(width) + : APSInt(value.zextOrTrunc(width), value.isUnsigned()); +} diff --git a/third_party/circt/lib/Support/BUILD b/third_party/circt/lib/Support/BUILD new file mode 100644 index 000000000..c02339399 --- /dev/null +++ b/third_party/circt/lib/Support/BUILD @@ -0,0 +1,37 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Support", + srcs = [ + "CustomDirectiveImpl.cpp", + "InstanceGraph.cpp", + "ValueMapper.cpp", + ], + hdrs = [ + "@heir//third_party/circt/include/circt/Support:BackedgeBuilder.h", + "@heir//third_party/circt/include/circt/Support:BuilderUtils.h", + "@heir//third_party/circt/include/circt/Support:CustomDirectiveImpl.h", + "@heir//third_party/circt/include/circt/Support:InstanceGraph.h", + "@heir//third_party/circt/include/circt/Support:InstanceGraphInterface.h", + "@heir//third_party/circt/include/circt/Support:LLVM.h", + "@heir//third_party/circt/include/circt/Support:Namespace.h", + "@heir//third_party/circt/include/circt/Support:ParsingUtils.h", + "@heir//third_party/circt/include/circt/Support:SymCache.h", + "@heir//third_party/circt/include/circt/Support:ValueMapper.h", + ], + copts = [ + "-Ithird_party/circt/include", + ], + strip_include_prefix = "/third_party/circt/include", + deps = [ + "@heir//third_party/circt/include/circt/Support:interfaces_inc_gen", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/circt/lib/Support/BackedgeBuilder.cpp b/third_party/circt/lib/Support/BackedgeBuilder.cpp new file mode 100644 index 000000000..737ea9f71 --- /dev/null +++ b/third_party/circt/lib/Support/BackedgeBuilder.cpp @@ -0,0 +1,71 @@ +//===- BackedgeBuilder.cpp - Support for building backedges ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provide support for building backedges. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/BackedgeBuilder.h" + +#include "circt/Support/LLVM.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace circt; + +Backedge::Backedge(mlir::Operation *op) : value(op->getResult(0)) {} + +void Backedge::setValue(mlir::Value newValue) { + assert(value.getType() == newValue.getType()); + assert(!set && "backedge already set to a value!"); + value.replaceAllUsesWith(newValue); + value = newValue; // In case the backedge is still referred to after setting. + set = true; + + // If the backedge is referenced again, it should now point to the updated + // value. + value = newValue; +} + +BackedgeBuilder::~BackedgeBuilder() { (void)clearOrEmitError(); } + +LogicalResult BackedgeBuilder::clearOrEmitError() { + unsigned numInUse = 0; + for (Operation *op : edges) { + if (!op->use_empty()) { + auto diag = op->emitError("backedge of type `") + << op->getResult(0).getType() << "`still in use"; + for (auto user : op->getUsers()) + diag.attachNote(user->getLoc()) << "used by " << *user; + ++numInUse; + continue; + } + if (rewriter) + rewriter->eraseOp(op); + else + op->erase(); + } + edges.clear(); + if (numInUse > 0) + mlir::emitRemark(loc, "abandoned ") << numInUse << " backedges"; + return success(numInUse == 0); +} + +void BackedgeBuilder::abandon() { edges.clear(); } + +BackedgeBuilder::BackedgeBuilder(OpBuilder &builder, Location loc) + : builder(builder), rewriter(nullptr), loc(loc) {} +BackedgeBuilder::BackedgeBuilder(PatternRewriter &rewriter, Location loc) + : builder(rewriter), rewriter(&rewriter), loc(loc) {} +Backedge BackedgeBuilder::get(Type t, mlir::LocationAttr optionalLoc) { + if (!optionalLoc) optionalLoc = loc; + Operation *op = builder.create( + optionalLoc, t, ValueRange{}); + edges.push_back(op); + return Backedge(op); +} diff --git a/third_party/circt/lib/Support/CMakeLists.txt b/third_party/circt/lib/Support/CMakeLists.txt new file mode 100644 index 000000000..27d6fd639 --- /dev/null +++ b/third_party/circt/lib/Support/CMakeLists.txt @@ -0,0 +1,52 @@ +##===- CMakeLists.txt - Define a support library --------------*- cmake -*-===// +## +##===----------------------------------------------------------------------===// + +set(VERSION_CPP "${CMAKE_CURRENT_BINARY_DIR}/Version.cpp") +set_source_files_properties("${VERSION_CPP}" PROPERTIES GENERATED TRUE) + +add_circt_library(CIRCTSupport + APInt.cpp + BackedgeBuilder.cpp + CustomDirectiveImpl.cpp + FieldRef.cpp + JSON.cpp + LoweringOptions.cpp + Passes.cpp + Path.cpp + PrettyPrinter.cpp + PrettyPrinterHelpers.cpp + ParsingUtils.cpp + SymCache.cpp + ValueMapper.cpp + InstanceGraph.cpp + "${VERSION_CPP}" + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransforms + MLIRTransformUtils + ) + +#------------------------------------------------------------------------------- +# Generate Version.cpp +#------------------------------------------------------------------------------- +find_first_existing_vc_file("${CIRCT_SOURCE_DIR}" CIRCT_GIT_LOGS_HEAD) +set(GEN_VERSION_SCRIPT "${CIRCT_SOURCE_DIR}/cmake/modules/GenVersionFile.cmake") + +if (CIRCT_RELEASE_TAG_ENABLED) + add_custom_command(OUTPUT "${VERSION_CPP}" + DEPENDS "${CIRCT_GIT_LOGS_HEAD}" "${GEN_VERSION_SCRIPT}" + COMMAND ${CMAKE_COMMAND} -DIN_FILE="${CMAKE_CURRENT_SOURCE_DIR}/Version.cpp.in" + -DOUT_FILE="${VERSION_CPP}" -DRELEASE_PATTERN=${CIRCT_RELEASE_TAG}* + -DDRY_RUN=OFF -DSOURCE_ROOT="${CIRCT_SOURCE_DIR}" + -P "${GEN_VERSION_SCRIPT}") +else () + # If the release tag generation is disabled, run the script only at the first + # cmake configuration. + add_custom_command(OUTPUT "${VERSION_CPP}" + DEPENDS "${GEN_VERSION_SCRIPT}" + COMMAND ${CMAKE_COMMAND} -DIN_FILE="${CMAKE_CURRENT_SOURCE_DIR}/Version.cpp.in" + -DOUT_FILE="${VERSION_CPP}" -DDRY_RUN=ON -DSOURCE_ROOT="${CIRCT_SOURCE_DIR}" + -P "${GEN_VERSION_SCRIPT}") +endif() diff --git a/third_party/circt/lib/Support/ConversionPatterns.cpp b/third_party/circt/lib/Support/ConversionPatterns.cpp new file mode 100644 index 000000000..5b7ebf84f --- /dev/null +++ b/third_party/circt/lib/Support/ConversionPatterns.cpp @@ -0,0 +1,90 @@ +//===- ConversionPatterns.cpp - Common Conversion patterns ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/ConversionPatterns.h" + +using namespace circt; + +// Converts a function type wrt. the given type converter. +static FunctionType convertFunctionType(const TypeConverter &typeConverter, + FunctionType type) { + // Convert the original function types. + llvm::SmallVector res, arg; + llvm::transform(type.getResults(), std::back_inserter(res), + [&](Type t) { return typeConverter.convertType(t); }); + llvm::transform(type.getInputs(), std::back_inserter(arg), + [&](Type t) { return typeConverter.convertType(t); }); + + return FunctionType::get(type.getContext(), arg, res); +} + +LogicalResult TypeConversionPattern::matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + // Convert the TypeAttrs. + llvm::SmallVector newAttrs; + newAttrs.reserve(op->getAttrs().size()); + for (auto attr : op->getAttrs()) { + if (auto typeAttr = attr.getValue().dyn_cast()) { + auto innerType = typeAttr.getValue(); + // TypeConvert::convertType doesn't handle function types, so we need to + // handle them manually. + if (auto funcType = innerType.dyn_cast(); innerType) + innerType = convertFunctionType(*getTypeConverter(), funcType); + else + innerType = getTypeConverter()->convertType(innerType); + newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType)); + } else { + newAttrs.push_back(attr); + } + } + + // Convert the result types. + llvm::SmallVector newResults; + if (failed( + getTypeConverter()->convertTypes(op->getResultTypes(), newResults))) + return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed"); + + // Build the state for the edited clone. + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + newResults, newAttrs, op->getSuccessors()); + for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) state.addRegion(); + + // Must create the op before running any modifications on the regions so that + // we don't crash with '-debug' and so we have something to 'root update'. + Operation *newOp = rewriter.create(state); + + // Move the regions over, converting the signatures as we go. + rewriter.startRootUpdate(newOp); + for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) { + Region ®ion = op->getRegion(i); + Region *newRegion = &newOp->getRegion(i); + + // TypeConverter::SignatureConversion drops argument locations, so we need + // to manually copy them over (a verifier in e.g. HWModule checks this). + llvm::SmallVector argLocs; + for (auto arg : region.getArguments()) argLocs.push_back(arg.getLoc()); + + // Move the region and convert the region args. + rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); + TypeConverter::SignatureConversion result(newRegion->getNumArguments()); + if (failed(getTypeConverter()->convertSignatureArgs( + newRegion->getArgumentTypes(), result))) + return rewriter.notifyMatchFailure(op->getLoc(), + "type conversion failed"); + rewriter.applySignatureConversion(newRegion, result, getTypeConverter()); + + // Apply the argument locations. + for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs)) + arg.setLoc(loc); + } + rewriter.finalizeRootUpdate(newOp); + + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} diff --git a/third_party/circt/lib/Support/CustomDirectiveImpl.cpp b/third_party/circt/lib/Support/CustomDirectiveImpl.cpp new file mode 100644 index 000000000..f8ef9eac2 --- /dev/null +++ b/third_party/circt/lib/Support/CustomDirectiveImpl.cpp @@ -0,0 +1,130 @@ +//===- CustomDirectiveImpl.cpp - Custom TableGen directives ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/CustomDirectiveImpl.h" + +#include "llvm/ADT/SmallString.h" + +using namespace circt; + +ParseResult circt::parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr) { + // Use the explicit name if one is provided as `name "xyz"`. + if (!parser.parseOptionalKeyword("name")) { + std::string str; + if (parser.parseString(&str)) return failure(); + attr = parser.getBuilder().getStringAttr(str); + return success(); + } + + // Infer the name from the SSA name of the operation's first result. + auto resultName = parser.getResultName(0).first; + if (!resultName.empty() && isdigit(resultName[0])) resultName = ""; + attr = parser.getBuilder().getStringAttr(resultName); + return success(); +} + +ParseResult circt::parseImplicitSSAName(OpAsmParser &parser, + NamedAttrList &attrs) { + if (parser.parseOptionalAttrDict(attrs)) return failure(); + inferImplicitSSAName(parser, attrs); + return success(); +} + +bool circt::inferImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs) { + // Don't do anything if a `name` attribute is explicitly provided. + if (attrs.get("name")) return false; + + // Infer the name from the SSA name of the operation's first result. + auto resultName = parser.getResultName(0).first; + if (!resultName.empty() && isdigit(resultName[0])) resultName = ""; + auto nameAttr = parser.getBuilder().getStringAttr(resultName); + auto *context = parser.getBuilder().getContext(); + attrs.push_back({StringAttr::get(context, "name"), nameAttr}); + return true; +} + +void circt::printImplicitSSAName(OpAsmPrinter &printer, Operation *op, + StringAttr attr) { + SmallString<32> resultNameStr; + llvm::raw_svector_ostream tmpStream(resultNameStr); + printer.printOperand(op->getResult(0), tmpStream); + auto actualName = tmpStream.str().drop_front(); + auto expectedName = attr.getValue(); + // Anonymous names are printed as digits, which is fine. + if (actualName == expectedName || + (expectedName.empty() && isdigit(actualName[0]))) + return; + + printer << " name " << attr; +} + +void circt::printImplicitSSAName(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs, + ArrayRef extraElides) { + SmallVector elides(extraElides.begin(), extraElides.end()); + elideImplicitSSAName(printer, op, attrs, elides); + printer.printOptionalAttrDict(attrs.getValue(), elides); +} + +void circt::elideImplicitSSAName(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs, + SmallVectorImpl &elides) { + SmallString<32> resultNameStr; + llvm::raw_svector_ostream tmpStream(resultNameStr); + printer.printOperand(op->getResult(0), tmpStream); + auto actualName = tmpStream.str().drop_front(); + auto expectedName = attrs.getAs("name").getValue(); + // Anonymous names are printed as digits, which is fine. + if (actualName == expectedName || + (expectedName.empty() && isdigit(actualName[0]))) + elides.push_back("name"); +} + +ParseResult circt::parseOptionalBinaryOpTypes(OpAsmParser &parser, Type &lhs, + Type &rhs) { + if (parser.parseType(lhs)) return failure(); + + // Parse an optional rhs type. + if (parser.parseOptionalComma()) { + rhs = lhs; + } else { + if (parser.parseType(rhs)) return failure(); + } + return success(); +} + +void circt::printOptionalBinaryOpTypes(OpAsmPrinter &p, Operation *op, Type lhs, + Type rhs) { + p << lhs; + // If operand types are not same, print a rhs type. + if (lhs != rhs) p << ", " << rhs; +} + +ParseResult circt::parseKeywordBool(OpAsmParser &parser, BoolAttr &attr, + StringRef trueKeyword, + StringRef falseKeyword) { + if (succeeded(parser.parseOptionalKeyword(trueKeyword))) { + attr = BoolAttr::get(parser.getContext(), true); + } else if (succeeded(parser.parseOptionalKeyword(falseKeyword))) { + attr = BoolAttr::get(parser.getContext(), false); + } else { + return parser.emitError(parser.getCurrentLocation()) + << "expected keyword \"" << trueKeyword << "\" or \"" << falseKeyword + << "\""; + } + return success(); +} + +void circt::printKeywordBool(OpAsmPrinter &printer, Operation *op, + BoolAttr attr, StringRef trueKeyword, + StringRef falseKeyword) { + if (attr.getValue()) + printer << trueKeyword; + else + printer << falseKeyword; +} diff --git a/third_party/circt/lib/Support/FieldRef.cpp b/third_party/circt/lib/Support/FieldRef.cpp new file mode 100644 index 000000000..9037ebfa8 --- /dev/null +++ b/third_party/circt/lib/Support/FieldRef.cpp @@ -0,0 +1,23 @@ +//===- FieldRef.cpp - Field Refs -------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines FieldRef and helpers for them. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/FieldRef.h" + +#include "mlir/IR/Block.h" +#include "mlir/IR/Value.h" + +using namespace circt; + +Operation *FieldRef::getDefiningOp() const { + if (auto *op = value.getDefiningOp()) return op; + return value.cast().getOwner()->getParentOp(); +} diff --git a/third_party/circt/lib/Support/InstanceGraph.cpp b/third_party/circt/lib/Support/InstanceGraph.cpp new file mode 100644 index 000000000..729bb0a6c --- /dev/null +++ b/third_party/circt/lib/Support/InstanceGraph.cpp @@ -0,0 +1,314 @@ +//===- InstanceGraph.cpp - Instance Graph -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/InstanceGraph.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Threading.h" + +using namespace circt; +using namespace igraph; + +void InstanceRecord::erase() { + // Update the prev node to point to the next node. + if (prevUse) + prevUse->nextUse = nextUse; + else + target->firstUse = nextUse; + // Update the next node to point to the prev node. + if (nextUse) nextUse->prevUse = prevUse; + getParent()->instances.erase(this); +} + +InstanceRecord *InstanceGraphNode::addInstance(InstanceOpInterface instance, + InstanceGraphNode *target) { + auto *instanceRecord = new InstanceRecord(this, instance, target); + target->recordUse(instanceRecord); + instances.push_back(instanceRecord); + return instanceRecord; +} + +void InstanceGraphNode::recordUse(InstanceRecord *record) { + record->nextUse = firstUse; + if (firstUse) firstUse->prevUse = record; + firstUse = record; +} + +InstanceGraphNode *InstanceGraph::getOrAddNode(StringAttr name) { + // Try to insert an InstanceGraphNode. If its not inserted, it returns + // an iterator pointing to the node. + auto *&node = nodeMap[name]; + if (!node) { + node = new InstanceGraphNode(); + nodes.push_back(node); + } + return node; +} + +InstanceGraph::InstanceGraph(Operation *parent) : parent(parent) { + assert(parent->hasTrait() && + "top-level operation must have a single block"); + SmallVector>> + moduleToInstances; + // First accumulate modules inside the parent op. + for (auto module : + parent->getRegion(0).front().getOps()) + moduleToInstances.push_back({module, {}}); + + // Populate instances in the module parallelly. + mlir::parallelFor(parent->getContext(), 0, moduleToInstances.size(), + [&](size_t idx) { + auto module = moduleToInstances[idx].first; + auto &instances = moduleToInstances[idx].second; + // Find all instance operations in the module body. + module.walk([&](InstanceOpInterface instanceOp) { + instances.push_back(instanceOp); + }); + }); + + // Construct an instance graph sequentially. + for (auto &[module, instances] : moduleToInstances) { + auto name = module.getModuleNameAttr(); + auto *currentNode = getOrAddNode(name); + currentNode->module = module; + for (auto instanceOp : instances) { + // Add an edge to indicate that this module instantiates the target. + auto *targetNode = getOrAddNode(instanceOp.getReferencedModuleNameAttr()); + currentNode->addInstance(instanceOp, targetNode); + } + } +} + +InstanceGraphNode *InstanceGraph::addModule(ModuleOpInterface module) { + assert(!nodeMap.count(module.getModuleNameAttr()) && "module already added"); + auto *node = new InstanceGraphNode(); + node->module = module; + nodeMap[module.getModuleNameAttr()] = node; + nodes.push_back(node); + return node; +} + +void InstanceGraph::erase(InstanceGraphNode *node) { + assert(node->noUses() && + "all instances of this module must have been erased."); + // Erase all instances inside this module. + for (auto *instance : llvm::make_early_inc_range(*node)) instance->erase(); + nodeMap.erase(node->getModule().getModuleNameAttr()); + nodes.erase(node); +} + +InstanceGraphNode *InstanceGraph::lookup(StringAttr name) { + auto it = nodeMap.find(name); + assert(it != nodeMap.end() && "Module not in InstanceGraph!"); + return it->second; +} + +InstanceGraphNode *InstanceGraph::lookup(ModuleOpInterface op) { + return lookup(cast(op).getModuleNameAttr()); +} + +ModuleOpInterface InstanceGraph::getReferencedModuleImpl( + InstanceOpInterface op) { + return lookup(op.getReferencedModuleNameAttr())->getModule(); +} + +void InstanceGraph::replaceInstance(InstanceOpInterface inst, + InstanceOpInterface newInst) { + assert(inst.getReferencedModuleName() == newInst.getReferencedModuleName() && + "Both instances must be targeting the same module"); + + // Find the instance record of this instance. + auto *node = lookup(inst.getReferencedModuleNameAttr()); + auto it = llvm::find_if(node->uses(), [&](InstanceRecord *record) { + return record->getInstance() == inst; + }); + assert(it != node->usesEnd() && "Instance of module not recorded in graph"); + + // We can just replace the instance op in the InstanceRecord without updating + // any instance lists. + (*it)->instance = newInst; +} + +bool InstanceGraph::isAncestor(ModuleOpInterface child, + ModuleOpInterface parent) { + DenseSet seen; + SmallVector worklist; + auto *cn = lookup(child); + worklist.push_back(cn); + seen.insert(cn); + while (!worklist.empty()) { + auto *node = worklist.back(); + worklist.pop_back(); + if (node->getModule() == parent) return true; + for (auto *use : node->uses()) { + auto *mod = use->getParent(); + if (!seen.count(mod)) { + seen.insert(mod); + worklist.push_back(mod); + } + } + } + return false; +} + +FailureOr> +InstanceGraph::getInferredTopLevelNodes() { + if (!inferredTopLevelNodes.empty()) return {inferredTopLevelNodes}; + + /// Topologically sort the instance graph. + llvm::SetVector visited, marked; + llvm::SetVector candidateTopLevels(this->begin(), + this->end()); + SmallVector cycleTrace; + + // Recursion function; returns true if a cycle was detected. + std::function)> + cycleUtil = + [&](InstanceGraphNode *node, SmallVector trace) { + if (visited.contains(node)) return false; + trace.push_back(node); + if (marked.contains(node)) { + // Cycle detected. + cycleTrace = trace; + return true; + } + marked.insert(node); + for (auto use : *node) { + InstanceGraphNode *targetModule = use->getTarget(); + candidateTopLevels.remove(targetModule); + if (cycleUtil(targetModule, trace)) + return true; // Cycle detected. + } + marked.remove(node); + visited.insert(node); + return false; + }; + + bool cyclic = false; + for (auto moduleIt : *this) { + if (visited.contains(moduleIt)) continue; + + cyclic |= cycleUtil(moduleIt, {}); + if (cyclic) break; + } + + if (cyclic) { + auto err = getParent()->emitOpError(); + err << "cannot deduce top level module - cycle " + "detected in instance graph ("; + llvm::interleave( + cycleTrace, err, + [&](auto node) { err << node->getModule().getModuleName(); }, "->"); + err << ")."; + return err; + } + assert(!candidateTopLevels.empty() && + "if non-cyclic, there should be at least 1 candidate top level"); + + inferredTopLevelNodes = llvm::SmallVector( + candidateTopLevels.begin(), candidateTopLevels.end()); + return {inferredTopLevelNodes}; +} + +static InstancePath empty{}; + +// NOLINTBEGIN(misc-no-recursion) +ArrayRef InstancePathCache::getAbsolutePaths( + ModuleOpInterface op) { + InstanceGraphNode *node = instanceGraph[op]; + + if (node == instanceGraph.getTopLevelNode()) { + return empty; + } + + // Fast path: hit the cache. + auto cached = absolutePathsCache.find(op); + if (cached != absolutePathsCache.end()) return cached->second; + + // For each instance, collect the instance paths to its parent and append the + // instance itself to each. + SmallVector extendedPaths; + for (auto *inst : node->uses()) { + if (auto module = inst->getParent()->getModule()) { + auto instPaths = getAbsolutePaths(module); + extendedPaths.reserve(instPaths.size()); + for (auto path : instPaths) { + extendedPaths.push_back(appendInstance( + path, cast(*inst->getInstance()))); + } + } else { + extendedPaths.emplace_back(empty); + } + } + + // Move the list of paths into the bump allocator for later quick retrieval. + ArrayRef pathList; + if (!extendedPaths.empty()) { + auto *paths = allocator.Allocate(extendedPaths.size()); + std::copy(extendedPaths.begin(), extendedPaths.end(), paths); + pathList = ArrayRef(paths, extendedPaths.size()); + } + absolutePathsCache.insert({op, pathList}); + return pathList; +} +// NOLINTEND(misc-no-recursion) + +InstancePath InstancePathCache::appendInstance(InstancePath path, + InstanceOpInterface inst) { + size_t n = path.size() + 1; + auto *newPath = allocator.Allocate(n); + std::copy(path.begin(), path.end(), newPath); + newPath[path.size()] = inst; + return InstancePath(ArrayRef(newPath, n)); +} + +InstancePath InstancePathCache::prependInstance(InstanceOpInterface inst, + InstancePath path) { + size_t n = path.size() + 1; + auto *newPath = allocator.Allocate(n); + std::copy(path.begin(), path.end(), newPath + 1); + newPath[0] = inst; + return InstancePath(ArrayRef(newPath, n)); +} + +void InstancePathCache::replaceInstance(InstanceOpInterface oldOp, + InstanceOpInterface newOp) { + instanceGraph.replaceInstance(oldOp, newOp); + + // Iterate over all the paths, and search for the old InstanceOpInterface. If + // found, then replace it with the new InstanceOpInterface, and create a new + // copy of the paths and update the cache. + auto instanceExists = [&](const ArrayRef &paths) -> bool { + return llvm::any_of( + paths, [&](InstancePath p) { return llvm::is_contained(p, oldOp); }); + }; + + for (auto &iter : absolutePathsCache) { + if (!instanceExists(iter.getSecond())) continue; + SmallVector updatedPaths; + for (auto path : iter.getSecond()) { + const auto *iter = llvm::find(path, oldOp); + if (iter == path.end()) { + // path does not contain the oldOp, just copy it as is. + updatedPaths.push_back(path); + continue; + } + auto *newPath = allocator.Allocate(path.size()); + llvm::copy(path, newPath); + newPath[iter - path.begin()] = newOp; + updatedPaths.push_back(InstancePath(ArrayRef(newPath, path.size()))); + } + // Move the list of paths into the bump allocator for later quick + // retrieval. + auto *paths = allocator.Allocate(updatedPaths.size()); + llvm::copy(updatedPaths, paths); + iter.getSecond() = ArrayRef(paths, updatedPaths.size()); + } +} + +#include "circt/Support/InstanceGraphInterface.cpp.inc" diff --git a/third_party/circt/lib/Support/JSON.cpp b/third_party/circt/lib/Support/JSON.cpp new file mode 100644 index 000000000..776545b90 --- /dev/null +++ b/third_party/circt/lib/Support/JSON.cpp @@ -0,0 +1,123 @@ +//===- Json.cpp - Json Utilities --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/JSON.h" + +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/OperationSupport.h" + +namespace json = llvm::json; + +using namespace circt; +using mlir::UnitAttr; + +// NOLINTBEGIN(misc-no-recursion) +LogicalResult circt::convertAttributeToJSON(llvm::json::OStream &json, + Attribute attr) { + return TypeSwitch(attr) + .Case([&](auto attr) { + json.objectBegin(); + for (auto subAttr : attr) { + json.attributeBegin(subAttr.getName()); + if (failed(convertAttributeToJSON(json, subAttr.getValue()))) + return failure(); + json.attributeEnd(); + } + json.objectEnd(); + return success(); + }) + .Case([&](auto attr) { + json.arrayBegin(); + for (auto subAttr : attr) + if (failed(convertAttributeToJSON(json, subAttr))) return failure(); + json.arrayEnd(); + return success(); + }) + .Case([&](auto attr) { + json.value(attr.getValue()); + return success(); + }) + .Case([&](auto attr) -> LogicalResult { + // If the integer can be accurately represented by a double, print + // it as an integer. Otherwise, convert it to an exact decimal string. + const auto &apint = attr.getValue(); + if (!apint.isSignedIntN(64)) return failure(); + json.value(apint.getSExtValue()); + return success(); + }) + .Case([&](auto attr) -> LogicalResult { + const auto &apfloat = attr.getValue(); + json.value(apfloat.convertToDouble()); + return success(); + }) + .Default([&](auto) -> LogicalResult { return failure(); }); +} +// NOLINTEND(misc-no-recursion) + +// NOLINTBEGIN(misc-no-recursion) +Attribute circt::convertJSONToAttribute(MLIRContext *context, + json::Value &value, json::Path p) { + // String or quoted JSON + if (auto a = value.getAsString()) { + // Test to see if this might be quoted JSON (a string that is actually + // JSON). Sometimes FIRRTL developers will do this to serialize objects + // that the Scala FIRRTL Compiler doesn't know about. + auto unquotedValue = json::parse(*a); + auto err = unquotedValue.takeError(); + // If this parsed without an error and we didn't just unquote a number, then + // it's more JSON and recurse on that. + // + // We intentionally do not want to unquote a number as, in JSON, the string + // "0" is different from the number 0. If we conflate these, then later + // expectations about annotation structure may be broken. I.e., an + // annotation expecting a string may see a number. + if (!err && !unquotedValue.get().getAsNumber()) + return convertJSONToAttribute(context, unquotedValue.get(), p); + // If there was an error, then swallow it and handle this as a string. + handleAllErrors(std::move(err), [&](const json::ParseError &a) {}); + return StringAttr::get(context, *a); + } + + // Integer + if (auto a = value.getAsInteger()) + return IntegerAttr::get(IntegerType::get(context, 64), *a); + + // Float + if (auto a = value.getAsNumber()) + return FloatAttr::get(mlir::FloatType::getF64(context), *a); + + // Boolean + if (auto a = value.getAsBoolean()) return BoolAttr::get(context, *a); + + // Null + if (auto a = value.getAsNull()) return mlir::UnitAttr::get(context); + + // Object + if (auto *a = value.getAsObject()) { + NamedAttrList metadata; + for (auto b : *a) + metadata.append( + b.first, convertJSONToAttribute(context, b.second, p.field(b.first))); + return DictionaryAttr::get(context, metadata); + } + + // Array + if (auto *a = value.getAsArray()) { + SmallVector metadata; + for (size_t i = 0, e = (*a).size(); i != e; ++i) + metadata.push_back(convertJSONToAttribute(context, (*a)[i], p.index(i))); + return ArrayAttr::get(context, metadata); + } + + llvm_unreachable("Impossible unhandled JSON type"); +} +// NOLINTEND(misc-no-recursion) diff --git a/third_party/circt/lib/Support/LoweringOptions.cpp b/third_party/circt/lib/Support/LoweringOptions.cpp new file mode 100644 index 000000000..ce49431de --- /dev/null +++ b/third_party/circt/lib/Support/LoweringOptions.cpp @@ -0,0 +1,187 @@ +//===- LoweringOptions.cpp - CIRCT Lowering Options -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Options for controlling the lowering process. Contains command line +// option definitions and support. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/LoweringOptions.h" + +#include "mlir/IR/BuiltinOps.h" + +using namespace circt; +using namespace mlir; + +//===----------------------------------------------------------------------===// +// LoweringOptions +//===----------------------------------------------------------------------===// + +LoweringOptions::LoweringOptions(StringRef options, ErrorHandlerT errorHandler) + : LoweringOptions() { + parse(options, errorHandler); +} + +LoweringOptions::LoweringOptions(mlir::ModuleOp module) : LoweringOptions() { + parseFromAttribute(module); +} + +static std::optional parseLocationInfoStyle( + StringRef option) { + return llvm::StringSwitch>( + option) + .Case("plain", LoweringOptions::Plain) + .Case("wrapInAtSquareBracket", LoweringOptions::WrapInAtSquareBracket) + .Case("none", LoweringOptions::None) + .Default(std::nullopt); +} + +static std::optional +parseWireSpillingHeuristic(StringRef option) { + return llvm::StringSwitch< + std::optional>(option) + .Case("spillLargeTermsWithNamehints", + LoweringOptions::SpillLargeTermsWithNamehints) + .Default(std::nullopt); +} + +void LoweringOptions::parse(StringRef text, ErrorHandlerT errorHandler) { + while (!text.empty()) { + // Remove the first option from the text. + auto split = text.split(","); + auto option = split.first.trim(); + text = split.second; + if (option == "") { + // Empty options are fine. + } else if (option == "noAlwaysComb") { + noAlwaysComb = true; + } else if (option == "exprInEventControl") { + allowExprInEventControl = true; + } else if (option == "disallowPackedArrays") { + disallowPackedArrays = true; + } else if (option == "disallowPackedStructAssignments") { + disallowPackedStructAssignments = true; + } else if (option == "disallowLocalVariables") { + disallowLocalVariables = true; + } else if (option == "verifLabels") { + enforceVerifLabels = true; + } else if (option.consume_front("emittedLineLength=")) { + if (option.getAsInteger(10, emittedLineLength)) { + errorHandler("expected integer source width"); + emittedLineLength = DEFAULT_LINE_LENGTH; + } + } else if (option == "explicitBitcast") { + explicitBitcast = true; + } else if (option == "emitReplicatedOpsToHeader") { + emitReplicatedOpsToHeader = true; + } else if (option.consume_front("maximumNumberOfTermsPerExpression=")) { + if (option.getAsInteger(10, maximumNumberOfTermsPerExpression)) { + errorHandler("expected integer source width"); + maximumNumberOfTermsPerExpression = DEFAULT_TERM_LIMIT; + } + } else if (option.consume_front("locationInfoStyle=")) { + if (auto style = parseLocationInfoStyle(option)) { + locationInfoStyle = *style; + } else { + errorHandler("expected 'plain', 'wrapInAtSquareBracket', or 'none'"); + } + } else if (option == "disallowPortDeclSharing") { + disallowPortDeclSharing = true; + } else if (option == "printDebugInfo") { + printDebugInfo = true; + } else if (option == "disallowExpressionInliningInPorts") { + disallowExpressionInliningInPorts = true; + } else if (option == "disallowMuxInlining") { + disallowMuxInlining = true; + } else if (option == "mitigateVivadoArrayIndexConstPropBug") { + mitigateVivadoArrayIndexConstPropBug = true; + } else if (option.consume_front("wireSpillingHeuristic=")) { + if (auto heuristic = parseWireSpillingHeuristic(option)) { + wireSpillingHeuristicSet |= *heuristic; + } else { + errorHandler("expected ''spillLargeTermsWithNamehints'"); + } + } else if (option.consume_front("wireSpillingNamehintTermLimit=")) { + if (option.getAsInteger(10, wireSpillingNamehintTermLimit)) { + errorHandler( + "expected integer for number of namehint heurstic term limit"); + wireSpillingNamehintTermLimit = DEFAULT_NAMEHINT_TERM_LIMIT; + } + } else if (option == "emitWireInPorts") { + emitWireInPorts = true; + } else if (option == "emitBindComments") { + emitBindComments = true; + } else if (option == "omitVersionComment") { + omitVersionComment = true; + } else if (option == "caseInsensitiveKeywords") { + caseInsensitiveKeywords = true; + } else { + errorHandler(llvm::Twine("unknown style option \'") + option + "\'"); + // We continue parsing options after a failure. + } + } +} + +std::string LoweringOptions::toString() const { + std::string options = ""; + // All options should add a trailing comma to simplify the code. + if (noAlwaysComb) options += "noAlwaysComb,"; + if (allowExprInEventControl) options += "exprInEventControl,"; + if (disallowPackedArrays) options += "disallowPackedArrays,"; + if (disallowPackedStructAssignments) + options += "disallowPackedStructAssignments,"; + if (disallowLocalVariables) options += "disallowLocalVariables,"; + if (enforceVerifLabels) options += "verifLabels,"; + if (explicitBitcast) options += "explicitBitcast,"; + if (emitReplicatedOpsToHeader) options += "emitReplicatedOpsToHeader,"; + if (locationInfoStyle == LocationInfoStyle::WrapInAtSquareBracket) + options += "locationInfoStyle=wrapInAtSquareBracket,"; + if (locationInfoStyle == LocationInfoStyle::None) + options += "locationInfoStyle=none,"; + if (disallowPortDeclSharing) options += "disallowPortDeclSharing,"; + if (printDebugInfo) options += "printDebugInfo,"; + if (isWireSpillingHeuristicEnabled( + WireSpillingHeuristic::SpillLargeTermsWithNamehints)) + options += "wireSpillingHeuristic=spillLargeTermsWithNamehints,"; + if (disallowExpressionInliningInPorts) + options += "disallowExpressionInliningInPorts,"; + if (disallowMuxInlining) options += "disallowMuxInlining,"; + if (mitigateVivadoArrayIndexConstPropBug) + options += "mitigateVivadoArrayIndexConstPropBug,"; + + if (emittedLineLength != DEFAULT_LINE_LENGTH) + options += "emittedLineLength=" + std::to_string(emittedLineLength) + ','; + if (maximumNumberOfTermsPerExpression != DEFAULT_TERM_LIMIT) + options += "maximumNumberOfTermsPerExpression=" + + std::to_string(maximumNumberOfTermsPerExpression) + ','; + if (emitWireInPorts) options += "emitWireInPorts,"; + if (emitBindComments) options += "emitBindComments,"; + if (omitVersionComment) options += "omitVersionComment,"; + if (caseInsensitiveKeywords) options += "caseInsensitiveKeywords,"; + + // Remove a trailing comma if present. + if (!options.empty()) { + assert(options.back() == ',' && "all options should add a trailing comma"); + options.pop_back(); + } + return options; +} + +StringAttr LoweringOptions::getAttributeFrom(ModuleOp module) { + return module->getAttrOfType("circt.loweringOptions"); +} + +void LoweringOptions::setAsAttribute(ModuleOp module) { + module->setAttr("circt.loweringOptions", + StringAttr::get(module.getContext(), toString())); +} + +void LoweringOptions::parseFromAttribute(ModuleOp module) { + if (auto styleAttr = getAttributeFrom(module)) + parse(styleAttr.getValue(), [&](Twine error) { module.emitError(error); }); +} diff --git a/third_party/circt/lib/Support/ParsingUtils.cpp b/third_party/circt/lib/Support/ParsingUtils.cpp new file mode 100644 index 000000000..2d0edcf6f --- /dev/null +++ b/third_party/circt/lib/Support/ParsingUtils.cpp @@ -0,0 +1,50 @@ +//===- ParsingUtils.cpp - CIRCT parsing common functions ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/ParsingUtils.h" + +using namespace circt; + +ParseResult circt::parsing_util::parseInitializerList( + OpAsmParser &parser, + llvm::SmallVector &inputArguments, + llvm::SmallVector &inputOperands, + llvm::SmallVector &inputTypes, ArrayAttr &inputNames) { + llvm::SmallVector names; + if (failed(parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + OpAsmParser::UnresolvedOperand inputOperand; + Type type; + auto &arg = inputArguments.emplace_back(); + if (parser.parseArgument(arg) || parser.parseColonType(type) || + parser.parseEqual() || parser.parseOperand(inputOperand)) + return failure(); + + inputOperands.push_back(inputOperand); + inputTypes.push_back(type); + arg.type = type; + names.push_back(StringAttr::get( + parser.getContext(), + /*drop leading %*/ arg.ssaName.name.drop_front())); + return success(); + }))) + return failure(); + + inputNames = ArrayAttr::get(parser.getContext(), names); + return success(); +} + +void circt::parsing_util::printInitializerList(OpAsmPrinter &p, ValueRange ins, + ArrayRef args) { + p << "("; + llvm::interleaveComma(llvm::zip(ins, args), p, [&](auto it) { + auto [in, arg] = it; + p << arg << " : " << in.getType() << " = " << in; + }); + p << ")"; +} diff --git a/third_party/circt/lib/Support/Passes.cpp b/third_party/circt/lib/Support/Passes.cpp new file mode 100644 index 000000000..da5a04003 --- /dev/null +++ b/third_party/circt/lib/Support/Passes.cpp @@ -0,0 +1,21 @@ +//===- Passes.cpp - Pass Utilities ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/Passes.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +using namespace circt; + +std::unique_ptr circt::createSimpleCanonicalizerPass() { + mlir::GreedyRewriteConfig config; + config.useTopDownTraversal = true; + config.enableRegionSimplification = false; + return mlir::createCanonicalizerPass(config); +} diff --git a/third_party/circt/lib/Support/Path.cpp b/third_party/circt/lib/Support/Path.cpp new file mode 100644 index 000000000..ff5d0525a --- /dev/null +++ b/third_party/circt/lib/Support/Path.cpp @@ -0,0 +1,32 @@ +//===- Path.cpp - Path Utilities --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Utilities for file system path handling, supplementing the ones from +// llvm::sys::path. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/Path.h" + +#include "llvm/Support/Path.h" + +using namespace circt; + +/// Append a path to an existing path, replacing it if the other path is +/// absolute. This mimicks the behaviour of `foo/bar` and `/foo/bar` being used +/// in a working directory `/home`, resulting in `/home/foo/bar` and `/foo/bar`, +/// respectively. +void circt::appendPossiblyAbsolutePath(llvm::SmallVectorImpl &base, + const llvm::Twine &suffix) { + if (llvm::sys::path::is_absolute(suffix)) { + base.clear(); + suffix.toVector(base); + } else { + llvm::sys::path::append(base, suffix); + } +} diff --git a/third_party/circt/lib/Support/PrettyPrinter.cpp b/third_party/circt/lib/Support/PrettyPrinter.cpp new file mode 100644 index 000000000..8ea7c9ebf --- /dev/null +++ b/third_party/circt/lib/Support/PrettyPrinter.cpp @@ -0,0 +1,305 @@ +//===- PrettyPrinter.cpp - Pretty printing --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This implements a pretty-printer. +// "PrettyPrinting", Derek C. Oppen, 1980. +// https://dx.doi.org/10.1145/357114.357115 +// +// This was selected as it is linear in number of tokens O(n) and requires +// memory O(linewidth). +// +// This has been adjusted from the paper: +// * Deque for tokens instead of ringbuffer + left/right cursors. +// This is simpler to reason about and allows us to easily grow the buffer +// to accommodate longer widths when needed (and not reserve 3*linewidth). +// Since scanStack references buffered tokens by index, we track an offset +// that we increase when dropping off the front. +// When the scan stack is cleared the buffer is reset, including this offset. +// * Indentation tracked from left not relative to margin (linewidth). +// * Indentation emitted lazily, avoid trailing whitespace. +// * Group indentation styles: Visual and Block, set on 'begin' tokens. +// "Visual" is the style in the paper, offset relative to current column. +// "Block" is relative to current base indentation. +// * Break: Add "Neverbreak": acts like a break re:sizing previous range, +// but will never be broken. Useful for adding content to end of line +// that may go over margin but should not influence layout. +// * Begin: Add "Never" breaking style, for forcing no breaks including +// within nested groups. Use sparingly. It is an error to insert +// a newline (Break with spaces==kInfinity) within such a group. +// * If leftTotal grows too large, "rebase" our datastructures by +// walking the tokens with pending sizes (scanStack) and adjusting +// them by `leftTotal - 1`. Also reset tokenOffset while visiting. +// This is mostly needed due to use of tokens/groups that 'never' break +// which can greatly increase times between `clear()`. +// * Optionally, minimum amount of space is granted regardless of indentation. +// To avoid forcing expressions against the line limit, never try to print +// an expression in, say, 2 columns, as this is unlikely to produce good +// output. +// (TODO) +// +// There are many pretty-printing implementations based on this paper, +// and research literature is rich with functional formulations based originally +// on this algorithm. +// +// Implementations of note that have interesting modifications for their +// languages and modernization of the paper's algorithm: +// * prettyplease / rustc_ast_pretty +// Pretty-printers for rust, the first being useful for rustfmt-like output. +// These have largely the same code and were based on one another. +// https://github.com/dtolnay/prettyplease +// https://github.com/rust-lang/rust/tree/master/compiler/rustc_ast_pretty +// This is closest to the paper's algorithm with modernizations, +// and most of the initial tweaks have also been implemented here (thanks!). +// * swift-format: https://github.com/apple/swift-format/ +// +// If we want fancier output or need to handle more complicated constructs, +// both are good references for lessons and ideas. +// +// FWIW, at time of writing these have compatible licensing (Apache 2.0). +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/PrettyPrinter.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" + +namespace circt { +namespace pretty { + +/// Destructor, anchor. +PrettyPrinter::Listener::~Listener() = default; + +/// Add token for printing. In Oppen, this is "scan". +void PrettyPrinter::add(Token t) { + // Add token to tokens, and add its index to scanStack. + auto addScanToken = [&](auto offset) { + auto right = tokenOffset + tokens.size(); + scanStack.push_back(right); + tokens.push_back({t, offset}); + }; + llvm::TypeSwitch(&t) + .Case([&](StringToken *s) { + // If nothing on stack, directly print + FormattedToken f{t, (int32_t)s->text().size()}; + // Empty string token isn't /wrong/ but can have unintended effect. + assert(!s->text().empty() && "empty string token"); + if (scanStack.empty()) return print(f); + tokens.push_back(f); + rightTotal += f.size; + assert(rightTotal > 0); + checkStream(); + }) + .Case([&](BreakToken *b) { + if (scanStack.empty()) + clear(); + else + checkStack(); + addScanToken(-rightTotal); + rightTotal += b->spaces(); + assert(rightTotal > 0); + }) + .Case([&](BeginToken *b) { + if (scanStack.empty()) clear(); + addScanToken(-rightTotal); + }) + .Case([&](EndToken *end) { + if (scanStack.empty()) return print({t, 0}); + addScanToken(-1); + }) + .Case([&](CallbackToken *c) { + // Callbacktoken must be associated with a listener, it doesn't have any + // meaning without it. + assert(listener); + if (scanStack.empty()) return print({t, 0}); + tokens.push_back({t, 0}); + }); + rebaseIfNeeded(); +} + +void PrettyPrinter::rebaseIfNeeded() { + // Check for too-large totals, reset. + // This can happen if we have an open group and emit + // many tokens, especially newlines which have artificial size. + if (tokens.empty()) return; + assert(leftTotal >= 0); + assert(rightTotal >= 0); + if (uint32_t(leftTotal) > rebaseThreshold) { + // Plan: reset leftTotal to '1', adjust all accordingly. + auto adjust = leftTotal - 1; + for (auto &scanIndex : scanStack) { + assert(scanIndex >= tokenOffset); + auto &t = tokens[scanIndex - tokenOffset]; + if (isa(&t.token)) { + if (t.size < 0) { + assert(t.size + adjust < 0); + t.size += adjust; + } + } + // While walking, reset tokenOffset too. + scanIndex -= tokenOffset; + } + leftTotal -= adjust; + rightTotal -= adjust; + tokenOffset = 0; + } +} + +void PrettyPrinter::eof() { + if (!scanStack.empty()) { + checkStack(); + advanceLeft(); + } + assert(scanStack.empty() && "unclosed groups at EOF"); + if (scanStack.empty()) clear(); +} + +void PrettyPrinter::clear() { + assert(scanStack.empty() && "clearing tokens while still on scan stack"); + assert(tokens.empty()); + leftTotal = rightTotal = 1; + tokens.clear(); + tokenOffset = 0; + if (listener && !donotClear) listener->clear(); +} + +/// Break encountered, set sizes of begin/breaks in scanStack that we now know. +void PrettyPrinter::checkStack() { + unsigned depth = 0; + while (!scanStack.empty()) { + auto x = scanStack.back(); + assert(x >= tokenOffset && tokens.size() + tokenOffset > x); + auto &t = tokens[x - tokenOffset]; + if (llvm::isa(&t.token)) { + if (depth == 0) break; + scanStack.pop_back(); + t.size += rightTotal; + --depth; + } else if (llvm::isa(&t.token)) { + scanStack.pop_back(); + t.size = 1; + ++depth; + } else { + scanStack.pop_back(); + t.size += rightTotal; + if (depth == 0) break; + } + } +} + +/// Check if there are enough tokens to hit width, if so print. +/// If scan size is wider than line, it's infinity. +void PrettyPrinter::checkStream() { + // While buffer needs more than 1 line to print, print and consume. + assert(!tokens.empty()); + assert(leftTotal >= 0); + assert(rightTotal >= 0); + while (rightTotal - leftTotal > space && !tokens.empty()) { + // Ran out of space, set size to infinity and take off scan stack. + // No need to keep track as we know enough to know this won't fit. + if (!scanStack.empty() && tokenOffset == scanStack.front()) { + tokens.front().size = kInfinity; + scanStack.pop_front(); + } + advanceLeft(); + } +} + +/// Print out tokens we know sizes for, and drop from token buffer. +void PrettyPrinter::advanceLeft() { + assert(!tokens.empty()); + + while (!tokens.empty() && tokens.front().size >= 0) { + const auto &f = tokens.front(); + print(f); + leftTotal += + llvm::TypeSwitch(&f.token) + .Case([&](const BreakToken *b) { return b->spaces(); }) + .Case([&](const StringToken *s) { return s->text().size(); }) + .Default([](const auto *) { return 0; }); + tokens.pop_front(); + ++tokenOffset; + } +} + +/// Compute indentation w/o overflow, clamp to [0,maxStartingIndent]. +/// Output looks better if we don't stop indenting entirely at target width, +/// but don't do this indefinitely. +static uint32_t computeNewIndent(ssize_t newIndent, int32_t offset, + uint32_t maxStartingIndent) { + return std::clamp(newIndent + offset, 0, maxStartingIndent); +} + +/// Print a token, maintaining printStack for context. +void PrettyPrinter::print(const FormattedToken &f) { + llvm::TypeSwitch(&f.token) + .Case([&](const StringToken *s) { + space -= f.size; + os.indent(pendingIndentation); + pendingIndentation = 0; + os << s->text(); + }) + .Case([&](const BreakToken *b) { + auto &frame = getPrintFrame(); + assert((b->spaces() != kInfinity || alwaysFits == 0) && + "newline inside never group"); + bool fits = + (alwaysFits > 0) || b->neverbreak() || + frame.breaks == PrintBreaks::Fits || + (frame.breaks == PrintBreaks::Inconsistent && f.size <= space); + if (fits) { + space -= b->spaces(); + pendingIndentation += b->spaces(); + } else { + os << "\n"; + pendingIndentation = + computeNewIndent(indent, b->offset(), maxStartingIndent); + space = margin - pendingIndentation; + } + }) + .Case([&](const BeginToken *b) { + if (b->breaks() == Breaks::Never) { + printStack.push_back({0, PrintBreaks::AlwaysFits}); + ++alwaysFits; + } else if (f.size > space && alwaysFits == 0) { + auto breaks = b->breaks() == Breaks::Consistent + ? PrintBreaks::Consistent + : PrintBreaks::Inconsistent; + ssize_t newIndent = indent; + if (b->style() == IndentStyle::Visual) + newIndent = ssize_t{margin} - space; + indent = computeNewIndent(newIndent, b->offset(), maxStartingIndent); + printStack.push_back({indent, breaks}); + } else { + printStack.push_back({0, PrintBreaks::Fits}); + } + }) + .Case([&](const EndToken *) { + assert(!printStack.empty() && "more ends than begins?"); + // Try to tolerate this when assertions are disabled. + if (printStack.empty()) return; + if (getPrintFrame().breaks == PrintBreaks::AlwaysFits) --alwaysFits; + printStack.pop_back(); + auto &frame = getPrintFrame(); + if (frame.breaks != PrintBreaks::Fits && + frame.breaks != PrintBreaks::AlwaysFits) + indent = frame.offset; + }) + .Case([&](const CallbackToken *c) { + if (pendingIndentation) { + // This is necessary to get the correct location on the stream for the + // callback invocation. + os.indent(pendingIndentation); + pendingIndentation = 0; + } + listener->print(); + }); +} +} // end namespace pretty +} // end namespace circt diff --git a/third_party/circt/lib/Support/PrettyPrinterHelpers.cpp b/third_party/circt/lib/Support/PrettyPrinterHelpers.cpp new file mode 100644 index 000000000..9ea3b1250 --- /dev/null +++ b/third_party/circt/lib/Support/PrettyPrinterHelpers.cpp @@ -0,0 +1,50 @@ +//===- PrettyPrinterHelpers.cpp - Pretty printing helpers -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Helper classes for using PrettyPrinter. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/PrettyPrinterHelpers.h" + +#include + +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/raw_ostream.h" + +namespace circt { +namespace pretty { + +//===----------------------------------------------------------------------===// +// Convenience builders. +//===----------------------------------------------------------------------===// + +void TokenStringSaver::clear() { alloc.Reset(); } + +/// Add multiple non-breaking spaces as a single token. +void detail::emitNBSP(unsigned n, llvm::function_ref add) { + static const std::array spaces = ([]() constexpr { + std::array s = {}; + for (auto &c : s) c = ' '; + return s; + })(); + + const auto size = spaces.size(); + if (n <= size) { + if (n != 0) add(StringToken({spaces.data(), n})); + return; + } + while (n) { + auto chunk = std::min(n, size); + add(StringToken({spaces.data(), chunk})); + n -= chunk; + } +} + +} // end namespace pretty +} // end namespace circt diff --git a/third_party/circt/lib/Support/SymCache.cpp b/third_party/circt/lib/Support/SymCache.cpp new file mode 100644 index 000000000..c45c57398 --- /dev/null +++ b/third_party/circt/lib/Support/SymCache.cpp @@ -0,0 +1,29 @@ +//===- SymCache.cpp - Declare Symbol Cache ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a Symbol Cache. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/SymCache.h" + +using namespace mlir; +using namespace circt; + +namespace circt { + +/// Virtual method anchor. +SymbolCacheBase::~SymbolCacheBase() {} + +void SymbolCacheBase::addDefinitions(mlir::Operation *top) { + for (auto ®ion : top->getRegions()) + for (auto &block : region.getBlocks()) + for (auto symOp : block.getOps()) + addSymbol(symOp); +} +} // namespace circt diff --git a/third_party/circt/lib/Support/ValueMapper.cpp b/third_party/circt/lib/Support/ValueMapper.cpp new file mode 100644 index 000000000..e6000e969 --- /dev/null +++ b/third_party/circt/lib/Support/ValueMapper.cpp @@ -0,0 +1,62 @@ +//===- ValueMapper.cpp - Support for mapping SSA values ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides support for mapping SSA values between two domains. +// Provided a BackedgeBuilder, the ValueMapper supports mappings between +// GraphRegions, creating Backedges in cases of 'get'ing mapped values which are +// yet to be 'set'. +// +//===----------------------------------------------------------------------===// + +#include "circt/Support/ValueMapper.h" + +using namespace mlir; +using namespace circt; +mlir::Value ValueMapper::get(Value from, TypeTransformer typeTransformer) { + if (mapping.count(from) == 0) { + assert(bb && + "Trying to 'get' a mapped value without any value set. No " + "BackedgeBuilder was provided, so cannot provide any mapped " + "SSA value!"); + // Create a backedge which will be resolved at a later time once all + // operands are created. + mapping[from] = bb->get(typeTransformer(from.getType())); + } + auto operandMapping = mapping[from]; + Value mappedOperand; + if (auto *v = std::get_if(&operandMapping)) + mappedOperand = *v; + else + mappedOperand = std::get(operandMapping); + return mappedOperand; +} + +llvm::SmallVector ValueMapper::get(ValueRange from, + TypeTransformer typeTransformer) { + llvm::SmallVector to; + for (auto f : from) to.push_back(get(f, typeTransformer)); + return to; +} + +void ValueMapper::set(Value from, Value to, bool replace) { + auto it = mapping.find(from); + if (it != mapping.end()) { + if (auto *backedge = std::get_if(&it->second)) + backedge->setValue(to); + else if (!replace) + assert(false && "'from' was already mapped to a final value!"); + } + // Register the new mapping + mapping[from] = to; +} + +void ValueMapper::set(ValueRange from, ValueRange to, bool replace) { + assert(from.size() == to.size() && + "Expected # of 'from' values and # of 'to' values to be identical."); + for (auto [f, t] : llvm::zip(from, to)) set(f, t, replace); +} diff --git a/third_party/circt/lib/Support/Version.cpp.in b/third_party/circt/lib/Support/Version.cpp.in new file mode 100644 index 000000000..545cfc88f --- /dev/null +++ b/third_party/circt/lib/Support/Version.cpp.in @@ -0,0 +1,6 @@ +#include "circt/Support/Version.h" + +const char *circt::getCirctVersion() { return "CIRCT @GIT_DESCRIBE_OUTPUT@"; } +const char *circt::getCirctVersionComment() { + return "// Generated by CIRCT @GIT_DESCRIBE_OUTPUT@\n"; +} diff --git a/tools/BUILD b/tools/BUILD index 91114d78e..4bc4077a2 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -9,7 +9,10 @@ package( cc_binary( name = "heir-opt", srcs = ["heir-opt.cpp"], - includes = ["include"], + # FIXME: Propagate these from deps, see why includes isn't working + includes = [ + "/third_party/circt/include", + ], deps = [ "@heir//lib/Conversion/BGVToPoly", "@heir//lib/Conversion/MemrefToArith:ExpandCopy", @@ -20,6 +23,8 @@ cc_binary( "@heir//lib/Dialect/Poly/IR:Dialect", "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/Secret/Transforms", + "@heir//third_party/circt/lib/Dialect/Comb:Dialect", + "@heir//third_party/circt/lib/Dialect/HW:Dialect", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineTransforms", "@llvm-project//mlir:AllPassesAndDialects", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index b23ffe7ac..a34389c99 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -24,6 +24,8 @@ #include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project +#include "third_party/circt/include/circt/Dialect/Comb/CombDialect.h" +#include "third_party/circt/include/circt/Dialect/HW/HWDialect.h" void tosaPipelineBuilder(mlir::OpPassManager &manager) { // TOSA to linalg @@ -78,6 +80,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Add expected MLIR dialects to the registry. registry.insert();