From eab2b70d30f8e9c3dc6cc8e6aea6a8979aff0ad7 Mon Sep 17 00:00:00 2001 From: Asra Ali Date: Wed, 27 Sep 2023 19:14:39 +0000 Subject: [PATCH] remove hw and support Signed-off-by: Asra Ali fix include Signed-off-by: Asra Ali --- include/circt/Dialect/Comb/BUILD | 1 - include/circt/Dialect/Comb/CMakeLists.txt | 14 - include/circt/Dialect/Comb/Comb.td | 5 +- include/circt/Dialect/Comb/CombOps.h | 45 +- include/circt/Dialect/Comb/CombPasses.h | 33 - include/circt/Dialect/Comb/CombVisitors.h | 113 - include/circt/Dialect/Comb/Combinational.td | 52 +- include/circt/Dialect/Comb/Passes.td | 25 - include/circt/Dialect/HW/BUILD | 204 - include/circt/Dialect/HW/CMakeLists.txt | 44 - include/circt/Dialect/HW/ConversionPatterns.h | 36 - .../circt/Dialect/HW/CustomDirectiveImpl.h | 70 - include/circt/Dialect/HW/HW.td | 30 - include/circt/Dialect/HW/HWAggregates.td | 306 -- include/circt/Dialect/HW/HWAttributes.h | 45 - include/circt/Dialect/HW/HWAttributes.td | 309 -- .../circt/Dialect/HW/HWAttributesNaming.td | 70 - include/circt/Dialect/HW/HWDialect.h | 26 - include/circt/Dialect/HW/HWDialect.td | 49 - include/circt/Dialect/HW/HWInstanceGraph.h | 56 - include/circt/Dialect/HW/HWMiscOps.td | 192 - include/circt/Dialect/HW/HWModuleGraph.h | 186 - include/circt/Dialect/HW/HWOpInterfaces.h | 297 -- include/circt/Dialect/HW/HWOpInterfaces.td | 528 --- include/circt/Dialect/HW/HWOps.h | 140 - include/circt/Dialect/HW/HWPasses.h | 38 - include/circt/Dialect/HW/HWReductions.h | 29 - include/circt/Dialect/HW/HWStructure.td | 720 ---- include/circt/Dialect/HW/HWSymCache.h | 118 - include/circt/Dialect/HW/HWTypeDecls.td | 63 - include/circt/Dialect/HW/HWTypeInterfaces.h | 44 - include/circt/Dialect/HW/HWTypeInterfaces.td | 81 - include/circt/Dialect/HW/HWTypes.h | 165 - include/circt/Dialect/HW/HWTypes.td | 179 - include/circt/Dialect/HW/HWTypesImpl.td | 280 -- include/circt/Dialect/HW/HWVisitors.h | 140 - .../circt/Dialect/HW/InnerSymbolNamespace.h | 51 - include/circt/Dialect/HW/InnerSymbolTable.h | 264 -- .../circt/Dialect/HW/InstanceImplementation.h | 104 - .../circt/Dialect/HW/ModuleImplementation.h | 56 - include/circt/Dialect/HW/Passes.td | 61 - include/circt/Dialect/HW/PortConverter.h | 182 - include/circt/Support/APInt.h | 30 - include/circt/Support/BUILD | 42 - include/circt/Support/BackedgeBuilder.h | 100 - include/circt/Support/BuilderUtils.h | 49 - include/circt/Support/CMakeLists.txt | 1 - include/circt/Support/CustomDirectiveImpl.h | 91 - include/circt/Support/FieldRef.h | 116 - include/circt/Support/FoldUtils.h | 38 - include/circt/Support/InstanceGraph.h | 453 --- .../circt/Support/InstanceGraphInterface.h | 23 - .../circt/Support/InstanceGraphInterface.td | 74 - include/circt/Support/JSON.h | 30 - include/circt/Support/LoweringOptions.h | 169 - include/circt/Support/LoweringOptionsParser.h | 57 - include/circt/Support/Namespace.h | 134 - include/circt/Support/ParsingUtils.h | 57 - include/circt/Support/Passes.h | 64 - include/circt/Support/Path.h | 30 - include/circt/Support/PrettyPrinter.h | 325 -- include/circt/Support/PrettyPrinterHelpers.h | 374 -- include/circt/Support/SymCache.h | 132 - include/circt/Support/ValueMapper.h | 63 - include/circt/Support/Version.h | 16 - lib/circt/Dialect/Comb/BUILD | 9 +- lib/circt/Dialect/Comb/CMakeLists.txt | 26 - lib/circt/Dialect/Comb/CombAnalysis.cpp | 87 - lib/circt/Dialect/Comb/CombDialect.cpp | 34 +- lib/circt/Dialect/Comb/CombFolds.cpp | 3026 --------------- lib/circt/Dialect/Comb/CombOps.cpp | 114 +- .../Dialect/Comb/Transforms/CMakeLists.txt | 15 - .../Dialect/Comb/Transforms/LowerComb.cpp | 88 - .../Dialect/Comb/Transforms/PassDetails.h | 27 - lib/circt/Dialect/HW/BUILD | 55 - lib/circt/Dialect/HW/CMakeLists.txt | 52 - lib/circt/Dialect/HW/ConversionPatterns.cpp | 103 - lib/circt/Dialect/HW/CustomDirectiveImpl.cpp | 136 - lib/circt/Dialect/HW/HWAttributes.cpp | 1032 ----- lib/circt/Dialect/HW/HWDialect.cpp | 116 - lib/circt/Dialect/HW/HWInstanceGraph.cpp | 33 - lib/circt/Dialect/HW/HWModuleOpInterface.cpp | 88 - lib/circt/Dialect/HW/HWOpInterfaces.cpp | 99 - lib/circt/Dialect/HW/HWOps.cpp | 3376 ----------------- lib/circt/Dialect/HW/HWReductions.cpp | 157 - lib/circt/Dialect/HW/HWTypeInterfaces.cpp | 74 - lib/circt/Dialect/HW/HWTypes.cpp | 974 ----- lib/circt/Dialect/HW/InnerSymbolTable.cpp | 251 -- .../Dialect/HW/InstanceImplementation.cpp | 347 -- lib/circt/Dialect/HW/ModuleImplementation.cpp | 328 -- lib/circt/Dialect/HW/PortConverter.cpp | 233 -- .../Dialect/HW/Transforms/CMakeLists.txt | 20 - lib/circt/Dialect/HW/Transforms/FlattenIO.cpp | 429 --- .../HW/Transforms/HWPrintInstanceGraph.cpp | 36 - .../Dialect/HW/Transforms/HWSpecialize.cpp | 422 --- lib/circt/Dialect/HW/Transforms/PassDetails.h | 31 - .../HW/Transforms/PrintHWModuleGraph.cpp | 42 - .../HW/Transforms/VerifyInnerRefNamespace.cpp | 44 - lib/circt/Support/APInt.cpp | 27 - lib/circt/Support/BUILD | 19 - lib/circt/Support/BackedgeBuilder.cpp | 71 - lib/circt/Support/CMakeLists.txt | 52 - lib/circt/Support/CustomDirectiveImpl.cpp | 130 - lib/circt/Support/FieldRef.cpp | 23 - lib/circt/Support/InstanceGraph.cpp | 314 -- lib/circt/Support/JSON.cpp | 123 - lib/circt/Support/LoweringOptions.cpp | 187 - lib/circt/Support/ParsingUtils.cpp | 50 - lib/circt/Support/Passes.cpp | 21 - lib/circt/Support/Path.cpp | 32 - lib/circt/Support/PrettyPrinter.cpp | 305 -- lib/circt/Support/PrettyPrinterHelpers.cpp | 50 - lib/circt/Support/SymCache.cpp | 29 - lib/circt/Support/ValueMapper.cpp | 62 - lib/circt/Support/Version.cpp.in | 6 - 115 files changed, 132 insertions(+), 21192 deletions(-) delete mode 100644 include/circt/Dialect/Comb/CMakeLists.txt delete mode 100644 include/circt/Dialect/Comb/CombPasses.h delete mode 100644 include/circt/Dialect/Comb/CombVisitors.h delete mode 100644 include/circt/Dialect/Comb/Passes.td delete mode 100644 include/circt/Dialect/HW/BUILD delete mode 100644 include/circt/Dialect/HW/CMakeLists.txt delete mode 100644 include/circt/Dialect/HW/ConversionPatterns.h delete mode 100644 include/circt/Dialect/HW/CustomDirectiveImpl.h delete mode 100644 include/circt/Dialect/HW/HW.td delete mode 100644 include/circt/Dialect/HW/HWAggregates.td delete mode 100644 include/circt/Dialect/HW/HWAttributes.h delete mode 100644 include/circt/Dialect/HW/HWAttributes.td delete mode 100644 include/circt/Dialect/HW/HWAttributesNaming.td delete mode 100644 include/circt/Dialect/HW/HWDialect.h delete mode 100644 include/circt/Dialect/HW/HWDialect.td delete mode 100644 include/circt/Dialect/HW/HWInstanceGraph.h delete mode 100644 include/circt/Dialect/HW/HWMiscOps.td delete mode 100644 include/circt/Dialect/HW/HWModuleGraph.h delete mode 100644 include/circt/Dialect/HW/HWOpInterfaces.h delete mode 100644 include/circt/Dialect/HW/HWOpInterfaces.td delete mode 100644 include/circt/Dialect/HW/HWOps.h delete mode 100644 include/circt/Dialect/HW/HWPasses.h delete mode 100644 include/circt/Dialect/HW/HWReductions.h delete mode 100644 include/circt/Dialect/HW/HWStructure.td delete mode 100644 include/circt/Dialect/HW/HWSymCache.h delete mode 100644 include/circt/Dialect/HW/HWTypeDecls.td delete mode 100644 include/circt/Dialect/HW/HWTypeInterfaces.h delete mode 100644 include/circt/Dialect/HW/HWTypeInterfaces.td delete mode 100644 include/circt/Dialect/HW/HWTypes.h delete mode 100644 include/circt/Dialect/HW/HWTypes.td delete mode 100644 include/circt/Dialect/HW/HWTypesImpl.td delete mode 100644 include/circt/Dialect/HW/HWVisitors.h delete mode 100644 include/circt/Dialect/HW/InnerSymbolNamespace.h delete mode 100644 include/circt/Dialect/HW/InnerSymbolTable.h delete mode 100644 include/circt/Dialect/HW/InstanceImplementation.h delete mode 100644 include/circt/Dialect/HW/ModuleImplementation.h delete mode 100644 include/circt/Dialect/HW/Passes.td delete mode 100644 include/circt/Dialect/HW/PortConverter.h delete mode 100644 include/circt/Support/APInt.h delete mode 100644 include/circt/Support/BackedgeBuilder.h delete mode 100644 include/circt/Support/BuilderUtils.h delete mode 100644 include/circt/Support/CMakeLists.txt delete mode 100644 include/circt/Support/CustomDirectiveImpl.h delete mode 100644 include/circt/Support/FieldRef.h delete mode 100644 include/circt/Support/FoldUtils.h delete mode 100644 include/circt/Support/InstanceGraph.h delete mode 100644 include/circt/Support/InstanceGraphInterface.h delete mode 100644 include/circt/Support/InstanceGraphInterface.td delete mode 100644 include/circt/Support/JSON.h delete mode 100644 include/circt/Support/LoweringOptions.h delete mode 100644 include/circt/Support/LoweringOptionsParser.h delete mode 100644 include/circt/Support/Namespace.h delete mode 100644 include/circt/Support/ParsingUtils.h delete mode 100644 include/circt/Support/Passes.h delete mode 100644 include/circt/Support/Path.h delete mode 100644 include/circt/Support/PrettyPrinter.h delete mode 100644 include/circt/Support/PrettyPrinterHelpers.h delete mode 100644 include/circt/Support/SymCache.h delete mode 100644 include/circt/Support/ValueMapper.h delete mode 100644 include/circt/Support/Version.h delete mode 100644 lib/circt/Dialect/Comb/CMakeLists.txt delete mode 100644 lib/circt/Dialect/Comb/CombAnalysis.cpp delete mode 100644 lib/circt/Dialect/Comb/CombFolds.cpp delete mode 100644 lib/circt/Dialect/Comb/Transforms/CMakeLists.txt delete mode 100644 lib/circt/Dialect/Comb/Transforms/LowerComb.cpp delete mode 100644 lib/circt/Dialect/Comb/Transforms/PassDetails.h delete mode 100644 lib/circt/Dialect/HW/BUILD delete mode 100644 lib/circt/Dialect/HW/CMakeLists.txt delete mode 100644 lib/circt/Dialect/HW/ConversionPatterns.cpp delete mode 100644 lib/circt/Dialect/HW/CustomDirectiveImpl.cpp delete mode 100644 lib/circt/Dialect/HW/HWAttributes.cpp delete mode 100644 lib/circt/Dialect/HW/HWDialect.cpp delete mode 100644 lib/circt/Dialect/HW/HWInstanceGraph.cpp delete mode 100644 lib/circt/Dialect/HW/HWModuleOpInterface.cpp delete mode 100644 lib/circt/Dialect/HW/HWOpInterfaces.cpp delete mode 100644 lib/circt/Dialect/HW/HWOps.cpp delete mode 100644 lib/circt/Dialect/HW/HWReductions.cpp delete mode 100644 lib/circt/Dialect/HW/HWTypeInterfaces.cpp delete mode 100644 lib/circt/Dialect/HW/HWTypes.cpp delete mode 100644 lib/circt/Dialect/HW/InnerSymbolTable.cpp delete mode 100644 lib/circt/Dialect/HW/InstanceImplementation.cpp delete mode 100644 lib/circt/Dialect/HW/ModuleImplementation.cpp delete mode 100644 lib/circt/Dialect/HW/PortConverter.cpp delete mode 100644 lib/circt/Dialect/HW/Transforms/CMakeLists.txt delete mode 100644 lib/circt/Dialect/HW/Transforms/FlattenIO.cpp delete mode 100644 lib/circt/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp delete mode 100644 lib/circt/Dialect/HW/Transforms/HWSpecialize.cpp delete mode 100644 lib/circt/Dialect/HW/Transforms/PassDetails.h delete mode 100644 lib/circt/Dialect/HW/Transforms/PrintHWModuleGraph.cpp delete mode 100644 lib/circt/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp delete mode 100644 lib/circt/Support/APInt.cpp delete mode 100644 lib/circt/Support/BackedgeBuilder.cpp delete mode 100644 lib/circt/Support/CMakeLists.txt delete mode 100644 lib/circt/Support/CustomDirectiveImpl.cpp delete mode 100644 lib/circt/Support/FieldRef.cpp delete mode 100644 lib/circt/Support/InstanceGraph.cpp delete mode 100644 lib/circt/Support/JSON.cpp delete mode 100644 lib/circt/Support/LoweringOptions.cpp delete mode 100644 lib/circt/Support/ParsingUtils.cpp delete mode 100644 lib/circt/Support/Passes.cpp delete mode 100644 lib/circt/Support/Path.cpp delete mode 100644 lib/circt/Support/PrettyPrinter.cpp delete mode 100644 lib/circt/Support/PrettyPrinterHelpers.cpp delete mode 100644 lib/circt/Support/SymCache.cpp delete mode 100644 lib/circt/Support/ValueMapper.cpp delete mode 100644 lib/circt/Support/Version.cpp.in diff --git a/include/circt/Dialect/Comb/BUILD b/include/circt/Dialect/Comb/BUILD index a7f284e9a9..c1b17ba0af 100644 --- a/include/circt/Dialect/Comb/BUILD +++ b/include/circt/Dialect/Comb/BUILD @@ -17,7 +17,6 @@ td_library( ], includes = ["include"], deps = [ - "@heir//include/circt/Dialect/HW:td_files", "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:FunctionInterfacesTdFiles", diff --git a/include/circt/Dialect/Comb/CMakeLists.txt b/include/circt/Dialect/Comb/CMakeLists.txt deleted file mode 100644 index a7bb1a6251..0000000000 --- a/include/circt/Dialect/Comb/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -add_circt_dialect(Comb comb) -add_circt_dialect_doc(Comb comb) - -set(LLVM_TARGET_DEFINITIONS Comb.td) -mlir_tablegen(CombEnums.h.inc -gen-enum-decls) -mlir_tablegen(CombEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRCombEnumsIncGen) - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) -add_public_tablegen_target(CIRCTCombTransformsIncGen) - -# Generate Pass documentation. -add_circt_doc(Passes CombPasses -gen-pass-doc) diff --git a/include/circt/Dialect/Comb/Comb.td b/include/circt/Dialect/Comb/Comb.td index 13dfa99015..3e4bc28a37 100644 --- a/include/circt/Dialect/Comb/Comb.td +++ b/include/circt/Dialect/Comb/Comb.td @@ -26,7 +26,7 @@ def CombDialect : Dialect { This dialect defines the `comb` dialect, which is intended to be a generic representation of combinational logic outside of a particular use-case. }]; - let hasConstantMaterializer = 1; + // let hasConstantMaterializer = 1; let cppNamespace = "::circt::comb"; // This will be the default after next LLVM bump. @@ -38,7 +38,6 @@ def CombDialect : Dialect { class CombOp traits = []> : Op; -include "include/circt/Dialect/HW/HWTypes.td" -include "include/circt/Dialect/Comb/Combinational.td" +include "Combinational.td" #endif // COMB_TD diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index bcb88c5560..effe7c88c1 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -14,7 +14,6 @@ #define CIRCT_DIALECT_COMB_COMBOPS_H #include "include/circt/Dialect/Comb/CombDialect.h" -#include "include/circt/Dialect/HW/HWTypes.h" #include "include/circt/Support/LLVM.h" #include "mlir/include/mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project @@ -34,30 +33,34 @@ class PatternRewriter; #define GET_OP_CLASSES #include "include/circt/Dialect/Comb/Comb.h.inc" -namespace circt { -namespace comb { +// namespace circt { +// namespace comb { -using llvm::KnownBits; +// using llvm::KnownBits; -/// Compute "known bits" information about the specified value - the set of bits -/// that are guaranteed to always be zero, and the set of bits that are -/// guaranteed to always be one (these must be exclusive!). A bit that exists -/// in neither set is unknown. -KnownBits computeKnownBits(Value value); +// /// Compute "known bits" information about the specified value - the set of +// bits +// /// that are guaranteed to always be zero, and the set of bits that are +// /// guaranteed to always be one (these must be exclusive!). A bit that +// exists +// /// in neither set is unknown. +// KnownBits computeKnownBits(Value value); -/// Create a sign extension operation from a value of integer type to an equal -/// or larger integer type. -Value createOrFoldSExt(Location loc, Value value, Type destTy, - OpBuilder &builder); -Value createOrFoldSExt(Value value, Type destTy, ImplicitLocOpBuilder &builder); +// /// Create a sign extension operation from a value of integer type to an +// equal +// /// or larger integer type. +// Value createOrFoldSExt(Location loc, Value value, Type destTy, +// OpBuilder &builder); +// Value createOrFoldSExt(Value value, Type destTy, ImplicitLocOpBuilder +// &builder); -/// Create a ``Not'' gate on a value. -Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, - bool twoState = false); -Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, - bool twoState = false); +// /// Create a ``Not'' gate on a value. +// Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, +// bool twoState = false); +// Value createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, +// bool twoState = false); -} // namespace comb -} // namespace circt +// } // namespace comb +// } // namespace circt #endif // CIRCT_DIALECT_COMB_COMBOPS_H diff --git a/include/circt/Dialect/Comb/CombPasses.h b/include/circt/Dialect/Comb/CombPasses.h deleted file mode 100644 index 84b4db8286..0000000000 --- a/include/circt/Dialect/Comb/CombPasses.h +++ /dev/null @@ -1,33 +0,0 @@ -//===- Passes.h - Comb pass entry points ------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header file defines prototypes that expose pass constructors. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_COMB_COMBPASSES_H -#define CIRCT_DIALECT_COMB_COMBPASSES_H - -#include -#include - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project - -namespace circt { -namespace comb { - -/// Generate the code for registering passes. -#define GEN_PASS_DECL -#define GEN_PASS_REGISTRATION -#include "include/circt/Dialect/Comb/Passes.h.inc" - -} // namespace comb -} // namespace circt - -#endif // CIRCT_DIALECT_COMB_COMBPASSES_H diff --git a/include/circt/Dialect/Comb/CombVisitors.h b/include/circt/Dialect/Comb/CombVisitors.h deleted file mode 100644 index c92ed4e966..0000000000 --- a/include/circt/Dialect/Comb/CombVisitors.h +++ /dev/null @@ -1,113 +0,0 @@ -//===- CombVisitors.h - Comb Dialect Visitors -------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines visitors that make it easier to work with the Comb IR. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_COMB_COMBVISITORS_H -#define CIRCT_DIALECT_COMB_COMBVISITORS_H - -#include "include/circt/Dialect/Comb/CombOps.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project - -namespace circt { -namespace comb { - -/// This helps visit Combinational nodes. -template -class CombinationalVisitor { - public: - ResultType dispatchCombinationalVisitor(Operation *op, ExtraArgs... args) { - auto *thisCast = static_cast(this); - return TypeSwitch(op) - .template Case< - // Arithmetic and Logical Binary Operations. - AddOp, SubOp, MulOp, DivUOp, DivSOp, ModUOp, ModSOp, ShlOp, ShrUOp, - ShrSOp, - // Bitwise operations - AndOp, OrOp, XorOp, - // Comparison operations - ICmpOp, - // Reduction Operators - ParityOp, - // Other operations. - ConcatOp, ReplicateOp, ExtractOp, MuxOp>( - [&](auto expr) -> ResultType { - return thisCast->visitComb(expr, args...); - }) - .Default([&](auto expr) -> ResultType { - return thisCast->visitInvalidComb(op, args...); - }); - } - - /// This callback is invoked on any non-expression operations. - ResultType visitInvalidComb(Operation *op, ExtraArgs... args) { - op->emitOpError("unknown combinational node"); - abort(); - } - - /// This callback is invoked on any combinational operations that are not - /// handled by the concrete visitor. - ResultType visitUnhandledComb(Operation *op, ExtraArgs... args) { - return ResultType(); - } - - /// This fallback is invoked on any binary node that isn't explicitly handled. - /// The default implementation delegates to the 'unhandled' fallback. - ResultType visitBinaryComb(Operation *op, ExtraArgs... args) { - return static_cast(this)->visitUnhandledComb(op, args...); - } - - ResultType visitUnaryComb(Operation *op, ExtraArgs... args) { - return static_cast(this)->visitUnhandledComb(op, args...); - } - - ResultType visitVariadicComb(Operation *op, ExtraArgs... args) { - return static_cast(this)->visitUnhandledComb(op, args...); - } - -#define HANDLE(OPTYPE, OPKIND) \ - ResultType visitComb(OPTYPE op, ExtraArgs... args) { \ - return static_cast(this)->visit##OPKIND##Comb(op, \ - args...); \ - } - - // Arithmetic and Logical Binary Operations. - HANDLE(AddOp, Binary); - HANDLE(SubOp, Binary); - HANDLE(MulOp, Binary); - HANDLE(DivUOp, Binary); - HANDLE(DivSOp, Binary); - HANDLE(ModUOp, Binary); - HANDLE(ModSOp, Binary); - HANDLE(ShlOp, Binary); - HANDLE(ShrUOp, Binary); - HANDLE(ShrSOp, Binary); - - HANDLE(AndOp, Variadic); - HANDLE(OrOp, Variadic); - HANDLE(XorOp, Variadic); - - HANDLE(ParityOp, Unary); - - HANDLE(ICmpOp, Binary); - - // Other operations. - HANDLE(ConcatOp, Unhandled); - HANDLE(ReplicateOp, Unhandled); - HANDLE(ExtractOp, Unhandled); - HANDLE(MuxOp, Unhandled); -#undef HANDLE -}; - -} // namespace comb -} // namespace circt - -#endif // CIRCT_DIALECT_COMB_COMBVISITORS_H diff --git a/include/circt/Dialect/Comb/Combinational.td b/include/circt/Dialect/Comb/Combinational.td index 8e53c843a1..4090e4bf81 100644 --- a/include/circt/Dialect/Comb/Combinational.td +++ b/include/circt/Dialect/Comb/Combinational.td @@ -17,6 +17,10 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/IR/EnumAttr.td" +def HWIntegerType : Type< + CPred<"$_self.isSignlessInteger()">, "signless integer", + "::mlir::IntegerType">; + // Base class for binary operators. class BinOp traits = []> : CombOp { @@ -45,8 +49,8 @@ class UTVariadicOp traits = []> : VariadicOp { - let hasCanonicalizeMethod = true; - let hasFolder = true; + // let hasCanonicalizeMethod = true; + // let hasFolder = true; let hasVerifier = 1; let assemblyFormat = "(`bin` $twoState^)? $inputs attr-dict `:` qualified(type($result))"; @@ -62,18 +66,18 @@ class UTVariadicOp traits = []> : // Arithmetic and Logical Operations. def AddOp : UTVariadicOp<"add", [Commutative]>; def MulOp : UTVariadicOp<"mul", [Commutative]>; -let hasFolder = true in { - def DivUOp : UTBinOp<"divu">; - def DivSOp : UTBinOp<"divs">; - def ModUOp : UTBinOp<"modu">; - def ModSOp : UTBinOp<"mods">; - let hasCanonicalizeMethod = true in { - def ShlOp : UTBinOp<"shl">; - def ShrUOp : UTBinOp<"shru">; - def ShrSOp : UTBinOp<"shrs">; - def SubOp : UTBinOp<"sub">; - } -} +// let hasFolder = true in { +// def DivUOp : UTBinOp<"divu">; +// def DivSOp : UTBinOp<"divs">; +// def ModUOp : UTBinOp<"modu">; +// def ModSOp : UTBinOp<"mods">; +// let hasCanonicalizeMethod = true in { +// def ShlOp : UTBinOp<"shl">; +// def ShrUOp : UTBinOp<"shru">; +// def ShrSOp : UTBinOp<"shrs">; +// def SubOp : UTBinOp<"sub">; +// } +// } def AndOp : UTVariadicOp<"and", [Commutative]>; def OrOp : UTVariadicOp<"or", [Commutative]>; @@ -132,8 +136,8 @@ def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { let assemblyFormat = "(`bin` $twoState^)? $predicate $lhs `,` $rhs attr-dict `:` qualified(type($lhs))"; - let hasFolder = true; - let hasCanonicalizeMethod = true; + // let hasFolder = true; + // let hasCanonicalizeMethod = true; let extraClassDeclaration = [{ /// Returns the flipped predicate, reversing the LHS and RHS operands. The @@ -166,7 +170,7 @@ class UnaryI1ReductionOp traits = []> : CombOp { let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState); let results = (outs I1:$result); - let hasFolder = 1; + // let hasFolder = 1; let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))"; } @@ -188,9 +192,9 @@ def ExtractOp : CombOp<"extract", [Pure]> { let assemblyFormat = "$input `from` $lowBit attr-dict `:` functional-type($input, $result)"; - let hasFolder = true; + // let hasFolder = true; let hasVerifier = 1; - let hasCanonicalizeMethod = true; + // let hasCanonicalizeMethod = true; let builders = [ OpBuilder<(ins "Value":$lhs, "int32_t":$lowBit, "int32_t":$bitWidth), [{ @@ -212,8 +216,8 @@ def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { let arguments = (ins Variadic:$inputs); let results = (outs HWIntegerType:$result); - let hasFolder = true; - let hasCanonicalizeMethod = true; + // let hasFolder = true; + // let hasCanonicalizeMethod = true; let hasVerifier = 1; let assemblyFormat = "$inputs attr-dict `:` qualified(type($inputs))"; @@ -246,7 +250,7 @@ def ReplicateOp : CombOp<"replicate", [Pure]> { let assemblyFormat = "$input attr-dict `:` functional-type($input, $result)"; - let hasFolder = true; + // let hasFolder = true; let hasVerifier = 1; let builders = [ @@ -283,8 +287,8 @@ def MuxOp : CombOp<"mux", let assemblyFormat = "(`bin` $twoState^)? $cond `,` $trueValue `,` $falseValue attr-dict `:` qualified(type($result))"; - let hasFolder = true; - let hasCanonicalizer = true; + // let hasFolder = true; + // let hasCanonicalizer = true; } def TruthTableOp : CombOp<"truth_table", [Pure]> { diff --git a/include/circt/Dialect/Comb/Passes.td b/include/circt/Dialect/Comb/Passes.td deleted file mode 100644 index 5ecaa997df..0000000000 --- a/include/circt/Dialect/Comb/Passes.td +++ /dev/null @@ -1,25 +0,0 @@ -//===-- Passes.td - Comb pass definition file --------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_COMB_PASSES_TD -#define CIRCT_DIALECT_COMB_PASSES_TD - -include "mlir/Pass/PassBase.td" - -def LowerComb : Pass<"lower-comb", "mlir::ModuleOp"> { - let summary = "Lowers the some of the comb ops"; - let description = [{ - Some operations in the comb dialect (e.g. `comb.truth_table`) are not - directly supported by ExportVerilog. They need to be lowered into ops which - are supported. There are many ways to lower these ops so we do this in a - separate pass. This also allows the lowered form to participate in - optimizations like the comb canonicalizers. - }]; -} - -#endif // CIRCT_DIALECT_COMB_PASSES_TD diff --git a/include/circt/Dialect/HW/BUILD b/include/circt/Dialect/HW/BUILD deleted file mode 100644 index e6487a41c6..0000000000 --- a/include/circt/Dialect/HW/BUILD +++ /dev/null @@ -1,204 +0,0 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") - -package( - default_applicable_licenses = ["@heir//:license"], - default_visibility = ["//visibility:public"], -) - -exports_files( - glob(["*.h"]), -) - -td_library( - name = "td_files", - srcs = glob([ - "*.td", - ]), - includes = ["include"], - deps = [ - "@heir//include/circt/Support:td_files", - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "dialect_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-dialect-decls", - ], - "HWDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - ], - "HWDialect.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HWDialect.td", - deps = [ - ":td_files", - "@heir//include/circt/Support:interfaces_inc_gen", - ], -) - -gentbl_cc_library( - name = "types_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-typedef-decls", - ], - "HWTypes.h.inc", - ), - ( - [ - "-gen-typedef-defs", - ], - "HWTypes.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HW.td", - deps = [ - ":dialect_inc_gen", - ":td_files", - ], -) - -gentbl_cc_library( - name = "ops_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-op-decls", - ], - "HW.h.inc", - ), - ( - [ - "-gen-op-defs", - ], - "HW.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HW.td", - deps = [ - ":dialect_inc_gen", - ":td_files", - ":types_inc_gen", - ], -) - -gentbl_cc_library( - name = "attributes_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-attrdef-decls", - ], - "HWAttributes.h.inc", - ), - ( - [ - "-gen-attrdef-defs", - ], - "HWAttributes.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HW.td", - deps = [ - ":dialect_inc_gen", - ":td_files", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "op_interfaces_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-op-interface-decls", - ], - "HWOpInterfaces.h.inc", - ), - ( - [ - "-gen-op-interface-defs", - ], - "HWOpInterfaces.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HWOpInterfaces.td", - deps = [ - ":td_files", - ], -) - -gentbl_cc_library( - name = "type_interfaces_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-type-interface-decls", - ], - "HWTypeInterfaces.h.inc", - ), - ( - [ - "-gen-type-interface-defs", - ], - "HWTypeInterfaces.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HWTypeInterfaces.td", - deps = [ - ":attributes_inc_gen", - ":dialect_inc_gen", - ":td_files", - ], -) - -gentbl_cc_library( - name = "enum_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-enum-decls", - ], - "HWEnums.h.inc", - ), - ( - [ - "-gen-enum-defs", - ], - "HWEnums.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "HW.td", - deps = [ - ":dialect_inc_gen", - ":td_files", - ], -) diff --git a/include/circt/Dialect/HW/CMakeLists.txt b/include/circt/Dialect/HW/CMakeLists.txt deleted file mode 100644 index 8572e8a76c..0000000000 --- a/include/circt/Dialect/HW/CMakeLists.txt +++ /dev/null @@ -1,44 +0,0 @@ -add_circt_dialect(HW hw) - -set(LLVM_TARGET_DEFINITIONS HW.td) - -mlir_tablegen(HWAttributes.h.inc -gen-attrdef-decls) -mlir_tablegen(HWAttributes.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRHWAttrIncGen) -add_dependencies(circt-headers MLIRHWAttrIncGen) - -mlir_tablegen(HWEnums.h.inc -gen-enum-decls) -mlir_tablegen(HWEnums.cpp.inc -gen-enum-defs) -add_public_tablegen_target(MLIRHWEnumsIncGen) -add_dependencies(circt-headers MLIRHWEnumsIncGen) - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls) -add_public_tablegen_target(CIRCTHWTransformsIncGen) - -set(LLVM_TARGET_DEFINITIONS HWOpInterfaces.td) -mlir_tablegen(HWOpInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(HWOpInterfaces.cpp.inc -gen-op-interface-defs) -add_public_tablegen_target(CIRCTHWOpInterfacesIncGen) -add_dependencies(circt-headers CIRCTHWOpInterfacesIncGen) - -set(LLVM_TARGET_DEFINITIONS HWTypeInterfaces.td) -mlir_tablegen(HWTypeInterfaces.h.inc -gen-type-interface-decls) -mlir_tablegen(HWTypeInterfaces.cpp.inc -gen-type-interface-defs) -add_public_tablegen_target(CIRCTHWTypeInterfacesIncGen) -add_dependencies(circt-headers CIRCTHWTypeInterfacesIncGen) - -# Generate Dialect documentation. -add_circt_doc(HWAggregates Dialects/HWAggregateOps -gen-op-doc) -add_circt_doc(HWAttributes Dialects/HWAttributes -gen-attrdef-doc) -add_circt_doc(HWAttributesNaming Dialects/HWAttributesNaming -gen-attrdef-doc) -add_circt_doc(HWMiscOps Dialects/HWMiscOps -gen-op-doc) -add_circt_doc(HWOpInterfaces Dialects/HWOpInterfaces -gen-op-interface-docs) -add_circt_doc(HWStructure Dialects/HWStructureOps -gen-op-doc) -add_circt_doc(HWTypeDecls Dialects/HWTypeDeclsOps -gen-op-doc) -add_circt_doc(HWTypeInterfaces Dialects/HWTypeInterfaces -gen-type-interface-docs) -add_circt_doc(HWTypes Dialects/HWTypes -gen-typedef-doc) -add_circt_doc(HWTypesImpl Dialects/HWTypesImpl -gen-typedef-doc) - -# Generate Pass documentation. -add_circt_doc(Passes HWPasses -gen-pass-doc) diff --git a/include/circt/Dialect/HW/ConversionPatterns.h b/include/circt/Dialect/HW/ConversionPatterns.h deleted file mode 100644 index fdb0bf715e..0000000000 --- a/include/circt/Dialect/HW/ConversionPatterns.h +++ /dev/null @@ -1,36 +0,0 @@ -//===- ConversionPatterns.h - Common Conversion patterns --------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_CONVERSIONPATTERNS_H -#define CIRCT_SUPPORT_CONVERSIONPATTERNS_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - -namespace circt { - -/// Generic pattern which replaces an operation by one of the same operation -/// name, but with converted attributes, operands, and result types to eliminate -/// illegal types. Uses generic builders based on OperationState to make sure -/// that this pattern can apply to _any_ operation. -/// -/// Useful when a conversion can be entirely defined by a TypeConverter. -struct TypeConversionPattern : public mlir::ConversionPattern { - public: - TypeConversionPattern(TypeConverter &converter, MLIRContext *context) - : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, context) {} - using ConversionPattern::ConversionPattern; - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_CONVERSIONPATTERNS_H diff --git a/include/circt/Dialect/HW/CustomDirectiveImpl.h b/include/circt/Dialect/HW/CustomDirectiveImpl.h deleted file mode 100644 index e8f739b378..0000000000 --- a/include/circt/Dialect/HW/CustomDirectiveImpl.h +++ /dev/null @@ -1,70 +0,0 @@ -//===- CustomDirectiveImpl.h - Table-gen custom directive impl --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides common custom directives for table-gen assembly formats. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H -#define CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project - -namespace circt { - -//===----------------------------------------------------------------------===// -// InputPortList Custom Directive -//===----------------------------------------------------------------------===// - -/// Parse a list of instance input ports. -/// input-list ::= `(` ( input-element (`,` input-element )* )? `)` -/// input-element ::= identifier `:` value `:` type -ParseResult parseInputPortList( - OpAsmParser &parser, - SmallVectorImpl &inputs, - SmallVectorImpl &inputTypes, ArrayAttr &inputNames); - -/// Print a list of instance input ports. -void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, - TypeRange inputTypes, ArrayAttr inputNames); - -//===----------------------------------------------------------------------===// -// OutputPortList Custom Directive -//===----------------------------------------------------------------------===// - -/// Parse a list of instance output ports. -/// output-list ::= `(` ( output-element (`,` output-element )* )? `)` -/// output-element ::= identifier `:` type -ParseResult parseOutputPortList(OpAsmParser &parser, - SmallVectorImpl &resultTypes, - ArrayAttr &resultNames); - -/// Print a list of instance output ports. -void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, - ArrayAttr resultNames); - -//===----------------------------------------------------------------------===// -// OptionalParameterList Custom Directive -//===----------------------------------------------------------------------===// - -/// Parse an parameter list if present. -/// module-parameter-list ::= `<` parameter-decl (`,` parameter-decl)* `>` -/// parameter-decl ::= identifier `:` type -/// parameter-decl ::= identifier `:` type `=` attribute -ParseResult parseOptionalParameterList(OpAsmParser &parser, - ArrayAttr ¶meters); - -/// Print a parameter list for a module or instance. -void printOptionalParameterList(OpAsmPrinter &p, Operation *op, - ArrayAttr parameters); - -} // namespace circt - -#endif // CIRCT_DIALECT_HW_CUSTOMDIRECTIVEIMPL_H diff --git a/include/circt/Dialect/HW/HW.td b/include/circt/Dialect/HW/HW.td deleted file mode 100644 index 3935a7f46f..0000000000 --- a/include/circt/Dialect/HW/HW.td +++ /dev/null @@ -1,30 +0,0 @@ -//===- HW.td - HW dialect definition -----------------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the top level file for the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HW_TD -#define CIRCT_DIALECT_HW_HW_TD - -include "include/circt/Dialect/HW/HWDialect.td" -include "include/circt/Dialect/HW/HWAttributes.td" -include "include/circt/Dialect/HW/HWAttributesNaming.td" - -include "include/circt/Dialect/HW/HWTypesImpl.td" -include "include/circt/Dialect/HW/HWTypes.td" - -include "include/circt/Dialect/HW/HWOpInterfaces.td" -include "include/circt/Dialect/HW/HWTypeInterfaces.td" -include "include/circt/Dialect/HW/HWMiscOps.td" -include "include/circt/Dialect/HW/HWAggregates.td" -include "include/circt/Dialect/HW/HWStructure.td" -include "include/circt/Dialect/HW/HWTypeDecls.td" - -#endif // CIRCT_DIALECT_HW_HW_TD diff --git a/include/circt/Dialect/HW/HWAggregates.td b/include/circt/Dialect/HW/HWAggregates.td deleted file mode 100644 index c348330b3e..0000000000 --- a/include/circt/Dialect/HW/HWAggregates.td +++ /dev/null @@ -1,306 +0,0 @@ -//===- HWAggregates.td - HW ops for structs/arrays/etc -----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This describes the MLIR ops for working with aggregate values like structs -// and arrays. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWAGGREGATES_TD -#define CIRCT_DIALECT_HW_HWAGGREGATES_TD - -include "include/circt/Dialect/HW/HWDialect.td" -include "include/circt/Dialect/HW/HWTypes.td" -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/OpAsmInterface.td" - -//===----------------------------------------------------------------------===// -// Constants -//===----------------------------------------------------------------------===// - -def AggregateConstantOp : HWOp<"aggregate_constant", [Pure, ConstantLike]> { - let summary = "Produce a constant aggregate value"; - let description = [{ - This operation produces a constant value of an aggregate type. Clock and - reset values are supported. For nested aggregates, embedded arrays are - used. - - Examples: - ```mlir - %result = hw.aggregate.constant [1 : i1, 2 : i2, 3 : i2] : !hw.struct - %result = hw.aggregate.constant [1 : i1, [2 : i2, 3 : i2]] : !hw.struct> - ``` - }]; - - let arguments = (ins ArrayAttr:$fields); - let results = (outs HWAggregateType:$result); - let assemblyFormat = "$fields attr-dict `:` type($result)"; - let hasVerifier = 1; - let hasFolder = 1; -} - -//===----------------------------------------------------------------------===// -// Packed Array Processing Operations -//===----------------------------------------------------------------------===// - -def ArrayCreateOp : HWOp<"array_create", [Pure, SameTypeOperands]> { - let summary = "Create an array from values"; - let description = [{ - Creates an array from a variable set of values. One or more values must be - listed. - - ``` - // %a, %b, %c are all i4 - %array = hw.array_create %a, %b, %c : i4 - ``` - - See the HW-SV rationale document for details on operand ordering. - }]; - - let arguments = (ins Variadic:$inputs); - let results = (outs ArrayType:$result); - - let hasVerifier = 1; - let hasFolder = 1; - let hasCanonicalizeMethod = 1; - - let hasCustomAssemblyFormat = 1; - let builders = [ - // ValueRange needs to contain at least one element. - OpBuilder<(ins "ValueRange":$input)> - ]; - - let extraClassDeclaration = [{ - /// If the all elements of the array are identical, returns that element - /// value. Otherwise returns a null value. - Value getUniformElement(); - /// Returns true if all array elements are identical. - bool isUniform() { return !!getUniformElement(); } - }]; -} - -def ArrayConcatOp : HWOp<"array_concat", [Pure]> { - let summary = "Concatenate some arrays"; - let description = [{ - Creates an array by concatenating a variable set of arrays. One or more - values must be listed. - - ``` - // %a, %b, %c are hw arrays of i4 with sizes 2, 5, and 4 respectively. - %array = hw.array_concat %a, %b, %c : (2, 5, 4 x i4) - // %array is !hw.array<11 x i4> - ``` - - See the HW-SV rationale document for details on operand ordering. - }]; - - let arguments = (ins Variadic:$inputs); - let results = (outs ArrayType:$result); - - let assemblyFormat = [{ - $inputs attr-dict `:` custom(type($inputs), qualified(type($result))) - }]; - - let builders = [ - // ValueRange needs to contain at least one element. - OpBuilder<(ins "ValueRange":$inputs)> - ]; - let hasFolder = 1; - let hasCanonicalizeMethod = 1; -} - -def ArraySliceOp : HWOp<"array_slice", [Pure]> { - let summary = "Get a range of values from an array"; - let description = [{ - Extracts a sub-range from an array. The range is from `lowIndex` to - `lowIndex` + the number of elements in the return type, non-inclusive on - the high end. For instance, - - ``` - // Slices 16 elements starting at '%offset'. - %subArray = hw.slice %largerArray at %offset : - (!hw.array<1024xi8>) -> !hw.array<16xi8> - ``` - - Would translate to the following SystemVerilog: - - ``` - logic [7:0][15:0] subArray = largerArray[offset +: 16]; - ``` - - Width of 'idx' is defined to be the precise number of bits required to - index the 'input' array. More precisely: for an input array of size M, - the width of 'idx' is ceil(log2(M)). Lower and upper bound indexes which - are larger than the size of the 'input' array results in undefined - behavior. - }]; - - let arguments = (ins ArrayType:$input, HWIntegerType:$lowIndex); - let results = (outs ArrayType:$dst); - - let hasVerifier = 1; - let hasFolder = 1; - let hasCanonicalizeMethod = 1; - - let assemblyFormat = [{ - $input`[`$lowIndex`]` attr-dict `:` - `(` custom(type($input), qualified(type($lowIndex))) `)` `->` qualified(type($dst)) - }]; -} - -class IndexBitWidthConstraint - : PredOpTrait<"Index width should be exactly clog2 (size of array), or either 0 or 1 if the array is a singleton.", - CPred<"isValidIndexBitWidth($" # index # ", $" # input # ")">>; - -class ArrayElementTypeConstraint - : TypesMatchWith<"Result must be arrays element type", - input, result, - "type_cast($_self).getElementType()">; - -// hw.array_get does not work with unpacked arrays. -def ArrayGetOp : HWOp<"array_get", - [Pure, IndexBitWidthConstraint<"index", "input">, - ArrayElementTypeConstraint<"result", "input">]> { - let summary = "Get the value in an array at the specified index"; - let arguments = (ins ArrayType:$input, HWIntegerType:$index); - let results = (outs HWNonInOutType:$result); - - let assemblyFormat = [{ - $input`[`$index`]` attr-dict `:` qualified(type($input)) `,` qualified(type($index)) - }]; - - let hasFolder = 1; - let hasCanonicalizeMethod = 1; -} - -//===----------------------------------------------------------------------===// -// Structure Processing Operations -//===----------------------------------------------------------------------===// - -def StructCreateOp : HWOp<"struct_create", [Pure]> { - let summary = "Create a struct from constituent parts."; - let arguments = (ins Variadic:$input); - let results = (outs StructType:$result); - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasVerifier = 1; -} - -// Extract the value of a field of a structure. -def StructExtractOp : HWOp<"struct_extract", - [Pure, - DeclareOpInterfaceMethods - ]> { - let summary = "Extract a named field from a struct."; - let description = [{ - ``` - %result = hw.struct_extract %input["field"] : !hw.struct - ``` - }]; - - let arguments = (ins StructType:$input, StrAttr:$field); - let results = (outs HWNonInOutType:$result); - let hasCustomAssemblyFormat = 1; - - let builders = [ - OpBuilder<(ins "Value":$input, "StructType::FieldInfo":$field)>, - OpBuilder<(ins "Value":$input, "StringAttr":$field)>, - OpBuilder<(ins "Value":$input, "StringRef":$field), [{ - build(odsBuilder, odsState, input, odsBuilder.getStringAttr(field)); - }]> - ]; - - let hasFolder = 1; - let hasCanonicalizeMethod = 1; -} - -// Create a structure by replacing a field with a value in an existing one. -def StructInjectOp : HWOp<"struct_inject", [Pure, - AllTypesMatch<["input", "result"]>]> { - let summary = "Inject a value into a named field of a struct."; - let description = [{ - ``` - %result = hw.struct_inject %input["field"], %newValue - : !hw.struct - ``` - }]; - - let arguments = (ins StructType:$input, StrAttr:$field, - HWNonInOutType:$newValue); - let results = (outs StructType:$result); - let hasCustomAssemblyFormat = 1; - let hasFolder = 1; - let hasCanonicalizeMethod = 1; -} - -def StructExplodeOp : HWOp<"struct_explode", [Pure, - DeclareOpInterfaceMethods - ]> { - let summary = "Expand a struct into its constituent parts."; - let description = [{ - ``` - %result:2 = hw.struct_explode %input : !hw.struct - ``` - }]; - let arguments = (ins StructType:$input); - let results = (outs Variadic:$result); - let hasCustomAssemblyFormat = 1; - - let builders = [ - OpBuilder<(ins "Value":$input)> - ]; - - let hasFolder = 1; - let hasCanonicalizeMethod = 1; - } - -//===----------------------------------------------------------------------===// -// Union operations -//===----------------------------------------------------------------------===// - -def UnionCreateOp : HWOp<"union_create", [Pure]> { - let summary = "Create a union with the specified value."; - let description = [{ - Create a union with the value 'input', which can then be accessed via the - specified field. - - ``` - %x = hw.constant 0 : i3 - %z = hw.union_create "bar", %x : !hw.union - ``` - }]; - - let arguments = (ins StrAttr:$field, HWNonInOutType:$input); - let results = (outs UnionType:$result); - let hasCustomAssemblyFormat = 1; -} - -def UnionExtractOp : HWOp<"union_extract", [Pure, - DeclareOpInterfaceMethods, - ]> { - let summary = "Get a union member."; - let description = [{ - Get the value of a union, interpreting it as the type of the specified - member field. Extracting a value belonging to a different field than the - union was initially created will result in undefined behavior. - - ``` - %u = ... - %v = hw.union_extract %u["foo"] : !hw.union - // %v is of type 'i3' - ``` - }]; - - let arguments = (ins UnionType:$input, StrAttr:$field); - let results = (outs HWNonInOutType:$result); - let hasCustomAssemblyFormat = 1; -} - -#endif // CIRCT_DIALECT_HW_HWAGGREGATES_TD diff --git a/include/circt/Dialect/HW/HWAttributes.h b/include/circt/Dialect/HW/HWAttributes.h deleted file mode 100644 index 85586fdf36..0000000000 --- a/include/circt/Dialect/HW/HWAttributes.h +++ /dev/null @@ -1,45 +0,0 @@ -//===- HWAttributes.h - Declare HW dialect attributes ------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_ATTRIBUTES_H -#define CIRCT_DIALECT_HW_ATTRIBUTES_H - -#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project - -namespace circt { -namespace hw { -class PEOAttr; -class EnumType; -enum class PEO : uint32_t; - -// Forward declaration. -class GlobalRefOp; - -/// Returns a resolved version of 'type' wherein any parameter reference -/// has been evaluated based on the set of provided 'parameters'. -mlir::FailureOr evaluateParametricType(mlir::Location loc, - mlir::ArrayAttr parameters, - mlir::Type type); - -/// Evaluates a parametric attribute (param.decl.ref/param.expr) based on a set -/// of provided parameter values. -mlir::FailureOr evaluateParametricAttr( - mlir::Location loc, mlir::ArrayAttr parameters, mlir::Attribute paramAttr); - -/// Returns true if any part of t is parametric. -bool isParametricType(mlir::Type t); - -} // namespace hw -} // namespace circt - -#define GET_ATTRDEF_CLASSES -#include "include/circt/Dialect/HW/HWAttributes.h.inc" - -#endif // CIRCT_DIALECT_HW_ATTRIBUTES_H diff --git a/include/circt/Dialect/HW/HWAttributes.td b/include/circt/Dialect/HW/HWAttributes.td deleted file mode 100644 index 03fead38b0..0000000000 --- a/include/circt/Dialect/HW/HWAttributes.td +++ /dev/null @@ -1,309 +0,0 @@ -//===- HWAttributes.td - Attributes for HW dialect ---------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines HW dialect specific attributes. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWATTRIBUTES_TD -#define CIRCT_DIALECT_HW_HWATTRIBUTES_TD - -include "include/circt/Dialect/HW/HWDialect.td" -include "mlir/IR/EnumAttr.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" - -/// An attribute to indicate the output file an operation should be emitted to. -def OutputFileAttr : AttrDef { - let summary = "Output file attribute"; - let description = [{ - This attribute represents an output file for something which will be - printed. The `filename` string is the file to be output to. If `filename` - ends in a `/` it is considered an output directory. - - When ExportVerilog runs, one of the files produced is a list of all other - files which are produced. The flag `excludeFromFileList` controls if this - file should be included in this list. If any `OutputFileAttr` referring to - the same file sets this to `true`, it will be included in the file list. - This option defaults to `false`. - - For each file emitted by the verilog emitter, certain prelude output will - be included before the main content. The flag `includeReplicatedOps` can - be used to disable the addition of the prelude text. All `OutputFileAttr`s - referring to the same file must use a consistent setting for this value. - This option defaults to `true`. - - Examples: - ```mlir - #hw.ouput_file<"/home/tester/t.sv"> - #hw.ouput_file<"t.sv", excludeFromFileList, includeReplicatedOps> - ``` - }]; - let mnemonic = "output_file"; - let parameters = (ins "::mlir::StringAttr":$filename, - "::mlir::BoolAttr":$excludeFromFilelist, - "::mlir::BoolAttr":$includeReplicatedOps); - let builders = [ - AttrBuilderWithInferredContext<(ins - "::mlir::StringAttr":$filename, - "::mlir::BoolAttr":$excludeFromFileList, - "::mlir::BoolAttr":$includeReplicatedOps), [{ - return get(filename.getContext(), filename, excludeFromFileList, - includeReplicatedOps); - }]>, - ]; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - /// Get an OutputFileAttr from a string filename, canonicalizing the - /// filename. - static OutputFileAttr getFromFilename(::mlir::MLIRContext *context, - const ::mlir::Twine &filename, - bool excludeFromFileList = false, - bool includeReplicatedOps = false); - - /// Get an OutputFileAttr from a string filename, resolving it relative to - /// `directory`. If `filename` is an absolute path, the given `directory` - /// will not be used. - static OutputFileAttr getFromDirectoryAndFilename( - ::mlir::MLIRContext *context, - const ::mlir::Twine &directory, - const ::mlir::Twine &filename, - bool excludeFromFileList = false, - bool includeReplicatedOps = false); - - /// Get an OutputFileAttr from a string directory name. The name will have - /// a trailing `/` added if it is not there, ensuring that this will be - /// an output directory. - static OutputFileAttr getAsDirectory(::mlir::MLIRContext *context, - const ::mlir::Twine &directory, - bool excludeFromFileList = false, - bool includeReplicatedOps = false); - - /// Returns true if this a directory. - bool isDirectory(); - }]; -} - -// An attribute to indicate which filelist an operation's file should be -// included in. -def FileListAttr : AttrDef { - let summary = "Ouput filelist attribute"; - let description = [{ - This attribute represents an output filelist for something which will be - printed. The `filename` string is the file which the filename of the - operation to be output to. - - When ExportVerilog runs, some of the files produced are lists of other files - which are produced. Each filelist exported contains entities' output file - with `FileListAttr` marked. - - - Examples: - ```mlir - #hw.ouput_filelist<"/home/tester/t.F"> - #hw.ouput_filelist<"t.f"> - ``` - }]; - let mnemonic = "output_filelist"; - let parameters = (ins "::mlir::StringAttr":$filename); - let builders = [ - AttrBuilderWithInferredContext<(ins - "::mlir::StringAttr":$filename), [{ - return get(filename.getContext(), filename); - }]>, - ]; - - let assemblyFormat = "`<` $filename `>`"; - - let extraClassDeclaration = [{ - /// Get an OutputFileAttr from a string filename, canonicalizing the - /// filename. - static FileListAttr getFromFilename(::mlir::MLIRContext *context, - const ::mlir::Twine &filename); - }]; - -} - -def ParamDeclAttr : AttrDef { - let summary = "Module or instance parameter definition"; - let description = [{ - An attribute describing a module parameter, or instance parameter - specification. - }]; - - /// The value of the attribute - in a module, this is the default - /// value (and may be missing). In an instance, this is a required field that - /// specifies the value being passed. The verilog emitter omits printing the - /// parameter for an instance when the applied value and the default value are - /// the same. - let parameters = (ins "::mlir::StringAttr":$name, - AttributeSelfTypeParameter<"">:$type, - "::mlir::Attribute":$value); - let mnemonic = "param.decl"; - - let hasCustomAssemblyFormat = 1; - - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, - "::mlir::Type":$type), - "auto *context = type.getContext();\n" - "return $_get(context, name, type, Attribute());">, - AttrBuilderWithInferredContext<(ins "::mlir::StringRef":$name, - "::mlir::Type":$type), - "return get(StringAttr::get(type.getContext(), name), type);">, - - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, - "::mlir::TypedAttr":$value), - "auto *context = value.getContext();\n" - "return $_get(context, name, value.getType(), value);">, - AttrBuilderWithInferredContext<(ins "::mlir::StringRef":$name, - "::mlir::TypedAttr":$value), - "return get(StringAttr::get(value.getContext(), name), value);"> - ]; - - let extraClassDeclaration = [{ - static ParamDeclAttr getWithName(ParamDeclAttr param, - ::mlir::StringAttr name) { - return get(param.getContext(), name, param.getType(), param.getValue()); - } - }]; -} - -/// An array of ParamDeclAttr's that may or may not have a 'value' specified, -/// to be used on hw.module or hw.instance. The hw.instance verifier further -/// ensures that all the values are specified. -def ParamDeclArrayAttr - : TypedArrayAttrBase; - -/// This attribute models a reference to a named parameter within a module body. -/// The type of the ParamDeclRefAttr must always be the same as the type of the -/// parameter being referenced. -def ParamDeclRefAttr : AttrDef { - let summary = "Is a reference to a parameter value."; - let parameters = (ins "::mlir::StringAttr":$name, - AttributeSelfTypeParameter<"">:$type); - let mnemonic = "param.decl.ref"; - - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$name, - "::mlir::Type":$type), [{ - return get(name.getContext(), name, type); - }]> - ]; - - let hasCustomAssemblyFormat = 1; -} - -def ParamVerbatimAttr : AttrDef { - let summary = - "Represents text to emit directly to SystemVerilog for a parameter"; - let parameters = (ins "::mlir::StringAttr":$value, - AttributeSelfTypeParameter<"">:$type); - let mnemonic = "param.verbatim"; - let hasCustomAssemblyFormat = 1; - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$value), [{ - return get(value, NoneType::get(value.getContext())); - }]>, - AttrBuilderWithInferredContext< - (ins "::mlir::StringAttr":$value, "::mlir::Type":$type), [{ - return get(value.getContext(), value, type); - }]>, - ]; -} - -/// Parameter Expression Opcodes. -let cppNamespace = "circt::hw" in { - -/// Fully Associative Expression Opcodes. -def PEO_Add : I32EnumAttrCase<"Add", 0, "add">; -def PEO_Mul : I32EnumAttrCase<"Mul", 1, "mul">; -def PEO_And : I32EnumAttrCase<"And", 2, "and">; -def PEO_Or : I32EnumAttrCase<"Or", 3, "or">; -def PEO_Xor : I32EnumAttrCase<"Xor", 4, "xor">; - -// Binary Expression Opcodes. -def PEO_Shl : I32EnumAttrCase<"Shl" , 5, "shl">; -def PEO_ShrU : I32EnumAttrCase<"ShrU", 6, "shru">; -def PEO_ShrS : I32EnumAttrCase<"ShrS", 7, "shrs">; -def PEO_DivU : I32EnumAttrCase<"DivU", 8, "divu">; -def PEO_DivS : I32EnumAttrCase<"DivS", 9, "divs">; -def PEO_ModU : I32EnumAttrCase<"ModU",10, "modu">; -def PEO_ModS : I32EnumAttrCase<"ModS",11, "mods">; - -// Unary Expression Opcodes. -def PEO_CLog2 : I32EnumAttrCase<"CLog2", 12, "clog2">; - -// String manipulation Opcodes. -def PEO_StrConcat : I32EnumAttrCase<"StrConcat", 13, "str.concat">; - -def PEOAttr : I32EnumAttr<"PEO", "Parameter Expression Opcode", - [PEO_Add, PEO_Mul, PEO_And, PEO_Or, PEO_Xor, - PEO_Shl, PEO_ShrU, PEO_ShrS, - PEO_DivU, PEO_DivS, PEO_ModU, PEO_ModS, - PEO_CLog2, PEO_StrConcat]>; -} - -def ParamExprAttr : AttrDef { - let summary = "Parameter expression combining operands"; - let parameters = (ins "PEO":$opcode, - ArrayRefParameter<"::mlir::TypedAttr">:$operands, - AttributeSelfTypeParameter<"">:$type); - let mnemonic = "param.expr"; - - // Force all clients to go through our building logic so we can canonicalize - // during building. - let skipDefaultBuilders = 1; - - let extraClassDeclaration = [{ - /// Build a parameter expression. This automatically canonicalizes and - /// folds, so it may not necessarily return a ParamExprAttr. - static mlir::TypedAttr get(PEO opcode, - mlir::ArrayRef operands); - - /// Build a binary parameter expression for convenience. - static mlir::TypedAttr get(PEO opcode, mlir::TypedAttr lhs, - mlir::TypedAttr rhs) { - mlir::TypedAttr operands[] = { lhs, rhs }; - return get(opcode, operands); - } - }]; - - let hasCustomAssemblyFormat = 1; -} - -// An attribute to indicate an enumeration value. -def EnumFieldAttr : AttrDef { - let summary = "Enumeration field attribute"; - let description = [{ - This attribute represents a field of an enumeration. - - Examples: - ```mlir - #hw.enum.value> - ``` - }]; - let mnemonic = "enum.field"; - let parameters = (ins "::mlir::StringAttr":$field, "::mlir::TypeAttr":$type); - - // Force all clients to go through our custom builder so we can check - // whether the requested enum value is part of the provided enum type. - let skipDefaultBuilders = 1; - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - /// Builds a new EnumFieldAttr of the provided value. - /// This will fail if the value is not a member of the provided enum type. - static EnumFieldAttr get(::mlir::Location loc, ::mlir::StringAttr value, mlir::Type type); - }]; -} - -#endif // CIRCT_DIALECT_HW_HWATTRIBUTES_TD diff --git a/include/circt/Dialect/HW/HWAttributesNaming.td b/include/circt/Dialect/HW/HWAttributesNaming.td deleted file mode 100644 index c2bee100f6..0000000000 --- a/include/circt/Dialect/HW/HWAttributesNaming.td +++ /dev/null @@ -1,70 +0,0 @@ -//===- HWAttributesNaming.td - Attributes for HW dialect ---*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines HW dialect attributes used in other dialects. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWATTRIBUTESNAMING -#define CIRCT_DIALECT_HW_HWATTRIBUTESNAMING - -include "include/circt/Dialect/HW/HWDialect.td" -include "mlir/IR/AttrTypeBase.td" - -def InnerRefAttr : AttrDef { - let summary = "Refer to a name inside a module"; - let description = [{ - This works like a symbol reference, but to a name inside a module. - }]; - let mnemonic = "innerNameRef"; - let parameters = (ins "::mlir::FlatSymbolRefAttr":$moduleRef, - "::mlir::StringAttr":$name); - - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$module, - "::mlir::StringAttr":$name), [{ - return $_get( - module.getContext(), mlir::FlatSymbolRefAttr::get(module), name); - }]> - ]; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - /// Get the InnerRefAttr for an operation and add the sym on it. - static InnerRefAttr getFromOperation(mlir::Operation *op, - mlir::StringAttr symName, - mlir::StringAttr moduleName); - - /// Return the name of the referenced module. - mlir::StringAttr getModule() const { return getModuleRef().getAttr(); } - }]; -} - -def GlobalRefAttr : AttrDef { - let summary = "Refer to a non-local symbol"; - let description = [{ - This works like a symbol reference, but to a global symbol with a possible - unique instance path. - }]; - let mnemonic = "globalNameRef"; - let parameters = (ins "::mlir::FlatSymbolRefAttr":$glblSym); - let builders = [ - AttrBuilderWithInferredContext<(ins "::circt::hw::GlobalRefOp":$ref),[{ - return get(ref.getContext(), SymbolRefAttr::get(ref)); - }]>, - ]; - - let assemblyFormat = "`<` $glblSym `>`"; - - let extraClassDeclaration = [{ - static constexpr char DialectAttrName[] = "circt.globalRef"; - }]; -} - -#endif // CIRCT_DIALECT_HW_HWATTRIBUTESNAMING diff --git a/include/circt/Dialect/HW/HWDialect.h b/include/circt/Dialect/HW/HWDialect.h deleted file mode 100644 index 4be1503800..0000000000 --- a/include/circt/Dialect/HW/HWDialect.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- HWDialect.h - HW dialect declaration ---------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines an HW MLIR dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWDIALECT_H -#define CIRCT_DIALECT_HW_HWDIALECT_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project - -// Pull in the dialect definition. -#include "include/circt/Dialect/HW/HWDialect.h.inc" - -// Pull in all enum type definitions and utility function declarations. -#include "include/circt/Dialect/HW/HWEnums.h.inc" - -#endif // CIRCT_DIALECT_HW_HWDIALECT_H diff --git a/include/circt/Dialect/HW/HWDialect.td b/include/circt/Dialect/HW/HWDialect.td deleted file mode 100644 index 4d9a12965f..0000000000 --- a/include/circt/Dialect/HW/HWDialect.td +++ /dev/null @@ -1,49 +0,0 @@ -//===- HWDialect.td - HW dialect definition ----------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This contains the HWDialect definition to be included in other files. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWDIALECT -#define CIRCT_DIALECT_HW_HWDIALECT - -include "mlir/IR/OpBase.td" - -def HWDialect : Dialect { - let name = "hw"; - let cppNamespace = "::circt::hw"; - - let summary = "Types and operations for the hardware dialect"; - let description = [{ - This dialect defines the `hw` dialect, which is intended to be a generic - representation of HW outside of a particular use-case. - }]; - - let hasConstantMaterializer = 1; - let useDefaultTypePrinterParser = 1; - - // Opt-out of properties for now, must migrate by LLVM 19. #5273. - let usePropertiesForAttributes = 0; - - let extraClassDeclaration = [{ - /// Register all HW types. - void registerTypes(); - /// Register all HW attributes. - void registerAttributes(); - - Attribute parseAttribute(DialectAsmParser &p, Type type) const override; - void printAttribute(Attribute attr, DialectAsmPrinter &p) const override; - }]; -} - -// Base class for the operation in this dialect. -class HWOp traits = []> : - Op; - -#endif // CIRCT_DIALECT_HW_HWDIALECT diff --git a/include/circt/Dialect/HW/HWInstanceGraph.h b/include/circt/Dialect/HW/HWInstanceGraph.h deleted file mode 100644 index 9cc93055ce..0000000000 --- a/include/circt/Dialect/HW/HWInstanceGraph.h +++ /dev/null @@ -1,56 +0,0 @@ -//===- InstanceGraph.h - Instance graph -------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the HW InstanceGraph, which is similar to a CallGraph. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H -#define CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H - -#include "include/circt/Dialect/HW/HWOpInterfaces.h" -#include "include/circt/Support/InstanceGraph.h" - -namespace circt { -namespace hw { - -/// HW-specific instance graph with a virtual entry node linking to -/// all publicly visible modules. -class InstanceGraph : public igraph::InstanceGraph { - public: - InstanceGraph(Operation *operation); - - /// Return the entry node linking to all public modules. - igraph::InstanceGraphNode *getTopLevelNode() override { return &entry; } - - /// Adds a module, updating links to entry. - igraph::InstanceGraphNode *addHWModule(HWModuleLike module); - - /// Erases a module, updating links to entry. - void erase(igraph::InstanceGraphNode *node) override; - - private: - using igraph::InstanceGraph::addModule; - igraph::InstanceGraphNode entry; -}; - -} // namespace hw -} // namespace circt - -// Specialisation for the HW instance graph. -template <> -struct llvm::GraphTraits - : public llvm::GraphTraits {}; - -template <> -struct llvm::DOTGraphTraits - : public llvm::DOTGraphTraits { - using llvm::DOTGraphTraits::DOTGraphTraits; -}; - -#endif // CIRCT_DIALECT_HW_HWINSTANCEGRAPH_H diff --git a/include/circt/Dialect/HW/HWMiscOps.td b/include/circt/Dialect/HW/HWMiscOps.td deleted file mode 100644 index b9bbe28690..0000000000 --- a/include/circt/Dialect/HW/HWMiscOps.td +++ /dev/null @@ -1,192 +0,0 @@ -//===- HWMiscOps.td - Miscellaneous HW ops -----------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This defines miscellaneous generic HW ops, like ConstantOp and BitcastOp. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWMISCOPS_TD -#define CIRCT_DIALECT_HW_HWMISCOPS_TD - -include "include/circt/Dialect/HW/HWAttributes.td" -include "include/circt/Dialect/HW/HWDialect.td" -include "include/circt/Dialect/HW/HWOpInterfaces.td" -include "include/circt/Dialect/HW/HWTypes.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -def ConstantOp - : HWOp<"constant", [Pure, ConstantLike, FirstAttrDerivedResultType, - DeclareOpInterfaceMethods]> { - let summary = "Produce a constant value"; - let description = [{ - The constant operation produces a constant value of standard integer type - without a sign. - ``` - %result = hw.constant 42 : t1 - ``` - }]; - - let arguments = (ins APIntAttr:$value); - let results = (outs HWIntegerType:$result); - let hasCustomAssemblyFormat = 1; - - let builders = [ - /// Build a ConstantOp from an APInt, infering the result type from the - /// width of the APInt. - OpBuilder<(ins "const APInt &":$value)>, - - /// This builder allows construction of small signed integers like 0, 1, -1 - /// matching a specified MLIR IntegerType. This shouldn't be used for - /// general constant folding because it only works with values that can be - /// expressed in an int64_t. Use APInt's instead. - OpBuilder<(ins "Type":$type, "int64_t":$value)>, - - /// Build a ConstantOp from a prebuilt attribute. - OpBuilder<(ins "IntegerAttr":$value)> - ]; - let hasFolder = true; - let hasVerifier = 1; -} - -def WireOp : HWOp<"wire", [ - SameOperandsAndResultType, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Assign a name or symbol to an SSA edge"; - let description = [{ - An `hw.wire` is used to assign a human-readable name or a symbol for remote - references to an SSA edge. It takes a single operand and returns its value - unchanged as a result. The operation guarantees the following: - - - If the wire has a symbol, the value of its operand remains observable - under that symbol within the IR. - - - If the wire has a name, the name is treated as a hint. If the wire - persists until code generation the resulting wire will have this name, - with a potential suffix to ensure uniqueness. If the wire is canonicalized - away, its name is propagated to its input operand as a name hint. - - - The users of its result will always observe the operand through the - operation itself, meaning that optimizations cannot bypass the wire. This - ensures that if the wire's value is *forced*, for example through a - Verilog force statement, the forced value will affect all users of the - wire in the output. - - Example: - ``` - %1 = hw.wire %0 : i42 - %2 = hw.wire %0 sym @mySym : i42 - %3 = hw.wire %0 name "myWire" : i42 - %myWire = hw.wire %0 : i42 - ``` - }]; - - let arguments = (ins AnyType:$input, - OptionalAttr:$name, - OptionalAttr:$inner_sym); - let results = (outs AnyType:$result); - - let hasFolder = true; - let hasCanonicalizeMethod = 1; - let builders = [ - OpBuilder<(ins "mlir::Value":$input, - CArg<"const StringAttrOrRef &", "{}">:$name, - CArg<"hw::InnerSymAttr", "{}">:$innerSym), [{ - auto *context = odsBuilder.getContext(); - odsState.addOperands(input); - if (auto attr = name.get(context)) - odsState.addAttribute(getNameAttrName(odsState.name), attr); - if (innerSym) - odsState.addAttribute(getInnerSymAttrName(odsState.name), innerSym); - odsState.addTypes(input.getType()); - }]> - ]; - - let assemblyFormat = [{ - $input (`sym` $inner_sym^)? custom($name) attr-dict - `:` qualified(type($input)) - }]; -} - -def KnownBitWidthType : Type, - "Type wherein the bitwidth in hardware is known">; - -def BitcastOp: HWOp<"bitcast", [Pure]> { - let summary = [{ - Reinterpret one value to another value of the same size and - potentially different type. See the `hw` dialect rationale document for - more details. - }]; - - let arguments = (ins KnownBitWidthType:$input); - let results = (outs KnownBitWidthType:$result); - let hasCanonicalizeMethod = true; - let hasFolder = true; - let hasVerifier = 1; - - let assemblyFormat = "$input attr-dict `:` functional-type($input, $result)"; -} - -def ParamValueOp : HWOp<"param.value", - [FirstAttrDerivedResultType, Pure, - ConstantLike]> { - let summary = [{ - Return the value of a parameter expression as an SSA value that may be used - by other ops. - }]; - - let arguments = (ins AnyAttr:$value); - let results = (outs HWValueType:$result); - let assemblyFormat = "custom($value, qualified(type($result))) attr-dict"; - let hasVerifier = 1; - let hasFolder = true; -} - -def EnumConstantOp : HWOp<"enum.constant", [Pure, ConstantLike, - DeclareOpInterfaceMethods]> { - let summary = "Produce a constant enumeration value."; - let description = [{ - The enum.constant operation produces an enumeration value of the specified - enum value attribute. - ``` - %0 = hw.enum.constant A : !hw.enum - ``` - }]; - - let arguments = (ins EnumFieldAttr:$field); - let results = (outs EnumType:$result); - let hasCustomAssemblyFormat = 1; - let hasFolder = true; - let hasVerifier = true; - let builders = [ - OpBuilder<(ins "hw::EnumFieldAttr":$field)>, - ]; -} - -def EnumCmpOp : HWOp<"enum.cmp", [Pure]> { - let summary = "Compare two values of an enumeration"; - let description = [{ - This operation compares two values with the same canonical enumeration - type, returning 0 if they are different, and 1 if they are the same. - - Example: - ```mlir - %enumcmp = hw.enum.cmp %A, %B : !hw.enum, !hw.enum - ``` - }]; - let arguments = (ins EnumType:$lhs, EnumType:$rhs); - let results = (outs I1:$result); - let hasVerifier = true; - let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` qualified(type($lhs)) `,` qualified(type($rhs)) - }]; -} - -#endif // CIRCT_DIALECT_HW_HWMISCOPS_TD diff --git a/include/circt/Dialect/HW/HWModuleGraph.h b/include/circt/Dialect/HW/HWModuleGraph.h deleted file mode 100644 index ef72c57960..0000000000 --- a/include/circt/Dialect/HW/HWModuleGraph.h +++ /dev/null @@ -1,186 +0,0 @@ -//===- HWModuleGraph.h - HWModule graph -------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the HWModuleGraph. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWMODULEGRAPH_H -#define CIRCT_DIALECT_HW_HWMODULEGRAPH_H - -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWInstanceGraph.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/Seq/SeqOps.h" -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/GraphTraits.h" // from @llvm-project -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/ADT/iterator.h" // from @llvm-project -#include "llvm/include/llvm/Support/DOTGraphTraits.h" // from @llvm-project -#include "llvm/include/llvm/Support/GraphWriter.h" // from @llvm-project - -namespace circt { -namespace hw { -namespace detail { - -// Using declaration to avoid polluting global namespace with CIRCT-specific -// graph traits for mlir::Operation. -using HWOperation = mlir::Operation; - -} // namespace detail -} // namespace hw -} // namespace circt - -template <> -struct llvm::GraphTraits { - using NodeType = circt::hw::detail::HWOperation; - using NodeRef = NodeType *; - - using ChildIteratorType = mlir::Operation::user_iterator; - static NodeRef getEntryNode(NodeRef op) { return op; } - static ChildIteratorType child_begin(NodeRef op) { return op->user_begin(); } - static ChildIteratorType child_end(NodeRef op) { return op->user_end(); } -}; - -template <> -struct llvm::GraphTraits - : public llvm::GraphTraits { - using GraphType = circt::hw::HWModuleOp; - - static NodeRef getEntryNode(GraphType mod) { - return &mod.getBodyBlock()->front(); - } - - using nodes_iterator = pointer_iterator; - static nodes_iterator nodes_begin(GraphType mod) { - return nodes_iterator{mod.getBodyBlock()->begin()}; - } - static nodes_iterator nodes_end(GraphType mod) { - return nodes_iterator{mod.getBodyBlock()->end()}; - } -}; - -template <> -struct llvm::DOTGraphTraits - : public llvm::DefaultDOTGraphTraits { - using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - - static std::string getNodeLabel(circt::hw::detail::HWOperation *node, - circt::hw::HWModuleOp) { - return llvm::TypeSwitch(node) - .Case([&](auto) { return "+"; }) - .Case([&](auto) { return "-"; }) - .Case([&](auto) { return "&"; }) - .Case([&](auto) { return "|"; }) - .Case([&](auto) { return "^"; }) - .Case([&](auto) { return "*"; }) - .Case([&](auto) { return "mux"; }) - .Case( - [&](auto) { return ">>"; }) - .Case([&](auto) { return "<<"; }) - .Case([&](auto op) { - switch (op.getPredicate()) { - case circt::comb::ICmpPredicate::eq: - case circt::comb::ICmpPredicate::ceq: - case circt::comb::ICmpPredicate::weq: - return "=="; - case circt::comb::ICmpPredicate::wne: - case circt::comb::ICmpPredicate::cne: - case circt::comb::ICmpPredicate::ne: - return "!="; - case circt::comb::ICmpPredicate::uge: - case circt::comb::ICmpPredicate::sge: - return ">="; - case circt::comb::ICmpPredicate::ugt: - case circt::comb::ICmpPredicate::sgt: - return ">"; - case circt::comb::ICmpPredicate::ule: - case circt::comb::ICmpPredicate::sle: - return "<="; - case circt::comb::ICmpPredicate::ult: - case circt::comb::ICmpPredicate::slt: - return "<"; - } - llvm_unreachable("unhandled ICmp predicate"); - }) - .Case( - [&](auto op) { return op.getName().str(); }) - .Case([&](auto op) { - llvm::SmallString<64> valueString; - op.getValue().toString(valueString, 10, false); - return valueString.str().str(); - }) - .Default([&](auto op) { return op->getName().getStringRef().str(); }); - } - - std::string getNodeAttributes(circt::hw::detail::HWOperation *node, - circt::hw::HWModuleOp) { - return llvm::TypeSwitch(node) - .Case( - [&](auto) { return "fillcolor=darkgoldenrod1,style=filled"; }) - .Case([&](auto) { - return "shape=invtrapezium,fillcolor=bisque,style=filled"; - }) - .Case( - [&](auto) { return "fillcolor=lightblue,style=filled"; }) - .Default([&](auto op) { - return llvm::TypeSwitch( - op->getDialect()) - .Case([&](auto) { - return "shape=oval,fillcolor=bisque,style=filled"; - }) - .template Case([&](auto) { - return "shape=folder,fillcolor=gainsboro,style=filled"; - }) - .Default([&](auto) { return ""; }); - }); - } - - static void addCustomGraphFeatures( - circt::hw::HWModuleOp mod, llvm::GraphWriter &g) { - // Add module input args. - auto &os = g.getOStream(); - os << "subgraph cluster_entry_args {\n"; - os << "label=\"Input arguments\";\n"; - auto iports = mod.getPortList(); - for (auto [info, arg] : - llvm::zip(iports.getInputs(), mod.getBodyBlock()->getArguments())) { - g.emitSimpleNode(reinterpret_cast(&arg), "", - info.getName().str()); - } - os << "}\n"; - for (auto [info, arg] : - llvm::zip(iports.getInputs(), mod.getBodyBlock()->getArguments())) { - for (auto *user : arg.getUsers()) { - g.emitEdge(reinterpret_cast(&arg), 0, user, -1, ""); - } - } - } - - template - static std::string getEdgeAttributes(circt::hw::detail::HWOperation *node, - Iterator it, circt::hw::HWModuleOp mod) { - mlir::OpOperand &operand = *it.getCurrent(); - mlir::Value v = operand.get(); - std::string str; - llvm::raw_string_ostream os(str); - auto verboseEdges = mod->getAttrOfType("dot_verboseEdges"); - if (verboseEdges.getValue()) { - os << "label=\"" << operand.getOperandNumber() << " (" << v.getType() - << ")\""; - } - - int64_t width = circt::hw::getBitWidth(v.getType()); - if (width > 1) os << " style=bold"; - - return os.str(); - } -}; - -#endif // CIRCT_DIALECT_HW_HWMODULEGRAPH_H diff --git a/include/circt/Dialect/HW/HWOpInterfaces.h b/include/circt/Dialect/HW/HWOpInterfaces.h deleted file mode 100644 index 0dc8bc357a..0000000000 --- a/include/circt/Dialect/HW/HWOpInterfaces.h +++ /dev/null @@ -1,297 +0,0 @@ -//===- HWOpInterfaces.h - Declare HW op interfaces --------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the operation interfaces for the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWOPINTERFACES_H -#define CIRCT_DIALECT_HW_HWOPINTERFACES_H - -#include "include/circt/Dialect/HW/HWTypes.h" -#include "include/circt/Dialect/HW/InnerSymbolTable.h" -#include "include/circt/Support/InstanceGraphInterface.h" -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project - -namespace circt { -namespace hw { - -void populateHWModuleLikeTypeConversionPattern(StringRef moduleLikeOpName, - RewritePatternSet &patterns, - TypeConverter &converter); - -/// This holds the name, type, direction of a module's ports -struct PortInfo : public ModulePort { - /// This is the argument index or the result index depending on the direction. - /// "0" for an output means the first output, "0" for a in/inout means the - /// first argument. - size_t argNum = ~0U; - - /// The optional symbol for this port. - InnerSymAttr sym = {}; - DictionaryAttr attrs = {}; - LocationAttr loc = {}; - - StringRef getName() const { return name.getValue(); } - bool isInput() const { return dir == ModulePort::Direction::Input; } - bool isOutput() const { return dir == ModulePort::Direction::Output; } - bool isInOut() const { return dir == ModulePort::Direction::InOut; } - - /// Return a unique numeric identifier for this port. - ssize_t getId() const { return isOutput() ? argNum : (-1 - argNum); }; -}; - -raw_ostream &operator<<(raw_ostream &printer, PortInfo port); - -/// This holds a decoded list of input/inout and output ports for a module or -/// instance. -struct ModulePortInfo { - explicit ModulePortInfo(ArrayRef inputs, - ArrayRef outputs) { - ports.insert(ports.end(), inputs.begin(), inputs.end()); - ports.insert(ports.end(), outputs.begin(), outputs.end()); - sanitizeInOut(); - } - - explicit ModulePortInfo(ArrayRef mergedPorts) - : ports(mergedPorts.begin(), mergedPorts.end()) { - sanitizeInOut(); - } - - using iterator = SmallVector::iterator; - using const_iterator = SmallVector::const_iterator; - - iterator begin() { return ports.begin(); } - iterator end() { return ports.end(); } - const_iterator begin() const { return ports.begin(); } - const_iterator end() const { return ports.end(); } - - using PortDirectionRange = llvm::iterator_range< - llvm::filter_iterator>>; - - using ConstPortDirectionRange = llvm::iterator_range>>; - - PortDirectionRange getPortsOfDirection(bool input) { - std::function predicateFn; - if (input) { - predicateFn = [](const PortInfo &port) -> bool { - return port.dir == ModulePort::Direction::Input || - port.dir == ModulePort::Direction::InOut; - }; - } else { - predicateFn = [](const PortInfo &port) -> bool { - return port.dir == ModulePort::Direction::Output; - }; - } - return llvm::make_filter_range(ports, predicateFn); - } - - ConstPortDirectionRange getPortsOfDirection(bool input) const { - std::function predicateFn; - if (input) { - predicateFn = [](const PortInfo &port) -> bool { - return port.dir == ModulePort::Direction::Input || - port.dir == ModulePort::Direction::InOut; - }; - } else { - predicateFn = [](const PortInfo &port) -> bool { - return port.dir == ModulePort::Direction::Output; - }; - } - return llvm::make_filter_range(ports, predicateFn); - } - - PortDirectionRange getInputs() { return getPortsOfDirection(true); } - - PortDirectionRange getOutputs() { return getPortsOfDirection(false); } - - ConstPortDirectionRange getInputs() const { - return getPortsOfDirection(true); - } - - ConstPortDirectionRange getOutputs() const { - return getPortsOfDirection(false); - } - - size_t size() const { return ports.size(); } - size_t sizeInputs() const { - auto r = getInputs(); - return std::distance(r.begin(), r.end()); - } - size_t sizeOutputs() const { - auto r = getOutputs(); - return std::distance(r.begin(), r.end()); - } - - size_t portNumForInput(size_t idx) const { - size_t port = 0; - while (idx || ports[port].isOutput()) { - if (!ports[port].isOutput()) --idx; - ++port; - } - return port; - } - - size_t portNumForOutput(size_t idx) const { - size_t port = 0; - while (idx || !ports[port].isOutput()) { - if (ports[port].isOutput()) --idx; - ++port; - } - return port; - } - - PortInfo &at(size_t idx) { return ports[idx]; } - PortInfo &atInput(size_t idx) { return ports[portNumForInput(idx)]; } - PortInfo &atOutput(size_t idx) { return ports[portNumForOutput(idx)]; } - - const PortInfo &at(size_t idx) const { return ports[idx]; } - const PortInfo &atInput(size_t idx) const { - return ports[portNumForInput(idx)]; - } - const PortInfo &atOutput(size_t idx) const { - return ports[portNumForOutput(idx)]; - } - - void eraseInput(size_t idx) { - assert(idx < sizeInputs()); - ports.erase(ports.begin() + portNumForInput(idx)); - } - - private: - // convert input inout -> inout type - void sanitizeInOut() { - for (auto &p : ports) - if (auto inout = dyn_cast(p.type)) { - p.type = inout.getElementType(); - p.dir = ModulePort::Direction::InOut; - } - } - - /// This contains a list of all ports. Input first. - SmallVector ports; -}; - -// This provides capability for looking up port indices based on port names. -struct ModulePortLookupInfo { - FailureOr lookupPortIndex( - const llvm::DenseMap &portMap, - StringAttr name) const { - auto it = portMap.find(name); - if (it == portMap.end()) return failure(); - return it->second; - } - - public: - explicit ModulePortLookupInfo(MLIRContext *ctx, - const ModulePortInfo &portInfo) - : ctx(ctx) { - for (auto &in : portInfo.getInputs()) inputPortMap[in.name] = in.argNum; - - for (auto &out : portInfo.getOutputs()) - outputPortMap[out.name] = out.argNum; - } - - // Return the index of the input port with the specified name. - FailureOr getInputPortIndex(StringAttr name) const { - return lookupPortIndex(inputPortMap, name); - } - - // Return the index of the output port with the specified name. - FailureOr getOutputPortIndex(StringAttr name) const { - return lookupPortIndex(outputPortMap, name); - } - - FailureOr getInputPortIndex(StringRef name) const { - return getInputPortIndex(StringAttr::get(ctx, name)); - } - - FailureOr getOutputPortIndex(StringRef name) const { - return getOutputPortIndex(StringAttr::get(ctx, name)); - } - - private: - llvm::DenseMap inputPortMap; - llvm::DenseMap outputPortMap; - MLIRContext *ctx; -}; - -class InnerSymbolOpInterface; -/// Verification hook for verifying InnerSym Attribute. -LogicalResult verifyInnerSymAttr(InnerSymbolOpInterface op); - -namespace detail { -LogicalResult verifyInnerRefNamespace(Operation *op); -} // namespace detail - -/// Classify operations that are InnerRefNamespace-like, -/// until structure is in place to do this via Traits. -/// Useful for getParentOfType<>, or scheduling passes. -/// Prefer putting the trait on operations here or downstream. -struct InnerRefNamespaceLike { - /// Return if this operation is explicitly an IRN or appears compatible. - static bool classof(mlir::Operation *op); - /// Return if this operation is explicitly an IRN or appears compatible. - static bool classof(const mlir::RegisteredOperationName *opInfo); -}; - -} // namespace hw -} // namespace circt - -namespace mlir { -namespace OpTrait { - -/// This trait is for operations that define a scope for resolving InnerRef's, -/// and provides verification for InnerRef users (via InnerRefUserOpInterface). -template -class InnerRefNamespace : public TraitBase { - public: - static LogicalResult verifyRegionTrait(Operation *op) { - static_assert( - ConcreteType::template hasTrait<::mlir::OpTrait::SymbolTable>(), - "expected operation to be a SymbolTable"); - - if (op->getNumRegions() != 1) - return op->emitError("expected operation to have a single region"); - if (!op->getRegion(0).hasOneBlock()) - return op->emitError("expected operation to have a single block"); - - // Verify all InnerSymbolTable's and InnerRef users. - return ::circt::hw::detail::verifyInnerRefNamespace(op); - } -}; - -/// A trait for inner symbol table functionality on an operation. -template -class InnerSymbolTable : public TraitBase { - public: - static LogicalResult verifyRegionTrait(Operation *op) { - // Insist that ops with InnerSymbolTable's provide a Symbol, this is - // essential to how InnerRef's work. - static_assert( - ConcreteType::template hasTrait<::mlir::SymbolOpInterface::Trait>(), - "expected operation to define a Symbol"); - - // InnerSymbolTable's must be directly nested within an InnerRefNamespace. - auto *parent = op->getParentOp(); - if (!parent || !isa(parent)) - return op->emitError( - "InnerSymbolTable must have InnerRefNamespace parent"); - - return success(); - } -}; -} // namespace OpTrait -} // namespace mlir - -#include "include/circt/Dialect/HW/HWOpInterfaces.h.inc" - -#endif // CIRCT_DIALECT_HW_HWOPINTERFACES_H diff --git a/include/circt/Dialect/HW/HWOpInterfaces.td b/include/circt/Dialect/HW/HWOpInterfaces.td deleted file mode 100644 index 930fcb7500..0000000000 --- a/include/circt/Dialect/HW/HWOpInterfaces.td +++ /dev/null @@ -1,528 +0,0 @@ -//===- HWOpInterfaces.td - Operation Interfaces ------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This describes the HW operation interfaces. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWOPINTERFACES -#define CIRCT_DIALECT_HW_HWOPINTERFACES - -include "mlir/IR/SymbolInterfaces.td" -include "mlir/IR/OpBase.td" -include "include/circt/Support/InstanceGraphInterface.td" - -def PortList : OpInterface<"PortList", []> { - let cppNamespace = "circt::hw"; - let description = "Operations which produce a unified port list representation"; - let methods = [ - InterfaceMethod<"Get port list", - "::circt::hw::ModulePortInfo", "getPortList", (ins)>, - - InterfaceMethod<"Get the port a specific input", - "size_t", "getPortIdForInputId", (ins "size_t":$idx)>, - - InterfaceMethod<"Get the port a specific output", - "size_t", "getPortIdForOutputId", (ins "size_t":$idx)>, - - InterfaceMethod<"Get the number of ports", - "size_t", "getNumPorts", (ins)>, - InterfaceMethod<"Get the number of input ports", - "size_t", "getNumInputPorts", (ins)>, - InterfaceMethod<"Get the number of output ports", - "size_t", "getNumOutputPorts", (ins)>, - ]; -} - -def SymboledPortList : OpInterface<"SymboledPortList", [PortList]> { - let cppNamespace = "circt::hw"; - let description = "Operations which produce a unified port list representation which includes port symbols"; - let methods = [ - InterfaceMethod<"Get a port symbol attribute", - "::circt::hw::InnerSymAttr", "getPortSymbolAttr", (ins "size_t":$portIndex)>, - ]; -} - - -def HWModuleLike : OpInterface<"HWModuleLike", [ - Symbol, SymboledPortList, InstanceGraphModuleOpInterface]> { - let cppNamespace = "circt::hw"; - let description = "Provide common module information."; - - let methods = [ - InterfaceMethod<"Get the module type", - "::circt::hw::ModuleType", "getHWModuleType", (ins)>, - - InterfaceMethod<"Get the port Attributes", - "SmallVector", "getAllPortAttrs", (ins)>, - - InterfaceMethod<"Set the port Attributes", - "void", "setAllPortAttrs", (ins "ArrayRef":$attrs)>, - - InterfaceMethod<"Remove the port Attributes", - "void", "removeAllPortAttrs", (ins)>, - - InterfaceMethod<"Get the port Locations", - "SmallVector", "getAllPortLocs", (ins)>, - - InterfaceMethod<"Set the port Locations", - "void", "setAllPortLocs", (ins "ArrayRef":$locs)>, - - InterfaceMethod<"Set the module type (and port names)", - "void", "setHWModuleType", (ins "::circt::hw::ModuleType":$type)>, - - InterfaceMethod<"Set the port names", - "void", "setAllPortNames", (ins "ArrayRef":$names)>, - - ]; - - let extraSharedClassDeclaration = [{ - /// Attribute name for port symbols. - static StringRef getPortSymbolAttrName() { - return "hw.exportPort"; - } - - /// Return the region containing the body of this function. - Region &getModuleBody() { return $_op->getRegion(0); } - Block *getBodyBlock() { - Region &body = getModuleBody(); - if (body.empty()) - return nullptr; - return &body.front(); - } - - /// Returns the entry block argument at the given index. - BlockArgument getArgumentForInput(unsigned idx) { - return $_op.getModuleBody().getArgument(idx); - } - - /// Returns the entry block argument for the given port. May be null. - BlockArgument getArgumentForPort(unsigned idx) { - return $_op.getModuleBody().getArgument($_op.getHWModuleType().getInputIdForPortId(idx)); - } - - SmallVector getPortTypes() { - return $_op.getHWModuleType().getPortTypes(); - } - - SmallVector getInputTypes() { - return $_op.getHWModuleType().getInputTypes(); - } - - SmallVector getOutputTypes() { - return $_op.getHWModuleType().getOutputTypes(); - } - - /// Return the set of names on input and inout ports - SmallVector getInputNamesStr() { - return $_op.getHWModuleType().getInputNamesStr(); - } - - /// Return the set of names on output ports - SmallVector getOutputNamesStr() { - return $_op.getHWModuleType().getOutputNamesStr(); - } - - /// Return the set of names on input and inout ports - SmallVector getInputNames() { - return $_op.getHWModuleType().getInputNames(); - } - - /// Return the set of names on output ports - SmallVector getOutputNames() { - return $_op.getHWModuleType().getOutputNames(); - } - - void setInputNames(ArrayRef names) { - SmallVector newNames(names.begin(), names.end()); - auto resNames = $_op.getOutputNames(); - newNames.append(resNames.begin(), resNames.end()); - $_op.setAllPortNames(newNames); - } - - void setOutputNames(ArrayRef names) { - SmallVector newNames = $_op.getInputNames(); - newNames.append(names.begin(), names.end()); - $_op.setAllPortNames(newNames); - } - - // Get the name for the specified input or inout port - StringRef getPortName(size_t idx) { - return $_op.getHWModuleType().getPortName(idx); - } - - // Get the name for the specified input or inout port - StringRef getInputName(size_t idx) { - return $_op.getHWModuleType().getInputName(idx); - } - - // Get the name for the specified output port - StringRef getOutputName(size_t idx) { - return $_op.getHWModuleType().getOutputName(idx); - } - - StringAttr getInputNameAttr(size_t idx) { - return $_op.getHWModuleType().getInputNameAttr(idx); - } - - StringAttr getOutputNameAttr(size_t idx) { - return $_op.getHWModuleType().getOutputNameAttr(idx); - } - - Attribute getPortAttrs(size_t idx) { - return $_op.getAllPortAttrs()[idx]; - } - - SmallVector getAllInputAttrs() { - auto attrs = $_op.getAllPortAttrs(); - SmallVector retval; - auto modType = $_op.getHWModuleType(); - for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) - retval.push_back(attrs[modType.getPortIdForInputId(x)]); - return retval; - } - - SmallVector getAllOutputAttrs() { - auto attrs = $_op.getAllPortAttrs(); - SmallVector retval; - auto modType = $_op.getHWModuleType(); - for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) - retval.push_back(attrs[modType.getPortIdForOutputId(x)]); - return retval; - } - - void setAllInputAttrs(ArrayRef attrs) { - SmallVector retval(attrs.begin(), attrs.end()); - auto resAttrs = $_op.getAllOutputAttrs(); - retval.append(resAttrs.begin(), resAttrs.end()); - $_op.setAllPortAttrs(retval); - } - - void setAllOutputAttrs(ArrayRef attrs) { - SmallVector retval = $_op.getAllInputAttrs(); - retval.append(attrs.begin(), attrs.end()); - $_op.setAllPortAttrs(retval); - } - - Attribute getInputAttrs(size_t idx) { - return $_op.getAllPortAttrs()[$_op.getPortIdForInputId(idx)]; - } - - Attribute getOutputAttrs(size_t idx) { - return $_op.getAllPortAttrs()[$_op.getPortIdForOutputId(idx)]; - } - - void setPortAttrs(size_t idx, DictionaryAttr attr) { - auto attrs = $_op.getAllPortAttrs(); - attrs[idx] = attr; - $_op.setAllPortAttrs(attrs); - } - - void setPortAttr(size_t idx, StringAttr name, Attribute value) { - auto attrs = $_op.getAllPortAttrs(); - NamedAttrList pattr(cast(attrs[idx])); - Attribute oldValue; - if (!value) - oldValue = pattr.erase(name); - else - oldValue = pattr.set(name, value); - if (oldValue != value) { - attrs[idx] = pattr.getDictionary($_op.getContext()); - $_op.setAllPortAttrs(attrs); - } - } - - void setPortAttrs(StringAttr attrName, ArrayRef newAttrs) { - auto attrs = $_op.getAllPortAttrs(); - auto ctxt = $_op.getContext(); - assert(newAttrs.size() == attrs.size()); - for (size_t idx = 0, e = attrs.size(); idx != e; ++idx) { - NamedAttrList pattr(cast(attrs[idx])); - auto newAttr = newAttrs[idx]; - if (newAttr) - pattr.set(attrName, newAttr); - else - pattr.erase(attrName); - attrs[idx] = pattr.getDictionary(ctxt); - } - $_op.setAllPortAttrs(attrs); - } - - Location getPortLoc(size_t idx) { - return $_op.getAllPortLocs()[idx]; - } - - void setPortLoc(size_t idx, Location loc) { - auto locs = $_op.getAllPortLocs(); - locs[idx] = loc; - return $_op.setAllPortLocs(locs); - } - - Location getInputLoc(size_t idx) { - return $_op.getAllPortLocs()[$_op.getPortIdForInputId(idx)]; - } - - Location getOutputLoc(size_t idx) { - return $_op.getAllPortLocs()[$_op.getPortIdForOutputId(idx)]; - } - - SmallVector getInputLocs() { - auto locs = $_op.getAllPortLocs(); - SmallVector retval; - for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) - retval.push_back(locs[$_op.getPortIdForInputId(x)]); - return retval; - } - - ArrayAttr getInputLocsAttr() { - auto locs = $_op.getAllPortLocs(); - SmallVector retval; - for (unsigned x = 0, e = $_op.getNumInputPorts(); x < e; ++x) - retval.push_back(locs[$_op.getPortIdForInputId(x)]); - return ArrayAttr::get($_op->getContext(), retval); - } - - void setInputLocs(ArrayRef inAttrs) { - assert(inAttrs.size() == $_op.getNumInputPorts()); - auto outAttrs = getOutputLocs(); - SmallVector attrs; - attrs.append(inAttrs.begin(), inAttrs.end()); - attrs.append(outAttrs.begin(), outAttrs.end()); - $_op.setAllPortLocs(attrs); - } - - SmallVector getOutputLocs() { - auto locs = $_op.getAllPortLocs(); - SmallVector retval; - for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) - retval.push_back(locs[$_op.getPortIdForOutputId(x)]); - return retval; - } - - ArrayAttr getOutputLocsAttr() { - auto locs = $_op.getAllPortLocs(); - SmallVector retval; - for (unsigned x = 0, e = $_op.getNumOutputPorts(); x < e; ++x) - retval.push_back(locs[$_op.getPortIdForOutputId(x)]); - return ArrayAttr::get($_op->getContext(), retval); - } - - void setOutputLocs(ArrayRef outAttrs) { - assert(outAttrs.size() == $_op.getNumOutputPorts()); - auto inAttrs = getInputLocs(); - SmallVector attrs; - attrs.append(inAttrs.begin(), inAttrs.end()); - attrs.append(outAttrs.begin(), outAttrs.end()); - $_op.setAllPortLocs(attrs); - } - }]; - - let verify = [{ - static_assert( - ConcreteOp::template hasTrait<::mlir::SymbolOpInterface::Trait>(), - "expected operation to be a symbol"); - return success(); - }]; -} - -def HWMutableModuleLike : OpInterface<"HWMutableModuleLike", [HWModuleLike]> { - let cppNamespace = "circt::hw"; - let description = "Provide methods to mutate a module."; - - let methods = [ - - InterfaceMethod<"Get a handle to a utility class which provides by-name lookup of port indices. The returned object does _not_ update if the module is mutated.", - "::circt::hw::ModulePortLookupInfo", "getPortLookupInfo", (ins), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - return hw::ModulePortLookupInfo( - $_op->getContext(), - $_op.getPortList()); - }]>, - - /// Insert and remove input and output ports of this module. Does not modify - /// the block arguments of the module body. The insertion and removal - /// indices must be in ascending order. The indices refer to the port - /// positions before any insertion or removal occurs. Ports inserted at the - /// same index will appear in the module in the same order as they were - /// listed in the insertion arrays. - InterfaceMethod<"Insert and remove input and output ports", - "void", "modifyPorts", (ins - "ArrayRef>":$insertInputs, - "ArrayRef>":$insertOutputs, - "ArrayRef":$eraseInputs, "ArrayRef":$eraseOutputs), - /*methodBody=*/[{ - $_op.modifyPorts(insertInputs, insertOutputs, eraseInputs, eraseOutputs); - }]>, - - /// Insert ports into the module. Does not modify the block arguments of the - /// module body. - InterfaceMethod<"Insert ports into this module", - "void", "insertPorts", (ins - "ArrayRef>":$insertInputs, - "ArrayRef>":$insertOutputs), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - $_op.modifyPorts(insertInputs, insertOutputs, {}, {}); - }]>, - - /// Erase ports from the module. Does not modify the block arguments of the - /// module body. - InterfaceMethod<"Erase ports from this module", - "void", "erasePorts", (ins - "ArrayRef":$eraseInputs, - "ArrayRef":$eraseOutputs), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ - $_op.modifyPorts({}, {}, eraseInputs, eraseOutputs); - }]>, - - /// Appends output ports to the module with the specified names and rewrites - /// the output op to return the associated values. - InterfaceMethod<"Append output values to this module", - "void", "appendOutputs", (ins - "ArrayRef>":$outputs)> - ]; -} - - -def HWInstanceLike : OpInterface<"HWInstanceLike", [ - PortList, InstanceGraphInstanceOpInterface]> { - let cppNamespace = "circt::hw"; - let description = "Provide common module information."; -} - -def InnerRefNamespace : NativeOpTrait<"InnerRefNamespace">; - -def InnerSymbol : OpInterface<"InnerSymbolOpInterface"> { - let description = [{ - This interface describes an operation that may define an - `inner_sym`. An `inner_sym` operation resides - in arbitrarily-nested regions of a region that defines a - `InnerSymbolTable`. - Inner Symbols are different from normal symbols due to - MLIR symbol table resolution rules. Specifically normal - symbols are resolved by first going up to the closest - parent symbol table and resolving from there (recursing - down for complex symbol paths). In HW and SV, modules - define a symbol in a circuit or std.module symbol table. - For instances to be able to resolve the modules they - instantiate, the symbol use in an instance must resolve - in the top-level symbol table. If a module were a - symbol table, instances resolving a symbol would start from - their own module, never seeing other modules (since - resolution would start in the parent module of the - instance and be unable to go to the global scope). - The second problem arises from nesting. Symbols defining - ops must be immediate children of a symbol table. HW - and SV operations which define a inner_sym are grandchildren, - at least, of a symbol table and may be much further nested. - Lastly, ports need to define inner_sym, something not allowed - by normal symbols. - - Any operation implementing an InnerSymbol may have the inner symbol be - optional and all methods should be robuse to the attribute not being - defined. - }]; - - let cppNamespace = "::circt::hw"; - let methods = [ - InterfaceMethod<"Returns the name of the top-level inner symbol defined by this operation, if present.", - "::mlir::StringAttr", "getInnerNameAttr", (ins), [{}], - /*defaultImplementation=*/[{ - if (auto attr = - this->getOperation()->template getAttrOfType( - circt::hw::InnerSymbolTable::getInnerSymbolAttrName())) - return attr.getSymName(); - return {}; - }] - >, - InterfaceMethod<"Returns the name of the top-level inner symbol defined by this operation, if present.", - "::std::optional<::mlir::StringRef>", "getInnerName", (ins), [{}], - /*defaultImplementation=*/[{ - auto attr = this->getInnerNameAttr(); - return attr ? ::std::optional(attr.getValue()) : ::std::nullopt; - }] - >, - InterfaceMethod<"Sets the name of the top-level inner symbol defined by this operation to the specified string, dropping any symbols on fields.", - "void", "setInnerSymbol", (ins "::mlir::StringAttr":$name), [{}], - /*defaultImplementation=*/[{ - this->getOperation()->setAttr( - InnerSymbolTable::getInnerSymbolAttrName(), hw::InnerSymAttr::get(name)); - }] - >, - InterfaceMethod<"Sets the inner symbols defined by this operation.", - "void", "setInnerSymbolAttr", (ins "::circt::hw::InnerSymAttr":$sym), [{}], - /*defaultImplementation=*/[{ - if (sym && !sym.empty()) - this->getOperation()->setAttr( - InnerSymbolTable::getInnerSymbolAttrName(), sym); - else - this->getOperation()->removeAttr(InnerSymbolTable::getInnerSymbolAttrName()); - }] - >, - InterfaceMethod<"Returns an InnerRef to this operation's top-level inner symbol, which must be present.", - "::circt::hw::InnerRefAttr", "getInnerRef", (ins), [{}], - /*defaultImplementation=*/[{ - auto *op = this->getOperation(); - return hw::InnerRefAttr::get( - SymbolTable::getSymbolName( - op->template getParentWithTrait()), - InnerSymbolTable::getInnerSymbol(op)); - }] - >, - InterfaceMethod<"Returns the InnerSymAttr representing all inner symbols defined by this operation.", - "::circt::hw::InnerSymAttr", "getInnerSymAttr", (ins), [{}], - /*defaultImplementation=*/[{ - return this->getOperation()->template getAttrOfType( - circt::hw::InnerSymbolTable::getInnerSymbolAttrName()); - }] - >, - // Ask an operation if per-field symbols are allowed. - // Defaults to indicating they're allowed iff there's a defined target result, - // but let operations answer this differently if for some reason that makes sense. - StaticInterfaceMethod<"Returns whether per-field symbols are supported for this operation type.", - "bool", "supportsPerFieldSymbols", (ins), [{}], /*defaultImplementation=*/[{ - return ConcreteOp::getTargetResultIndex().has_value(); - }]>, - StaticInterfaceMethod<"Returns the index of the result the innner symbol targets, if applicable. Per-field symbols are resolved into this.", - "std::optional", "getTargetResultIndex">, - InterfaceMethod<"Returns the result the innner symbol targets, if applicable. Per-field symbols are resolved into this.", - "OpResult", "getTargetResult", (ins), [{}], /*defaultImplementation=*/[{ - auto idx = ConcreteOp::getTargetResultIndex(); - if (!idx) - return {}; - return $_op->getResult(*idx); - }]>, - ]; - - let verify = [{ - return verifyInnerSymAttr(cast(op)); - }]; -} - -def InnerSymbolTable : NativeOpTrait<"InnerSymbolTable">; - -def InnerRefUserOpInterface : OpInterface<"InnerRefUserOpInterface"> { - let description = [{ - This interface describes an operation that may use a `InnerRef`. This - interface allows for users of inner symbols to hook into verification and - other inner symbol related utilities that are either costly or otherwise - disallowed within a traditional operation. - }]; - let cppNamespace = "::circt::hw"; - - let methods = [ - InterfaceMethod<"Verify the inner ref uses held by this operation.", - "::mlir::LogicalResult", "verifyInnerRefs", - (ins "::circt::hw::InnerRefNamespace&":$ns) - >, - ]; -} - -#endif diff --git a/include/circt/Dialect/HW/HWOps.h b/include/circt/Dialect/HW/HWOps.h deleted file mode 100644 index f3d3033882..0000000000 --- a/include/circt/Dialect/HW/HWOps.h +++ /dev/null @@ -1,140 +0,0 @@ -//===- HWOps.h - Declare HW dialect operations ------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the operation classes for the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_OPS_H -#define CIRCT_DIALECT_HW_OPS_H - -#include "include/circt/Dialect/HW/HWDialect.h" -#include "include/circt/Dialect/HW/HWOpInterfaces.h" -#include "include/circt/Dialect/HW/HWTypes.h" -#include "include/circt/Support/BuilderUtils.h" -#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project -#include "mlir/include/mlir/IR/RegionKindInterface.h" // from @llvm-project -#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project - -namespace circt { -namespace hw { - -class EnumFieldAttr; - -/// Flip a port direction. -ModulePort::Direction flip(ModulePort::Direction direction); - -/// TODO: Move all these functions to a hw::ModuleLike interface. - -/// Insert and remove ports of a module. The insertion and removal indices must -/// be in ascending order. The indices refer to the port positions before any -/// insertion or removal occurs. Ports inserted at the same index will appear in -/// the module in the same order as they were listed in the `insert*` array. -/// If 'body' is provided, additionally inserts/removes the corresponding -/// block arguments. -void modifyModulePorts(Operation *op, - ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef removeInputs, - ArrayRef removeOutputs, Block *body = nullptr); - -// Helpers for working with modules. - -/// Return true if isAnyModule or instance. -bool isAnyModuleOrInstance(Operation *module); - -/// Return the signature for the specified module as a function type. -FunctionType getModuleType(Operation *module); - -/// Returns the verilog module name attribute or symbol name of any module-like -/// operations. -StringAttr getVerilogModuleNameAttr(Operation *module); -inline StringRef getVerilogModuleName(Operation *module) { - return getVerilogModuleNameAttr(module).getValue(); -} - -// Index width should be exactly clog2 (size of array), or either 0 or 1 if the -// array is a singleton. -bool isValidIndexBitWidth(Value index, Value array); - -/// Return true if the specified operation is a combinational logic op. -bool isCombinational(Operation *op); - -/// Return true if the specified attribute tree is made up of nodes that are -/// valid in a parameter expression. -bool isValidParameterExpression(Attribute attr, Operation *module); - -/// Check parameter specified by `value` to see if it is valid within the scope -/// of the specified module `module`. If not, emit an error at the location of -/// `usingOp` and return failure, otherwise return success. -/// -/// If `disallowParamRefs` is true, then parameter references are not allowed. -LogicalResult checkParameterInContext(Attribute value, Operation *module, - Operation *usingOp, - bool disallowParamRefs = false); - -/// Check parameter specified by `value` to see if it is valid according to the -/// module's parameters. If not, emit an error to the diagnostic provided as an -/// argument to the lambda 'instanceError' and return failure, otherwise return -/// success. -/// -/// If `disallowParamRefs` is true, then parameter references are not allowed. -LogicalResult checkParameterInContext( - Attribute value, ArrayAttr moduleParameters, - const std::function)> - &instanceError, - bool disallowParamRefs = false); - -// Check whether an integer value is an offset from a base. -bool isOffset(Value base, Value index, uint64_t offset); - -// A class for providing access to the in- and output ports of a module through -// use of the HWModuleBuilder. -class HWModulePortAccessor { - public: - HWModulePortAccessor(Location loc, const ModulePortInfo &info, - Region &bodyRegion); - - // Returns the i'th/named input port of the module. - Value getInput(unsigned i); - Value getInput(StringRef name); - ValueRange getInputs() { return inputArgs; } - - // Assigns the i'th/named output port of the module. - void setOutput(unsigned i, Value v); - void setOutput(StringRef name, Value v); - - const ModulePortInfo &getPortList() const { return info; } - const llvm::SmallVector &getOutputOperands() const { - return outputOperands; - } - - private: - llvm::StringMap inputIdx, outputIdx; - llvm::SmallVector inputArgs; - llvm::SmallVector outputOperands; - ModulePortInfo info; -}; - -using HWModuleBuilder = - llvm::function_ref; - -} // namespace hw -} // namespace circt - -#define GET_OP_CLASSES -#include "include/circt/Dialect/HW/HW.h.inc" - -#endif // CIRCT_DIALECT_HW_OPS_H diff --git a/include/circt/Dialect/HW/HWPasses.h b/include/circt/Dialect/HW/HWPasses.h deleted file mode 100644 index 42d583b44c..0000000000 --- a/include/circt/Dialect/HW/HWPasses.h +++ /dev/null @@ -1,38 +0,0 @@ -//===- Passes.h - HW pass entry points --------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header file defines prototypes that expose pass constructors. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWPASSES_H -#define CIRCT_DIALECT_HW_HWPASSES_H - -#include -#include - -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassRegistry.h" // from @llvm-project - -namespace circt { -namespace hw { - -std::unique_ptr createPrintInstanceGraphPass(); -std::unique_ptr createHWSpecializePass(); -std::unique_ptr createPrintHWModuleGraphPass(); -std::unique_ptr createFlattenIOPass(); -std::unique_ptr createVerifyInnerRefNamespacePass(); - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include "include/circt/Dialect/HW/Passes.h.inc" - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_HWPASSES_H diff --git a/include/circt/Dialect/HW/HWReductions.h b/include/circt/Dialect/HW/HWReductions.h deleted file mode 100644 index 0108eead24..0000000000 --- a/include/circt/Dialect/HW/HWReductions.h +++ /dev/null @@ -1,29 +0,0 @@ -//===- HWReductions.h - HW reduction interface declaration ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWREDUCTIONS_H -#define CIRCT_DIALECT_HW_HWREDUCTIONS_H - -#include "include/circt/Reduce/Reduction.h" - -namespace circt { -namespace hw { - -/// A dialect interface to provide reduction patterns to a reducer tool. -struct HWReducePatternDialectInterface : public ReducePatternDialectInterface { - using ReducePatternDialectInterface::ReducePatternDialectInterface; - void populateReducePatterns(circt::ReducePatternSet &patterns) const override; -}; - -/// Register the HW Reduction pattern dialect interface to the given registry. -void registerReducePatternDialectInterface(mlir::DialectRegistry ®istry); - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_HWREDUCTIONS_H diff --git a/include/circt/Dialect/HW/HWStructure.td b/include/circt/Dialect/HW/HWStructure.td deleted file mode 100644 index 678f483f0e..0000000000 --- a/include/circt/Dialect/HW/HWStructure.td +++ /dev/null @@ -1,720 +0,0 @@ -//===- HWStructure.td - HW structure ops -------------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This describes the MLIR ops for structure. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWSTRUCTURE_TD -#define CIRCT_DIALECT_HW_HWSTRUCTURE_TD - -include "include/circt/Dialect/HW/HWAttributes.td" -include "include/circt/Dialect/HW/HWDialect.td" -include "include/circt/Dialect/HW/HWOpInterfaces.td" -include "include/circt/Dialect/HW/HWTypes.td" -include "mlir/Interfaces/FunctionInterfaces.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/RegionKindInterface.td" -include "mlir/IR/SymbolInterfaces.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -/// Base class factoring out some of the additional class declarations common to -/// the module-like operations. -class HWModuleOpBase traits = []> : - HWOp, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - Symbol, InnerSymbolTable, - OpAsmOpInterface, HasParent<"mlir::ModuleOp">]> { - /// Additional class declarations inside the module op. - code extraModuleClassDeclaration = ?; - - let extraClassDeclaration = extraModuleClassDeclaration # [{ - /// Insert and remove input and output ports of this module. Does not modify - /// the block arguments of the module body. The insertion and removal - /// indices must be in ascending order. The indices refer to the port - /// positions before any insertion or removal occurs. Ports inserted at the - /// same index will appear in the module in the same order as they were - /// listed in the insertion arrays. - void modifyPorts( - ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef eraseInputs, - ArrayRef eraseOutputs - ); - - void setPortSymbolAttr(size_t portIndex, ::circt::hw::InnerSymAttr sym); - - StringAttr getArgName(size_t index) { - return getHWModuleType().getInputNameAttr(index); - } - - }]; - - /// Additional class definitions inside the module op. - code extraModuleClassDefinition = [{}]; - - let extraClassDefinition = extraModuleClassDefinition # [{ - - ModuleType $cppClass::getHWModuleType() { - return getModuleType(); - } - - ::circt::hw::InnerSymAttr $cppClass::getPortSymbolAttr(size_t portIndex) { - DictionaryAttr portAttrs = cast_or_null(getPortAttrs(portIndex)); - if (portAttrs) - return portAttrs.getAs<::circt::hw::InnerSymAttr>( - getPortSymbolAttrName()); - return {}; - } - - void $cppClass::setPortSymbolAttr(size_t portIndex, ::circt::hw::InnerSymAttr sym) { - auto portSymAttr = StringAttr::get(getContext(), getPortSymbolAttrName()); - setPortAttr(portIndex, portSymAttr, sym); - } - - size_t $cppClass::getNumPorts() { - auto modty = getHWModuleType(); - return modty.getNumPorts(); - } - - size_t $cppClass::getNumInputPorts() { - auto modty = getHWModuleType(); - return modty.getNumInputs(); - } - - size_t $cppClass::getNumOutputPorts() { - auto modty = getHWModuleType(); - return modty.getNumOutputs(); - } - - size_t $cppClass::getPortIdForInputId(size_t idx) { - auto modty = getHWModuleType(); - return modty.getPortIdForInputId(idx); - } - - size_t $cppClass::getPortIdForOutputId(size_t idx) { - auto modty = getHWModuleType(); - return modty.getPortIdForOutputId(idx); - } - - ModulePortInfo $cppClass::getPortList() { - auto modTy = getHWModuleType(); - SmallVector inputs, outputs; - auto emptyDict = DictionaryAttr::get(getContext()); - auto argTypes = modTy.getInputTypes(); - auto argNames = modTy.getInputNames(); - auto argLocs = getArgLocs(); - auto argAttrs = getArgAttrs(); - for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { - auto type = argTypes[i]; - auto direction = ModulePort::Direction::Input; - if (auto inout = type.dyn_cast()) { - type = inout.getElementType(); - direction = ModulePort::Direction::InOut; - } - LocationAttr loc; - if (argLocs) - loc = argLocs[i].cast(); - DictionaryAttr attrs = emptyDict; - if (argAttrs && (*argAttrs)[i]) - attrs = cast((*argAttrs)[i]); - inputs.push_back({{cast(argNames[i]), type, direction}, - i, - extractSym(attrs), - attrs, - loc}); - } - - auto resultNames = modTy.getOutputNames(); - auto resultTypes = modTy.getOutputTypes(); - auto resultLocs = getResultLocs(); - auto resultAttrs = getResAttrs(); - for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) { - LocationAttr loc; - if (resultLocs) - loc = resultLocs[i].cast(); - DictionaryAttr attrs = emptyDict; - if (resultAttrs && (*resultAttrs)[i]) - attrs = cast((*resultAttrs)[i]); - outputs.push_back({{cast(resultNames[i]), resultTypes[i], - ModulePort::Direction::Output}, - i, - extractSym(attrs), - attrs, - loc}); - } - return ModulePortInfo(inputs, outputs); - } - - }]; - -} - -def HWTestModuleOp : HWOp<"testmodule", [Symbol, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - OpAsmOpInterface, - HasParent<"mlir::ModuleOp">, - IsolatedFromAbove, - SingleBlockImplicitTerminator<"OutputOp">]> { - let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$module_type, - OptionalAttr:$port_attrs, - OptionalAttr:$port_locs, - ParamDeclArrayAttr:$parameters, - OptionalAttr:$comment); - let results = (outs); - let regions = (region SizedRegion<1>:$body); - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - void getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn); - - Block *getBodyBlock() { return &getRegion().front(); } - }]; - -} - -def HWModuleOp : HWModuleOpBase<"module", - [IsolatedFromAbove, RegionKindInterface, - SingleBlockImplicitTerminator<"OutputOp">]>{ - let summary = "HW Module"; - let description = [{ - The "hw.module" operation represents a Verilog module, including a given - name, a list of ports, a list of parameters, and a body that represents the - connections within the module. - }]; - let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$module_type, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs, - LocationArrayAttr:$argLocs, - LocationArrayAttr:$resultLocs, - ParamDeclArrayAttr:$parameters, - StrAttr:$comment); - let results = (outs); - let regions = (region SizedRegion<1>:$body); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "StringAttr":$name, "ArrayRef":$ports, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes, - CArg<"StringAttr", "{}">:$comment)>, - OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes, - CArg<"StringAttr", "{}">:$comment, - CArg<"bool", "true">:$shouldEnsureTerminator)>, - OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, - "HWModuleBuilder":$modBuilder, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes, - CArg<"StringAttr", "{}">:$comment)> - ]; - - let extraModuleClassDeclaration = [{ - - // Implement RegionKindInterface. - static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph;} - - /// Append an input with a given name and type to the port list. - /// If the name is not unique, a unique name is created and returned. - std::pair - appendInput(const Twine &name, Type ty) { - return insertInput(getNumInputPorts(), name, ty); - } - - std::pair - appendInput(StringAttr name, Type ty) { - return insertInput(getNumInputPorts(), name.getValue(), ty); - } - - /// Prepend an input with a given name and type to the port list. - /// If the name is not unique, a unique name is created and returned. - std::pair - prependInput(const Twine &name, Type ty) { - return insertInput(0, name, ty); - } - - std::pair - prependInput(StringAttr name, Type ty) { - return insertInput(0, name.getValue(), ty); - } - - /// Insert an input with a given name and type into the port list. - /// The input is added at the specified index. - std::pair - insertInput(unsigned index, StringAttr name, Type ty); - - std::pair - insertInput(unsigned index, const Twine &name, Type ty) { - ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); - return insertInput(index, nameAttr, ty); - } - - /// Append an output with a given name and type to the port list. - /// If the name is not unique, a unique name is created. - void appendOutput(StringAttr name, Value value) { - return insertOutputs(getNumOutputPorts(), {{name, value}}); - } - - void appendOutput(const Twine &name, Value value) { - ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); - return insertOutputs(getNumOutputPorts(), {{nameAttr, value}}); - } - - /// Prepend an output with a given name and type to the port list. - /// If the name is not unique, a unique name is created. - void prependOutput(StringAttr name, Value value) { - return insertOutputs(0, {{name, value}}); - } - - void prependOutput(const Twine &name, Value value) { - ::mlir::StringAttr nameAttr = ::mlir::StringAttr::get(getContext(), name); - return insertOutputs(0, {{nameAttr, value}}); - } - - /// Inserts a list of output ports into the port list at a specific - /// location, shifting all subsequent ports. Rewrites the output op - /// to return the associated values. - void insertOutputs(unsigned index, - ArrayRef> outputs); - - // Get the module's symbolic name as StringAttr. - StringAttr getNameAttr() { - return (*this)->getAttrOfType( - ::mlir::SymbolTable::getSymbolAttrName()); - } - - // Get the module's symbolic name. - StringRef getName() { - return getNameAttr().getValue(); - } - void getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn); - - /// Verifies the body of the function. - LogicalResult verifyBody(); - }]; - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - -def HWModuleExternOp : HWModuleOpBase<"module.extern"> { - let summary = "HW external Module"; - let description = [{ - The "hw.module.extern" operation represents an external reference to a - Verilog module, including a given name and a list of ports. - - The 'verilogName' attribute (when present) specifies the spelling of the - module name in Verilog we can use. TODO: This is a hack because we don't - have proper parameterization in the hw.dialect. We need a way to represent - parameterized types instead of just concrete types. - }]; - let arguments = (ins SymbolNameAttr:$sym_name, - TypeAttrOf:$module_type, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs, - LocationArrayAttr:$argLocs, - LocationArrayAttr:$resultLocs, - ParamDeclArrayAttr:$parameters, - OptionalAttr:$verilogName); - let results = (outs); - let regions = (region AnyRegion:$body); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "StringAttr":$name, "ArrayRef":$ports, - CArg<"StringRef", "StringRef()">:$verilogName, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "StringAttr":$name, "const ModulePortInfo &":$ports, - CArg<"StringRef", "StringRef()">:$verilogName, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes)> - ]; - - let extraModuleClassDeclaration = [{ - - /// Return the name to use for the Verilog module that we're referencing - /// here. This is typically the symbol, but can be overridden with the - /// verilogName attribute. - StringRef getVerilogModuleName() { - return getVerilogModuleNameAttr().getValue(); - } - - /// Return the name to use for the Verilog module that we're referencing - /// here. This is typically the symbol, but can be overridden with the - /// verilogName attribute. - StringAttr getVerilogModuleNameAttr(); - - // Get the module's symbolic name as StringAttr. - StringAttr getNameAttr() { - return (*this)->getAttrOfType( - ::mlir::SymbolTable::getSymbolAttrName()); - } - - // Get the module's symbolic name. - StringRef getName() { - return getNameAttr().getValue(); - } - - void getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn); - - }]; - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - -def HWGeneratorSchemaOp : HWOp<"generator.schema", - [Symbol, HasParent<"mlir::ModuleOp">]> { - let summary = "HW Generator Schema declaration"; - let description = [{ - The "hw.generator.schema" operation declares a kind of generated module by - declaring the schema of meta-data required. - A generated module instance of a schema is independent of the external - method of producing it. It is assumed that for well known schema instances, - multiple external tools might exist which can process it. Generator nodes - list attributes required by hw.module.generated instances. - - For example: - generator.schema @MEMORY, "Simple-Memory", ["ports", "write_latency", "read_latency"] - module.generated @mymem, @MEMORY(ports) - -> (ports) {write_latency=1, read_latency=1, ports=["read","write"]} - }]; - - let arguments = (ins SymbolNameAttr:$sym_name, StrAttr:$descriptor, - StrArrayAttr:$requiredAttrs); - let results = (outs); - let assemblyFormat = "$sym_name `,` $descriptor `,` $requiredAttrs attr-dict"; -} - -def HWModuleGeneratedOp : HWModuleOpBase<"module.generated", [ - DeclareOpInterfaceMethods, - IsolatedFromAbove]> { - let summary = "HW Generated Module"; - let description = [{ - The "hw.module.generated" operation represents a reference to an external - module that will be produced by some external process. - This represents the name and list of ports to be generated. - - The 'verilogName' attribute (when present) specifies the spelling of the - module name in Verilog we can use. See hw.module for an explanation. - }]; - let arguments = (ins SymbolNameAttr:$sym_name, - FlatSymbolRefAttr:$generatorKind, - TypeAttrOf:$module_type, - OptionalAttr:$arg_attrs, - OptionalAttr:$res_attrs, - LocationArrayAttr:$argLocs, - LocationArrayAttr:$resultLocs, - ParamDeclArrayAttr:$parameters, - OptionalAttr:$verilogName); - let results = (outs); - let regions = (region AnyRegion:$body); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "FlatSymbolRefAttr":$genKind, "StringAttr":$name, - "ArrayRef":$ports, - CArg<"StringRef", "StringRef()">:$verilogName, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "FlatSymbolRefAttr":$genKind, "StringAttr":$name, - "const ModulePortInfo &":$ports, - CArg<"StringRef", "StringRef()">:$verilogName, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"ArrayRef", "{}">:$attributes)> - ]; - - let extraModuleClassDeclaration = [{ - /// Return the name to use for the Verilog module that we're referencing - /// here. This is typically the symbol, but can be overridden with the - /// verilogName attribute. - StringRef getVerilogModuleName() { - return getVerilogModuleNameAttr().getValue(); - } - - /// Return the name to use for the Verilog module that we're referencing - /// here. This is typically the symbol, but can be overridden with the - /// verilogName attribute. - StringAttr getVerilogModuleNameAttr(); - - /// Lookup the generator kind for the symbol. This returns null on - /// invalid IR. - Operation *getGeneratorKindOp(); - - void getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn); - - }]; - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - -def InstanceOp : HWOp<"instance", [ - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { - let summary = "Create an instance of a module"; - let description = [{ - This represents an instance of a module. The inputs and results are - the referenced module's inputs and outputs. The `argNames` and - `resultNames` attributes must match the referenced module. - - Any parameters in the "old" format (slated to be removed) are stored in the - `oldParameters` dictionary. - }]; - - let arguments = (ins StrAttr:$instanceName, - FlatSymbolRefAttr:$moduleName, - Variadic:$inputs, - StrArrayAttr:$argNames, StrArrayAttr:$resultNames, - ParamDeclArrayAttr:$parameters, - OptionalAttr:$inner_sym); - let results = (outs Variadic:$results); - - let builders = [ - /// Create a instance that refers to a known module. - OpBuilder<(ins "Operation*":$module, "StringAttr":$name, - "ArrayRef":$inputs, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"InnerSymAttr", "{}">:$innerSym)>, - /// Create a instance that refers to a known module. - OpBuilder<(ins "Operation*":$module, "StringRef":$name, - "ArrayRef":$inputs, - CArg<"ArrayAttr", "{}">:$parameters, - CArg<"InnerSymAttr", "{}">:$innerSym), [{ - build($_builder, $_state, module, $_builder.getStringAttr(name), inputs, - parameters, innerSym); - }]>, - ]; - - let extraClassDeclaration = [{ - /// Return the name of the specified input port or null if it cannot be - /// determined. - StringAttr getArgumentName(size_t i); - - /// Return the name of the specified result or null if it cannot be - /// determined. - StringAttr getResultName(size_t i); - - /// Change the name of the specified input port. - void setArgumentName(size_t i, StringAttr name); - - /// Change the name of the specified output port. - void setResultName(size_t i, StringAttr name); - - /// Change the names of all the input ports. - void setInputNames(ArrayAttr names) { - setArgNamesAttr(names); - } - - /// Change the names of all the result ports. - void setOutputNames(ArrayAttr names) { - setResultNamesAttr(names); - } - - /// Lookup the module or extmodule for the symbol. This returns null on - /// invalid IR. - Operation *getReferencedModule(const HWSymbolCache *cache); - Operation *getReferencedModule(SymbolTable& tbl); - Operation *getReferencedModuleSlow(); - - /// Return the values for the port in port order. - /// Note: The module ports may not be input, output ordered. This computes - /// the port index to instance result/input Value mapping. - void getValues(SmallVectorImpl &values, const ModulePortInfo &mpi); - - //===------------------------------------------------------------------===// - // SymbolOpInterface Methods - //===------------------------------------------------------------------===// - - /// An InstanceOp may optionally define a symbol. - bool isOptionalSymbol() { return true; } - - }]; - - let hasCustomAssemblyFormat = 1; - let hasVerifier = 1; -} - -def OutputOp : HWOp<"output", [Terminator, HasParent<"HWModuleOp, HWTestModuleOp">, - Pure, ReturnLike]> { - let summary = "HW termination operation"; - let description = [{ - "hw.output" marks the end of a region in the HW dialect and the values - to put on the output ports. - }]; - - let arguments = (ins Variadic:$outputs); - - let builders = [ - OpBuilder<(ins), "build($_builder, $_state, std::nullopt);"> - ]; - - let assemblyFormat = "attr-dict ($outputs^ `:` qualified(type($outputs)))?"; - - let hasVerifier = 1; -} - -def GlobalRefOp : HWOp<"globalRef", [ - DeclareOpInterfaceMethods, - IsolatedFromAbove, Symbol]> { - let summary = "A global reference to uniquely identify each" - "instance of an operation"; - let description = [{ - This works like a symbol reference to an operation by specifying the - instance path to uniquely identify it globally. - It can be used to attach per instance metadata (non-local attributes). - This also lets components of the path point to a common entity. - }]; - - let arguments = (ins SymbolNameAttr:$sym_name, NameRefArrayAttr:$namepath); - - let assemblyFormat = [{ $sym_name $namepath attr-dict }]; -} - -def HierPathOp : HWOp<"hierpath", - [IsolatedFromAbove, Symbol, - DeclareOpInterfaceMethods]> { - let summary = "Hierarchical path specification"; - let description = [{ - The "hw.hierpath" operation represents a path through the hierarchy. - This is used to specify namable things for use in other operations, for - example in verbatim substitution. Non-local annotations also use these. - }]; - let arguments = (ins SymbolNameAttr:$sym_name, NameRefArrayAttr:$namepath); - let results = (outs); - let hasCustomAssemblyFormat = 1; - let extraClassDeclaration = [{ - /// Drop the module from the namepath. If its a InnerNameRef, then drop - /// the Module-Instance pair, else drop the final module from the namepath. - /// Return true if any update is made. - bool dropModule(StringAttr moduleToDrop); - - /// Inline the module in the namepath. - /// Update the symbol name for the inlined module instance, by prepending - /// the symbol name of the instance at which the inling was done. - /// Return true if any update is made. - bool inlineModule(StringAttr moduleToDrop); - - /// Replace the oldMod module with newMod module in the namepath of the NLA. - /// Return true if any update is made. - bool updateModule(StringAttr oldMod, StringAttr newMod); - - /// Replace the oldMod module with newMod module in the namepath of the NLA. - /// Since the module is being updated, the symbols inside the module should - /// also be renamed. Use the rename Map to update the corresponding - /// inner_sym names in the namepath. Return true if any update is made. - bool updateModuleAndInnerRef(StringAttr oldMod, StringAttr newMod, - const llvm::DenseMap &innerSymRenameMap); - - /// Truncate the namepath for this NLA, at atMod module. - /// If includeMod is false, drop atMod and beyond, else include it and drop - /// everything after it. - /// Return true if any update is made. - bool truncateAtModule(StringAttr atMod, bool includeMod = true); - - /// Return just the module part of the namepath at a specific index. - StringAttr modPart(unsigned i); - - /// Return the root module. - StringAttr root(); - - /// Return just the reference part of the namepath at a specific index. - /// This will return an empty attribute if this is the leaf and the leaf is - /// a module. - StringAttr refPart(unsigned i); - - /// Return the leaf reference. This returns an empty attribute if the leaf - /// reference is a module. - StringAttr ref(); - - /// Return the leaf Module. - StringAttr leafMod(); - - /// Returns true, if the NLA path contains the module. - bool hasModule(StringAttr modName); - - /// Returns true, if the NLA path contains the InnerSym {modName, symName}. - bool hasInnerSym(StringAttr modName, StringAttr symName) const; - - /// Returns true if this NLA targets a module or instance of a module (as - /// opposed to an instance's port or something inside an instance). - bool isModule(); - - /// Returns true if this NLA targets something inside a module (as opposed - /// to a module or an instance of a module); - bool isComponent(); - }]; -} - -// Edge behavior for trigger blocks. Currently these map 1:1 to SV event -// control kinds. - -/// AtPosEdge triggers on a rise from 0 to 1/X/Z, or X/Z to 1. -def AtPosEdge: I32EnumAttrCase<"AtPosEdge", 0, "posedge">; -/// AtNegEdge triggers on a drop from 1 to 0/X/Z, or X/Z to 0. -def AtNegEdge: I32EnumAttrCase<"AtNegEdge", 1, "negedge">; -/// AtEdge(v) is syntactic sugar for "AtPosEdge(v) or AtNegEdge(v)". -def AtEdge : I32EnumAttrCase<"AtEdge", 2, "edge">; - -def EventControlAttr : I32EnumAttr<"EventControl", "edge control trigger", - [AtPosEdge, AtNegEdge, AtEdge]> { - let cppNamespace = "circt::hw"; -} - -def TriggeredOp : HWOp<"triggered", [ - IsolatedFromAbove, SingleBlock, NoTerminator]> { - let summary = "A procedural region with a trigger condition"; - let description = [{ - A procedural region that can be triggered by an event. The trigger - condition is a 1-bit value that is activated based on some event control - attribute. - The operation is `IsolatedFromAbove`, and thus requires values passed into - the trigger region to be explicitly passed in through the `inputs` list. - }]; - - let regions = (region SizedRegion<1>:$body); - let arguments = (ins EventControlAttr:$event, I1:$trigger, Variadic:$inputs); - let results = (outs); - - let assemblyFormat = [{ - $event $trigger `(` $inputs `)` `:` type($inputs) $body attr-dict - }]; - - let extraClassDeclaration = [{ - Block *getBodyBlock() { return &getBody().front(); } - - // Return the input arguments inside the trigger region. - ArrayRef getInnerInputs() { - return getBodyBlock()->getArguments(); - } - }]; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "EventControlAttr":$event, "Value":$trigger, "ValueRange":$inputs)> - ]; -} - -#endif // CIRCT_DIALECT_HW_HWSTRUCTURE_TD diff --git a/include/circt/Dialect/HW/HWSymCache.h b/include/circt/Dialect/HW/HWSymCache.h deleted file mode 100644 index b152ea96d5..0000000000 --- a/include/circt/Dialect/HW/HWSymCache.h +++ /dev/null @@ -1,118 +0,0 @@ -//===- HWSymCache.h - Declare Symbol Cache ---------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares a Symbol Cache specialized for HW instances. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_SYMCACHE_H -#define CIRCT_DIALECT_HW_SYMCACHE_H - -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Support/SymCache.h" - -namespace circt { -namespace hw { - -/// This stores lookup tables to make manipulating and working with the IR more -/// efficient. There are two phases to this object: the "building" phase in -/// which it is "write only" and then the "using" phase which is read-only (and -/// thus can be used by multiple threads). The "freeze" method transitions -/// between the two states. -class HWSymbolCache : public SymbolCacheBase { - public: - class Item { - public: - Item(mlir::Operation *op) : op(op), port(~0ULL) {} - Item(mlir::Operation *op, size_t port) : op(op), port(port) {} - bool hasPort() const { return port != ~0ULL; } - size_t getPort() const { return port; } - mlir::Operation *getOp() const { return op; } - - private: - mlir::Operation *op; - size_t port; - }; - - // Add inner names, which might be ports - void addDefinition(mlir::StringAttr modSymbol, mlir::StringAttr name, - mlir::Operation *op, size_t port = ~0ULL) { - auto key = InnerRefAttr::get(modSymbol, name); - symbolCache.try_emplace(key, op, port); - } - - void addDefinition(mlir::Attribute key, mlir::Operation *op) override { - assert(!isFrozen && "cannot mutate a frozen cache"); - symbolCache.try_emplace(key, op); - } - - // Pull in getDefinition(mlir::FlatSymbolRefAttr symbol) - using SymbolCacheBase::getDefinition; - mlir::Operation *getDefinition(mlir::Attribute attr) const override { - assert(isFrozen && "cannot read from this cache until it is frozen"); - auto it = symbolCache.find(attr); - if (it == symbolCache.end()) return nullptr; - assert(!it->second.hasPort() && "Module names should never be ports"); - return it->second.getOp(); - } - - HWSymbolCache::Item getInnerDefinition(mlir::StringAttr modSymbol, - mlir::StringAttr name) const { - return lookupInner(InnerRefAttr::get(modSymbol, name)); - } - - HWSymbolCache::Item getInnerDefinition(InnerRefAttr inner) const { - return lookupInner(inner); - } - - /// Mark the cache as frozen, which allows it to be shared across threads. - void freeze() { isFrozen = true; } - - private: - Item lookupInner(InnerRefAttr attr) const { - assert(isFrozen && "cannot read from this cache until it is frozen"); - auto it = symbolCache.find(attr); - return it == symbolCache.end() ? Item{nullptr, ~0ULL} : it->second; - } - - bool isFrozen = false; - - /// This stores a lookup table from symbol attribute to the item - /// that defines it. - llvm::DenseMap symbolCache; - - private: - // Iterator support. Map from Item's to their inner operations. - using Iterator = decltype(symbolCache)::iterator; - struct HwSymbolCacheIteratorImpl : public CacheIteratorImpl { - HwSymbolCacheIteratorImpl(Iterator it) : it(it) {} - CacheItem operator*() override { - return {it->getFirst(), it->getSecond().getOp()}; - } - void operator++() override { it++; } - bool operator==(CacheIteratorImpl *other) override { - return it == static_cast(other)->it; - } - Iterator it; - }; - - public: - SymbolCacheBase::Iterator begin() override { - return SymbolCacheBase::Iterator( - std::make_unique(symbolCache.begin())); - } - SymbolCacheBase::Iterator end() override { - return SymbolCacheBase::Iterator( - std::make_unique(symbolCache.end())); - } -}; - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_SYMCACHE_H diff --git a/include/circt/Dialect/HW/HWTypeDecls.td b/include/circt/Dialect/HW/HWTypeDecls.td deleted file mode 100644 index 8a1db6b36d..0000000000 --- a/include/circt/Dialect/HW/HWTypeDecls.td +++ /dev/null @@ -1,63 +0,0 @@ -//===- HWTypeDecls.td - HW data type declaration ops and types ------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Support for type declarations in the HW type system. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWTYPEDECLS_TD -#define CIRCT_DIALECT_HW_HWTYPEDECLS_TD - -include "include/circt/Dialect/HW/HWDialect.td" -include "mlir/IR/SymbolInterfaces.td" - -//===----------------------------------------------------------------------===// -// Declaration operations -//===----------------------------------------------------------------------===// - -def TypeScopeOp : HWOp<"type_scope", - [Symbol, SymbolTable, SingleBlock, NoTerminator, NoRegionArguments]> { - let summary = "Type declaration wrapper."; - let description = [{ - An operation whose one body block contains type declarations. This op - provides a scope for type declarations at the top level of an MLIR module. - It is a symbol that may be looked up within the module, as well as a symbol - table itself, so type declarations may be looked up. - }]; - - let regions = (region SizedRegion<1>:$body); - let arguments = (ins SymbolNameAttr:$sym_name); - - let assemblyFormat = "$sym_name $body attr-dict"; - - let extraClassDeclaration = [{ - Block *getBodyBlock() { return &getBody().front(); } - }]; -} - -def TypedeclOp : HWOp<"typedecl", [Symbol, HasParent<"TypeScopeOp">]> { - let summary = "Type declaration."; - let description = "Associate a symbolic name with a type."; - - let arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttr:$type, - OptionalAttr:$verilogName - ); - - let assemblyFormat = "$sym_name (`,` $verilogName^)? `:` $type attr-dict"; - - let extraClassDeclaration = [{ - StringRef getPreferredName(); - - // Returns the type alias type which this typedecl op defines. - Type getAliasType(); - }]; -} - -#endif // CIRCT_DIALECT_HW_HWTYPEDECLS_TD diff --git a/include/circt/Dialect/HW/HWTypeInterfaces.h b/include/circt/Dialect/HW/HWTypeInterfaces.h deleted file mode 100644 index 70159338cb..0000000000 --- a/include/circt/Dialect/HW/HWTypeInterfaces.h +++ /dev/null @@ -1,44 +0,0 @@ -//===- HWTypeInterfaces.h - Declare HW type interfaces ----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares type interfaces for the HW Dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWTYPEINTERFACES_H -#define CIRCT_DIALECT_HW_HWTYPEINTERFACES_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project - -namespace circt { -namespace hw { -namespace FieldIdImpl { -uint64_t getMaxFieldID(Type); - -std::pair<::mlir::Type, uint64_t> getSubTypeByFieldID(Type, uint64_t fieldID); - -::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID); - -std::pair projectToChildFieldID(Type, uint64_t fieldID, - uint64_t index); - -std::pair getIndexAndSubfieldID(Type type, - uint64_t fieldID); - -uint64_t getFieldID(Type type, uint64_t index); - -uint64_t getIndexForFieldID(Type type, uint64_t fieldID); - -} // namespace FieldIdImpl -} // namespace hw -} // namespace circt - -#include "include/circt/Dialect/HW/HWTypeInterfaces.h.inc" - -#endif // CIRCT_DIALECT_HW_HWTYPEINTERFACES_H diff --git a/include/circt/Dialect/HW/HWTypeInterfaces.td b/include/circt/Dialect/HW/HWTypeInterfaces.td deleted file mode 100644 index b71b6e5ed0..0000000000 --- a/include/circt/Dialect/HW/HWTypeInterfaces.td +++ /dev/null @@ -1,81 +0,0 @@ -//===- HWTypeInterfaces.td - HW Type Interfaces ------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This describes the type interfaces of the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD -#define CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD - -include "mlir/IR/OpBase.td" - -def FieldIDTypeInterface : TypeInterface<"FieldIDTypeInterface"> { - let cppNamespace = "circt::hw"; - let description = [{ - Common methods for types which can be indexed by a FieldID. - FieldID is a depth-first numbering of the elements of a type. For example: - ``` - struct a /* 0 */ { - int b; /* 1 */ - struct c /* 2 */ { - int d; /* 3 */ - } - } - - int e; /* 0 */ - ``` - }]; - - let methods = [ - InterfaceMethod<"Get the maximum field ID for this type", - "uint64_t", "getMaxFieldID">, - - InterfaceMethod<[{ - Get the sub-type of a type for a field ID, and the subfield's ID. Strip - off a single layer of this type and return the sub-type and a field ID - targeting the same field, but rebased on the sub-type. - - The resultant type *may* not be a FieldIDTypeInterface if the resulting - fieldID is zero. This means that leaf types may be ground without - implementing an interface. An empty aggregate will also appear as a - zero. - }], "std::pair<::mlir::Type, uint64_t>", - "getSubTypeByFieldID", (ins "uint64_t":$fieldID)>, - - InterfaceMethod<[{ - Returns the effective field id when treating the index field as the - root of the type. Essentially maps a fieldID to a fieldID after a - subfield op. Returns the new id and whether the id is in the given - child. - }], "std::pair", "projectToChildFieldID", - (ins "uint64_t":$fieldID, "uint64_t":$index)>, - - InterfaceMethod<[{ - Returns the index (e.g. struct or vector element) for a given FieldID. - This returns the containing index in the case that the fieldID points to a - child field of a field. - }], "uint64_t", "getIndexForFieldID", (ins "uint64_t":$fieldID)>, - - InterfaceMethod<[{ - Return the fieldID of a given index (e.g. struct or vector element). - Field IDs start at 1, and are assigned - to each field in a recursive depth-first walk of all - elements. A field ID of 0 is used to reference the type itself. - }], "uint64_t", "getFieldID", (ins "uint64_t":$fieldID)>, - - InterfaceMethod<[{ - Find the index of the element that contains the given fieldID. - As well, rebase the fieldID to the element. - }], "std::pair", "getIndexAndSubfieldID", - (ins "uint64_t":$fieldID)>, - - ]; -} - -#endif // CIRCT_DIALECT_HW_HWTYPEINTERFACES_TD diff --git a/include/circt/Dialect/HW/HWTypes.h b/include/circt/Dialect/HW/HWTypes.h deleted file mode 100644 index c137bf13fc..0000000000 --- a/include/circt/Dialect/HW/HWTypes.h +++ /dev/null @@ -1,165 +0,0 @@ -//===- HWTypes.h - Types for the HW dialect ---------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Types for the HW dialect are mostly in tablegen. This file should contain -// C++ types used in MLIR type parameters. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_TYPES_H -#define CIRCT_DIALECT_HW_TYPES_H - -#include "include/circt/Dialect/HW/HWDialect.h" -#include "include/circt/Dialect/HW/HWTypeInterfaces.h" -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project - -namespace circt { -namespace hw { - -struct ModulePort { - enum Direction { Input, Output, InOut }; - mlir::StringAttr name; - mlir::Type type; - Direction dir; -}; - -class HWSymbolCache; -class ParamDeclAttr; -class TypedeclOp; -class ModuleType; - -namespace detail { - -ModuleType fnToMod(Operation *op, ArrayRef inputNames, - ArrayRef outputNames); -ModuleType fnToMod(FunctionType fn, ArrayRef inputNames, - ArrayRef outputNames); - -/// Struct defining a field. Used in structs. -struct FieldInfo { - mlir::StringAttr name; - mlir::Type type; -}; - -/// Struct defining a field with an offset. Used in unions. -struct OffsetFieldInfo { - StringAttr name; - Type type; - size_t offset; -}; -} // namespace detail -} // namespace hw -} // namespace circt - -#define GET_TYPEDEF_CLASSES -#include "include/circt/Dialect/HW/HWTypes.h.inc" - -namespace circt { -namespace hw { - -// Returns the canonical type of a HW type (inner type of a type alias). -mlir::Type getCanonicalType(mlir::Type type); - -/// Return true if the specified type is a value HW Integer type. This checks -/// that it is a signless standard dialect type. -bool isHWIntegerType(mlir::Type type); - -/// Return true if the specified type is a HW Enum type. -bool isHWEnumType(mlir::Type type); - -/// Return true if the specified type can be used as an HW value type, that is -/// the set of types that can be composed together to represent synthesized, -/// hardware but not marker types like InOutType or unknown types from other -/// dialects. -bool isHWValueType(mlir::Type type); - -/// Return the hardware bit width of a type. Does not reflect any encoding, -/// padding, or storage scheme, just the bit (and wire width) of a -/// statically-size type. Reflects the number of wires needed to transmit a -/// value of this type. Returns -1 if the type is not known or cannot be -/// statically computed. -int64_t getBitWidth(mlir::Type type); - -/// Return true if the specified type contains known marker types like -/// InOutType. Unlike isHWValueType, this is not conservative, it only returns -/// false on known InOut types, rather than any unknown types. -bool hasHWInOutType(mlir::Type type); - -template -bool type_isa(Type type) { - // First check if the type is the requested type. - if (type.isa()) return true; - - // Then check if it is a type alias wrapping the requested type. - if (auto alias = type.dyn_cast()) - return type_isa(alias.getInnerType()); - - return false; -} - -// type_isa for a nullable argument. -template -bool type_isa_and_nonnull(Type type) { // NOLINT(readability-identifier-naming) - if (!type) return false; - return type_isa(type); -} - -template -BaseTy type_cast(Type type) { - assert(type_isa(type) && "type must convert to requested type"); - - // If the type is the requested type, return it. - if (type.isa()) return type.cast(); - - // Otherwise, it must be a type alias wrapping the requested type. - return type_cast(type.cast().getInnerType()); -} - -template -BaseTy type_dyn_cast(Type type) { - if (!type_isa(type)) return BaseTy(); - - return type_cast(type); -} - -/// Utility type that wraps a type that may be one of several possible Types. -/// This is similar to std::variant but is implemented for mlir::Type, and it -/// understands how to handle type aliases. -template -class TypeVariant - : public ::mlir::Type::TypeBase, mlir::Type, - mlir::TypeStorage> { - using mlir::Type::TypeBase, mlir::Type, - mlir::TypeStorage>::Base::Base; - - public: - // Support LLVM isa/cast/dyn_cast to one of the possible types. - static bool classof(Type other) { return type_isa(other); } -}; - -template -class TypeAliasOr - : public ::mlir::Type::TypeBase, mlir::Type, - mlir::TypeStorage> { - using mlir::Type::TypeBase, mlir::Type, - mlir::TypeStorage>::Base::Base; - - public: - // Support LLVM isa/cast/dyn_cast to BaseTy. - static bool classof(Type other) { return type_isa(other); } - - // Support C++ implicit conversions to BaseTy. - operator BaseTy() const { return type_cast(*this); } -}; - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_TYPES_H diff --git a/include/circt/Dialect/HW/HWTypes.td b/include/circt/Dialect/HW/HWTypes.td deleted file mode 100644 index ef11ea15b8..0000000000 --- a/include/circt/Dialect/HW/HWTypes.td +++ /dev/null @@ -1,179 +0,0 @@ -//===- HWTypes.td - HW data type definitions ---------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Basic data types for the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWTYPES -#define CIRCT_DIALECT_HW_HWTYPES - -include "include/circt/Dialect/HW/HWDialect.td" -include "mlir/IR/AttrTypeBase.td" - -//===----------------------------------------------------------------------===// -// Type predicates -//===----------------------------------------------------------------------===// - -// Type constraint that indicates that an operand/result may only be a valid, -// known, non-directional type. -def HWIntegerType : DialectType, - "a signless integer bitvector", - "::circt::hw::TypeVariant<::mlir::IntegerType, ::circt::hw::IntType>">; - -// Type constraint that indicates that an operand/result may only be a valid, -// known, non-directional type. -def HWValueType : DialectType, "a known primitive element">; - -// Type constraint that indicates that an operand/result may only be a valid -// non-directional type. -def HWNonInOutType : DialectType, "a type without inout">; - -def InOutType : DialectType($_self)">, - "InOutType", "InOutType">; - -// A handle to refer to circt::hw::ArrayType in ODS. -def ArrayType : DialectType($_self)">, - "an ArrayType", "::circt::hw::TypeAliasOr">; - -// A handle to refer to circt::hw::StructType in ODS. -def StructType : DialectType($_self)">, - "a StructType", "::circt::hw::TypeAliasOr">; - -// A handle to refer to circt::hw::UnionType in ODS. -def UnionType : DialectType($_self)">, - "a UnionType", "::circt::hw::TypeAliasOr">; - -// A handle to refer to circt::hw::EnumType in ODS. -def EnumType : DialectType($_self)">, - "a EnumType", "::circt::hw::TypeAliasOr">; - -def HWAggregateType : DialectType($_self)}]>, - "an ArrayType or StructType", - [{::circt::hw::TypeVariant< - ::circt::hw::ArrayType, - ::circt::hw::UnpackedArrayType, - ::circt::hw::StructType>}]>; - -// A handle to refer to circt::hw::ModuleType in ODS. -def ModuleType : DialectType($_self)">, - "a module", "::circt::hw::ModuleType">; - -// A handle to refer to circt::hw::StringType in ODS. -def HWStringType : - DialectType($_self)">, - "a HW string", "::circt::hw::StringType">, - BuildableType<"::circt::hw::StringType::get($_builder.getContext())">; - -/// A flat symbol reference or a reference to a name within a module. -def NameRefAttr : Attr< - CPred<"$_self.isa<::mlir::FlatSymbolRefAttr, ::circt::hw::InnerRefAttr>()">, - "name reference attribute">{ - let returnType = "::mlir::Attribute"; - let convertFromStorage = "$_self"; - let valueType = NoneType; -} - -// Like a FlatSymbolRefArrayAttr, but can also refer to names inside modules. -def NameRefArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getArrayAttr($0)"; -} - -def InnerSymProperties : AttrDef { - let mnemonic = "innerSymProps"; - let parameters = (ins - "::mlir::StringAttr":$name, - DefaultValuedParameter<"uint64_t", "0">:$fieldID, - DefaultValuedParameter<"::mlir::StringAttr", "public">:$sym_visibility - ); - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$sym),[{ - return get(sym.getContext(), sym, 0, - mlir::StringAttr::get(sym.getContext(), "public") ); - }]> - ]; - let hasCustomAssemblyFormat = 1; - // The assembly format is as follows: - // "`<` `@` $name `,` $fieldID `,` $sym_visibility `>`"; - let genVerifyDecl = 1; -} - - -def InnerSymAttr : AttrDef { - let summary = "Inner symbol definition"; - let description = [{ - Defines the properties of an inner_sym attribute. It specifies the symbol - name and symbol visibility for each field ID. For any ground types, - there are no subfields and the field ID is 0. For aggregate types, a - unique field ID is assigned to each field by visiting them in a - depth-first pre-order. The custom assembly format ensures that for ground - types, only `@` is printed. - }]; - let mnemonic = "innerSym"; - let parameters = (ins ArrayRefParameter<"InnerSymPropertiesAttr">:$props); - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::StringAttr":$sym),[{ - return get(sym.getContext(), - {InnerSymPropertiesAttr::get(sym.getContext(), sym, 0, - mlir::StringAttr::get(sym.getContext(), "public"))}); - }]>, - // Create an empty array, represents an invalid InnerSym. - AttrBuilder<(ins),[{ - return get($_ctxt, {}); - }]> - ]; - let extraClassDeclaration = [{ - /// Get the inner sym name for fieldID, if it exists. - mlir::StringAttr getSymIfExists(uint64_t fieldID) const; - - /// Get the inner sym name for fieldID=0, if it exists. - mlir::StringAttr getSymName() const { return getSymIfExists(0); } - - /// Get the number of inner symbols defined. - size_t size() const { return getProps().size(); } - - /// Check if this is an empty array, no sym names stored. - bool empty() const { return getProps().empty(); } - - /// Return an InnerSymAttr with the inner symbol for the specified fieldID removed. - InnerSymAttr erase(uint64_t fieldID) const; - - using iterator = mlir::ArrayRef::iterator; - /// Iterator begin for all the InnerSymProperties. - iterator begin() const { return getProps().begin(); } - - /// Iterator end for all the InnerSymProperties. - iterator end() const { return getProps().end(); } - - /// Invoke the func, for all sym names. Return success(), - /// if the callback function never returns failure(). - mlir::LogicalResult walkSymbols(llvm::function_ref< - mlir::LogicalResult (::mlir::StringAttr)>) const; - }]; - - let hasCustomAssemblyFormat = 1; - // Example format: - // firrtl.wire sym [<@x,1,private>, <@w,2,public>, <@syh,4,public>] -} - -#endif // CIRCT_DIALECT_HW_HWTYPES diff --git a/include/circt/Dialect/HW/HWTypesImpl.td b/include/circt/Dialect/HW/HWTypesImpl.td deleted file mode 100644 index d83c9f33cd..0000000000 --- a/include/circt/Dialect/HW/HWTypesImpl.td +++ /dev/null @@ -1,280 +0,0 @@ -//===- HWTypesImpl.td - HW data type definitions -----------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Basic data type implementations for the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWTYPESIMPL_TD -#define CIRCT_DIALECT_HW_HWTYPESIMPL_TD - -include "include/circt/Dialect/HW/HWDialect.td" -include "include/circt/Dialect/HW/HWTypeInterfaces.td" -include "mlir/IR/AttrTypeBase.td" - -// Base class for other typedefs. Provides dialact-specific defaults. -class HWType traits = []> - : TypeDef { } - -//===----------------------------------------------------------------------===// -// Type declarations -//===----------------------------------------------------------------------===// - -// A parameterized integer type. Declares the hw::IntType in C++. -def IntTypeImpl : HWType<"Int"> { - let summary = "parameterized-width integer"; - let description = [{ - Parameterized integer types are equivalent to the MLIR standard integer - type: it is signless, and may be any width integer. This type represents - the case when the width is a parameter in the HW dialect sense. - }]; - - let mnemonic = "int"; - let parameters = (ins "::mlir::TypedAttr":$width); - - let hasCustomAssemblyFormat = 1; - - let skipDefaultBuilders = 1; - - let extraClassDeclaration = [{ - /// Get an integer type for the specified width. Note that this may return - /// a builtin integer type if the width is a known-constant value. - static Type get(::mlir::TypedAttr width); - }]; -} - -// A simple fixed size array. Declares the hw::ArrayType in C++. -def ArrayTypeImpl : HWType<"Array", [DeclareTypeInterfaceMethods]> { - let summary = "fixed-sized array"; - let description = [{ - Fixed sized HW arrays are roughly similar to C arrays. On the wire (vs. - in a memory), arrays are always packed. Memory layout is not defined as - it does not need to be since in silicon there is not implicit memory - sharing. - }]; - - let mnemonic = "array"; - let parameters = (ins "::mlir::Type":$elementType, "::mlir::Attribute":$sizeAttr); - let genVerifyDecl = 1; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - static ArrayType get(Type elementType, size_t size) { - auto *ctx = elementType.getContext(); - auto intType = ::mlir::IntegerType::get(ctx, 64); - auto sizeAttr = ::mlir::IntegerAttr::get(intType, size); - return ArrayType::get(ctx, elementType, sizeAttr); - } - size_t getNumElements() const; - }]; -} - -// An 'unpacked' array of fixed size. -def UnpackedArrayType : HWType<"UnpackedArray", [DeclareTypeInterfaceMethods]> { - let summary = "SystemVerilog 'unpacked' fixed-sized array"; - let description = [{ - Unpacked arrays are a more flexible array representation than packed arrays, - and are typically used to model memories. See SystemVerilog Spec 7.4.2. - }]; - - let mnemonic = "uarray"; - let parameters = (ins "::mlir::Type":$elementType, "::mlir::Attribute":$sizeAttr); - let genVerifyDecl = 1; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - static UnpackedArrayType get(Type elementType, size_t size) { - auto *ctx = elementType.getContext(); - auto intType = ::mlir::IntegerType::get(ctx, 64); - auto sizeAttr = ::mlir::IntegerAttr::get(intType, size); - return UnpackedArrayType::get(ctx, elementType, sizeAttr); - } - size_t getNumElements() const; - }]; -} - -def InOutTypeImpl : HWType<"InOut"> { - let summary = "inout type"; - let description = [{ - InOut type is used for model operations and values that have "connection" - semantics, instead of typical dataflow behavior. This is used for wires - and inout ports in Verilog. - }]; - - let mnemonic = "inout"; - let parameters = (ins "::mlir::Type":$elementType); - let genVerifyDecl = 1; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - static InOutType get(Type elementType) { - return get(elementType.getContext(), elementType); - } - }]; -} - -// A packed struct. Declares the hw::StructType in C++. -def StructTypeImpl : HWType<"Struct", [DeclareTypeInterfaceMethods]> { - let summary = "HW struct type"; - let description = [{ - Represents a structure of name, value pairs. - !hw.struct - }]; - let mnemonic = "struct"; - - let hasCustomAssemblyFormat = 1; - - let parameters = ( - ins ArrayRefParameter< - "::circt::hw::StructType::FieldInfo", "struct fields">: $elements - ); - - let extraClassDeclaration = [{ - using FieldInfo = ::circt::hw::detail::FieldInfo; - mlir::Type getFieldType(mlir::StringRef fieldName); - void getInnerTypes(mlir::SmallVectorImpl&); - std::optional getFieldIndex(mlir::StringRef fieldName); - std::optional getFieldIndex(mlir::StringAttr fieldName); - }]; -} - -// An enum type. Declares the hw::EnumType in C++. -def EnumTypeImpl : HWType<"Enum"> { - let summary = "HW Enum type"; - let description = [{ - Represents an enumeration of values. Enums are interpreted as integers with - a synthesis-defined encoding. - !hw.enum - }]; - let mnemonic = "enum"; - let parameters = ( - ins "mlir::ArrayAttr":$fields - ); - - let extraClassDeclaration = [{ - /// Returns true if the requested field is part of this enum - bool contains(mlir::StringRef field); - - /// Returns the number of bits used by the enum - size_t getBitWidth(); - - /// Returns the index of the requested field, or a nullopt if the field is - // not part of this enum. - std::optional indexOf(mlir::StringRef field); - }]; - - let hasCustomAssemblyFormat = 1; -} - -// An untagged union. Declares the hw::UnionType in C++. -def UnionTypeImpl : HWType<"Union"> { - let summary = "An untagged union of types"; - let parameters = ( - ins ArrayRefParameter< - "::circt::hw::UnionType::FieldInfo", "union fields">: $elements - ); - let mnemonic = "union"; - - let hasCustomAssemblyFormat = 1; - - let extraClassDeclaration = [{ - using FieldInfo = ::circt::hw::detail::OffsetFieldInfo; - - FieldInfo getFieldInfo(::mlir::StringRef fieldName); - - ::mlir::Type getFieldType(::mlir::StringRef fieldName); - }]; -} - -def TypeAliasType : HWType<"TypeAlias"> { - let summary = "An symbolic reference to a type declaration"; - let description = [{ - A TypeAlias is parameterized by a SymbolRefAttr, which points to a - TypedeclOp. The root reference should refer to a TypeScope within the same - outer ModuleOp, and the leaf reference should refer to a type within that - TypeScope. A TypeAlias is further parameterized by the inner type, which is - needed to be known at the time the type is parsed. - - Upon construction, a TypeAlias stores the symbol reference and type, and - canonicalizes the type to resolve any nested type aliases. The canonical - type is also cached to avoid recomputing it when needed. - }]; - - let mnemonic = "typealias"; - - let parameters = (ins - "mlir::SymbolRefAttr":$ref, - "mlir::Type":$innerType, - "mlir::Type":$canonicalType - ); - - let hasCustomAssemblyFormat = 1; - - let builders = [ - TypeBuilderWithInferredContext<(ins - "mlir::SymbolRefAttr":$ref, "mlir::Type":$innerType)> - ]; - - let extraClassDeclaration = [{ - /// Return the Typedecl referenced by this TypeAlias, given the module to - /// look in. This returns null when the IR is malformed. - TypedeclOp getTypeDecl(const HWSymbolCache &cache); - }]; -} - -def ModuleTypeImpl : HWType<"Module"> { - let summary = "Module Type"; - let description = [{ - Module types have ports. - }]; - let parameters = (ins ArrayRefParameter<"::circt::hw::ModulePort", "port list">:$ports); - let hasCustomAssemblyFormat = 1; - let genVerifyDecl = 1; - let mnemonic = "modty"; - - let extraClassDeclaration = [{ - // Many of these are transitional and will be removed when modules and instances - // have moved over to this type. - size_t getNumPorts(); - size_t getNumInputs(); - size_t getNumOutputs(); - SmallVector getPortTypes(); - SmallVector getInputTypes(); - SmallVector getOutputTypes(); - Type getPortType(size_t); - Type getInputType(size_t); - Type getOutputType(size_t); - SmallVector getInputNamesStr(); - SmallVector getOutputNamesStr(); - SmallVector getInputNames(); - SmallVector getOutputNames(); - StringAttr getPortNameAttr(size_t); - StringRef getPortName(size_t); - StringAttr getInputNameAttr(size_t); - StringRef getInputName(size_t); - StringAttr getOutputNameAttr(size_t); - StringRef getOutputName(size_t); - FunctionType getFuncType(); - bool isOutput(size_t); - size_t getInputIdForPortId(size_t); - size_t getOutputIdForPortId(size_t); - size_t getPortIdForInputId(size_t); - size_t getPortIdForOutputId(size_t); - }]; -} - -def HWStringTypeImpl : HWType<"String"> { - let summary = "String type"; - let description = "Defines a string type for the hw-centric dialects"; - let mnemonic = "string"; -} - -#endif // CIRCT_DIALECT_HW_HWTYPESIMPL_TD diff --git a/include/circt/Dialect/HW/HWVisitors.h b/include/circt/Dialect/HW/HWVisitors.h deleted file mode 100644 index 2c00d181b5..0000000000 --- a/include/circt/Dialect/HW/HWVisitors.h +++ /dev/null @@ -1,140 +0,0 @@ -//===- HWVisitors.h - HW Dialect Visitors ---------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines visitors that make it easier to work with HW IR. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_HWVISITORS_H -#define CIRCT_DIALECT_HW_HWVISITORS_H - -#include "include/circt/Dialect/HW/HWOps.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project - -namespace circt { -namespace hw { - -/// This helps visit TypeOp nodes. -template -class TypeOpVisitor { - public: - ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args) { - auto *thisCast = static_cast(this); - return TypeSwitch(op) - .template Case([&](auto expr) -> ResultType { - return thisCast->visitTypeOp(expr, args...); - }) - .Default([&](auto expr) -> ResultType { - return thisCast->visitInvalidTypeOp(op, args...); - }); - } - - /// This callback is invoked on any non-expression operations. - ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args) { - op->emitOpError("unknown HW combinational node"); - abort(); - } - - /// This callback is invoked on any combinational operations that are not - /// handled by the concrete visitor. - ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args) { - return ResultType(); - } - -#define HANDLE(OPTYPE, OPKIND) \ - ResultType visitTypeOp(OPTYPE op, ExtraArgs... args) { \ - return static_cast(this)->visit##OPKIND##TypeOp(op, \ - args...); \ - } - - HANDLE(ConstantOp, Unhandled); - HANDLE(AggregateConstantOp, Unhandled); - HANDLE(BitcastOp, Unhandled); - HANDLE(ParamValueOp, Unhandled); - HANDLE(StructCreateOp, Unhandled); - HANDLE(StructExtractOp, Unhandled); - HANDLE(StructInjectOp, Unhandled); - HANDLE(UnionCreateOp, Unhandled); - HANDLE(UnionExtractOp, Unhandled); - HANDLE(ArraySliceOp, Unhandled); - HANDLE(ArrayGetOp, Unhandled); - HANDLE(ArrayCreateOp, Unhandled); - HANDLE(ArrayConcatOp, Unhandled); - HANDLE(EnumCmpOp, Unhandled); - HANDLE(EnumConstantOp, Unhandled); -#undef HANDLE -}; - -/// This helps visit TypeOp nodes. -template -class StmtVisitor { - public: - ResultType dispatchStmtVisitor(Operation *op, ExtraArgs... args) { - auto *thisCast = static_cast(this); - return TypeSwitch(op) - .template Case( - [&](auto expr) -> ResultType { - return thisCast->visitStmt(expr, args...); - }) - .Default([&](auto expr) -> ResultType { - return thisCast->visitInvalidStmt(op, args...); - }); - } - - /// This callback is invoked on any non-expression operations. - ResultType visitInvalidStmt(Operation *op, ExtraArgs... args) { - op->emitOpError("unknown hw statement"); - abort(); - } - - /// This callback is invoked on any combinational operations that are not - /// handled by the concrete visitor. - ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args) { - return ResultType(); - } - - /// This fallback is invoked on any binary node that isn't explicitly handled. - /// The default implementation delegates to the 'unhandled' fallback. - ResultType visitBinaryTypeOp(Operation *op, ExtraArgs... args) { - return static_cast(this)->visitUnhandledTypeOp(op, args...); - } - - ResultType visitUnaryTypeOp(Operation *op, ExtraArgs... args) { - return static_cast(this)->visitUnhandledTypeOp(op, args...); - } - -#define HANDLE(OPTYPE, OPKIND) \ - ResultType visitStmt(OPTYPE op, ExtraArgs... args) { \ - return static_cast(this)->visit##OPKIND##Stmt(op, \ - args...); \ - } - - // Basic nodes. - HANDLE(OutputOp, Unhandled); - HANDLE(InstanceOp, Unhandled); - HANDLE(TypeScopeOp, Unhandled); - HANDLE(TypedeclOp, Unhandled); -#undef HANDLE -}; - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_HWVISITORS_H diff --git a/include/circt/Dialect/HW/InnerSymbolNamespace.h b/include/circt/Dialect/HW/InnerSymbolNamespace.h deleted file mode 100644 index 814fb8022c..0000000000 --- a/include/circt/Dialect/HW/InnerSymbolNamespace.h +++ /dev/null @@ -1,51 +0,0 @@ -//===- InnerSymbolNamespace.h - Inner Symbol Table Namespace ----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the InnerSymbolNamespace, which tracks the names -// used by inner symbols within an InnerSymbolTable. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H -#define CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H - -#include "include/circt/Dialect/HW/InnerSymbolTable.h" -#include "include/circt/Support/Namespace.h" - -namespace circt { -namespace hw { - -struct InnerSymbolNamespace : Namespace { - InnerSymbolNamespace() = default; - InnerSymbolNamespace(Operation *module) { add(module); } - - /// Populate the namespace from a module-like operation. This namespace will - /// be composed of the `inner_sym`s of the module's ports and declarations. - void add(Operation *module) { - hw::InnerSymbolTable::walkSymbols( - module, [&](StringAttr name, const InnerSymTarget &target) { - nextIndex.insert({name.getValue(), 0}); - }); - } -}; - -struct InnerSymbolNamespaceCollection { - InnerSymbolNamespace &get(Operation *op) { - return collection.try_emplace(op, op).first->second; - } - - InnerSymbolNamespace &operator[](Operation *op) { return get(op); } - - private: - DenseMap collection; -}; - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_INNERSYMBOLNAMESPACE_H diff --git a/include/circt/Dialect/HW/InnerSymbolTable.h b/include/circt/Dialect/HW/InnerSymbolTable.h deleted file mode 100644 index ef752c6351..0000000000 --- a/include/circt/Dialect/HW/InnerSymbolTable.h +++ /dev/null @@ -1,264 +0,0 @@ -//===- InnerSymbolTable.h - Inner Symbol Table -----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares the InnerSymbolTable and related classes, used for -// managing and tracking "inner symbols". -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_INNERSYMBOLTABLE_H -#define CIRCT_DIALECT_HW_INNERSYMBOLTABLE_H - -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project - -namespace circt { -namespace hw { - -/// The target of an inner symbol, the entity the symbol is a handle for. -class InnerSymTarget { - public: - /// Default constructor, invalid. - InnerSymTarget() { assert(!*this); } - - /// Target an operation. - explicit InnerSymTarget(Operation *op) : InnerSymTarget(op, 0) {} - - /// Target an operation and a field (=0 means the op itself). - InnerSymTarget(Operation *op, size_t fieldID) - : op(op), portIdx(invalidPort), fieldID(fieldID) {} - - /// Target a port, and optionally a field (=0 means the port itself). - InnerSymTarget(size_t portIdx, Operation *op, size_t fieldID = 0) - : op(op), portIdx(portIdx), fieldID(fieldID) {} - - InnerSymTarget(const InnerSymTarget &) = default; - InnerSymTarget(InnerSymTarget &&) = default; - - // Accessors: - - /// Return the target's fieldID. - auto getField() const { return fieldID; } - - /// Return the target's base operation. For ports, this is the module. - Operation *getOp() const { return op; } - - /// Return the target's port, if valid. Check "isPort()". - auto getPort() const { - assert(isPort()); - return portIdx; - } - - // Classification: - - /// Return if this targets a field (nonzero fieldID). - bool isField() const { return fieldID != 0; } - - /// Return if this targets a port. - bool isPort() const { return portIdx != invalidPort; } - - /// Returns if this targets an operation only (not port or field). - bool isOpOnly() const { return !isPort() && !isField(); } - - /// Return a target to the specified field within the given base. - /// FieldID is relative to the specified base target. - static InnerSymTarget getTargetForSubfield(const InnerSymTarget &base, - size_t fieldID) { - if (base.isPort()) - return InnerSymTarget(base.portIdx, base.op, base.fieldID + fieldID); - return InnerSymTarget(base.op, base.fieldID + fieldID); - } - - private: - auto asTuple() const { return std::tie(op, portIdx, fieldID); } - Operation *op = nullptr; - size_t portIdx = 0; - size_t fieldID = 0; - static constexpr size_t invalidPort = ~size_t{0}; - - public: - // Operators are defined below. - - // Comparison operators: - bool operator==(const InnerSymTarget &rhs) const { - return asTuple() == rhs.asTuple(); - } - - // Assignment operators: - InnerSymTarget &operator=(InnerSymTarget &&) = default; - InnerSymTarget &operator=(const InnerSymTarget &) = default; - - /// Check if this target is valid. - operator bool() const { return op; } -}; - -/// A table of inner symbols and their resolutions. -class InnerSymbolTable { - public: - /// Build an inner symbol table for the given operation. The operation must - /// have the InnerSymbolTable trait. - explicit InnerSymbolTable(Operation *op); - - /// Non-copyable - InnerSymbolTable(const InnerSymbolTable &) = delete; - InnerSymbolTable &operator=(const InnerSymbolTable &) = delete; - - // Moveable - InnerSymbolTable(InnerSymbolTable &&) = default; - InnerSymbolTable &operator=(InnerSymbolTable &&) = default; - - /// Look up a symbol with the specified name, returning empty InnerSymTarget - /// if no such name exists. Names never include the @ on them. - InnerSymTarget lookup(StringRef name) const; - InnerSymTarget lookup(StringAttr name) const; - - /// Look up a symbol with the specified name, returning null if no such - /// name exists or doesn't target just an operation. - Operation *lookupOp(StringRef name) const; - template - T lookupOp(StringRef name) const { - return dyn_cast_or_null(lookupOp(name)); - } - - /// Look up a symbol with the specified name, returning null if no such - /// name exists or doesn't target just an operation. - Operation *lookupOp(StringAttr name) const; - template - T lookupOp(StringAttr name) const { - return dyn_cast_or_null(lookupOp(name)); - } - - /// Get InnerSymbol for an operation. - static StringAttr getInnerSymbol(Operation *op); - - /// Get InnerSymbol for a target. - static StringAttr getInnerSymbol(const InnerSymTarget &target); - - /// Return the name of the attribute used for inner symbol names. - static StringRef getInnerSymbolAttrName() { return "inner_sym"; } - - /// Construct an InnerSymbolTable, checking for verification failure. - /// Emits diagnostics describing encountered issues. - static FailureOr get(Operation *op); - - using InnerSymCallbackFn = - llvm::function_ref; - - /// Walk the given IST operation and invoke the callback for all encountered - /// inner symbols. - /// This variant allows callbacks that return LogicalResult OR void, - /// and wraps the underlying implementation. - template > - static RetTy walkSymbols(Operation *op, FuncTy &&callback) { - if constexpr (std::is_void_v) - return (void)walkSymbols( - op, InnerSymCallbackFn( - [&](StringAttr name, const InnerSymTarget &target) { - std::invoke(std::forward(callback), name, target); - return success(); - })); - else - return walkSymbols( - op, InnerSymCallbackFn([&](StringAttr name, - const InnerSymTarget &target) { - return std::invoke(std::forward(callback), name, target); - })); - } - - /// Walk the given IST operation and invoke the callback for all encountered - /// inner symbols. - /// This variant is the underlying implementation. - /// If callback returns failure, the walk is aborted and failure is returned. - /// A successful walk with no failures returns success. - static LogicalResult walkSymbols(Operation *op, InnerSymCallbackFn callback); - - private: - using TableTy = DenseMap; - /// Construct an inner symbol table for the given operation, - /// with pre-populated table contents. - explicit InnerSymbolTable(Operation *op, TableTy &&table) - : innerSymTblOp(op), symbolTable(table){}; - - /// This is the operation this table is constructed for, which must have the - /// InnerSymbolTable trait. - Operation *innerSymTblOp; - - /// This maps inner symbol names to their targets. - TableTy symbolTable; -}; - -/// This class represents a collection of InnerSymbolTable's. -class InnerSymbolTableCollection { - public: - /// Get or create the InnerSymbolTable for the specified operation. - InnerSymbolTable &getInnerSymbolTable(Operation *op); - - /// Populate tables in parallel for all InnerSymbolTable operations in the - /// given InnerRefNamespace operation, verifying each and returning - /// the verification result. - LogicalResult populateAndVerifyTables(Operation *innerRefNSOp); - - explicit InnerSymbolTableCollection() = default; - explicit InnerSymbolTableCollection(Operation *innerRefNSOp) { - // Caller is not interested in verification, no way to report it upwards. - auto result = populateAndVerifyTables(innerRefNSOp); - (void)result; - assert(succeeded(result)); - } - InnerSymbolTableCollection(const InnerSymbolTableCollection &) = delete; - InnerSymbolTableCollection &operator=(const InnerSymbolTableCollection &) = - delete; - - private: - /// This maps Operations to their InnnerSymbolTable's. - DenseMap> symbolTables; -}; - -/// This class represents the namespace in which InnerRef's can be resolved. -struct InnerRefNamespace { - SymbolTable &symTable; - InnerSymbolTableCollection &innerSymTables; - - /// Resolve the InnerRef to its target within this namespace, returning empty - /// target if no such name exists. - InnerSymTarget lookup(hw::InnerRefAttr inner); - - /// Resolve the InnerRef to its target within this namespace, returning - /// empty target if no such name exists or it's not an operation. - /// Template type can be used to limit results to specified op type. - Operation *lookupOp(hw::InnerRefAttr inner); - template - T lookupOp(hw::InnerRefAttr inner) { - return dyn_cast_or_null(lookupOp(inner)); - } -}; - -/// Printing InnerSymTarget's. -template -OS &operator<<(OS &os, const InnerSymTarget &target) { - if (!target) return os << ""; - - if (target.isField()) os << "field " << target.getField() << " of "; - - if (target.isPort()) - os << "port " << target.getPort() << " on @" - << SymbolTable::getSymbolName(target.getOp()).getValue() << ""; - else - os << "op " << *target.getOp() << ""; - - return os; -} - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_FIRRTL_INNERSYMBOLTABLE_H diff --git a/include/circt/Dialect/HW/InstanceImplementation.h b/include/circt/Dialect/HW/InstanceImplementation.h deleted file mode 100644 index 6265ce0179..0000000000 --- a/include/circt/Dialect/HW/InstanceImplementation.h +++ /dev/null @@ -1,104 +0,0 @@ -//===- InstanceImplementation.h - Instance-like Op utilities ----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides utility functions for implementing instance-like -// operations, in particular, parsing, and printing common to instance-like -// operations. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H -#define CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H - -#include - -#include "include/circt/Support/LLVM.h" - -namespace circt { -namespace hw { -// Forward declarations. -class HWSymbolCache; - -namespace instance_like_impl { - -/// Whenever the nested function returns true, a note referring to the -/// referenced module is attached to the error. -using EmitErrorFn = - std::function)>; - -/// Return a pointer to the referenced module operation. -Operation *getReferencedModule(const HWSymbolCache *cache, - Operation *instanceOp, - mlir::FlatSymbolRefAttr moduleName); - -/// Verify that the instance refers to a valid HW module. -LogicalResult verifyReferencedModule(Operation *instanceOp, - SymbolTableCollection &symbolTable, - mlir::FlatSymbolRefAttr moduleName, - Operation *&module); - -/// Stores a resolved version of each type in @param types wherein any parameter -/// reference has been evaluated based on the set of provided @param parameters -/// in @param resolvedTypes -LogicalResult resolveParametricTypes(Location loc, ArrayAttr parameters, - ArrayRef types, - SmallVectorImpl &resolvedTypes, - const EmitErrorFn &emitError); - -/// Verify that the list of inputs of the instance and the module match in terms -/// of length, names, and types. -LogicalResult verifyInputs(ArrayAttr argNames, ArrayAttr moduleArgNames, - TypeRange inputTypes, - ArrayRef moduleInputTypes, - const EmitErrorFn &emitError); - -/// Verify that the list of outputs of the instance and the module match in -/// terms of length, names, and types. -LogicalResult verifyOutputs(ArrayAttr resultNames, ArrayAttr moduleResultNames, - TypeRange resultTypes, - ArrayRef moduleResultTypes, - const EmitErrorFn &emitError); - -/// Verify that the parameter lists of the instance and the module match in -/// terms of length, names, and types. -LogicalResult verifyParameters(ArrayAttr parameters, ArrayAttr moduleParameters, - const EmitErrorFn &emitError); - -/// Combines verifyReferencedModule, verifyInputs, verifyOutputs, and -/// verifyParameters. It is only allowed to call this function when the instance -/// refers to a HW module. The @param parameters attribute may be null in which -/// case not parameters are verified. -LogicalResult verifyInstanceOfHWModule( - Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, - TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, - ArrayAttr parameters, SymbolTableCollection &symbolTable); - -/// Check that all the parameter values specified to the instance are -/// structurally valid. -LogicalResult verifyParameterStructure(ArrayAttr parameters, - ArrayAttr moduleParameters, - const EmitErrorFn &emitError); - -/// Return the name at the specified index of the ArrayAttr or null if it cannot -/// be determined. -StringAttr getName(ArrayAttr names, size_t idx); - -/// Change the name at the specified index of the @param oldNames ArrayAttr to -/// @param name -ArrayAttr updateName(ArrayAttr oldNames, size_t i, StringAttr name); - -/// Suggest a name for each result value based on the saved result names -/// attribute. -void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, - ArrayAttr resultNames, ValueRange results); - -} // namespace instance_like_impl -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_INSTANCEIMPLEMENTATION_H diff --git a/include/circt/Dialect/HW/ModuleImplementation.h b/include/circt/Dialect/HW/ModuleImplementation.h deleted file mode 100644 index f47c55abe5..0000000000 --- a/include/circt/Dialect/HW/ModuleImplementation.h +++ /dev/null @@ -1,56 +0,0 @@ -//===- ModuleImplementation.h - Module-like Op utilities --------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides utility functions for implementing module-like -// operations, in particular, parsing, and printing common to module-like -// operations. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H -#define CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H - -#include "include/circt/Dialect/HW/HWTypes.h" -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project - -namespace circt { -namespace hw { - -namespace module_like_impl { - -struct PortParse : OpAsmParser::Argument { - ModulePort::Direction direction; -}; - -/// This is a variant of mlir::parseFunctionSignature that allows names on -/// result arguments. -ParseResult parseModuleFunctionSignature( - OpAsmParser &parser, bool &isVariadic, - SmallVectorImpl &args, - SmallVectorImpl &argNames, SmallVectorImpl &argLocs, - SmallVectorImpl &resultNames, - SmallVectorImpl &resultAttrs, - SmallVectorImpl &resultLocs, TypeAttr &type); - -/// Print a module signature with named results. -void printModuleSignature(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes, bool &needArgNamesAttr); - -/// New Style parsing -ParseResult parseModuleSignature(OpAsmParser &parser, - SmallVectorImpl &args, - TypeAttr &modType); -void printModuleSignatureNew(OpAsmPrinter &p, Operation *op); - -} // namespace module_like_impl -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_MODULEIMPLEMENTATION_H diff --git a/include/circt/Dialect/HW/Passes.td b/include/circt/Dialect/HW/Passes.td deleted file mode 100644 index 15d95000f8..0000000000 --- a/include/circt/Dialect/HW/Passes.td +++ /dev/null @@ -1,61 +0,0 @@ -//===-- Passes.td - HW pass definition file ----------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines the passes that work on the HW dialect. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_PASSES_TD -#define CIRCT_DIALECT_HW_PASSES_TD - -include "mlir/Pass/PassBase.td" - -def PrintInstanceGraph : Pass<"hw-print-instance-graph", "mlir::ModuleOp"> { - let summary = "Print a DOT graph of the module hierarchy."; - let constructor = "circt::hw::createPrintInstanceGraphPass()"; -} - -def PrintHWModuleGraph : Pass<"hw-print-module-graph", "mlir::ModuleOp"> { - let summary = "Print a DOT graph of the HWModule's within a top-level module."; - let constructor = "circt::hw::createPrintHWModuleGraphPass()"; - - let options = [ - Option<"verboseEdges", "verbose-edges", "bool", "false", - "Print information on SSA edges (types, operand #, ...)">, - ]; -} - -def FlattenIO : Pass<"hw-flatten-io", "mlir::ModuleOp"> { - let summary = "Flattens hw::Structure typed in- and output ports."; - let constructor = "circt::hw::createFlattenIOPass()"; - - let options = [ - Option<"recursive", "recursive", "bool", "false", - "Recursively flatten nested structs.">, - ]; -} - -def HWSpecialize : Pass<"hw-specialize", "mlir::ModuleOp"> { - let summary = "Specializes instances of parametric hw.modules"; - let constructor = "circt::hw::createHWSpecializePass()"; - let description = [{ - Any `hw.instance` operation instantiating a parametric `hw.module` will - trigger a specialization procedure which resolves all parametric types and - values within the module based on the set of provided parameters to the - `hw.instance` operation. This specialized module is created as a new - `hw.module` and the referring `hw.instance` operation is rewritten to - instantiate the newly specialized module. - }]; -} - -def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> { - let summary = "Verify InnerRefNamespaceLike operations, if not self-verifying."; - let constructor = "circt::hw::createVerifyInnerRefNamespacePass()"; -} - -#endif // CIRCT_DIALECT_HW_PASSES_TD diff --git a/include/circt/Dialect/HW/PortConverter.h b/include/circt/Dialect/HW/PortConverter.h deleted file mode 100644 index f3bb412d12..0000000000 --- a/include/circt/Dialect/HW/PortConverter.h +++ /dev/null @@ -1,182 +0,0 @@ -//===- PortConverter.h - Module I/O rewriting utility -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The PortConverter is a utility class for rewriting arguments of a -// HWMutableModuleLike operation. -// It is intended to be a generic utility that can facilitate replacement of -// a given module in- or output to an arbitrary set of new inputs and outputs -// (i.e. 1 port -> N in, M out ports). Typical usecases is where an in (or -// output) of a module represents some higher-level abstraction that will be -// implemented by a set of lower-level in- and outputs ports + supporting -// operations within a module. It also attempts to do so in an optimal way, by -// e.g. being able to collect multiple port modifications of a module, and -// perform them all at once. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_DIALECT_HW_PORTCONVERTER_H -#define CIRCT_DIALECT_HW_PORTCONVERTER_H - -#include "include/circt/Dialect/HW/HWInstanceGraph.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Support/BackedgeBuilder.h" -#include "include/circt/Support/LLVM.h" - -namespace circt { -namespace hw { - -class PortConversionBuilder; -class PortConversion; - -class PortConverterImpl { - public: - /// Run port conversion. - LogicalResult run(); - Block *getBody() const { return body; } - hw::HWMutableModuleLike getModule() const { return mod; } - - /// These two methods take care of allocating new ports in the correct place - /// based on the position of 'origPort'. The new port is based on the original - /// name and suffix. The specification for the new port is given by `newPort` - /// and is recorded internally. Any changes to 'newPort' after calling this - /// will not be reflected in the modules new port list. Will also add the new - /// input to the block arguments of the body of the module. - Value createNewInput(hw::PortInfo origPort, const Twine &suffix, Type type, - hw::PortInfo &newPort); - /// Same as above. 'output' is the value fed into the new port and is required - /// if 'body' is non-null. Important note: cannot be a backedge which gets - /// replaced since this isn't attached to an op until later in the pass. - void createNewOutput(hw::PortInfo origPort, const Twine &suffix, Type type, - Value output, hw::PortInfo &newPort); - - protected: - PortConverterImpl(igraph::InstanceGraphNode *moduleNode); - - std::unique_ptr ssb; - - private: - /// Updates an instance of the module. This is called after the module has - /// been updated. It will update the instance to match the new port - void updateInstance(hw::InstanceOp); - - // If the module has a block and it wants to be modified, this'll be - // non-null. - Block *body = nullptr; - - igraph::InstanceGraphNode *moduleNode; - hw::HWMutableModuleLike mod; - OpBuilder b; - - // Keep around a reference to the specific port conversion classes to - // facilitate updating the instance ops. Indexed by the original port - // location. - SmallVector> loweredInputs; - SmallVector> loweredOutputs; - - // Tracking information to modify the module. Populated by the - // 'createNew(Input|Output)' methods. Will be cleared once port changes have - // materialized. Default length is 0 to save memory in case we'll be keeping - // this around for later use. - SmallVector, 0> newInputs; - SmallVector, 0> newOutputs; - - // Maintain a handle to the terminator of the body, if any. This will get - // continuously updated during port conversion whenever a new output is added - // to the module. - Operation *terminator = nullptr; -}; - -/// Base class for the port conversion of a particular port. Abstracts the -/// details of a particular port conversion from the port layout. Subclasses -/// keep around port mapping information to use when updating instances. -class PortConversion { - public: - PortConversion(PortConverterImpl &converter, hw::PortInfo origPort) - : converter(converter), body(converter.getBody()), origPort(origPort) {} - virtual ~PortConversion() = default; - - // An optional initialization step that can be overridden by subclasses. - // This allows subclasses to perform a failable post-construction - // initialization step. - virtual LogicalResult init() { return success(); } - - // Lower the specified port into a wire-level signaling protocol. The two - // virtual methods 'build*Signals' should be overridden by subclasses. They - // should use the 'create*' methods in 'PortConverter' to create the - // necessary ports. - void lowerPort() { - if (origPort.dir == hw::ModulePort::Direction::Output) - buildOutputSignals(); - else - buildInputSignals(); - } - - /// Update an instance port to the new port information. - virtual void mapInputSignals(OpBuilder &b, Operation *inst, Value instValue, - SmallVectorImpl &newOperands, - ArrayRef newResults) = 0; - virtual void mapOutputSignals(OpBuilder &b, Operation *inst, Value instValue, - SmallVectorImpl &newOperands, - ArrayRef newResults) = 0; - - MLIRContext *getContext() { return getModule()->getContext(); } - bool isUntouched() const { return isUntouchedFlag; } - - protected: - // Build the input and output signals for the port. This pertains to modifying - // the module itself. - virtual void buildInputSignals() = 0; - virtual void buildOutputSignals() = 0; - - PortConverterImpl &converter; - Block *body; - hw::PortInfo origPort; - - hw::HWMutableModuleLike getModule() { return converter.getModule(); } - - // We don't need full LLVM-style RTTI support for PortConversion (would - // require some mechanism of registering user-provided PortConversion-derived - // classes), we only need to dynamically tell whether any given PortConversion - // is the UntouchedPortConversion. - bool isUntouchedFlag = false; -}; // namespace hw - -// A PortConversionBuilder will, given an input type, build the appropriate -// port conversion for that type. -class PortConversionBuilder { - public: - PortConversionBuilder(PortConverterImpl &converter) : converter(converter) {} - virtual ~PortConversionBuilder() = default; - - // Builds the appropriate port conversion for the port. Users should - // override this method with their own llvm::TypeSwitch-based dispatch code, - // and by default call this method when no port conversion applies. - virtual FailureOr> build(hw::PortInfo port); - - PortConverterImpl &converter; -}; - -// A PortConverter wraps a single HWMutableModuleLike operation, and is -// initialized from an instance graph node. The port converter is templated -// on a PortConversionBuilder, which is used to build the appropriate -// port conversion for each port type. -template -class PortConverter : public PortConverterImpl { - public: - template - PortConverter(hw::InstanceGraph &graph, hw::HWMutableModuleLike mod, - Args &&...args) - : PortConverterImpl(graph.lookup(cast(*mod))) { - ssb = std::make_unique(*this, args...); - } -}; - -} // namespace hw -} // namespace circt - -#endif // CIRCT_DIALECT_HW_PORTCONVERTER_H diff --git a/include/circt/Support/APInt.h b/include/circt/Support/APInt.h deleted file mode 100644 index 0aae1eb12e..0000000000 --- a/include/circt/Support/APInt.h +++ /dev/null @@ -1,30 +0,0 @@ -//===- APInt.h - CIRCT Lowering Options -------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for working around limitations of upstream LLVM APInts. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_APINT_H -#define CIRCT_SUPPORT_APINT_H - -#include "include/circt/Support/LLVM.h" - -namespace circt { - -/// A safe version of APInt::sext that will NOT assert on zero-width -/// signed APSInts. Instead of asserting, this will zero extend. -APInt sextZeroWidth(APInt value, unsigned width); - -/// A safe version of APSInt::extOrTrunc that will NOT assert on zero-width -/// signed APSInts. Instead of asserting, this will zero extend. -APSInt extOrTruncZeroWidth(APSInt value, unsigned width); - -} // namespace circt - -#endif // CIRCT_SUPPORT_APINT_H diff --git a/include/circt/Support/BUILD b/include/circt/Support/BUILD index cfbfa0eb14..6209665878 100644 --- a/include/circt/Support/BUILD +++ b/include/circt/Support/BUILD @@ -1,5 +1,3 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") - package( default_applicable_licenses = ["@heir//:license"], default_visibility = ["//visibility:public"], @@ -10,43 +8,3 @@ exports_files( "*.h", ]), ) - -td_library( - name = "td_files", - srcs = glob([ - "*.td", - ]), - includes = ["include"], - deps = [ - "@llvm-project//mlir:BuiltinDialectTdFiles", - "@llvm-project//mlir:ControlFlowInterfacesTdFiles", - "@llvm-project//mlir:FunctionInterfacesTdFiles", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -gentbl_cc_library( - name = "interfaces_inc_gen", - includes = ["include"], - tbl_outs = [ - ( - [ - "-gen-op-interface-decls", - ], - "InstanceGraphInterface.h.inc", - ), - ( - [ - "-gen-op-interface-defs", - ], - "InstanceGraphInterface.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "InstanceGraphInterface.td", - deps = [ - ":td_files", - ], -) diff --git a/include/circt/Support/BackedgeBuilder.h b/include/circt/Support/BackedgeBuilder.h deleted file mode 100644 index ec6cee0ef9..0000000000 --- a/include/circt/Support/BackedgeBuilder.h +++ /dev/null @@ -1,100 +0,0 @@ -//===- BackedgeBuilder.h - Support for building backedges -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Backedges are operations/values which have to exist as operands before -// they are produced in a result. Since it isn't clear how to build backedges -// in MLIR, these helper classes set up a canonical way to do so. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_BACKEDGEBUILDER_H -#define CIRCT_SUPPORT_BACKEDGEBUILDER_H - -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "mlir/include/mlir/IR/Location.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project - -namespace mlir { -class OpBuilder; -class PatternRewriter; -class Operation; -} // namespace mlir - -namespace circt { - -class Backedge; - -/// Instantiate one of these and use it to build typed backedges. Backedges -/// which get used as operands must be assigned to with the actual value before -/// this class is destructed, usually at the end of a scope. It will check that -/// invariant then erase all the backedge ops during destruction. -/// -/// Example use: -/// ``` -/// circt::BackedgeBuilder back(rewriter, loc); -/// circt::Backedge ready = back.get(rewriter.getI1Type()); -/// // Use `ready` as a `Value`. -/// auto addOp = rewriter.create(loc, ready); -/// // When the actual value is available, -/// ready.set(anotherOp.getResult(0)); -/// ``` -class BackedgeBuilder { - friend class Backedge; - - public: - /// To build a backedge op and manipulate it, we need a `PatternRewriter` and - /// a `Location`. Store them during construct of this instance and use them - /// when building. - BackedgeBuilder(mlir::OpBuilder &builder, mlir::Location loc); - BackedgeBuilder(mlir::PatternRewriter &rewriter, mlir::Location loc); - ~BackedgeBuilder(); - - /// Create a typed backedge. If no location is provided, the one passed to the - /// constructor will be used. - Backedge get(mlir::Type resultType, mlir::LocationAttr optionalLoc = {}); - - /// Clear the backedges, erasing any remaining cursor ops. Returns `failure` - /// and emits diagnostic messages if a backedge is still active. - mlir::LogicalResult clearOrEmitError(); - - /// Abandon the backedges, suppressing any diagnostics if they are still - /// active upon destruction of the backedge builder. Also, any currently - /// existing cursor ops will be abandoned. - void abandon(); - - private: - mlir::OpBuilder &builder; - mlir::PatternRewriter *rewriter; - mlir::Location loc; - llvm::SmallVector edges; -}; - -/// `Backedge` is a wrapper class around a `Value`. When assigned another -/// `Value`, it replaces all uses of itself with the new `Value` then become a -/// wrapper around the new `Value`. -class Backedge { - friend class BackedgeBuilder; - - /// `Backedge` is constructed exclusively by `BackedgeBuilder`. - Backedge(mlir::Operation *op); - - public: - Backedge() {} - - explicit operator bool() const { return !!value; } - operator mlir::Value() const { return value; } - void setValue(mlir::Value); - - private: - mlir::Value value; - bool set = false; -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_BACKEDGEBUILDER_H diff --git a/include/circt/Support/BuilderUtils.h b/include/circt/Support/BuilderUtils.h deleted file mode 100644 index dc1ba20322..0000000000 --- a/include/circt/Support/BuilderUtils.h +++ /dev/null @@ -1,49 +0,0 @@ -//===- BuilderUtils.h - Operation builder utilities -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_BUILDERUTILS_H -#define CIRCT_SUPPORT_BUILDERUTILS_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project - -// Note: The following must be added to include the LLVM types used. -#include "llvm/include/llvm/ADT/PointerUnion.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project - -namespace circt { - -/// A helper union that can represent a `StringAttr`, `StringRef`, or `Twine`. -/// It is intended to be used as arguments to an op's `build` function. This -/// allows a single builder to accept any flavor value for a string attribute. -/// The `get` function can then be used to obtain a `StringAttr` from any of the -/// possible variants `StringAttrOrRef` can take. -class StringAttrOrRef { - using Value = llvm::PointerUnion; - Value value; - - public: - StringAttrOrRef() : value() {} - StringAttrOrRef(StringAttr attr) : value(attr) {} - StringAttrOrRef(const StringRef &str) - : value(const_cast(&str)) {} - StringAttrOrRef(const Twine &twine) : value(const_cast(&twine)) {} - - /// Return the represented string as a `StringAttr`. - StringAttr get(MLIRContext *context) const { - return TypeSwitch(value) - .Case([&](auto value) { return value; }) - .Case( - [&](auto value) { return StringAttr::get(context, *value); }) - .Default([](auto) { return StringAttr{}; }); - } -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_BUILDERUTILS_H diff --git a/include/circt/Support/CMakeLists.txt b/include/circt/Support/CMakeLists.txt deleted file mode 100644 index e63fd5422f..0000000000 --- a/include/circt/Support/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_mlir_interface(InstanceGraphInterface) diff --git a/include/circt/Support/CustomDirectiveImpl.h b/include/circt/Support/CustomDirectiveImpl.h deleted file mode 100644 index 317ce2c969..0000000000 --- a/include/circt/Support/CustomDirectiveImpl.h +++ /dev/null @@ -1,91 +0,0 @@ -//===- CustomDirectiveImpl.h - Custom TableGen directives -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides common custom directives for table-gen assembly formats. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H -#define CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project - -namespace circt { - -//===----------------------------------------------------------------------===// -// ImplicitSSAName Custom Directive -//===----------------------------------------------------------------------===// - -/// Parse an implicit SSA name string attribute. If the name is not provided in -/// the input text, its value is inferred from the SSA name of the operation's -/// first result. -/// -/// implicit-name ::= (`name` str-attr)? -ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr); - -/// Parse an attribute dictionary and ensure that it contains a `name` field by -/// inferring its value from the SSA name of the operation's first result if -/// necessary. -ParseResult parseImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs); - -/// Ensure that `attrs` contains a `name` attribute by inferring its value from -/// the SSA name of the operation's first result if necessary. Returns true if a -/// name was inferred, false if `attrs` already contained a `name`. -bool inferImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs); - -/// Print an implicit SSA name string attribute. If the given string attribute -/// does not match the SSA name of the operation's first result, the name is -/// explicitly printed. Prints a leading space in front of `name` if any name is -/// present. -/// -/// implicit-name ::= (`name` str-attr)? -void printImplicitSSAName(OpAsmPrinter &p, Operation *op, StringAttr attr); - -/// Print an attribute dictionary and elide the `name` field if its value -/// matches the SSA name of the operation's first result. -void printImplicitSSAName(OpAsmPrinter &p, Operation *op, DictionaryAttr attrs, - ArrayRef extraElides = {}); - -/// Check if the `name` attribute in `attrs` matches the SSA name of the -/// operation's first result. If it does, add `name` to `elides`. This is -/// helpful during printing of attribute dictionaries in order to determine if -/// the inclusion of the `name` field would be redundant. -void elideImplicitSSAName(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs, - SmallVectorImpl &elides); - -/// Print/parse binary operands type only when types are different. -/// optional-bin-op-types := type($lhs) (, type($rhs))? -void printOptionalBinaryOpTypes(OpAsmPrinter &p, Operation *op, Type lhs, - Type rhs); -ParseResult parseOptionalBinaryOpTypes(OpAsmParser &parser, Type &lhs, - Type &rhs); - -//===----------------------------------------------------------------------===// -// KeywordBool Custom Directive -//===----------------------------------------------------------------------===// - -/// Parse a boolean as one of two keywords. The `trueKeyword` will result in a -/// true boolean; the `falseKeyword` will result in a false boolean. -/// -/// labeled-bool ::= (true-label | false-label) -ParseResult parseKeywordBool(OpAsmParser &parser, BoolAttr &attr, - StringRef trueKeyword, StringRef falseKeyword); - -/// Print a boolean as one of two keywords. If the boolean is true, the -/// `trueKeyword` is used; if it is false, the `falseKeyword` is used. -/// -/// labeled-bool ::= (true-label | false-label) -void printKeywordBool(OpAsmPrinter &printer, Operation *op, BoolAttr attr, - StringRef trueKeyword, StringRef falseKeyword); - -} // namespace circt - -#endif // CIRCT_SUPPORT_CUSTOMDIRECTIVEIMPL_H diff --git a/include/circt/Support/FieldRef.h b/include/circt/Support/FieldRef.h deleted file mode 100644 index 52204e80dc..0000000000 --- a/include/circt/Support/FieldRef.h +++ /dev/null @@ -1,116 +0,0 @@ -//===- FieldRef.h - Field References ---------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header file defines FieldRefs and helpers for them. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_FIELDREF_H -#define CIRCT_SUPPORT_FIELDREF_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/DenseMapInfo.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project - -namespace circt { - -/// This class represents a reference to a specific field or element of an -/// aggregate value. Typically, the user will assign a unique field ID to each -/// field in an aggregate type by visiting them in a depth-first pre-order. -/// -/// This can be used as the key in a hashtable to store field specific -/// information. -class FieldRef { - public: - /// Get a null FieldRef. - FieldRef() {} - - /// Get a FieldRef location for the specified value. - FieldRef(Value value, unsigned id) : value(value), id(id) {} - - /// Get the Value which created this location. - Value getValue() const { return value; } - - /// Get the operation which defines this field. If the field is a block - /// argument it will return the operation which owns the block. - Operation *getDefiningOp() const; - - /// Get the operation which defines this field and cast it to the OpTy. - /// Returns null if the defining operation is of a different type. - template - OpTy getDefiningOp() const { - return llvm::dyn_cast(getDefiningOp()); - } - - template - bool isa() const { - auto *op = getDefiningOp(); - assert(op && "isa<> used on a null type."); - return ::llvm::isa(op); - } - - /// Get the field ID of this FieldRef, which is a unique identifier mapped to - /// a specific field in a bundle. - unsigned getFieldID() const { return id; } - - /// Get a reference to a subfield. - FieldRef getSubField(unsigned subFieldID) const { - return FieldRef(value, id + subFieldID); - } - - /// Get the location associated with the value of this field ref. - Location getLoc() const { return getValue().getLoc(); } - - bool operator==(const FieldRef &other) const { - return value == other.value && id == other.id; - } - - bool operator<(const FieldRef &other) const { - if (value.getImpl() < other.value.getImpl()) return true; - if (value.getImpl() > other.value.getImpl()) return false; - return id < other.id; - } - - operator bool() const { return bool(value); } - - private: - /// A pointer to the value which created this. - Value value; - - /// A unique field ID. - unsigned id = 0; -}; - -/// Get a hash code for a FieldRef. -inline ::llvm::hash_code hash_value(const FieldRef &fieldRef) { - return llvm::hash_combine(fieldRef.getValue(), fieldRef.getFieldID()); -} - -} // namespace circt - -namespace llvm { -/// Allow using FieldRef with DenseMaps. This hash is based on the Value -/// identity and field ID. -template <> -struct DenseMapInfo { - static inline circt::FieldRef getEmptyKey() { - return circt::FieldRef(DenseMapInfo::getEmptyKey(), 0); - } - static inline circt::FieldRef getTombstoneKey() { - return circt::FieldRef(DenseMapInfo::getTombstoneKey(), 0); - } - static unsigned getHashValue(const circt::FieldRef &val) { - return circt::hash_value(val); - } - static bool isEqual(const circt::FieldRef &lhs, const circt::FieldRef &rhs) { - return lhs == rhs; - } -}; -} // namespace llvm - -#endif // CIRCT_SUPPORT_FIELDREF_H diff --git a/include/circt/Support/FoldUtils.h b/include/circt/Support/FoldUtils.h deleted file mode 100644 index eb40d17d4b..0000000000 --- a/include/circt/Support/FoldUtils.h +++ /dev/null @@ -1,38 +0,0 @@ -//===- FoldUtils.h - Common folder and canonicalizer utilities --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_FOLDUTILS_H -#define CIRCT_SUPPORT_FOLDUTILS_H - -#include "llvm/include/llvm/ADT/APInt.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project - -namespace circt { - -/// Determine the integer value of a constant operand. -static inline std::optional getConstantInt(Attribute operand) { - if (!operand) return {}; - if (auto attr = dyn_cast(operand)) return attr.getValue(); - return {}; -} - -/// Determine whether a constant operand is a zero value. -static inline bool isConstantZero(Attribute operand) { - if (auto cst = getConstantInt(operand)) return cst->isZero(); - return false; -} - -/// Determine whether a constant operand is a one value. -static inline bool isConstantOne(Attribute operand) { - if (auto cst = getConstantInt(operand)) return cst->isOne(); - return false; -} - -} // namespace circt - -#endif // CIRCT_SUPPORT_FOLDUTILS_H diff --git a/include/circt/Support/InstanceGraph.h b/include/circt/Support/InstanceGraph.h deleted file mode 100644 index 56dcd450d2..0000000000 --- a/include/circt/Support/InstanceGraph.h +++ /dev/null @@ -1,453 +0,0 @@ -//===- InstanceGraph.h - Instance graph -------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines a generic instance graph for module- and instance-likes. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_INSTANCEGRAPH_H -#define CIRCT_SUPPORT_INSTANCEGRAPH_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/GraphTraits.h" // from @llvm-project -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/iterator.h" // from @llvm-project -#include "llvm/include/llvm/Support/DOTGraphTraits.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project - -/// The InstanceGraph op interface, see InstanceGraphInterface.td for more -/// details. -#include "include/circt/Support/InstanceGraphInterface.h" - -namespace circt { -namespace igraph { - -namespace detail { -/// This just maps a iterator of references to an iterator of addresses. -template -struct AddressIterator - : public llvm::mapped_iterator { - // This using statement is to get around a bug in MSVC. Without it, it - // tries to look up "It" as a member type of the parent class. - using Iterator = It; - static typename Iterator::value_type *addrOf( - typename Iterator::value_type &v) noexcept { - return std::addressof(v); - } - /* implicit */ AddressIterator(Iterator iterator) - : llvm::mapped_iterator(iterator, - addrOf) {} -}; -} // namespace detail - -class InstanceGraphNode; - -/// This is an edge in the InstanceGraph. This tracks a specific instantiation -/// of a module. -class InstanceRecord - : public llvm::ilist_node_with_parent { - public: - /// Get the instance-like op that this is tracking. - template - auto getInstance() { - if constexpr (std::is_same::value) - return instance; - return dyn_cast_or_null(instance.getOperation()); - } - - /// Get the module where the instantiation lives. - InstanceGraphNode *getParent() const { return parent; } - - /// Get the module which the instance-like is instantiating. - InstanceGraphNode *getTarget() const { return target; } - - /// Erase this instance record, removing it from the parent module and the - /// target's use-list. - void erase(); - - private: - friend class InstanceGraph; - friend class InstanceGraphNode; - - InstanceRecord(InstanceGraphNode *parent, InstanceOpInterface instance, - InstanceGraphNode *target) - : parent(parent), instance(instance), target(target) {} - InstanceRecord(const InstanceRecord &) = delete; - - /// This is the module where the instantiation lives. - InstanceGraphNode *parent; - - /// The InstanceLike that this is tracking. - InstanceOpInterface instance; - - /// This is the module which the instance-like is instantiating. - InstanceGraphNode *target; - /// Intrusive linked list for other uses. - InstanceRecord *nextUse = nullptr; - InstanceRecord *prevUse = nullptr; -}; - -/// This is a Node in the InstanceGraph. Each node represents a Module in a -/// Circuit. Both external modules and regular modules can be represented by -/// this class. It is possible to efficiently iterate all modules instantiated -/// by this module, as well as all instantiations of this module. -class InstanceGraphNode : public llvm::ilist_node { - using InstanceList = llvm::iplist; - - public: - InstanceGraphNode() : module(nullptr) {} - - /// Get the module that this node is tracking. - template - auto getModule() { - if constexpr (std::is_same::value) - return module; - return cast(module.getOperation()); - } - - /// Iterate the instance records in this module. - using iterator = detail::AddressIterator; - iterator begin() { return instances.begin(); } - iterator end() { return instances.end(); } - - /// Return true if there are no more instances of this module. - bool noUses() { return !firstUse; } - - /// Return true if this module has exactly one use. - bool hasOneUse() { return llvm::hasSingleElement(uses()); } - - /// Get the number of direct instantiations of this module. - size_t getNumUses() { return std::distance(usesBegin(), usesEnd()); } - - /// Iterator for module uses. - struct UseIterator - : public llvm::iterator_facade_base< - UseIterator, std::forward_iterator_tag, InstanceRecord *> { - UseIterator() : current(nullptr) {} - UseIterator(InstanceGraphNode *node) : current(node->firstUse) {} - InstanceRecord *operator*() const { return current; } - using llvm::iterator_facade_base::operator++; - UseIterator &operator++() { - assert(current && "incrementing past end"); - current = current->nextUse; - return *this; - } - bool operator==(const UseIterator &other) const { - return current == other.current; - } - - private: - InstanceRecord *current; - }; - - /// Iterate the instance records which instantiate this module. - UseIterator usesBegin() { return {this}; } - UseIterator usesEnd() { return {}; } - llvm::iterator_range uses() { - return llvm::make_range(usesBegin(), usesEnd()); - } - - /// Record a new instance op in the body of this module. Returns a newly - /// allocated InstanceRecord which will be owned by this node. - InstanceRecord *addInstance(InstanceOpInterface instance, - InstanceGraphNode *target); - - private: - friend class InstanceRecord; - - InstanceGraphNode(const InstanceGraphNode &) = delete; - - /// Record that a module instantiates this module. - void recordUse(InstanceRecord *record); - - /// The module. - ModuleOpInterface module; - - /// List of instance operations in this module. This member owns the - /// InstanceRecords, which may be pointed to by other InstanceGraphNode's use - /// lists. - InstanceList instances; - - /// List of instances which instantiate this module. - InstanceRecord *firstUse = nullptr; - - // Provide access to the constructor. - friend class InstanceGraph; -}; - -/// This graph tracks modules and where they are instantiated. This is intended -/// to be used as a cached analysis on circuits. This class can be used -/// to walk the modules efficiently in a bottom-up or top-down order. -/// -/// To use this class, retrieve a cached copy from the analysis manager: -/// auto &instanceGraph = getAnalysis(getOperation()); -class InstanceGraph { - /// This is the list of InstanceGraphNodes in the graph. - using NodeList = llvm::iplist; - - public: - /// Create a new module graph of a circuit. Must be called on the parent - /// operation of ModuleOpInterface ops. - InstanceGraph(Operation *parent); - InstanceGraph(const InstanceGraph &) = delete; - virtual ~InstanceGraph() = default; - - /// Look up an InstanceGraphNode for a module. - InstanceGraphNode *lookup(ModuleOpInterface op); - - /// Lookup an module by name. - InstanceGraphNode *lookup(StringAttr name); - - /// Lookup an InstanceGraphNode for a module. - InstanceGraphNode *operator[](ModuleOpInterface op) { return lookup(op); } - - /// Look up the referenced module from an InstanceOp. This will use a - /// hashtable lookup to find the module, where - /// InstanceOp.getReferencedModule() will be a linear search through the IR. - template - auto getReferencedModule(InstanceOpInterface op) { - return cast(getReferencedModuleImpl(op).getOperation()); - } - - /// Check if child is instantiated by a parent. - bool isAncestor(ModuleOpInterface child, ModuleOpInterface parent); - - /// Get the node corresponding to the top-level module of a circuit. - virtual InstanceGraphNode *getTopLevelNode() { return nullptr; } - - /// Get the nodes corresponding to the inferred top-level modules of a - /// circuit. - FailureOr> getInferredTopLevelNodes(); - - /// Return the parent under which all nodes are nested. - Operation *getParent() { return parent; } - - /// Returns pointer to member of operation list. - static NodeList InstanceGraph::*getSublistAccess(Operation *) { - return &InstanceGraph::nodes; - } - - /// Iterate through all modules. - using iterator = detail::AddressIterator; - iterator begin() { return nodes.begin(); } - iterator end() { return nodes.end(); } - - //===------------------------------------------------------------------------- - // Methods to keep an InstanceGraph up to date. - // - // These methods are not thread safe. Make sure that modifications are - // properly synchronized or performed in a serial context. When the - // InstanceGraph is used as an analysis, this is only safe when the pass is - // on a CircuitOp or a ModuleOp. - - /// Add a newly created module to the instance graph. - virtual InstanceGraphNode *addModule(ModuleOpInterface module); - - /// Remove this module from the instance graph. This will also remove all - /// InstanceRecords in this module. All instances of this module must have - /// been removed from the graph. - virtual void erase(InstanceGraphNode *node); - - /// Replaces an instance of a module with another instance. The target module - /// of both InstanceOps must be the same. - virtual void replaceInstance(InstanceOpInterface inst, - InstanceOpInterface newInst); - - protected: - ModuleOpInterface getReferencedModuleImpl(InstanceOpInterface op); - - /// Get the node corresponding to the module. If the node has does not exist - /// yet, it will be created. - InstanceGraphNode *getOrAddNode(StringAttr name); - - /// The node under which all modules are nested. - Operation *parent; - - /// The storage for graph nodes, with deterministic iteration. - NodeList nodes; - - /// This maps each operation to its graph node. - llvm::DenseMap nodeMap; - - /// A caching of the inferred top level module(s). - llvm::SmallVector inferredTopLevelNodes; -}; - -struct InstancePathCache; - -/** - * An instance path composed of a series of instances. - */ -class InstancePath final { - public: - InstancePath() = default; - - InstanceOpInterface top() const { - assert(!empty() && "instance path is empty"); - return path[0]; - } - - InstanceOpInterface leaf() const { - assert(!empty() && "instance path is empty"); - return path.back(); - } - - InstancePath dropFront() const { return InstancePath(path.drop_front()); } - - InstanceOpInterface operator[](size_t idx) const { return path[idx]; } - ArrayRef::iterator begin() const { return path.begin(); } - ArrayRef::iterator end() const { return path.end(); } - size_t size() const { return path.size(); } - bool empty() const { return path.empty(); } - - /// Print the path to any stream-like object. - template - void print(T &into) const { - into << "$root"; - for (auto inst : path) - into << "/" << inst.getInstanceName() << ":" - << inst.getReferencedModuleName(); - } - - private: - // Only the path cache is allowed to create paths. - friend struct InstancePathCache; - InstancePath(ArrayRef path) : path(path) {} - - ArrayRef path; -}; - -template -static T &operator<<(T &os, const InstancePath &path) { - return path.print(os); -} - -/// A data structure that caches and provides absolute paths to module instances -/// in the IR. -struct InstancePathCache { - /// The instance graph of the IR. - InstanceGraph &instanceGraph; - - explicit InstancePathCache(InstanceGraph &instanceGraph) - : instanceGraph(instanceGraph) {} - ArrayRef getAbsolutePaths(ModuleOpInterface op); - - /// Replace an InstanceOp. This is required to keep the cache updated. - void replaceInstance(InstanceOpInterface oldOp, InstanceOpInterface newOp); - - /// Append an instance to a path. - InstancePath appendInstance(InstancePath path, InstanceOpInterface inst); - - /// Prepend an instance to a path. - InstancePath prependInstance(InstanceOpInterface inst, InstancePath path); - - private: - /// An allocator for individual instance paths and entire path lists. - llvm::BumpPtrAllocator allocator; - - /// Cached absolute instance paths. - DenseMap> absolutePathsCache; -}; - -} // namespace igraph -} // namespace circt - -// Graph traits for modules. -template <> -struct llvm::GraphTraits { - using NodeType = circt::igraph::InstanceGraphNode; - using NodeRef = NodeType *; - - // Helper for getting the module referenced by the instance op. - static NodeRef getChild(const circt::igraph::InstanceRecord *record) { - return record->getTarget(); - } - - using ChildIteratorType = - llvm::mapped_iterator; - - static NodeRef getEntryNode(NodeRef node) { return node; } - static ChildIteratorType child_begin(NodeRef node) { - return {node->begin(), &getChild}; - } - static ChildIteratorType child_end(NodeRef node) { - return {node->end(), &getChild}; - } -}; - -// Provide graph traits for iterating the modules in inverse order. -template <> -struct llvm::GraphTraits> { - using NodeType = circt::igraph::InstanceGraphNode; - using NodeRef = NodeType *; - - // Helper for getting the module containing the instance op. - static NodeRef getParent(const circt::igraph::InstanceRecord *record) { - return record->getParent(); - } - - using ChildIteratorType = - llvm::mapped_iterator; - - static NodeRef getEntryNode(Inverse inverse) { - return inverse.Graph; - } - static ChildIteratorType child_begin(NodeRef node) { - return {node->usesBegin(), &getParent}; - } - static ChildIteratorType child_end(NodeRef node) { - return {node->usesEnd(), &getParent}; - } -}; - -// Graph traits for the common instance graph. -template <> -struct llvm::GraphTraits - : public llvm::GraphTraits { - using nodes_iterator = circt::igraph::InstanceGraph::iterator; - - static NodeRef getEntryNode(circt::igraph::InstanceGraph *graph) { - return graph->getTopLevelNode(); - } - // NOLINTNEXTLINE(readability-identifier-naming) - static nodes_iterator nodes_begin(circt::igraph::InstanceGraph *graph) { - return graph->begin(); - } - // NOLINTNEXTLINE(readability-identifier-naming) - static nodes_iterator nodes_end(circt::igraph::InstanceGraph *graph) { - return graph->end(); - } -}; - -// Graph traits for DOT labeling. -template <> -struct llvm::DOTGraphTraits - : public llvm::DefaultDOTGraphTraits { - using DefaultDOTGraphTraits::DefaultDOTGraphTraits; - - static std::string getNodeLabel(circt::igraph::InstanceGraphNode *node, - circt::igraph::InstanceGraph *) { - // The name of the graph node is the module name. - return node->getModule().getModuleName().str(); - } - - template - static std::string getEdgeAttributes( - const circt::igraph::InstanceGraphNode *node, Iterator it, - circt::igraph::InstanceGraph *) { - // Set an edge label that is the name of the instance. - auto *instanceRecord = *it.getCurrent(); - auto instanceOp = instanceRecord->getInstance(); - return ("label=" + instanceOp.getInstanceName()).str(); - } -}; - -#endif // CIRCT_SUPPORT_INSTANCEGRAPH_H diff --git a/include/circt/Support/InstanceGraphInterface.h b/include/circt/Support/InstanceGraphInterface.h deleted file mode 100644 index 77e065760d..0000000000 --- a/include/circt/Support/InstanceGraphInterface.h +++ /dev/null @@ -1,23 +0,0 @@ -//===- InstanceGraphInterface.h - Instance graph interface ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares stuff related to the instance graph interface. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H -#define CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project - -/// The InstanceGraph op interface, see InstanceGraphInterface.td for more -/// details. -#include "include/circt/Support/InstanceGraphInterface.h.inc" - -#endif // CIRCT_SUPPORT_INSTANCEGRAPHINTERFACE_H diff --git a/include/circt/Support/InstanceGraphInterface.td b/include/circt/Support/InstanceGraphInterface.td deleted file mode 100644 index 261dd15709..0000000000 --- a/include/circt/Support/InstanceGraphInterface.td +++ /dev/null @@ -1,74 +0,0 @@ -//===- InstanceGraphInterface.td - Interface for instance graphs --------*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains interfaces and other utilities for interacting with the -// generic CIRCT instance graph. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD -#define CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD - -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" - -def InstanceGraphInstanceOpInterface : OpInterface<"InstanceOpInterface"> { - let description = [{ - This interface provides hooks for an instance-like operation. - }]; - let cppNamespace = "::circt::igraph"; - - let methods = [ - InterfaceMethod<"Get the name of the instance", - "::llvm::StringRef", "getInstanceName", (ins)>, - - InterfaceMethod<"Get the name of the instance", - "::mlir::StringAttr", "getInstanceNameAttr", (ins)>, - - InterfaceMethod<"Get the name of the instantiated module", - "::llvm::StringRef", "getReferencedModuleName", (ins), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return $_op.getModuleName(); }]>, - - InterfaceMethod<"Get the name of the instantiated module", - "::mlir::StringAttr", "getReferencedModuleNameAttr", (ins), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return $_op.getModuleNameAttr().getAttr(); }]>, - - InterfaceMethod<[{ - Get the referenced module (slow, unsafe). This function directly accesses - the parent operation to lookup a symbol, which is unsafe in many contexts. - }], - "::mlir::Operation *", "getReferencedModuleSlow", (ins)>, - - InterfaceMethod<"Get the referenced module via a symbol table.", - "::mlir::Operation *", "getReferencedModule", (ins "SymbolTable&":$symtbl)>, - ]; -} - -def InstanceGraphModuleOpInterface : OpInterface<"ModuleOpInterface"> { - let description = [{ - This interface provides hooks for a module-like operation. - }]; - let cppNamespace = "::circt::igraph"; - - let methods = [ - InterfaceMethod<"Get the module name", - "::llvm::StringRef", "getModuleName", (ins), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return $_op.getModuleNameAttr().getValue(); }]>, - - InterfaceMethod<"Get the module name", - "::mlir::StringAttr", "getModuleNameAttr", (ins), - /*methodBody=*/[{}], - /*defaultImplementation=*/[{ return $_op.getNameAttr(); }]>, - ]; - -} - -#endif // CIRCT_SUPPORT_INSTANCEGRAPH_INSTANCEGRAPHINTERFACE_TD diff --git a/include/circt/Support/JSON.h b/include/circt/Support/JSON.h deleted file mode 100644 index 181fd4dc81..0000000000 --- a/include/circt/Support/JSON.h +++ /dev/null @@ -1,30 +0,0 @@ -//===- Json.h - Json Utilities ----------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for JSON-to-Attribute conversion. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_JSON_H -#define CIRCT_SUPPORT_JSON_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/Support/JSON.h" // from @llvm-project - -namespace circt { - -/// Convert a simple attribute to JSON. -LogicalResult convertAttributeToJSON(llvm::json::OStream &json, Attribute attr); - -/// Convert arbitrary JSON to an MLIR Attribute. -Attribute convertJSONToAttribute(MLIRContext *context, llvm::json::Value &value, - llvm::json::Path p); - -} // namespace circt - -#endif // CIRCT_SUPPORT_JSON_H diff --git a/include/circt/Support/LoweringOptions.h b/include/circt/Support/LoweringOptions.h deleted file mode 100644 index b59fab49d1..0000000000 --- a/include/circt/Support/LoweringOptions.h +++ /dev/null @@ -1,169 +0,0 @@ -//===- LoweringOptions.h - CIRCT Lowering Options ---------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Options for controlling the lowering process and verilog exporting. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_LOWERINGOPTIONS_H -#define CIRCT_SUPPORT_LOWERINGOPTIONS_H - -#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project -#include "llvm/include/llvm/ADT/Twine.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project - -namespace mlir { -class ModuleOp; -} - -namespace circt { - -/// Options which control the emission from CIRCT to Verilog. -struct LoweringOptions { - /// Error callback type used to indicate errors parsing the options string. - using ErrorHandlerT = llvm::function_ref; - - /// Create a LoweringOptions with the default values. - LoweringOptions() = default; - - /// Create a LoweringOptions and read in options from a string, - /// overriding only the set options in the string. - LoweringOptions(llvm::StringRef options, ErrorHandlerT errorHandler); - - /// Create a LoweringOptions with values loaded from an MLIR ModuleOp. This - /// loads a string attribute with the key `circt.loweringOptions`. If there is - /// an error parsing the attribute this will print an error using the - /// ModuleOp. - LoweringOptions(mlir::ModuleOp module); - - /// Return the value of the `circt.loweringOptions` in the specified module - /// if present, or a null attribute if not. - static mlir::StringAttr getAttributeFrom(mlir::ModuleOp module); - - /// Read in options from a string, overriding only the set options in the - /// string. - void parse(llvm::StringRef options, ErrorHandlerT callback); - - /// Returns a string representation of the options. - std::string toString() const; - - /// Write the verilog emitter options to a module's attributes. - void setAsAttribute(mlir::ModuleOp module); - - /// Load any emitter options from the module. If there is an error validating - /// the attribute, this will print an error using the ModuleOp. - void parseFromAttribute(mlir::ModuleOp module); - - /// If true, emits `sv.alwayscomb` as Verilog `always @(*)` statements. - /// Otherwise, print them as `always_comb`. - bool noAlwaysComb = false; - - /// If true, expressions are allowed in the sensitivity list of `always` - /// statements, otherwise they are forced to be simple wires. Some EDA - /// tools rely on these being simple wires. - bool allowExprInEventControl = false; - - /// If true, eliminate packed arrays for tools that don't support them (e.g. - /// Yosys). - bool disallowPackedArrays = false; - - /// If true, eliminate packed struct assignments in favor of a wire + - /// assignments to the individual fields. - bool disallowPackedStructAssignments = false; - - /// If true, do not emit SystemVerilog locally scoped "automatic" or logic - /// declarations - emit top level wire and reg's instead. - bool disallowLocalVariables = false; - - /// If true, verification statements like `assert`, `assume`, and `cover` will - /// always be emitted with a label. If the statement has no label in the IR, a - /// generic one will be created. Some EDA tools require verification - /// statements to be labeled. - bool enforceVerifLabels = false; - - /// This is the maximum number of terms in an expression before that - /// expression spills a wire. - enum { DEFAULT_TERM_LIMIT = 256 }; - unsigned maximumNumberOfTermsPerExpression = DEFAULT_TERM_LIMIT; - - /// This is the target width of lines in an emitted Verilog source file in - /// columns. - enum { DEFAULT_LINE_LENGTH = 90 }; - unsigned emittedLineLength = DEFAULT_LINE_LENGTH; - - /// Add an explicit bitcast for avoiding bitwidth mismatch LINT errors. - bool explicitBitcast = false; - - /// If true, replicated ops are emitted to a header file. - bool emitReplicatedOpsToHeader = false; - - /// This option controls emitted location information style. - enum LocationInfoStyle { - Plain, // Default. - WrapInAtSquareBracket, // Wrap location info in @[..]. - None, // No location info comment. - } locationInfoStyle = Plain; - - /// If true, every port is declared separately - /// (each includes direction and type (e.g., `input [3:0]`)). - /// When false (default), ports share declarations when possible. - bool disallowPortDeclSharing = false; - - /// Print debug info. - bool printDebugInfo = false; - - /// If true, every mux expression is spilled to a wire. - bool disallowMuxInlining = false; - - /// This controls extra wire spilling performed in PrepareForEmission to - /// improve readablitiy and debuggability. - enum WireSpillingHeuristic : unsigned { - SpillLargeTermsWithNamehints = 1, // Spill wires for expressions with - // namehints if the term size is greater - // than `wireSpillingNamehintTermLimit`. - }; - - unsigned wireSpillingHeuristicSet = 0; - - bool isWireSpillingHeuristicEnabled(WireSpillingHeuristic heurisic) const { - return static_cast(wireSpillingHeuristicSet & heurisic); - } - - enum { DEFAULT_NAMEHINT_TERM_LIMIT = 3 }; - unsigned wireSpillingNamehintTermLimit = DEFAULT_NAMEHINT_TERM_LIMIT; - - /// If true, every expression passed to an instance port is driven by a wire. - /// Some lint tools dislike expressions being inlined into input ports so this - /// option avoids such warnings. - bool disallowExpressionInliningInPorts = false; - - /// If true, every expression used as an array index is driven by a wire, and - /// the wire is marked as `(* keep = "true" *)`. Certain versions of Vivado - /// produce incorrect synthesis results for certain arithmetic ops inlined - /// into the array index. - bool mitigateVivadoArrayIndexConstPropBug = false; - - /// If true, emit `wire` in port lists rather than nothing. Used in cases - /// where `default_nettype is not set to wire. - bool emitWireInPorts = false; - - /// If true, emit a comment wherever an instance wasn't printed, because - /// it's emitted elsewhere as a bind. - bool emitBindComments = false; - - /// If true, do not emit a version comment at the top of each verilog file. - bool omitVersionComment = false; - - /// If true, then unique names that collide with keywords case insensitively. - /// This is used to avoid stricter lint warnings which, e.g., treat "REG" as a - /// Verilog keyword. - bool caseInsensitiveKeywords = false; -}; -} // namespace circt - -#endif // CIRCT_SUPPORT_LOWERINGOPTIONS_H diff --git a/include/circt/Support/LoweringOptionsParser.h b/include/circt/Support/LoweringOptionsParser.h deleted file mode 100644 index 16e7f2ad46..0000000000 --- a/include/circt/Support/LoweringOptionsParser.h +++ /dev/null @@ -1,57 +0,0 @@ -//===- LoweringOptionsParser.h - CIRCT Lowering Option Parser ---*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Parser for lowering options. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H -#define CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project -#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project - -namespace circt { - -/// Commandline parser for LoweringOptions. Delegates to the parser -/// defined by LoweringOptions. -struct LoweringOptionsParser : public llvm::cl::parser { - LoweringOptionsParser(llvm::cl::Option &option) - : llvm::cl::parser(option) {} - - bool parse(llvm::cl::Option &option, StringRef argName, StringRef argValue, - LoweringOptions &value) { - bool failed = false; - value.parse(argValue, [&](Twine error) { failed = option.error(error); }); - return failed; - } -}; - -struct LoweringOptionsOption - : llvm::cl::opt { - LoweringOptionsOption(llvm::cl::OptionCategory &cat) - : llvm::cl::opt{ - "lowering-options", - llvm::cl::desc( - "Style options. Valid flags include: " - "noAlwaysComb, exprInEventControl, disallowPackedArrays, " - "disallowLocalVariables, verifLabels, emittedLineLength=, " - "maximumNumberOfTermsPerExpression=, " - "explicitBitcast, emitReplicatedOpsToHeader, " - "locationInfoStyle={plain,wrapInAtSquareBracket,none}, " - "disallowPortDeclSharing, printDebugInfo, " - "disallowExpressionInliningInPorts, disallowMuxInlining, " - "emitWireInPort, emitBindComments, omitVersionComment, " - "caseInsensitiveKeywords"), - llvm::cl::cat(cat), llvm::cl::value_desc("option")} {} -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_LOWERINGOPTIONSPARSER_H diff --git a/include/circt/Support/Namespace.h b/include/circt/Support/Namespace.h deleted file mode 100644 index 53d2a69e8c..0000000000 --- a/include/circt/Support/Namespace.h +++ /dev/null @@ -1,134 +0,0 @@ -//===- Namespace.h - Utilities for generating names -------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides utilities for generating new names that do not conflict -// with existing names. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_NAMESPACE_H -#define CIRCT_SUPPORT_NAMESPACE_H - -#include "include/circt/Support/LLVM.h" -#include "include/circt/Support/SymCache.h" -#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringSet.h" // from @llvm-project -#include "llvm/include/llvm/ADT/Twine.h" // from @llvm-project - -namespace circt { - -/// A namespace that is used to store existing names and generate new names in -/// some scope within the IR. This exists to work around limitations of -/// SymbolTables. This acts as a base class providing facilities common to all -/// namespaces implementations. -class Namespace { - public: - Namespace() {} - Namespace(const Namespace &other) = default; - Namespace(Namespace &&other) : nextIndex(std::move(other.nextIndex)) {} - - Namespace &operator=(const Namespace &other) = default; - Namespace &operator=(Namespace &&other) { - nextIndex = std::move(other.nextIndex); - return *this; - } - - /// SymbolCache initializer; initialize from every key that is convertible to - /// a StringAttr in the SymbolCache. - void add(SymbolCache &symCache) { - for (auto &&[attr, _] : symCache) - if (auto strAttr = attr.dyn_cast()) - nextIndex.insert({strAttr.getValue(), 0}); - } - - /// Empty the namespace. - void clear() { nextIndex.clear(); } - - /// Return a unique name, derived from the input `name`, and add the new name - /// to the internal namespace. There are two possible outcomes for the - /// returned name: - /// - /// 1. The original name is returned. - /// 2. The name is given a `_` suffix where `` is a number starting from - /// `0` and incrementing by one each time (`_0`, ...). - StringRef newName(const Twine &name) { - // Special case the situation where there is no name collision to avoid - // messing with the SmallString allocation below. - llvm::SmallString<64> tryName; - auto inserted = nextIndex.insert({name.toStringRef(tryName), 0}); - if (inserted.second) return inserted.first->getKey(); - - // Try different suffixes until we get a collision-free one. - if (tryName.empty()) - name.toVector(tryName); // toStringRef may leave tryName unfilled - - // Indexes less than nextIndex[tryName] are lready used, so skip them. - // Indexes larger than nextIndex[tryName] may be used in another name. - size_t &i = nextIndex[tryName]; - tryName.push_back('_'); - size_t baseLength = tryName.size(); - do { - tryName.resize(baseLength); - Twine(i++).toVector(tryName); // append integer to tryName - inserted = nextIndex.insert({tryName, 0}); - } while (!inserted.second); - - return inserted.first->getKey(); - } - - /// Return a unique name, derived from the input `name` and ensure the - /// returned name has the input `suffix`. Also add the new name to the - /// internal namespace. - /// There are two possible outcomes for the returned name: - /// 1. The original name + `_` is returned. - /// 2. The name is given a suffix `__` where `` is a number - /// starting from `0` and incrementing by one each time. - StringRef newName(const Twine &name, const Twine &suffix) { - // Special case the situation where there is no name collision to avoid - // messing with the SmallString allocation below. - llvm::SmallString<64> tryName; - auto inserted = nextIndex.insert( - {name.concat("_").concat(suffix).toStringRef(tryName), 0}); - if (inserted.second) return inserted.first->getKey(); - - // Try different suffixes until we get a collision-free one. - tryName.clear(); - name.toVector(tryName); // toStringRef may leave tryName unfilled - tryName.push_back('_'); - size_t baseLength = tryName.size(); - - // Get the initial number to start from. Since `:` is not a valid character - // in a verilog identifier, we use it separate the name and suffix. - // Next number for name+suffix is stored with key `name_:suffix`. - tryName.push_back(':'); - suffix.toVector(tryName); - - // Indexes less than nextIndex[tryName] are already used, so skip them. - // Indexes larger than nextIndex[tryName] may be used in another name. - size_t &i = nextIndex[tryName]; - do { - tryName.resize(baseLength); - Twine(i++).toVector(tryName); // append integer to tryName - tryName.push_back('_'); - suffix.toVector(tryName); - inserted = nextIndex.insert({tryName, 0}); - } while (!inserted.second); - - return inserted.first->getKey(); - } - - protected: - // The "next index" that will be tried when trying to unique a string within a - // namespace. It follows that all values less than the "next index" value are - // already used. - llvm::StringMap nextIndex; -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_NAMESPACE_H diff --git a/include/circt/Support/ParsingUtils.h b/include/circt/Support/ParsingUtils.h deleted file mode 100644 index b62b680637..0000000000 --- a/include/circt/Support/ParsingUtils.h +++ /dev/null @@ -1,57 +0,0 @@ -//===- ParsingUtils.h - CIRCT parsing common functions ----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities to help with parsing. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_PARSINGUTILS_H -#define CIRCT_SUPPORT_PARSINGUTILS_H - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project - -namespace circt { -namespace parsing_util { - -/// Get a name from an SSA value string, if said value name is not a -/// number. -static inline StringAttr getNameFromSSA(MLIRContext *context, StringRef name) { - if (!name.empty()) { - // Ignore numeric names like %42 - assert(name.size() > 1 && name[0] == '%' && "Unknown MLIR name"); - if (isdigit(name[1])) - name = StringRef(); - else - name = name.drop_front(); - } - return StringAttr::get(context, name); -} - -//===----------------------------------------------------------------------===// -// Initializer lists -//===----------------------------------------------------------------------===// - -/// Parses an initializer. -/// An initializer list is a list of operands, types and names on the format: -/// (%arg = %input : type, ...) -ParseResult parseInitializerList( - mlir::OpAsmParser &parser, - llvm::SmallVector &inputArguments, - llvm::SmallVector &inputOperands, - llvm::SmallVector &inputTypes, ArrayAttr &inputNames); - -// Prints an initializer list. -void printInitializerList(OpAsmPrinter &p, ValueRange ins, - ArrayRef args); - -} // namespace parsing_util -} // namespace circt - -#endif // CIRCT_SUPPORT_PARSINGUTILS_H diff --git a/include/circt/Support/Passes.h b/include/circt/Support/Passes.h deleted file mode 100644 index a76f1153c3..0000000000 --- a/include/circt/Support/Passes.h +++ /dev/null @@ -1,64 +0,0 @@ -//===- Passes.h - Helpers for pipeline instrumentation ----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_PASSES_H -#define CIRCT_SUPPORT_PASSES_H - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/Support/Chrono.h" // from @llvm-project -#include "llvm/include/llvm/Support/Format.h" // from @llvm-project -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/include/mlir/Pass/PassInstrumentation.h" // from @llvm-project - -namespace circt { -// This class prints logs before and after of pass executions when its pass -// operation is in `LoggedOpTypes`. Note that `runBeforePass` and `runAfterPass` -// are not thread safe so `LoggedOpTypes` must be a set of operations whose -// passes are ran sequentially (e.g. mlir::ModuleOp, firrtl::CircuitOp). -template -class VerbosePassInstrumentation : public mlir::PassInstrumentation { - // This stores start time points of passes. - using TimePoint = llvm::sys::TimePoint<>; - llvm::SmallVector timePoints; - int level = 0; - const char *toolName; - - public: - VerbosePassInstrumentation(const char *toolName) : toolName(toolName){}; - void runBeforePass(Pass *pass, Operation *op) override { - if (isa(op)) { - timePoints.push_back(TimePoint::clock::now()); - auto &os = llvm::errs(); - os << llvm::format("[%s] ", toolName); - os.indent(2 * level++); - os << "Running \""; - pass->printAsTextualPipeline(llvm::errs()); - os << "\"\n"; - } - } - - void runAfterPass(Pass *pass, Operation *op) override { - using namespace std::chrono; - if (isa(op)) { - auto &os = llvm::errs(); - auto elapsed = duration(TimePoint::clock::now() - - timePoints.pop_back_val()) / - seconds(1); - os << llvm::format("[%s] ", toolName); - os.indent(2 * --level); - os << "-- Done in " << llvm::format("%.3f", elapsed) << " sec\n"; - } - } -}; - -/// Create a simple canonicalizer pass. -std::unique_ptr createSimpleCanonicalizerPass(); - -} // namespace circt - -#endif // CIRCT_SUPPORT_PASSES_H diff --git a/include/circt/Support/Path.h b/include/circt/Support/Path.h deleted file mode 100644 index 2de9d9ba5e..0000000000 --- a/include/circt/Support/Path.h +++ /dev/null @@ -1,30 +0,0 @@ -//===- Path.h - Path Utilities ----------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for file system path handling, supplementing the ones from -// llvm::sys::path. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_PATH_H -#define CIRCT_SUPPORT_PATH_H - -#include "include/circt/Support/LLVM.h" - -namespace circt { - -/// Append a path to an existing path, replacing it if the other path is -/// absolute. This mimicks the behaviour of `foo/bar` and `/foo/bar` being used -/// in a working directory `/home`, resulting in `/home/foo/bar` and `/foo/bar`, -/// respectively. -void appendPossiblyAbsolutePath(llvm::SmallVectorImpl &base, - const llvm::Twine &suffix); - -} // namespace circt - -#endif // CIRCT_SUPPORT_PATH_H diff --git a/include/circt/Support/PrettyPrinter.h b/include/circt/Support/PrettyPrinter.h deleted file mode 100644 index 70cf8f2289..0000000000 --- a/include/circt/Support/PrettyPrinter.h +++ /dev/null @@ -1,325 +0,0 @@ -//===- PrettyPrinter.h - Pretty printing ----------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This implements a pretty-printer. -// "PrettyPrinting", Derek C. Oppen, 1980. -// https://dx.doi.org/10.1145/357114.357115 -// -// This was selected as it is linear in number of tokens O(n) and requires -// memory O(linewidth). -// -// See PrettyPrinter.cpp for more information. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_PRETTYPRINTER_H -#define CIRCT_SUPPORT_PRETTYPRINTER_H - -#include -#include -#include - -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project -#include "llvm/include/llvm/Support/ErrorHandling.h" // from @llvm-project -#include "llvm/include/llvm/Support/SaveAndRestore.h" // from @llvm-project - -namespace circt { -namespace pretty { - -//===----------------------------------------------------------------------===// -// Tokens -//===----------------------------------------------------------------------===// - -/// Style of breaking within a group: -/// - Consistent: all fits or all breaks. -/// - Inconsistent: best fit, break where needed. -/// - Never: force no breaking including nested groups. -enum class Breaks { Consistent, Inconsistent, Never }; - -/// Style of indent when starting a group: -/// - Visual: offset is relative to current column. -/// - Block: offset is relative to current base indentation. -enum class IndentStyle { Visual, Block }; - -class Token { - public: - enum class Kind { String, Break, Begin, End, Callback }; - - struct TokenInfo { - Kind kind; // Common initial sequence. - }; - struct StringInfo : public TokenInfo { - uint32_t len; - const char *str; - }; - struct BreakInfo : public TokenInfo { - uint32_t spaces; // How many spaces to emit when not broken. - int32_t offset; // How many spaces to emit when broken. - bool neverbreak; // If set, behaves like break except this always 'fits'. - }; - struct BeginInfo : public TokenInfo { - int32_t offset; // Adjust base indentation by this amount. - Breaks breaks; - IndentStyle style; - }; - struct EndInfo : public TokenInfo { - // Nothing - }; - // This can be used to associate a callback with the print event on the - // tokens stream. Since tokens strictly follow FIFO order on a queue, each - // CallbackToken can uniquely identify the data it is associated with, when it - // is added and popped from the queue. - struct CallbackInfo : public TokenInfo {}; - - private: - union { - TokenInfo info; - StringInfo stringInfo; - BreakInfo breakInfo; - BeginInfo beginInfo; - EndInfo endInfo; - CallbackInfo callbackInfo; - } data; - - protected: - template - static auto &getInfoImpl(T &t) { - if constexpr (k == Kind::String) return t.data.stringInfo; - if constexpr (k == Kind::Break) return t.data.breakInfo; - if constexpr (k == Kind::Begin) return t.data.beginInfo; - if constexpr (k == Kind::End) return t.data.endInfo; - if constexpr (k == Kind::Callback) return t.data.callbackInfo; - llvm_unreachable("unhandled token kind"); - } - - Token(Kind k) { data.info.kind = k; } - - public: - Kind getKind() const { return data.info.kind; } -}; - -/// Helper class to CRTP-derive common functions. -template -struct TokenBase : public Token { - static bool classof(const Token *t) { return t->getKind() == DerivedKind; } - - protected: - TokenBase() : Token(DerivedKind) {} - - using InfoType = std::remove_reference_t), Token &>>; - - InfoType &getInfoMut() { return Token::getInfoImpl(*this); } - - const InfoType &getInfo() const { - return Token::getInfoImpl(*this); - } - - template - void initialize(Args &&...args) { - getInfoMut() = InfoType{{DerivedKind}, args...}; - } -}; - -/// Token types. - -struct StringToken : public TokenBase { - StringToken(llvm::StringRef text) { - assert(text.size() == (uint32_t)text.size()); - initialize((uint32_t)text.size(), text.data()); - } - StringRef text() const { return StringRef(getInfo().str, getInfo().len); } -}; - -struct BreakToken : public TokenBase { - BreakToken(uint32_t spaces = 1, int32_t offset = 0, bool neverbreak = false) { - initialize(spaces, offset, neverbreak); - } - uint32_t spaces() const { return getInfo().spaces; } - int32_t offset() const { return getInfo().offset; } - bool neverbreak() const { return getInfo().neverbreak; } -}; - -struct BeginToken : public TokenBase { - BeginToken(int32_t offset = 2, Breaks breaks = Breaks::Inconsistent, - IndentStyle style = IndentStyle::Visual) { - initialize(offset, breaks, style); - } - int32_t offset() const { return getInfo().offset; } - Breaks breaks() const { return getInfo().breaks; } - IndentStyle style() const { return getInfo().style; } -}; - -struct EndToken : public TokenBase {}; - -struct CallbackToken : public TokenBase { - CallbackToken() = default; -}; - -//===----------------------------------------------------------------------===// -// PrettyPrinter -//===----------------------------------------------------------------------===// - -class PrettyPrinter { - public: - /// Listener to Token storage events. - struct Listener { - virtual ~Listener(); - /// No tokens referencing external memory are present. - virtual void clear(){}; - /// Listener for print event. - virtual void print(){}; - }; - - /// PrettyPrinter for specified stream. - /// - margin: line width. - /// - baseIndent: always indent at least this much (starting 'indent' value). - /// - currentColumn: current column, used to calculate space remaining. - /// - maxStartingIndent: max column indentation starts at, must be >= margin. - PrettyPrinter(llvm::raw_ostream &os, uint32_t margin, uint32_t baseIndent = 0, - uint32_t currentColumn = 0, - uint32_t maxStartingIndent = kInfinity / 4, - Listener *listener = nullptr) - : space(margin - std::max(currentColumn, baseIndent)), - defaultFrame{baseIndent, PrintBreaks::Inconsistent}, - indent(baseIndent), - margin(margin), - maxStartingIndent(std::max(maxStartingIndent, margin)), - os(os), - listener(listener) { - assert(maxStartingIndent < kInfinity / 2); - assert(maxStartingIndent > baseIndent); - assert(margin > currentColumn); - // Ensure first print advances to at least baseIndent. - pendingIndentation = - baseIndent > currentColumn ? baseIndent - currentColumn : 0; - } - ~PrettyPrinter() { eof(); } - - /// Add token for printing. In Oppen, this is "scan". - void add(Token t); - - /// Add a range of tokens. - template - void addTokens(R &&tokens) { - // Don't invoke listener until range processed, we own it now. - { - llvm::SaveAndRestore save(donotClear, true); - for (Token &t : tokens) add(t); - } - // Invoke it now if appropriate. - if (scanStack.empty()) clear(); - } - - void eof(); - - void setListener(Listener *newListener) { listener = newListener; }; - auto *getListener() const { return listener; } - - static constexpr uint32_t kInfinity = (1U << 15) - 1; - - private: - /// Format token with tracked size. - struct FormattedToken { - Token token; /// underlying token - int32_t size; /// calculate size when positive. - }; - - /// Breaking style for a printStack entry. - /// This is "Breaks" values with extra for "Fits". - /// Breaks::Never is "AlwaysFits" here. - enum class PrintBreaks { Consistent, Inconsistent, AlwaysFits, Fits }; - - /// Printing information for active scope, stored in printStack. - struct PrintEntry { - uint32_t offset; - PrintBreaks breaks; - }; - - /// Print out tokens we know sizes for, and drop from token buffer. - void advanceLeft(); - - /// Break encountered, set sizes of begin/breaks in scanStack we now know. - void checkStack(); - - /// Check if there's enough tokens to hit width, if so print. - /// If scan size is wider than line, it's infinity. - void checkStream(); - - /// Print a token, maintaining printStack for context. - void print(const FormattedToken &f); - - /// Clear token buffer, scanStack must be empty. - void clear(); - - /// Reset leftTotal and tokenOffset, rebase size data and scanStack indices. - void rebaseIfNeeded(); - - /// Get current printing frame. - auto &getPrintFrame() { - return printStack.empty() ? defaultFrame : printStack.back(); - } - - /// Characters left on this line. - int32_t space; - - /// Sizes: printed, enqueued - int32_t leftTotal; - int32_t rightTotal; - - /// Unprinted tokens, combination of 'token' and 'size' in Oppen. - std::deque tokens; - /// index of first token, for resolving scanStack entries. - uint32_t tokenOffset = 0; - - /// Stack of begin/break tokens, adjust by tokenOffset to index into tokens. - std::deque scanStack; - - /// Stack of printing contexts (indentation + breaking behavior). - SmallVector printStack; - - /// Printing context when stack is empty. - const PrintEntry defaultFrame; - - /// Number of "AlwaysFits" on print stack. - uint32_t alwaysFits = 0; - - /// Current indentation level - uint32_t indent; - - /// Whitespace to print before next, tracked to avoid trailing whitespace. - uint32_t pendingIndentation; - - /// Target line width. - const uint32_t margin; - - /// Maximum starting indentation level (default=kInfinity/4). - /// Useful to continue indentation past margin while still providing a limit - /// to avoid pathological output and for consumption by tools with limits. - const uint32_t maxStartingIndent; - - /// Output stream. - llvm::raw_ostream &os; - - /// Hook for Token storage events. - Listener *listener = nullptr; - - /// Flag to identify a state when the clear cannot be called. - bool donotClear = false; - - /// Threshold for walking scan state and "rebasing" totals/offsets. - static constexpr decltype(leftTotal) rebaseThreshold = - 1UL << (std::numeric_limits::digits - 3); -}; - -} // end namespace pretty -} // end namespace circt - -#endif // CIRCT_SUPPORT_PRETTYPRINTER_H diff --git a/include/circt/Support/PrettyPrinterHelpers.h b/include/circt/Support/PrettyPrinterHelpers.h deleted file mode 100644 index 49f45105ff..0000000000 --- a/include/circt/Support/PrettyPrinterHelpers.h +++ /dev/null @@ -1,374 +0,0 @@ -//===- PrettyPrinterHelpers.h - Pretty printing helpers -------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Helper classes for using PrettyPrinter. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H -#define CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H - -#include - -#include "include/circt/Support/PrettyPrinter.h" -#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/ScopeExit.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project -#include "llvm/include/llvm/Support/Allocator.h" // from @llvm-project -#include "llvm/include/llvm/Support/StringSaver.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project - -namespace circt { -namespace pretty { - -//===----------------------------------------------------------------------===// -// PrettyPrinter standin that buffers tokens until flushed. -//===----------------------------------------------------------------------===// - -/// Buffer tokens for clients that need to adjust things. -struct BufferingPP { - using BufferVec = SmallVectorImpl; - BufferVec &tokens; - bool hasEOF = false; - - BufferingPP(BufferVec &tokens) : tokens(tokens) {} - - void add(Token t) { - assert(!hasEOF); - tokens.push_back(t); - } - - /// Add a range of tokens. - template - void addTokens(R &&newTokens) { - assert(!hasEOF); - llvm::append_range(tokens, newTokens); - } - - /// Buffer a final EOF, no tokens allowed after this. - void eof() { - assert(!hasEOF); - hasEOF = true; - } - - /// Flush buffered tokens to the specified pretty printer. - /// Emit the EOF is one was added. - void flush(PrettyPrinter &pp) { - pp.addTokens(tokens); - tokens.clear(); - if (hasEOF) { - pp.eof(); - hasEOF = false; - } - } -}; - -//===----------------------------------------------------------------------===// -// Convenience Token builders. -//===----------------------------------------------------------------------===// - -namespace detail { -void emitNBSP(unsigned n, llvm::function_ref add); -} // end namespace detail - -/// Add convenience methods for generating pretty-printing tokens. -template -class TokenBuilder { - PPTy &pp; - - public: - TokenBuilder(PPTy &pp) : pp(pp) {} - - //===- Add tokens -------------------------------------------------------===// - - /// Add new token. - template - typename std::enable_if_t> add(Args &&...args) { - pp.add(T(std::forward(args)...)); - } - void addToken(Token t) { pp.add(t); } - - /// End of a stream. - void eof() { pp.eof(); } - - //===- Strings ----------------------------------------------------------===// - - /// Add a literal (with external storage). - void literal(StringRef str) { add(str); } - - /// Add a non-breaking space. - void nbsp() { literal(" "); } - - /// Add multiple non-breaking spaces as a single token. - void nbsp(unsigned n) { - detail::emitNBSP(n, [&](Token t) { addToken(t); }); - } - - //===- Breaks -----------------------------------------------------------===// - - /// Add a 'neverbreak' break. Always 'fits'. - void neverbreak() { add(0, 0, true); } - - /// Add a newline (break too wide to fit, always breaks). - void newline() { add(PrettyPrinter::kInfinity); } - - /// Add breakable spaces. - void spaces(uint32_t n) { add(n); } - - /// Add a breakable space. - void space() { spaces(1); } - - /// Add a break that is zero-wide if not broken. - void zerobreak() { add(0); } - - //===- Groups -----------------------------------------------------------===// - - /// Start a IndentStyle::Block group with specified offset. - void bbox(int32_t offset = 0, Breaks breaks = Breaks::Consistent) { - add(offset, breaks, IndentStyle::Block); - } - - /// Start a consistent group with specified offset. - void cbox(int32_t offset = 0, IndentStyle style = IndentStyle::Visual) { - add(offset, Breaks::Consistent, style); - } - - /// Start an inconsistent group with specified offset. - void ibox(int32_t offset = 0, IndentStyle style = IndentStyle::Visual) { - add(offset, Breaks::Inconsistent, style); - } - - /// Start a group that cannot break, including nested groups. - /// Use sparingly. - void neverbox() { add(0, Breaks::Never); } - - /// End a group. - void end() { add(); } -}; - -/// PrettyPrinter::Listener that saves strings while live. -/// Once they're no longer referenced, memory is reset. -/// Allows differentiating between strings to save and external strings. -class TokenStringSaver : public PrettyPrinter::Listener { - llvm::BumpPtrAllocator alloc; - llvm::StringSaver strings; - - public: - TokenStringSaver() : strings(alloc) {} - - /// Add string, save in storage. - [[nodiscard]] StringRef save(StringRef str) { return strings.save(str); } - - /// PrettyPrinter::Listener::clear -- indicates no external refs. - void clear() override; -}; - -/// Note: Callable class must implement a callable with signature: -/// void (Data) -template -class PrintEventAndStorageListener : public TokenStringSaver { - /// List of all the unique data associated with each callback token. - /// The fact that tokens on a stream can never be printed out of order, - /// ensures that CallbackTokens are always added and invoked in FIFO order, - /// hence no need to record an index into the Data list. - std::queue dataQ; - /// The storage for the callback, as a function object. - CallableTy &callable; - - public: - PrintEventAndStorageListener(CallableTy &c) : callable(c) {} - - /// PrettyPrinter::Listener::print -- indicates all the preceding tokens on - /// the stream have been printed. - /// This is invoked when the CallbackToken is printed. - void print() override { - std::invoke(callable, dataQ.front()); - dataQ.pop(); - } - /// Get a token with the obj data. - CallbackToken getToken(DataTy obj) { - // Insert data onto the list. - dataQ.push(obj); - return CallbackToken(); - } -}; - -//===----------------------------------------------------------------------===// -// Streaming support. -//===----------------------------------------------------------------------===// - -/// Send one of these to TokenStream to add the corresponding token. -/// See TokenBuilder for details of each. -enum class PP { - bbox2, - cbox0, - cbox2, - end, - eof, - ibox0, - ibox2, - nbsp, - neverbox, - neverbreak, - newline, - space, - zerobreak, -}; - -/// String wrapper to indicate string has external storage. -struct PPExtString { - StringRef str; - explicit PPExtString(StringRef str) : str(str) {} -}; - -/// String wrapper to indicate string needs to be saved. -struct PPSaveString { - StringRef str; - explicit PPSaveString(StringRef str) : str(str) {} -}; - -/// Wrap a PrettyPrinter with TokenBuilder features as well as operator<<'s. -/// String behavior: -/// Strings streamed as `const char *` are assumed to have external storage, -/// and StringRef's are saved until no longer needed. -/// Use PPExtString() and PPSaveString() wrappers to specify/override behavior. -template -class TokenStream : public TokenBuilder { - using Base = TokenBuilder; - TokenStringSaver &saver; - - public: - /// Create a TokenStream using the specified PrettyPrinter and StringSaver - /// storage. Strings are saved in `saver`, which is generally the listener in - /// the PrettyPrinter, but may not be (e.g., using BufferingPP). - TokenStream(PPTy &pp, TokenStringSaver &saver) : Base(pp), saver(saver) {} - - /// Add a string literal (external storage). - TokenStream &operator<<(const char *s) { - Base::literal(s); - return *this; - } - /// Add a string token (saved to storage). - TokenStream &operator<<(StringRef s) { - Base::template add(saver.save(s)); - return *this; - } - - /// String has external storage. - TokenStream &operator<<(const PPExtString &str) { - Base::literal(str.str); - return *this; - } - - /// String must be saved. - TokenStream &operator<<(const PPSaveString &str) { - Base::template add(saver.save(str.str)); - return *this; - } - - /// Convenience for inline streaming of builder methods. - TokenStream &operator<<(PP s) { - switch (s) { - case PP::bbox2: - Base::bbox(2); - break; - case PP::cbox0: - Base::cbox(0); - break; - case PP::cbox2: - Base::cbox(2); - break; - case PP::end: - Base::end(); - break; - case PP::eof: - Base::eof(); - break; - case PP::ibox0: - Base::ibox(0); - break; - case PP::ibox2: - Base::ibox(2); - break; - case PP::nbsp: - Base::nbsp(); - break; - case PP::neverbox: - Base::neverbox(); - break; - case PP::neverbreak: - Base::neverbreak(); - break; - case PP::newline: - Base::newline(); - break; - case PP::space: - Base::space(); - break; - case PP::zerobreak: - Base::zerobreak(); - break; - } - return *this; - } - - /// Stream support for user-created Token's. - TokenStream &operator<<(Token t) { - Base::addToken(t); - return *this; - } - - /// General-purpose "format this" helper, for types not supported by - /// operator<< yet. - template - TokenStream &addAsString(T &&t) { - invokeWithStringOS([&](auto &os) { os << std::forward(t); }); - return *this; - } - - /// Helper to invoke code with a llvm::raw_ostream argument for compatibility. - /// All data is gathered into a single string token. - template - auto invokeWithStringOS(Callable &&c) { - SmallString ss; - llvm::raw_svector_ostream ssos(ss); - auto flush = llvm::make_scope_exit([&]() { - if (!ss.empty()) *this << ss; - }); - return std::invoke(std::forward(c), ssos); - } - - /// Write escaped versions of the string, saved in storage. - TokenStream &writeEscaped(StringRef str, bool useHexEscapes = false) { - return writeQuotedEscaped(str, useHexEscapes, "", ""); - } - TokenStream &writeQuotedEscaped(StringRef str, bool useHexEscapes = false, - StringRef left = "\"", - StringRef right = "\"") { - // Add as a single StringToken. - invokeWithStringOS([&](auto &os) { - os << left; - os.write_escaped(str, useHexEscapes); - os << right; - }); - return *this; - } - - /// Open a box, invoke the lambda, and close it after. - template - auto scopedBox(T &&t, Callable &&c, Token close = EndToken()) { - *this << std::forward(t); - auto done = llvm::make_scope_exit([&]() { *this << close; }); - return std::invoke(std::forward(c)); - } -}; - -} // end namespace pretty -} // end namespace circt - -#endif // CIRCT_SUPPORT_PRETTYPRINTERHELPERS_H diff --git a/include/circt/Support/SymCache.h b/include/circt/Support/SymCache.h deleted file mode 100644 index 3be7798f4c..0000000000 --- a/include/circt/Support/SymCache.h +++ /dev/null @@ -1,132 +0,0 @@ -//===- SymCache.h - Declare Symbol Cache ------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file declares a Symbol Cache. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_SYMCACHE_H -#define CIRCT_SUPPORT_SYMCACHE_H - -#include "llvm/include/llvm/ADT/iterator.h" // from @llvm-project -#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project -#include "mlir/include/mlir/IR/SymbolTable.h" // from @llvm-project - -namespace circt { - -/// Base symbol cache class to allow for cache lookup through a pointer to some -/// abstract cache. A symbol cache stores lookup tables to make manipulating and -/// working with the IR more efficient. -class SymbolCacheBase { - public: - virtual ~SymbolCacheBase(); - - /// Defines 'op' as associated with the 'symbol' in the cache. - virtual void addDefinition(mlir::Attribute symbol, mlir::Operation *op) = 0; - - /// Adds the symbol-defining 'op' to the cache. - void addSymbol(mlir::SymbolOpInterface op) { - addDefinition(op.getNameAttr(), op); - } - - /// Populate the symbol cache with all symbol-defining operations within the - /// 'top' operation. - void addDefinitions(mlir::Operation *top); - - /// Lookup a definition for 'symbol' in the cache. - virtual mlir::Operation *getDefinition(mlir::Attribute symbol) const = 0; - - /// Lookup a definition for 'symbol' in the cache. - mlir::Operation *getDefinition(mlir::FlatSymbolRefAttr symbol) const { - return getDefinition(symbol.getAttr()); - } - - /// Iterator support through a pointer to some abstract cache. - /// The implementing cache must provide an iterator that carries values on the - /// form of . - using CacheItem = std::pair; - struct CacheIteratorImpl { - virtual ~CacheIteratorImpl() {} - virtual void operator++() = 0; - virtual CacheItem operator*() = 0; - virtual bool operator==(CacheIteratorImpl *other) = 0; - }; - - struct Iterator - : public llvm::iterator_facade_base { - Iterator(std::unique_ptr &&impl) - : impl(std::move(impl)) {} - CacheItem operator*() const { return **impl; } - using llvm::iterator_facade_base::operator++; - bool operator==(const Iterator &other) const { - return *impl == other.impl.get(); - } - void operator++() { impl->operator++(); } - - private: - std::unique_ptr impl; - }; - virtual Iterator begin() = 0; - virtual Iterator end() = 0; -}; - -/// Default symbol cache implementation; stores associations between names -/// (StringAttr's) to mlir::Operation's. -/// Adding/getting definitions from the symbol cache is not -/// thread safe. If this is required, synchronizing cache acccess should be -/// ensured by the caller. -class SymbolCache : public SymbolCacheBase { - public: - /// In the building phase, add symbols. - void addDefinition(mlir::Attribute key, mlir::Operation *op) override { - symbolCache.try_emplace(key, op); - } - - // Pull in getDefinition(mlir::FlatSymbolRefAttr symbol) - using SymbolCacheBase::getDefinition; - mlir::Operation *getDefinition(mlir::Attribute attr) const override { - auto it = symbolCache.find(attr); - if (it == symbolCache.end()) return nullptr; - return it->second; - } - - protected: - /// This stores a lookup table from symbol attribute to the operation - /// that defines it. - llvm::DenseMap symbolCache; - - private: - /// Iterator support: A simple mapping between decltype(symbolCache)::iterator - /// to SymbolCacheBase::Iterator. - using Iterator = decltype(symbolCache)::iterator; - struct SymbolCacheIteratorImpl : public CacheIteratorImpl { - SymbolCacheIteratorImpl(Iterator it) : it(it) {} - CacheItem operator*() override { return {it->getFirst(), it->getSecond()}; } - void operator++() override { it++; } - bool operator==(CacheIteratorImpl *other) override { - return it == static_cast(other)->it; - } - Iterator it; - }; - - public: - SymbolCacheBase::Iterator begin() override { - return SymbolCacheBase::Iterator( - std::make_unique(symbolCache.begin())); - } - SymbolCacheBase::Iterator end() override { - return SymbolCacheBase::Iterator( - std::make_unique(symbolCache.end())); - } -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_SYMCACHE_H diff --git a/include/circt/Support/ValueMapper.h b/include/circt/Support/ValueMapper.h deleted file mode 100644 index 3b253bd7cf..0000000000 --- a/include/circt/Support/ValueMapper.h +++ /dev/null @@ -1,63 +0,0 @@ -//===- ValueMapper.h - Support for mapping SSA values -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides support for mapping SSA values between two domains. -// Provided a BackedgeBuilder, the ValueMapper supports mappings between -// GraphRegions, creating Backedges in cases of 'get'ing mapped values which are -// yet to be 'set'. -// -//===----------------------------------------------------------------------===// - -#ifndef CIRCT_SUPPORT_VALUEMAPPER_H -#define CIRCT_SUPPORT_VALUEMAPPER_H - -#include -#include - -#include "include/circt/Support/BackedgeBuilder.h" -#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project -#include "mlir/include/mlir/IR/Location.h" // from @llvm-project -#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project - -namespace circt { - -/// The ValueMapper class facilitates the definition and connection of SSA -/// def-use chains between two location - a 'from' location (defining -/// use-def chains) and a 'to' location (where new operations are created based -/// on the 'from' location).´ -class ValueMapper { - public: - using TypeTransformer = llvm::function_ref; - static mlir::Type identity(mlir::Type t) { return t; }; - explicit ValueMapper(BackedgeBuilder *bb = nullptr) : bb(bb) {} - - // Get the mapped value of value 'from'. If no mapping has been registered, a - // new backedge is created. The type of the mapped value may optionally be - // modified through the 'typeTransformer'. - mlir::Value get(mlir::Value from, - TypeTransformer typeTransformer = ValueMapper::identity); - llvm::SmallVector get( - mlir::ValueRange from, - TypeTransformer typeTransformer = ValueMapper::identity); - - // Set the mapped value of 'from' to 'to'. If 'from' is already mapped to a - // backedge, replaces that backedge with 'to'. If 'replace' is not set, and a - // (non-backedge) mapping already exists, an assert is thrown. - void set(mlir::Value from, mlir::Value to, bool replace = false); - void set(mlir::ValueRange from, mlir::ValueRange to, bool replace = false); - - private: - BackedgeBuilder *bb = nullptr; - llvm::DenseMap> mapping; -}; - -} // namespace circt - -#endif // CIRCT_SUPPORT_VALUEMAPPER_H diff --git a/include/circt/Support/Version.h b/include/circt/Support/Version.h deleted file mode 100644 index 206f0688e1..0000000000 --- a/include/circt/Support/Version.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef CIRCT_SUPPORT_VERSION_H -#define CIRCT_SUPPORT_VERSION_H - -#include - -namespace circt { -const char *getCirctVersion(); -const char *getCirctVersionComment(); - -/// A generic bug report message for CIRCT-related projects -constexpr const char *circtBugReportMsg = - "PLEASE submit a bug report to https://github.com/llvm/circt and include " - "the crash backtrace.\n"; -} // namespace circt - -#endif // CIRCT_SUPPORT_VERSION_H diff --git a/lib/circt/Dialect/Comb/BUILD b/lib/circt/Dialect/Comb/BUILD index bd5fb37ed1..a5c574392d 100644 --- a/lib/circt/Dialect/Comb/BUILD +++ b/lib/circt/Dialect/Comb/BUILD @@ -5,20 +5,19 @@ package( cc_library( name = "Dialect", - srcs = glob([ - "*.cpp", - ]), + srcs = [ + "CombDialect.cpp", + "CombOps.cpp", + ], hdrs = [ "@heir//include/circt/Dialect/Comb:CombDialect.h", "@heir//include/circt/Dialect/Comb:CombOps.h", - "@heir//include/circt/Dialect/Comb:CombVisitors.h", ], deps = [ "@heir//include/circt/Dialect/Comb:dialect_inc_gen", "@heir//include/circt/Dialect/Comb:enum_inc_gen", "@heir//include/circt/Dialect/Comb:ops_inc_gen", "@heir//include/circt/Dialect/Comb:type_inc_gen", - "@heir//lib/circt/Dialect/HW:Dialect", "@heir//lib/circt/Support", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", diff --git a/lib/circt/Dialect/Comb/CMakeLists.txt b/lib/circt/Dialect/Comb/CMakeLists.txt deleted file mode 100644 index 8034eff9b7..0000000000 --- a/lib/circt/Dialect/Comb/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -add_circt_dialect_library(CIRCTComb - CombFolds.cpp - CombOps.cpp - CombAnalysis.cpp - CombDialect.cpp - - ADDITIONAL_HEADER_DIRS - ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/Comb - - DEPENDS - CIRCTHW - MLIRCombIncGen - MLIRCombEnumsIncGen - - LINK_COMPONENTS - Support - - LINK_LIBS PUBLIC - CIRCTHW - MLIRIR - MLIRInferTypeOpInterface - ) - -add_dependencies(circt-headers MLIRCombIncGen MLIRCombEnumsIncGen) - -add_subdirectory(Transforms) diff --git a/lib/circt/Dialect/Comb/CombAnalysis.cpp b/lib/circt/Dialect/Comb/CombAnalysis.cpp deleted file mode 100644 index 328a9ac294..0000000000 --- a/lib/circt/Dialect/Comb/CombAnalysis.cpp +++ /dev/null @@ -1,87 +0,0 @@ -//===- CombAnalysis.cpp - Analysis Helpers for Comb+HW operations ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "llvm/include/llvm/Support/KnownBits.h" // from @llvm-project - -using namespace circt; -using namespace comb; - -/// Given an integer SSA value, check to see if we know anything about the -/// result of the computation. For example, we know that "and with a constant" -/// always returns zeros for the zero bits in a constant. -/// -/// Expression trees can be very large, so we need ot make sure to cap our -/// recursion, this is controlled by `depth`. -static KnownBits computeKnownBits(Value v, unsigned depth) { - Operation *op = v.getDefiningOp(); - if (!op || depth == 5) return KnownBits(v.getType().getIntOrFloatBitWidth()); - - // A constant has all bits known! - if (auto constant = dyn_cast(op)) - return KnownBits::makeConstant(constant.getValue()); - - // `concat(x, y, z)` has whatever is known about the operands concat'd. - if (auto concatOp = dyn_cast(op)) { - auto result = computeKnownBits(concatOp.getOperand(0), depth + 1); - for (size_t i = 1, e = concatOp.getNumOperands(); i != e; ++i) { - auto otherBits = computeKnownBits(concatOp.getOperand(i), depth + 1); - unsigned width = otherBits.getBitWidth(); - unsigned newWidth = result.getBitWidth() + width; - result.Zero = - (result.Zero.zext(newWidth) << width) | otherBits.Zero.zext(newWidth); - result.One = - (result.One.zext(newWidth) << width) | otherBits.One.zext(newWidth); - } - return result; - } - - // `and(x, y, z)` has whatever is known about the operands intersected. - if (auto andOp = dyn_cast(op)) { - auto result = computeKnownBits(andOp.getOperand(0), depth + 1); - for (size_t i = 1, e = andOp.getNumOperands(); i != e; ++i) - result &= computeKnownBits(andOp.getOperand(i), depth + 1); - return result; - } - - // `or(x, y, z)` has whatever is known about the operands unioned. - if (auto orOp = dyn_cast(op)) { - auto result = computeKnownBits(orOp.getOperand(0), depth + 1); - for (size_t i = 1, e = orOp.getNumOperands(); i != e; ++i) - result |= computeKnownBits(orOp.getOperand(i), depth + 1); - return result; - } - - // `xor(x, cst)` inverts known bits and passes through unmodified ones. - if (auto xorOp = dyn_cast(op)) { - auto result = computeKnownBits(xorOp.getOperand(0), depth + 1); - for (size_t i = 1, e = xorOp.getNumOperands(); i != e; ++i) { - // If we don't know anything, we don't need to evaluate more subexprs. - if (result.isUnknown()) return result; - result ^= computeKnownBits(xorOp.getOperand(i), depth + 1); - } - return result; - } - - // `mux(cond, x, y)` is the intersection of the known bits of `x` and `y`. - if (auto muxOp = dyn_cast(op)) { - auto lhs = computeKnownBits(muxOp.getTrueValue(), depth + 1); - auto rhs = computeKnownBits(muxOp.getFalseValue(), depth + 1); - return lhs.intersectWith(rhs); - } - - return KnownBits(v.getType().getIntOrFloatBitWidth()); -} - -/// Given an integer SSA value, check to see if we know anything about the -/// result of the computation. For example, we know that "and with a -/// constant" always returns zeros for the zero bits in a constant. -KnownBits comb::computeKnownBits(Value value) { - return ::computeKnownBits(value, 0); -} diff --git a/lib/circt/Dialect/Comb/CombDialect.cpp b/lib/circt/Dialect/Comb/CombDialect.cpp index c1291aca3e..b0e24848b6 100644 --- a/lib/circt/Dialect/Comb/CombDialect.cpp +++ b/lib/circt/Dialect/Comb/CombDialect.cpp @@ -13,7 +13,6 @@ #include "include/circt/Dialect/Comb/CombDialect.h" #include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWOps.h" #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project @@ -40,22 +39,23 @@ void CombDialect::initialize() { /// like, i.e. single result, zero operands, non side-effecting, etc. On /// success, this hook should return the value generated to represent the /// constant value. Otherwise, it should return null on failure. -Operation *CombDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { - // Integer constants. - if (auto intType = type.dyn_cast()) - if (auto attrValue = value.dyn_cast()) - return builder.create(loc, type, attrValue); - - // Parameter expressions materialize into hw.param.value. - auto parentOp = builder.getBlock()->getParentOp(); - auto curModule = dyn_cast(parentOp); - if (!curModule) curModule = parentOp->getParentOfType(); - if (curModule && isValidParameterExpression(value, curModule)) - return builder.create(loc, type, value); - - return nullptr; -} +// Operation *CombDialect::materializeConstant(OpBuilder &builder, Attribute +// value, +// Type type, Location loc) { +// // Integer constants. +// if (auto intType = type.dyn_cast()) +// if (auto attrValue = value.dyn_cast()) +// return builder.create(loc, type, attrValue); + +// // Parameter expressions materialize into hw.param.value. +// auto parentOp = builder.getBlock()->getParentOp(); +// auto curModule = dyn_cast(parentOp); +// if (!curModule) curModule = parentOp->getParentOfType(); +// if (curModule && isValidParameterExpression(value, curModule)) +// return builder.create(loc, type, value); + +// return nullptr; +// } // Provide implementations for the enums we use. #include "include/circt/Dialect/Comb/CombDialect.cpp.inc" diff --git a/lib/circt/Dialect/Comb/CombFolds.cpp b/lib/circt/Dialect/Comb/CombFolds.cpp deleted file mode 100644 index bd33d6602f..0000000000 --- a/lib/circt/Dialect/Comb/CombFolds.cpp +++ /dev/null @@ -1,3026 +0,0 @@ -//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "llvm/include/llvm/ADT/SetVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallBitVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/KnownBits.h" // from @llvm-project -#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project - -using namespace mlir; -using namespace circt; -using namespace comb; -using namespace matchers; - -/// Create a new instance of a generic operation that only has value operands, -/// and has a single result value whose type matches the first operand. -/// -/// This should not be used to create instances of ops with attributes or with -/// more complicated type signatures. -static Value createGenericOp(Location loc, OperationName name, - ArrayRef operands, OpBuilder &builder) { - OperationState state(loc, name); - state.addOperands(operands); - state.addTypes(operands[0].getType()); - return builder.create(state)->getResult(0); -} - -static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) { - return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()), - value); -} - -/// Flatten concat and mux operands into a vector. -static void getConcatOperands(Value v, SmallVectorImpl &result) { - if (auto concat = v.getDefiningOp()) { - for (auto op : concat.getOperands()) getConcatOperands(op, result); - } else if (auto repl = v.getDefiningOp()) { - for (size_t i = 0, e = repl.getMultiple(); i != e; ++i) - getConcatOperands(repl.getOperand(), result); - } else { - result.push_back(v); - } -} - -/// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint" -/// attribute. If a replaced op has a "sv.namehint" attribute, this function -/// propagates the name to the new value. -static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, - Value newValue) { - if (auto *newOp = newValue.getDefiningOp()) { - auto name = op->getAttrOfType("sv.namehint"); - if (name && !newOp->hasAttr("sv.namehint")) - rewriter.updateRootInPlace(newOp, - [&] { newOp->setAttr("sv.namehint", name); }); - } - rewriter.replaceOp(op, newValue); -} - -/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate -/// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute, -/// this function propagates the name to the new value. -template -static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, - Operation *op, Args &&...args) { - auto name = op->getAttrOfType("sv.namehint"); - auto newOp = - rewriter.replaceOpWithNewOp(op, std::forward(args)...); - if (name && !newOp->hasAttr("sv.namehint")) - rewriter.updateRootInPlace(newOp, - [&] { newOp->setAttr("sv.namehint", name); }); - - return newOp; -} - -// Return true if the op has SV attributes. Note that we cannot use a helper -// function `hasSVAttributes` defined under SV dialect because of a cyclic -// dependency. -static bool hasSVAttributes(Operation *op) { - return op->hasAttr("sv.attributes"); -} - -namespace { -template -struct ComplementMatcher { - SubType lhs; - ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {} - bool match(Operation *op) { - auto xorOp = dyn_cast(op); - return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0)); - } -}; -} // end anonymous namespace - -template -static inline ComplementMatcher m_Complement(const SubType &subExpr) { - return ComplementMatcher(subExpr); -} - -/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined -/// as an Op. Returns true if successful, and false otherwise. -/// -/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true -/// -static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) { - auto inputs = op->getOperands(); - - for (size_t i = 0, size = inputs.size(); i != size; ++i) { - Operation *flattenOp = inputs[i].getDefiningOp(); - if (!flattenOp || flattenOp->getName() != op->getName()) continue; - - // Check for loops - if (flattenOp == op) continue; - - // Don't duplicate logic when it has multiple uses. - if (!inputs[i].hasOneUse()) { - // We can fold a multi-use binary operation into this one if this allows a - // constant to fold though. For example, fold - // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2) - // since the constants will both fold and we end up with the equiv cost. - // - // We don't do this for add/mul because the hardware won't be shared - // between the two ops if duplicated. - if (flattenOp->getNumOperands() != 2 || !isa(op) || - !flattenOp->getOperand(1).getDefiningOp() || - !inputs.back().getDefiningOp()) - continue; - } - - // Otherwise, flatten away. - auto flattenOpInputs = flattenOp->getOperands(); - - SmallVector newOperands; - newOperands.reserve(size + flattenOpInputs.size()); - - auto flattenOpIndex = inputs.begin() + i; - newOperands.append(inputs.begin(), flattenOpIndex); - newOperands.append(flattenOpInputs.begin(), flattenOpInputs.end()); - newOperands.append(flattenOpIndex + 1, inputs.end()); - - Value result = - createGenericOp(op->getLoc(), op->getName(), newOperands, rewriter); - - // If the original operation and flatten operand have bin flags, propagte - // the flag to new one. - if (op->hasAttrOfType("twoState") && - flattenOp->hasAttrOfType("twoState")) - result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr()); - - replaceOpAndCopyName(rewriter, op, result); - return true; - } - return false; -} - -// Given a range of uses of an operation, find the lowest and highest bits -// inclusive that are ever referenced. The range of uses must not be empty. -static std::pair getLowestBitAndHighestBitRequired( - Operation *op, bool narrowTrailingBits, size_t originalOpWidth) { - auto users = op->getUsers(); - assert(!users.empty() && - "getLowestBitAndHighestBitRequired cannot operate on " - "a empty list of uses."); - - // when we don't want to narrowTrailingBits (namely in arithmetic - // operations), forcing lowestBitRequired = 0 - size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0; - size_t highestBitRequired = 0; - - for (auto *user : users) { - if (auto extractOp = dyn_cast(user)) { - size_t lowBit = extractOp.getLowBit(); - size_t highBit = - extractOp.getType().cast().getWidth() + lowBit - 1; - highestBitRequired = std::max(highestBitRequired, highBit); - lowestBitRequired = std::min(lowestBitRequired, lowBit); - continue; - } - - highestBitRequired = originalOpWidth - 1; - lowestBitRequired = 0; - break; - } - - return {lowestBitRequired, highestBitRequired}; -} - -template -static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, - PatternRewriter &rewriter) { - IntegerType opType = - op.getResult().getType().template dyn_cast(); - if (!opType) return false; - - auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits, - opType.getWidth()); - if (range.second + 1 == opType.getWidth() && range.first == 0) return false; - - SmallVector args; - auto newType = rewriter.getIntegerType(range.second - range.first + 1); - for (auto inop : op.getOperands()) { - // deal with muxes here - if (inop.getType() != op.getType()) - args.push_back(inop); - else - args.push_back(rewriter.createOrFold(inop.getLoc(), newType, - inop, range.first)); - } - Value newop = rewriter.createOrFold(op.getLoc(), newType, args); - newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs()); - if (range.first) - newop = rewriter.createOrFold( - op.getLoc(), newop, - rewriter.create(op.getLoc(), - APInt::getZero(range.first))); - if (range.second + 1 < opType.getWidth()) - newop = rewriter.createOrFold( - op.getLoc(), - rewriter.create( - op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)), - newop); - rewriter.replaceOp(op, newop); - return true; -} - -//===----------------------------------------------------------------------===// -// Unary Operations -//===----------------------------------------------------------------------===// - -OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) { - // Replicate one time -> noop. - if (getType().cast().getWidth() == - getInput().getType().getIntOrFloatBitWidth()) - return getInput(); - - // Constant fold. - if (auto input = adaptor.getInput().dyn_cast_or_null()) { - if (input.getValue().getBitWidth() == 1) { - if (input.getValue().isZero()) - return getIntAttr( - APInt::getZero(getType().cast().getWidth()), - getContext()); - return getIntAttr( - APInt::getAllOnes(getType().cast().getWidth()), - getContext()); - } - - APInt result = APInt::getZeroWidth(); - for (auto i = getMultiple(); i != 0; --i) - result = result.concat(input.getValue()); - return getIntAttr(result, getContext()); - } - - return {}; -} - -OpFoldResult ParityOp::fold(FoldAdaptor adaptor) { - // Constant fold. - if (auto input = adaptor.getInput().dyn_cast_or_null()) - return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext()); - - return {}; -} - -//===----------------------------------------------------------------------===// -// Binary Operations -//===----------------------------------------------------------------------===// - -/// Performs constant folding `calculate` with element-wise behavior on the two -/// attributes in `operands` and returns the result if possible. -static Attribute constFoldBinaryOp(ArrayRef operands, - hw::PEO paramOpcode) { - assert(operands.size() == 2 && "binary op takes two operands"); - if (!operands[0] || !operands[1]) return {}; - - // Fold constants with ParamExprAttr::get which handles simple constants as - // well as parameter expressions. - return hw::ParamExprAttr::get(paramOpcode, operands[0].cast(), - operands[1].cast()); -} - -OpFoldResult ShlOp::fold(FoldAdaptor adaptor) { - if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { - unsigned shift = rhs.getValue().getZExtValue(); - unsigned width = getType().getIntOrFloatBitWidth(); - if (shift == 0) return getOperand(0); - if (width <= shift) return getIntAttr(APInt::getZero(width), getContext()); - } - - return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl); -} - -LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) { - // ShlOp(x, cst) -> Concat(Extract(x), zeros) - APInt value; - if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); - - unsigned width = op.getLhs().getType().cast().getWidth(); - unsigned shift = value.getZExtValue(); - - // This case is handled by fold. - if (width <= shift || shift == 0) return failure(); - - auto zeros = - rewriter.create(op.getLoc(), APInt::getZero(shift)); - - // Remove the high bits which would be removed by the Shl. - auto extract = - rewriter.create(op.getLoc(), op.getLhs(), 0, width - shift); - - replaceOpWithNewOpAndCopyName(rewriter, op, extract, zeros); - return success(); -} - -OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) { - if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { - unsigned shift = rhs.getValue().getZExtValue(); - if (shift == 0) return getOperand(0); - - unsigned width = getType().getIntOrFloatBitWidth(); - if (width <= shift) return getIntAttr(APInt::getZero(width), getContext()); - } - return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU); -} - -LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) { - // ShrUOp(x, cst) -> Concat(zeros, Extract(x)) - APInt value; - if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); - - unsigned width = op.getLhs().getType().cast().getWidth(); - unsigned shift = value.getZExtValue(); - - // This case is handled by fold. - if (width <= shift || shift == 0) return failure(); - - auto zeros = - rewriter.create(op.getLoc(), APInt::getZero(shift)); - - // Remove the low bits which would be removed by the Shr. - auto extract = rewriter.create(op.getLoc(), op.getLhs(), shift, - width - shift); - - replaceOpWithNewOpAndCopyName(rewriter, op, zeros, extract); - return success(); -} - -OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) { - if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { - if (rhs.getValue().getZExtValue() == 0) return getOperand(0); - } - return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS); -} - -LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) { - // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x)) - APInt value; - if (!matchPattern(op.getRhs(), m_ConstantInt(&value))) return failure(); - - unsigned width = op.getLhs().getType().cast().getWidth(); - unsigned shift = value.getZExtValue(); - - auto topbit = - rewriter.createOrFold(op.getLoc(), op.getLhs(), width - 1, 1); - auto sext = rewriter.createOrFold(op.getLoc(), topbit, shift); - - if (width <= shift) { - replaceOpAndCopyName(rewriter, op, {sext}); - return success(); - } - - auto extract = rewriter.create(op.getLoc(), op.getLhs(), shift, - width - shift); - - replaceOpWithNewOpAndCopyName(rewriter, op, sext, extract); - return success(); -} - -//===----------------------------------------------------------------------===// -// Other Operations -//===----------------------------------------------------------------------===// - -OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { - // If we are extracting the entire input, then return it. - if (getInput().getType() == getType()) return getInput(); - - // Constant fold. - if (auto input = adaptor.getInput().dyn_cast_or_null()) { - unsigned dstWidth = getType().cast().getWidth(); - return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth), - getContext()); - } - return {}; -} - -// Transforms extract(lo, cat(a, b, c, d, e)) into -// cat(extract(lo1, b), c, extract(lo2, d)). -// innerCat must be the argument of the provided ExtractOp. -static LogicalResult extractConcatToConcatExtract(ExtractOp op, - ConcatOp innerCat, - PatternRewriter &rewriter) { - auto reversedConcatArgs = llvm::reverse(innerCat.getInputs()); - size_t beginOfFirstRelevantElement = 0; - auto it = reversedConcatArgs.begin(); - size_t lowBit = op.getLowBit(); - - // This loop finds the first concatArg that is covered by the ExtractOp - for (; it != reversedConcatArgs.end(); it++) { - assert(beginOfFirstRelevantElement <= lowBit && - "incorrectly moved past an element that lowBit has coverage over"); - auto operand = *it; - - size_t operandWidth = operand.getType().getIntOrFloatBitWidth(); - if (lowBit < beginOfFirstRelevantElement + operandWidth) { - // A bit other than the first bit will be used in this element. - // ...... ........ ... - // ^---lowBit - // ^---beginOfFirstRelevantElement - // - // Edge-case close to the end of the range. - // ...... ........ ... - // ^---(position + operandWidth) - // ^---lowBit - // ^---beginOfFirstRelevantElement - // - // Edge-case close to the beginning of the rang - // ...... ........ ... - // ^---lowBit - // ^---beginOfFirstRelevantElement - // - break; - } - - // extraction discards this element. - // ...... ........ ... - // | ^---lowBit - // ^---beginOfFirstRelevantElement - beginOfFirstRelevantElement += operandWidth; - } - assert(it != reversedConcatArgs.end() && - "incorrectly failed to find an element which contains coverage of " - "lowBit"); - - SmallVector reverseConcatArgs; - size_t widthRemaining = op.getType().cast().getWidth(); - size_t extractLo = lowBit - beginOfFirstRelevantElement; - - // Transform individual arguments of innerCat(..., a, b, c,) into - // [ extract(a), b, extract(c) ], skipping an extract operation where - // possible. - for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) { - auto concatArg = *it; - size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth(); - size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo); - - if (widthToConsume == operandWidth && extractLo == 0) { - reverseConcatArgs.push_back(concatArg); - } else { - auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume); - reverseConcatArgs.push_back( - rewriter.create(op.getLoc(), resultType, *it, extractLo)); - } - - widthRemaining -= widthToConsume; - - // Beyond the first element, all elements are extracted from position 0. - extractLo = 0; - } - - if (reverseConcatArgs.size() == 1) { - replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]); - } else { - replaceOpWithNewOpAndCopyName( - rewriter, op, SmallVector(llvm::reverse(reverseConcatArgs))); - } - return success(); -} - -// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c). -static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, - PatternRewriter &rewriter) { - auto extractResultWidth = op.getType().cast().getWidth(); - auto replicateEltWidth = - replicate.getOperand().getType().getIntOrFloatBitWidth(); - - // If the extract starts at the base of an element and is an even multiple, - // we can replace the extract with a smaller replicate. - if (op.getLowBit() % replicateEltWidth == 0 && - extractResultWidth % replicateEltWidth == 0) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - replicate.getOperand()); - return true; - } - - // If the extract is completely contained in one element, extract from the - // element. - if (op.getLowBit() % replicateEltWidth + extractResultWidth <= - replicateEltWidth) { - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getType(), replicate.getOperand(), - op.getLowBit() % replicateEltWidth); - return true; - } - - // We don't currently handle the case of extracting from non-whole elements, - // e.g. `extract (replicate 2-bit-thing, N), 1`. - return false; -} - -LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) { - auto *inputOp = op.getInput().getDefiningOp(); - - // This turns out to be incredibly expensive. Disable until performance is - // addressed. -#if 0 - // If the extracted bits are all known, then return the result. - auto knownBits = computeKnownBits(op.getInput()) - .extractBits(op.getType().cast().getWidth(), - op.getLowBit()); - if (knownBits.isConstant()) { - replaceOpWithNewOpAndCopyName(rewriter, op, - knownBits.getConstant()); - return success(); - } -#endif - - // extract(olo, extract(ilo, x)) = extract(olo + ilo, x) - if (auto innerExtract = dyn_cast_or_null(inputOp)) { - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getType(), innerExtract.getInput(), - innerExtract.getLowBit() + op.getLowBit()); - return success(); - } - - // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d)) - if (auto innerCat = dyn_cast_or_null(inputOp)) - return extractConcatToConcatExtract(op, innerCat, rewriter); - - // extract(lo, replicate(a)) - if (auto replicate = dyn_cast_or_null(inputOp)) - if (extractFromReplicate(op, replicate, rewriter)) return success(); - - // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the - // and/or/xor are not modifying the extracted bits. - if (inputOp && inputOp->getNumOperands() == 2 && - isa(inputOp)) { - if (auto cstRHS = inputOp->getOperand(1).getDefiningOp()) { - auto extractedCst = cstRHS.getValue().extractBits( - op.getType().cast().getWidth(), op.getLowBit()); - if (isa(inputOp) && extractedCst.isZero()) { - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit()); - return success(); - } - - // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one - // extract to represent the result. Turning it into a pile of extracts is - // always fine by our cost model, but we don't want to explode things into - // a ton of bits because it will bloat the IR and generated Verilog. - if (isa(inputOp)) { - // For our cost model, we only do this if the bit pattern is a - // contiguous series of ones. - unsigned lz = extractedCst.countLeadingZeros(); - unsigned tz = extractedCst.countTrailingZeros(); - unsigned pop = extractedCst.popcount(); - if (extractedCst.getBitWidth() - lz - tz == pop) { - auto resultTy = rewriter.getIntegerType(pop); - SmallVector resultElts; - if (lz) - resultElts.push_back(rewriter.create( - op.getLoc(), APInt::getZero(lz))); - resultElts.push_back(rewriter.createOrFold( - op.getLoc(), resultTy, inputOp->getOperand(0), - op.getLowBit() + tz)); - if (tz) - resultElts.push_back(rewriter.create( - op.getLoc(), APInt::getZero(tz))); - replaceOpWithNewOpAndCopyName(rewriter, op, resultElts); - return success(); - } - } - } - } - - // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is - // extracted. - if (op.getType().cast().getWidth() == 1 && inputOp) - if (auto shlOp = dyn_cast(inputOp)) - if (auto lhsCst = shlOp.getOperand(0).getDefiningOp()) - if (lhsCst.getValue().isOne()) { - auto newCst = rewriter.create( - shlOp.getLoc(), - APInt(lhsCst.getValue().getBitWidth(), op.getLowBit())); - replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::eq, - shlOp->getOperand(1), newCst, - false); - return success(); - } - - return failure(); -} - -//===----------------------------------------------------------------------===// -// Associative Variadic operations -//===----------------------------------------------------------------------===// - -// Reduce all operands to a single value (either integer constant or parameter -// expression) if all the operands are constants. -static Attribute constFoldAssociativeOp(ArrayRef operands, - hw::PEO paramOpcode) { - assert(operands.size() > 1 && "caller should handle one-operand case"); - // We can only fold anything in the case where all operands are known to be - // constants. Check the least common one first for an early out. - if (!operands[1] || !operands[0]) return {}; - - // This will fold to a simple constant if all operands are constant. - if (llvm::all_of(operands.drop_front(2), - [&](Attribute in) { return !!in; })) { - SmallVector typedOperands; - typedOperands.reserve(operands.size()); - for (auto operand : operands) { - if (auto typedOperand = operand.dyn_cast()) - typedOperands.push_back(typedOperand); - else - break; - } - if (typedOperands.size() == operands.size()) - return hw::ParamExprAttr::get(paramOpcode, typedOperands); - } - - return {}; -} - -/// When we find a logical operation (and, or, xor) with a constant e.g. -/// `X & 42`, we want to push the constant into the computation of X if it leads -/// to simplification. -/// -/// This function handles the case where the logical operation has a concat -/// operand. We check to see if we can simplify the concat, e.g. when it has -/// constant operands. -/// -/// This returns true when a simplification happens. -static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, - size_t concatIdx, const APInt &cst, - PatternRewriter &rewriter) { - auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp(); - assert((isa(logicalOp) && concatOp)); - - // Check to see if any operands can be simplified by pushing the logical op - // into all parts of the concat. - bool canSimplify = - llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool { - auto *operandOp = operand.getDefiningOp(); - if (!operandOp) return false; - - // If the concat has a constant operand then we can transform this. - if (isa(operandOp)) return true; - // If the concat has the same logical operation and that operation has - // a constant operation than we can fold it into that suboperation. - return operandOp->getName() == logicalOp->getName() && - operandOp->hasOneUse() && operandOp->getNumOperands() != 0 && - operandOp->getOperands().back().getDefiningOp(); - }); - - if (!canSimplify) return false; - - // Create a new instance of the logical operation. We have to do this the - // hard way since we're generic across a family of different ops. - auto createLogicalOp = [&](ArrayRef operands) -> Value { - return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands, - rewriter); - }; - - // Ok, let's do the transformation. We do this by slicing up the constant - // for each unit of the concat and duplicate the operation into the - // sub-operand. - SmallVector newConcatOperands; - newConcatOperands.reserve(concatOp->getNumOperands()); - - // Work from MSB to LSB. - size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth(); - for (Value operand : concatOp->getOperands()) { - size_t operandWidth = operand.getType().getIntOrFloatBitWidth(); - nextOperandBit -= operandWidth; - // Take a slice of the constant. - auto eltCst = rewriter.create( - logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth)); - - newConcatOperands.push_back(createLogicalOp({operand, eltCst})); - } - - // Create the concat, and the rest of the logical op if we need it. - Value newResult = - rewriter.create(concatOp.getLoc(), newConcatOperands); - - // If we had a variadic logical op on the top level, then recreate it with the - // new concat and without the constant operand. - if (logicalOp->getNumOperands() > 2) { - auto origOperands = logicalOp->getOperands(); - SmallVector operands; - // Take any stuff before the concat. - operands.append(origOperands.begin(), origOperands.begin() + concatIdx); - // Take any stuff after the concat but before the constant. - operands.append(origOperands.begin() + concatIdx + 1, - origOperands.begin() + (origOperands.size() - 1)); - // Include the new concat. - operands.push_back(newResult); - newResult = createLogicalOp(operands); - } - - replaceOpAndCopyName(rewriter, logicalOp, newResult); - return true; -} - -OpFoldResult AndOp::fold(FoldAdaptor adaptor) { - APInt value = APInt::getAllOnes(getType().cast().getWidth()); - - auto inputs = adaptor.getInputs(); - - // and(x, 01, 10) -> 00 -- annulment. - for (auto operand : inputs) { - if (!operand) continue; - value &= operand.cast().getValue(); - if (value.isZero()) return getIntAttr(value, getContext()); - } - - // and(x, -1) -> x. - if (inputs.size() == 2 && inputs[1] && - inputs[1].cast().getValue().isAllOnes()) - return getInputs()[0]; - - // and(x, x, x) -> x. This also handles and(x) -> x. - if (llvm::all_of(getInputs(), - [&](auto in) { return in == this->getInputs()[0]; })) - return getInputs()[0]; - - // and(..., x, ..., ~x, ...) -> 0 - for (Value arg : getInputs()) { - Value subExpr; - if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) { - for (Value arg2 : getInputs()) - if (arg2 == subExpr) - return getIntAttr( - APInt::getZero(getType().cast().getWidth()), - getContext()); - } - } - - // Constant fold - return constFoldAssociativeOp(inputs, hw::PEO::And); -} - -/// Returns a single common operand that all inputs of the operation `op` can -/// be traced back to, or an empty `Value` if no such operand exists. -/// -/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a` -/// (assuming the bit-width of `a` is `n`). -template -static Value getCommonOperand(Op op) { - if (!op.getType().isInteger(1)) return Value(); - - auto inputs = op.getInputs(); - size_t size = inputs.size(); - - auto sourceOp = inputs[0].template getDefiningOp(); - if (!sourceOp) return Value(); - Value source = sourceOp.getOperand(); - - // Fast path: the input size is not equal to the width of the source. - if (size != source.getType().getIntOrFloatBitWidth()) return Value(); - - // Tracks the bits that were encountered. - llvm::BitVector bits(size); - bits.set(sourceOp.getLowBit()); - - for (size_t i = 1; i != size; ++i) { - auto extractOp = inputs[i].template getDefiningOp(); - if (!extractOp || extractOp.getOperand() != source) return Value(); - bits.set(extractOp.getLowBit()); - } - - return bits.all() ? source : Value(); -} - -/// Canonicalize an idempotent operation `op` so that only one input of any kind -/// occurs. -/// -/// Example: `and(x, y, x, z)` -> `and(x, y, z)` -template -static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - llvm::SmallSetVector uniqueInputs; - - for (const auto input : inputs) uniqueInputs.insert(input); - - if (uniqueInputs.size() < inputs.size()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - uniqueInputs.getArrayRef()); - return true; - } - - return false; -} - -LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands, `fold` should handle this"); - - // and(..., x, ..., x) -> and(..., x, ...) -- idempotent - // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above. - if (size > 2 && canonicalizeIdempotentInputs(op, rewriter)) return success(); - - // Patterns for and with a constant on RHS. - APInt value; - if (matchPattern(inputs.back(), m_ConstantInt(&value))) { - // and(..., '1) -> and(...) -- identity - if (value.isAllOnes()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back(), false); - return success(); - } - - // TODO: Combine multiple constants together even if they aren't at the - // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant - // folding - APInt value2; - if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { - auto cst = rewriter.create(op.getLoc(), value & value2); - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(cst); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - // Handle 'and' with a single bit constant on the RHS. - if (size == 2 && value.isPowerOf2()) { - // If the LHS is a replicate from a single bit, we can 'concat' it - // into place. e.g.: - // `replicate(x) & 4` -> `concat(zeros, x, zeros)` - // TODO: Generalize this for non-single-bit operands. - if (auto replicate = inputs[0].getDefiningOp()) { - auto replicateOperand = replicate.getOperand(); - if (replicateOperand.getType().isInteger(1)) { - unsigned resultWidth = op.getType().getIntOrFloatBitWidth(); - auto trailingZeros = value.countTrailingZeros(); - - // Don't add zero bit constants unnecessarily. - SmallVector concatOperands; - if (trailingZeros != resultWidth - 1) { - auto highZeros = rewriter.create( - op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1)); - concatOperands.push_back(highZeros); - } - concatOperands.push_back(replicateOperand); - if (trailingZeros != 0) { - auto lowZeros = rewriter.create( - op.getLoc(), APInt::getZero(trailingZeros)); - concatOperands.push_back(lowZeros); - } - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - concatOperands); - return success(); - } - } - } - - // If this is an and from an extract op, try shrinking the extract. - if (auto extractOp = inputs[0].getDefiningOp()) { - if (size == 2 && - // We can shrink it if the mask has leading or trailing zeros. - (value.countLeadingZeros() || value.countTrailingZeros())) { - unsigned lz = value.countLeadingZeros(); - unsigned tz = value.countTrailingZeros(); - - // Start by extracting the smaller number of bits. - auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz); - Value smallElt = rewriter.createOrFold( - extractOp.getLoc(), smallTy, extractOp->getOperand(0), - extractOp.getLowBit() + tz); - // Apply the 'and' mask if needed. - APInt smallMask = value.extractBits(smallTy.getWidth(), tz); - if (!smallMask.isAllOnes()) { - auto loc = inputs.back().getLoc(); - smallElt = rewriter.createOrFold( - loc, smallElt, rewriter.create(loc, smallMask), - false); - } - - // The final replacement will be a concat of the leading/trailing zeros - // along with the smaller extracted value. - SmallVector resultElts; - if (lz) - resultElts.push_back( - rewriter.create(op.getLoc(), APInt::getZero(lz))); - resultElts.push_back(smallElt); - if (tz) - resultElts.push_back( - rewriter.create(op.getLoc(), APInt::getZero(tz))); - replaceOpWithNewOpAndCopyName(rewriter, op, resultElts); - return success(); - } - } - - // and(concat(x, cst1), a, b, c, cst2) - // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')). - // We do this for even more multi-use concats since they are "just wiring". - for (size_t i = 0; i < size - 1; ++i) { - if (auto concat = inputs[i].getDefiningOp()) - if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) - return success(); - } - } - - // and(x, and(...)) -> and(x, ...) -- flatten - if (tryFlatteningOperands(op, rewriter)) return success(); - - // extracts only of and(...) -> and(extract()...) - if (narrowOperationWidth(op, true, rewriter)) return success(); - - // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1) - if (auto source = getCommonOperand(op)) { - auto cmpAgainst = - rewriter.create(op.getLoc(), APInt::getAllOnes(size)); - replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::eq, - source, cmpAgainst); - return success(); - } - - /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement - return failure(); -} - -OpFoldResult OrOp::fold(FoldAdaptor adaptor) { - auto value = APInt::getZero(getType().cast().getWidth()); - auto inputs = adaptor.getInputs(); - // or(x, 10, 01) -> 11 - for (auto operand : inputs) { - if (!operand) continue; - value |= operand.cast().getValue(); - if (value.isAllOnes()) return getIntAttr(value, getContext()); - } - - // or(x, 0) -> x - if (inputs.size() == 2 && inputs[1] && - inputs[1].cast().getValue().isZero()) - return getInputs()[0]; - - // or(x, x, x) -> x. This also handles or(x) -> x - if (llvm::all_of(getInputs(), - [&](auto in) { return in == this->getInputs()[0]; })) - return getInputs()[0]; - - // or(..., x, ..., ~x, ...) -> -1 - for (Value arg : getInputs()) { - Value subExpr; - if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) { - for (Value arg2 : getInputs()) - if (arg2 == subExpr) - return getIntAttr( - APInt::getAllOnes(getType().cast().getWidth()), - getContext()); - } - } - - // Constant fold - return constFoldAssociativeOp(inputs, hw::PEO::Or); -} - -/// Simplify concat ops in an or op when a constant operand is present in either -/// concat. -/// -/// This will invert an or(concat, concat) into concat(or, or, ...), which can -/// often be further simplified due to the smaller or ops being easier to fold. -/// -/// For example: -/// -/// or(..., concat(x, 0), concat(0, y)) -/// ==> or(..., concat(x, 0, y)), when x and y don't overlap. -/// -/// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1)) -/// ==> or(..., concat(or(x: i2, extract(cst2, 4..3)), -/// or(extract(cst1, 3..1), extract(cst2, 2..0)), -/// or(extract(cst1, 0..0), y: i1)) -static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, - size_t concatIdx2, - PatternRewriter &rewriter) { - assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2"); - - auto inputs = op.getInputs(); - auto concat1 = inputs[concatIdx1].getDefiningOp(); - auto concat2 = inputs[concatIdx2].getDefiningOp(); - - assert(concat1 && concat2 && "expected indexes to point to ConcatOps"); - - // We can simplify as long as a constant is present in either concat. - bool hasConstantOp1 = - llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool { - return operand.getDefiningOp(); - }); - if (!hasConstantOp1) { - bool hasConstantOp2 = - llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool { - return operand.getDefiningOp(); - }); - if (!hasConstantOp2) return false; - } - - SmallVector newConcatOperands; - - // Simultaneously iterate over the operands of both concat ops, from MSB to - // LSB, pushing out or's of overlapping ranges of the operands. When operands - // span different bit ranges, we extract only the maximum overlap. - auto operands1 = concat1->getOperands(); - auto operands2 = concat2->getOperands(); - // Number of bits already consumed from operands 1 and 2, respectively. - unsigned consumedWidth1 = 0; - unsigned consumedWidth2 = 0; - for (auto it1 = operands1.begin(), end1 = operands1.end(), - it2 = operands2.begin(), end2 = operands2.end(); - it1 != end1 && it2 != end2;) { - auto operand1 = *it1; - auto operand2 = *it2; - - unsigned remainingWidth1 = - hw::getBitWidth(operand1.getType()) - consumedWidth1; - unsigned remainingWidth2 = - hw::getBitWidth(operand2.getType()) - consumedWidth2; - unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2); - auto narrowedType = rewriter.getIntegerType(widthToConsume); - - auto extract1 = rewriter.createOrFold( - op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume); - auto extract2 = rewriter.createOrFold( - op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume); - - newConcatOperands.push_back( - rewriter.createOrFold(op.getLoc(), extract1, extract2, false)); - - consumedWidth1 += widthToConsume; - consumedWidth2 += widthToConsume; - - if (widthToConsume == remainingWidth1) { - ++it1; - consumedWidth1 = 0; - } - if (widthToConsume == remainingWidth2) { - ++it2; - consumedWidth2 = 0; - } - } - - ConcatOp newOp = rewriter.create(op.getLoc(), newConcatOperands); - - // Copy the old operands except for concatIdx1 and concatIdx2, and append the - // new ConcatOp to the end. - SmallVector newOrOperands; - newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1); - newOrOperands.append(inputs.begin() + concatIdx1 + 1, - inputs.begin() + concatIdx2); - newOrOperands.append(inputs.begin() + concatIdx2 + 1, - inputs.begin() + inputs.size()); - newOrOperands.push_back(newOp); - - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOrOperands); - return true; -} - -LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - // or(..., x, ..., x, ...) -> or(..., x) -- idempotent - // Trivial or(x), or(x, x) cases are handled by [OrOp::fold]. - if (size > 2 && canonicalizeIdempotentInputs(op, rewriter)) return success(); - - // Patterns for and with a constant on RHS. - APInt value; - if (matchPattern(inputs.back(), m_ConstantInt(&value))) { - // or(..., '0) -> or(...) -- identity - if (value.isZero()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back()); - return success(); - } - - // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding - APInt value2; - if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { - auto cst = rewriter.create(op.getLoc(), value | value2); - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(cst); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands); - return success(); - } - - // or(concat(x, cst1), a, b, c, cst2) - // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')). - // We do this for even more multi-use concats since they are "just wiring". - for (size_t i = 0; i < size - 1; ++i) { - if (auto concat = inputs[i].getDefiningOp()) - if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) - return success(); - } - } - - // or(x, or(...)) -> or(x, ...) -- flatten - if (tryFlatteningOperands(op, rewriter)) return success(); - - // or(..., concat(x, cst1), concat(cst2, y) - // ==> or(..., concat(x, cst3, y)), when x and y don't overlap. - for (size_t i = 0; i < size - 1; ++i) { - if (auto concat = inputs[i].getDefiningOp()) - for (size_t j = i + 1; j < size; ++j) - if (auto concat = inputs[j].getDefiningOp()) - if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter)) - return success(); - } - - // extracts only of or(...) -> or(extract()...) - if (narrowOperationWidth(op, true, rewriter)) return success(); - - // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0) - if (auto source = getCommonOperand(op)) { - auto cmpAgainst = - rewriter.create(op.getLoc(), APInt::getZero(size)); - replaceOpWithNewOpAndCopyName(rewriter, op, ICmpPredicate::ne, - source, cmpAgainst); - return success(); - } - - // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2, - // .., c_n), a, 0) - if (auto firstMux = op.getOperand(0).getDefiningOp()) { - APInt value; - if (op.getTwoState() && firstMux.getTwoState() && - matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) && - value.isZero()) { - SmallVector conditions{firstMux.getCond()}; - auto check = [&](Value v) { - auto mux = v.getDefiningOp(); - if (!mux) return false; - conditions.push_back(mux.getCond()); - return mux.getTwoState() && - firstMux.getTrueValue() == mux.getTrueValue() && - firstMux.getFalseValue() == mux.getFalseValue(); - }; - if (llvm::all_of(op.getOperands().drop_front(), check)) { - auto cond = rewriter.create(op.getLoc(), conditions, true); - replaceOpWithNewOpAndCopyName( - rewriter, op, cond, firstMux.getTrueValue(), - firstMux.getFalseValue(), true); - return success(); - } - } - } - - /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement - return failure(); -} - -OpFoldResult XorOp::fold(FoldAdaptor adaptor) { - auto size = getInputs().size(); - auto inputs = adaptor.getInputs(); - - // xor(x) -> x -- noop - if (size == 1) return getInputs()[0]; - - // xor(x, x) -> 0 -- idempotent - if (size == 2 && getInputs()[0] == getInputs()[1]) - return IntegerAttr::get(getType(), 0); - - // xor(x, 0) -> x - if (inputs.size() == 2 && inputs[1] && - inputs[1].cast().getValue().isZero()) - return getInputs()[0]; - - // xor(xor(x,1),1) -> x - // but not self loop - if (isBinaryNot()) { - Value subExpr; - if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) && - subExpr != getResult()) - return subExpr; - } - - // Constant fold - return constFoldAssociativeOp(inputs, hw::PEO::Xor); -} - -// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user. -static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, - PatternRewriter &rewriter) { - auto icmp = op.getOperand(icmpOperand).getDefiningOp(); - auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate()); - - Value result = - rewriter.create(icmp.getLoc(), negatedPred, icmp.getOperand(0), - icmp.getOperand(1), icmp.getTwoState()); - - // If the xor had other operands, rebuild it. - if (op.getNumOperands() > 2) { - SmallVector newOperands(op.getOperands()); - newOperands.pop_back(); - newOperands.erase(newOperands.begin() + icmpOperand); - newOperands.push_back(result); - result = rewriter.create(op.getLoc(), newOperands, op.getTwoState()); - } - - replaceOpAndCopyName(rewriter, op, result); -} - -LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - // xor(..., x, x) -> xor (...) -- idempotent - if (inputs[size - 1] == inputs[size - 2]) { - assert(size > 2 && - "expected idempotent case for 2 elements handled already."); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back(/*n=*/2), false); - return success(); - } - - // Patterns for xor with a constant on RHS. - APInt value; - if (matchPattern(inputs.back(), m_ConstantInt(&value))) { - // xor(..., 0) -> xor(...) -- identity - if (value.isZero()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back(), false); - return success(); - } - - // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2. - APInt value2; - if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { - auto cst = rewriter.create(op.getLoc(), value ^ value2); - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(cst); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - bool isSingleBit = value.getBitWidth() == 1; - - // Check for subexpressions that we can simplify. - for (size_t i = 0; i < size - 1; ++i) { - Value operand = inputs[i]; - - // xor(concat(x, cst1), a, b, c, cst2) - // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')). - // We do this for even more multi-use concats since they are "just - // wiring". - if (auto concat = operand.getDefiningOp()) - if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter)) - return success(); - - // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user. - if (isSingleBit && operand.hasOneUse()) { - assert(value == 1 && "single bit constant has to be one if not zero"); - if (auto icmp = operand.getDefiningOp()) - return canonicalizeXorIcmpTrue(op, i, rewriter), success(); - } - } - } - - // xor(x, xor(...)) -> xor(x, ...) -- flatten - if (tryFlatteningOperands(op, rewriter)) return success(); - - // extracts only of xor(...) -> xor(extract()...) - if (narrowOperationWidth(op, true, rewriter)) return success(); - - // xor(a[0], a[1], ..., a[n]) -> parity(a) - if (auto source = getCommonOperand(op)) { - replaceOpWithNewOpAndCopyName(rewriter, op, source); - return success(); - } - - return failure(); -} - -OpFoldResult SubOp::fold(FoldAdaptor adaptor) { - // sub(x - x) -> 0 - if (getRhs() == getLhs()) - return getIntAttr( - APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()), - getContext()); - - if (adaptor.getRhs()) { - // If both are constants, we can unconditionally fold. - if (adaptor.getLhs()) { - // Constant fold (c1 - c2) => (c1 + -1*c2). - auto negOne = getIntAttr( - APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()), - getContext()); - auto rhsNeg = hw::ParamExprAttr::get( - hw::PEO::Mul, adaptor.getRhs().cast(), negOne); - return hw::ParamExprAttr::get(hw::PEO::Add, - adaptor.getLhs().cast(), rhsNeg); - } - - // sub(x - 0) -> x - if (auto rhsC = adaptor.getRhs().dyn_cast()) { - if (rhsC.getValue().isZero()) return getLhs(); - } - } - - return {}; -} - -LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) { - // sub(x, cst) -> add(x, -cst) - APInt value; - if (matchPattern(op.getRhs(), m_ConstantInt(&value))) { - auto negCst = rewriter.create(op.getLoc(), -value); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), negCst, - false); - return success(); - } - - // extracts only of sub(...) -> sub(extract()...) - if (narrowOperationWidth(op, false, rewriter)) return success(); - - return failure(); -} - -OpFoldResult AddOp::fold(FoldAdaptor adaptor) { - auto size = getInputs().size(); - - // add(x) -> x -- noop - if (size == 1u) return getInputs()[0]; - - // Constant fold constant operands. - return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add); -} - -LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - APInt value, value2; - - // add(..., 0) -> add(...) -- identity - if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back(), false); - return success(); - } - - // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding - if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) && - matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { - auto cst = rewriter.create(op.getLoc(), value + value2); - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(cst); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - // add(..., x, x) -> add(..., shl(x, 1)) - if (inputs[size - 1] == inputs[size - 2]) { - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - - auto one = rewriter.create(op.getLoc(), op.getType(), 1); - auto shiftLeftOp = - rewriter.create(op.getLoc(), inputs.back(), one, false); - - newOperands.push_back(shiftLeftOp); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - auto shlOp = inputs[size - 1].getDefiningOp(); - // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1)) - if (shlOp && shlOp.getLhs() == inputs[size - 2] && - matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) { - APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false); - auto rhs = - rewriter.create(op.getLoc(), (one << value) + one); - - std::array factors = {shlOp.getLhs(), rhs}; - auto mulOp = rewriter.create(op.getLoc(), factors, false); - - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(mulOp); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - auto mulOp = inputs[size - 1].getDefiningOp(); - // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1)) - if (mulOp && mulOp.getInputs().size() == 2 && - mulOp.getInputs()[0] == inputs[size - 2] && - matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) { - APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false); - auto rhs = rewriter.create(op.getLoc(), value + one); - std::array factors = {mulOp.getInputs()[0], rhs}; - auto newMulOp = rewriter.create(op.getLoc(), factors, false); - - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(newMulOp); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands, false); - return success(); - } - - // add(x, add(...)) -> add(x, ...) -- flatten - if (tryFlatteningOperands(op, rewriter)) return success(); - - // extracts only of add(...) -> add(extract()...) - if (narrowOperationWidth(op, false, rewriter)) return success(); - - // add(add(x, c1), c2) -> add(x, c1 + c2) - auto addOp = inputs[0].getDefiningOp(); - if (addOp && addOp.getInputs().size() == 2 && - matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) && - inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) { - auto rhs = rewriter.create(op.getLoc(), value + value2); - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getType(), ArrayRef{addOp.getInputs()[0], rhs}, - /*twoState=*/op.getTwoState() && addOp.getTwoState()); - return success(); - } - - return failure(); -} - -OpFoldResult MulOp::fold(FoldAdaptor adaptor) { - auto size = getInputs().size(); - auto inputs = adaptor.getInputs(); - - // mul(x) -> x -- noop - if (size == 1u) return getInputs()[0]; - - auto width = getType().cast().getWidth(); - APInt value(/*numBits=*/width, 1, /*isSigned=*/false); - - // mul(x, 0, 1) -> 0 -- annulment - for (auto operand : inputs) { - if (!operand) continue; - value *= operand.cast().getValue(); - if (value.isZero()) return getIntAttr(value, getContext()); - } - - // Constant fold - return constFoldAssociativeOp(inputs, hw::PEO::Mul); -} - -LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - APInt value, value2; - - // mul(x, c) -> shl(x, log2(c)), where c is a power of two. - if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) && - value.isPowerOf2()) { - auto shift = rewriter.create(op.getLoc(), op.getType(), - value.exactLogBase2()); - auto shlOp = - rewriter.create(op.getLoc(), inputs[0], shift, false); - - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - ArrayRef(shlOp), false); - return success(); - } - - // mul(..., 1) -> mul(...) -- identity - if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - inputs.drop_back()); - return success(); - } - - // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding - if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) && - matchPattern(inputs[size - 2], m_ConstantInt(&value2))) { - auto cst = rewriter.create(op.getLoc(), value * value2); - SmallVector newOperands(inputs.drop_back(/*n=*/2)); - newOperands.push_back(cst); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands); - return success(); - } - - // mul(a, mul(...)) -> mul(a, ...) -- flatten - if (tryFlatteningOperands(op, rewriter)) return success(); - - // extracts only of mul(...) -> mul(extract()...) - if (narrowOperationWidth(op, false, rewriter)) return success(); - - return failure(); -} - -template -static OpFoldResult foldDiv(Op op, ArrayRef constants) { - if (auto rhsValue = constants[1].dyn_cast_or_null()) { - // divu(x, 1) -> x, divs(x, 1) -> x - if (rhsValue.getValue() == 1) return op.getLhs(); - - // If the divisor is zero, do not fold for now. - if (rhsValue.getValue().isZero()) return {}; - } - - return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU); -} - -OpFoldResult DivUOp::fold(FoldAdaptor adaptor) { - return foldDiv(*this, adaptor.getOperands()); -} - -OpFoldResult DivSOp::fold(FoldAdaptor adaptor) { - return foldDiv(*this, adaptor.getOperands()); -} - -template -static OpFoldResult foldMod(Op op, ArrayRef constants) { - if (auto rhsValue = constants[1].dyn_cast_or_null()) { - // modu(x, 1) -> 0, mods(x, 1) -> 0 - if (rhsValue.getValue() == 1) - return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()), - op.getContext()); - - // If the divisor is zero, do not fold for now. - if (rhsValue.getValue().isZero()) return {}; - } - - if (auto lhsValue = constants[0].dyn_cast_or_null()) { - // modu(0, x) -> 0, mods(0, x) -> 0 - if (lhsValue.getValue().isZero()) - return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()), - op.getContext()); - } - - return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU); -} - -OpFoldResult ModUOp::fold(FoldAdaptor adaptor) { - return foldMod(*this, adaptor.getOperands()); -} - -OpFoldResult ModSOp::fold(FoldAdaptor adaptor) { - return foldMod(*this, adaptor.getOperands()); -} -//===----------------------------------------------------------------------===// -// ConcatOp -//===----------------------------------------------------------------------===// - -// Constant folding -OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { - if (getNumOperands() == 1) return getOperand(0); - - // If all the operands are constant, we can fold. - for (auto attr : adaptor.getInputs()) - if (!attr || !attr.isa()) return {}; - - // If we got here, we can constant fold. - unsigned resultWidth = getType().getIntOrFloatBitWidth(); - APInt result(resultWidth, 0); - - unsigned nextInsertion = resultWidth; - // Insert each chunk into the result. - for (auto attr : adaptor.getInputs()) { - auto chunk = attr.cast().getValue(); - nextInsertion -= chunk.getBitWidth(); - result.insertBits(chunk, nextInsertion); - } - - return getIntAttr(result, getContext()); -} - -LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) { - auto inputs = op.getInputs(); - auto size = inputs.size(); - assert(size > 1 && "expected 2 or more operands"); - - // This function is used when we flatten neighboring operands of a - // (variadic) concat into a new vesion of the concat. first/last indices - // are inclusive. - auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex, - ValueRange replacements) -> LogicalResult { - SmallVector newOperands; - newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex); - newOperands.append(replacements.begin(), replacements.end()); - newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end()); - if (newOperands.size() == 1) - replaceOpAndCopyName(rewriter, op, newOperands[0]); - else - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOperands); - return success(); - }; - - Value commonOperand = inputs[0]; - for (size_t i = 0; i != size; ++i) { - // Check to see if all operands are the same. - if (inputs[i] != commonOperand) commonOperand = Value(); - - // If an operand to the concat is itself a concat, then we can fold them - // together. - if (auto subConcat = inputs[i].getDefiningOp()) - return flattenConcat(i, i, subConcat->getOperands()); - - // Check for canonicalization due to neighboring operands. - if (i != 0) { - // Merge neighboring constants. - if (auto cst = inputs[i].getDefiningOp()) { - if (auto prevCst = inputs[i - 1].getDefiningOp()) { - unsigned prevWidth = prevCst.getValue().getBitWidth(); - unsigned thisWidth = cst.getValue().getBitWidth(); - auto resultCst = cst.getValue().zext(prevWidth + thisWidth); - resultCst |= prevCst.getValue().zext(prevWidth + thisWidth) - << thisWidth; - Value replacement = - rewriter.create(op.getLoc(), resultCst); - return flattenConcat(i - 1, i, replacement); - } - } - - // If the two operands are the same, turn them into a replicate. - if (inputs[i] == inputs[i - 1]) { - Value replacement = - rewriter.createOrFold(op.getLoc(), inputs[i], 2); - return flattenConcat(i - 1, i, replacement); - } - - // If this input is a replicate, see if we can fold it with the previous - // one. - if (auto repl = inputs[i].getDefiningOp()) { - // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ... - if (repl.getOperand() == inputs[i - 1]) { - Value replacement = rewriter.createOrFold( - op.getLoc(), repl.getOperand(), repl.getMultiple() + 1); - return flattenConcat(i - 1, i, replacement); - } - // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ... - if (auto prevRepl = inputs[i - 1].getDefiningOp()) { - if (prevRepl.getOperand() == repl.getOperand()) { - Value replacement = rewriter.createOrFold( - op.getLoc(), repl.getOperand(), - repl.getMultiple() + prevRepl.getMultiple()); - return flattenConcat(i - 1, i, replacement); - } - } - } - - // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ... - if (auto repl = inputs[i - 1].getDefiningOp()) { - if (repl.getOperand() == inputs[i]) { - Value replacement = rewriter.createOrFold( - op.getLoc(), inputs[i], repl.getMultiple() + 1); - return flattenConcat(i - 1, i, replacement); - } - } - - // Merge neighboring extracts of neighboring inputs, e.g. - // {A[3], A[2]} -> A[3:2] - if (auto extract = inputs[i].getDefiningOp()) { - if (auto prevExtract = inputs[i - 1].getDefiningOp()) { - if (extract.getInput() == prevExtract.getInput()) { - auto thisWidth = extract.getType().cast().getWidth(); - if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) { - auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth(); - auto resType = rewriter.getIntegerType(thisWidth + prevWidth); - Value replacement = rewriter.create( - op.getLoc(), resType, extract.getInput(), - extract.getLowBit()); - return flattenConcat(i - 1, i, replacement); - } - } - } - } - // Merge neighboring array extracts of neighboring inputs, e.g. - // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2]) - - // This represents a slice of an array. - struct ArraySlice { - Value input; - Value index; - size_t width; - static std::optional get(Value value) { - assert(value.getType().isa() && "expected integer type"); - if (auto arrayGet = value.getDefiningOp()) - return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1}; - // array slice op is wrapped with bitcast. - if (auto bitcast = value.getDefiningOp()) - if (auto arraySlice = - bitcast.getInput().getDefiningOp()) - return ArraySlice{ - arraySlice.getInput(), arraySlice.getLowIndex(), - hw::type_cast(arraySlice.getType()) - .getNumElements()}; - return std::nullopt; - } - }; - if (auto extractOpt = ArraySlice::get(inputs[i])) { - if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) { - // Check that two array slices are mergable. - if (prevExtractOpt->index.getType() == extractOpt->index.getType() && - prevExtractOpt->input == extractOpt->input && - hw::isOffset(extractOpt->index, prevExtractOpt->index, - extractOpt->width)) { - auto resType = hw::ArrayType::get( - hw::type_cast(prevExtractOpt->input.getType()) - .getElementType(), - extractOpt->width + prevExtractOpt->width); - auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType)); - Value replacement = rewriter.create( - op.getLoc(), resIntType, - rewriter.create(op.getLoc(), resType, - prevExtractOpt->input, - extractOpt->index)); - return flattenConcat(i - 1, i, replacement); - } - } - } - } - } - - // If all operands were the same, then this is a replicate. - if (commonOperand) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - commonOperand); - return success(); - } - - return failure(); -} - -//===----------------------------------------------------------------------===// -// MuxOp -//===----------------------------------------------------------------------===// - -OpFoldResult MuxOp::fold(FoldAdaptor adaptor) { - // mux (c, b, b) -> b - if (getTrueValue() == getFalseValue()) return getTrueValue(); - - // mux(0, a, b) -> b - // mux(1, a, b) -> a - if (auto pred = adaptor.getCond().dyn_cast_or_null()) { - if (pred.getValue().isZero()) return getFalseValue(); - return getTrueValue(); - } - - // mux(cond, 1, 0) -> cond - if (auto tv = adaptor.getTrueValue().dyn_cast_or_null()) - if (auto fv = adaptor.getFalseValue().dyn_cast_or_null()) - if (tv.getValue().isOne() && fv.getValue().isZero() && - hw::getBitWidth(getType()) == 1) - return getCond(); - - return {}; -} - -/// Check to see if the condition to the specified mux is an equality -/// comparison `indexValue` and one or more constants. If so, put the -/// constants in the constants vector and return true, otherwise return false. -/// -/// This is part of foldMuxChain. -/// -static bool getMuxChainCondConstant( - Value cond, Value indexValue, bool isInverted, - std::function constantFn) { - // Handle `idx == 42` and `idx != 42`. - if (auto cmp = cond.getDefiningOp()) { - // TODO: We could handle things like "x < 2" as two entries. - auto requiredPredicate = - (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne); - if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) { - if (auto cst = cmp.getRhs().getDefiningOp()) { - constantFn(cst); - return true; - } - } - return false; - } - - // Handle mux(`idx == 1 || idx == 3`, value, muxchain). - if (auto orOp = cond.getDefiningOp()) { - if (!isInverted) return false; - for (auto operand : orOp.getOperands()) - if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn)) - return false; - return true; - } - - // Handle mux(`idx != 1 && idx != 3`, muxchain, value). - if (auto andOp = cond.getDefiningOp()) { - if (isInverted) return false; - for (auto operand : andOp.getOperands()) - if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn)) - return false; - return true; - } - - return false; -} - -/// Given a mux, check to see if the "on true" value (or "on false" value if -/// isFalseSide=true) is a mux tree with the same condition. This allows us -/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into -/// `array_get (array_create(A, B, C), VAL)` which is far more compact and -/// allows synthesis tools to do more interesting optimizations. -/// -/// This returns false if we cannot form the mux tree (or do not want to) and -/// returns true if the mux was replaced. -static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, - PatternRewriter &rewriter) { - // Get the index value being compared. Later we check to see if it is - // compared to a constant with the right predicate. - auto rootCmp = rootMux.getCond().getDefiningOp(); - if (!rootCmp) return false; - Value indexValue = rootCmp.getLhs(); - - // Return the value to use if the equality match succeeds. - auto getCaseValue = [&](MuxOp mux) -> Value { - return mux.getOperand(1 + unsigned(!isFalseSide)); - }; - - // Return the value to use if the equality match fails. This is the next - // mux in the sequence or the "otherwise" value. - auto getTreeValue = [&](MuxOp mux) -> Value { - return mux.getOperand(1 + unsigned(isFalseSide)); - }; - - // Start scanning the mux tree to see what we've got. Keep track of the - // constant comparison value and the SSA value to use when equal to it. - SmallVector locationsFound; - SmallVector, 4> valuesFound; - - /// Extract constants and values into `valuesFound` and return true if this is - /// part of the mux tree, otherwise return false. - auto collectConstantValues = [&](MuxOp mux) -> bool { - return getMuxChainCondConstant( - mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) { - valuesFound.push_back({cst, getCaseValue(mux)}); - locationsFound.push_back(mux.getCond().getLoc()); - locationsFound.push_back(mux->getLoc()); - }); - }; - - // Make sure the root is a correct comparison with a constant. - if (!collectConstantValues(rootMux)) return false; - - // Make sure that we're not looking at the intermediate node in a mux tree. - if (rootMux->hasOneUse()) { - if (auto userMux = dyn_cast(*rootMux->user_begin())) { - if (getTreeValue(userMux) == rootMux.getResult() && - getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide, - [&](hw::ConstantOp cst) {})) - return false; - } - } - - // Scan up the tree linearly. - auto nextTreeValue = getTreeValue(rootMux); - while (1) { - auto nextMux = nextTreeValue.getDefiningOp(); - if (!nextMux || !nextMux->hasOneUse()) break; - if (!collectConstantValues(nextMux)) break; - nextTreeValue = getTreeValue(nextMux); - } - - // We need to have more than three values to create an array. This is an - // arbitrary threshold which is saying that one or two muxes together is ok, - // but three should be folded. - if (valuesFound.size() < 3) return false; - - // If the array is greater that 9 bits, it will take over 512 elements and - // it will be too large for a single expression. - auto indexWidth = indexValue.getType().cast().getWidth(); - if (indexWidth >= 9) return false; - - // Next we need to see if the values are dense-ish. We don't want to have - // a tremendous number of replicated entries in the array. Some sparsity is - // ok though, so we require the table to be at least 5/8 utilized. - uint64_t tableSize = 1ULL << indexWidth; - if (valuesFound.size() < (tableSize * 5) / 8) - return false; // Not dense enough. - - // Ok, we're going to do the transformation, start by building the table - // filled with the "otherwise" value. - SmallVector table(tableSize, nextTreeValue); - - // Fill in entries in the table from the leaf to the root of the expression. - // This ensures that any duplicate matches end up with the ultimate value, - // which is the one closer to the root. - for (auto &elt : llvm::reverse(valuesFound)) { - uint64_t idx = elt.first.getValue().getZExtValue(); - assert(idx < table.size() && "constant should be same bitwidth as index"); - table[idx] = elt.second; - } - - // The hw.array_create operation has the operand list in unintuitive order - // with a[0] stored as the last element, not the first. - std::reverse(table.begin(), table.end()); - - // Build the array_create and the array_get. - auto fusedLoc = rewriter.getFusedLoc(locationsFound); - auto array = rewriter.create(fusedLoc, table); - replaceOpWithNewOpAndCopyName(rewriter, rootMux, array, - indexValue); - return true; -} - -/// Given a fully associative variadic operation like (a+b+c+d), break the -/// expression into two parts, one without the specified operand (e.g. -/// `tmp = a+b+d`) and one that combines that into the full expression (e.g. -/// `tmp+c`), and return the inner expression. -/// -/// NOTE: This mutates the operation in place if it only has a single user, -/// which assumes that user will be removed. -/// -static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, - size_t operandNo, - PatternRewriter &rewriter) { - assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops"); - assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #"); - - // If this expression already has two operands (the common case) no splitting - // is necessary. - if (fullyAssoc->getNumOperands() == 2) - return fullyAssoc->getOperand(operandNo ^ 1); - - // If the operation has a single use, mutate it in place. - if (fullyAssoc->hasOneUse()) { - fullyAssoc->eraseOperand(operandNo); - return fullyAssoc->getResult(0); - } - - // Form the new operation with the operands that remain. - SmallVector operands; - operands.append(fullyAssoc->getOperands().begin(), - fullyAssoc->getOperands().begin() + operandNo); - operands.append(fullyAssoc->getOperands().begin() + operandNo + 1, - fullyAssoc->getOperands().end()); - Value opWithoutExcluded = createGenericOp( - fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter); - Value excluded = fullyAssoc->getOperand(operandNo); - - Value fullResult = - createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(), - ArrayRef{opWithoutExcluded, excluded}, rewriter); - replaceOpAndCopyName(rewriter, fullyAssoc, fullResult); - return opWithoutExcluded; -} - -/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and -/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand -/// is true. Return true on successful transformation, false if not. -/// -/// These are various forms of "predicated ops" that can be handled with a -/// replicate/and combination. -static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, - PatternRewriter &rewriter) { - // Check to see the operand in question is an operation. If it is a port, - // we can't simplify it. - Operation *subExpr = - (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp(); - if (!subExpr || subExpr->getNumOperands() < 2) return false; - - // If this isn't an operation we can handle, don't spend energy on it. - if (!isa(subExpr)) return false; - - // Check to see if the common value occurs in the operand list for the - // subexpression op. If so, then we can simplify it. - Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue(); - size_t opNo = 0, e = subExpr->getNumOperands(); - while (opNo != e && subExpr->getOperand(opNo) != commonValue) ++opNo; - if (opNo == e) return false; - - // If we got a hit, then go ahead and simplify it! - Value cond = op.getCond(); - - // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)` - // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)` - // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)` - // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)` - if (auto subMux = dyn_cast(subExpr)) { - Value otherValue; - Value subCond = subMux.getCond(); - - // Invert th subCond if needed and dig out the 'b' value. - if (subMux.getTrueValue() == commonValue) - otherValue = subMux.getFalseValue(); - else if (subMux.getFalseValue() == commonValue) { - otherValue = subMux.getTrueValue(); - subCond = createOrFoldNot(op.getLoc(), subCond, rewriter); - } else { - // We can't fold `mux(cond, a, mux(a, x, y))`. - return false; - } - - // Invert the outer cond if needed, and combine the mux conditions. - if (!isTrueOperand) cond = createOrFoldNot(op.getLoc(), cond, rewriter); - cond = rewriter.createOrFold(op.getLoc(), cond, subCond, false); - replaceOpWithNewOpAndCopyName(rewriter, op, cond, commonValue, - otherValue, op.getTwoState()); - return true; - } - - // Invert the condition if needed. Or/Xor invert when dealing with - // TrueOperand, And inverts for False operand. - bool isaAndOp = isa(subExpr); - if (isTrueOperand ^ isaAndOp) - cond = createOrFoldNot(op.getLoc(), cond, rewriter); - - auto extendedCond = - rewriter.createOrFold(op.getLoc(), op.getType(), cond); - - // Cache this information before subExpr is erased by extraction below. - bool isaXorOp = isa(subExpr); - bool isaOrOp = isa(subExpr); - - // Handle the fully associative ops, start by pulling out the subexpression - // from a many operand version of the op. - auto restOfAssoc = - extractOperandFromFullyAssociative(subExpr, opNo, rewriter); - - // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a` - // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a` - if (isaOrOp || isaXorOp) { - auto masked = rewriter.createOrFold(op.getLoc(), extendedCond, - restOfAssoc, false); - if (isaXorOp) - replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, - false); - else - replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, - false); - return true; - } - - // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a` - assert(isaAndOp && "unexpected operation here"); - auto masked = rewriter.createOrFold(op.getLoc(), extendedCond, - restOfAssoc, false); - replaceOpWithNewOpAndCopyName(rewriter, op, masked, commonValue, - false); - return true; -} - -/// This function is invoke when we find a mux with true/false operations that -/// have the same opcode. Check to see if we can strength reduce the mux by -/// applying it to less data by applying this transformation: -/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))` -static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, - Operation *falseOp, - PatternRewriter &rewriter) { - // Right now we only apply to concat. - // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice - if (!isa(trueOp)) return false; - - // Decode the operands, looking through recursive concats and replicates. - SmallVector trueOperands, falseOperands; - getConcatOperands(trueOp->getResult(0), trueOperands); - getConcatOperands(falseOp->getResult(0), falseOperands); - - size_t numTrueOperands = trueOperands.size(); - size_t numFalseOperands = falseOperands.size(); - - if (!numTrueOperands || !numFalseOperands || - (trueOperands.front() != falseOperands.front() && - trueOperands.back() != falseOperands.back())) - return false; - - // Pull all leading shared operands out into their own op if any are common. - if (trueOperands.front() == falseOperands.front()) { - SmallVector operands; - size_t i; - for (i = 0; i < numTrueOperands; ++i) { - Value trueOperand = trueOperands[i]; - if (trueOperand == falseOperands[i]) - operands.push_back(trueOperand); - else - break; - } - if (i == numTrueOperands) { - // Selecting between distinct, but lexically identical, concats. - replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0)); - return true; - } - - Value sharedMSB; - if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); })) - sharedMSB = rewriter.createOrFold( - mux->getLoc(), operands.front(), operands.size()); - else - sharedMSB = rewriter.createOrFold(mux->getLoc(), operands); - operands.clear(); - - // Get a concat of the LSB's on each side. - operands.append(trueOperands.begin() + i, trueOperands.end()); - Value trueLSB = rewriter.createOrFold(trueOp->getLoc(), operands); - operands.clear(); - operands.append(falseOperands.begin() + i, falseOperands.end()); - Value falseLSB = - rewriter.createOrFold(falseOp->getLoc(), operands); - // Merge the LSBs with a new mux and concat the MSB with the LSB to be - // done. - Value lsb = rewriter.createOrFold( - mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState()); - replaceOpWithNewOpAndCopyName(rewriter, mux, sharedMSB, lsb); - return true; - } - - // If trailing operands match, try to commonize them. - if (trueOperands.back() == falseOperands.back()) { - SmallVector operands; - size_t i; - for (i = 0;; ++i) { - Value trueOperand = trueOperands[numTrueOperands - i - 1]; - if (trueOperand == falseOperands[numFalseOperands - i - 1]) - operands.push_back(trueOperand); - else - break; - } - std::reverse(operands.begin(), operands.end()); - Value sharedLSB = rewriter.createOrFold(mux->getLoc(), operands); - operands.clear(); - - // Get a concat of the MSB's on each side. - operands.append(trueOperands.begin(), trueOperands.end() - i); - Value trueMSB = rewriter.createOrFold(trueOp->getLoc(), operands); - operands.clear(); - operands.append(falseOperands.begin(), falseOperands.end() - i); - Value falseMSB = - rewriter.createOrFold(falseOp->getLoc(), operands); - // Merge the MSBs with a new mux and concat the MSB with the LSB to be done. - Value msb = rewriter.createOrFold( - mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState()); - replaceOpWithNewOpAndCopyName(rewriter, mux, msb, sharedLSB); - return true; - } - - return false; -} - -// If both arguments of the mux are arrays with the same elements, sink the -// mux and return a uniform array initializing all elements to it. -static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) { - auto trueVec = op.getTrueValue().getDefiningOp(); - auto falseVec = op.getFalseValue().getDefiningOp(); - if (!trueVec || !falseVec) return false; - if (!trueVec.isUniform() || !falseVec.isUniform()) return false; - - auto mux = rewriter.create( - op.getLoc(), op.getCond(), trueVec.getUniformElement(), - falseVec.getUniformElement(), op.getTwoState()); - - SmallVector values(trueVec.getInputs().size(), mux); - rewriter.replaceOpWithNewOp(op, values); - return true; -} - -namespace { -struct MuxRewriter : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MuxOp op, - PatternRewriter &rewriter) const override; -}; - -LogicalResult MuxRewriter::matchAndRewrite(MuxOp op, - PatternRewriter &rewriter) const { - // If the op has a SV attribute, don't optimize it. - if (hasSVAttributes(op)) return failure(); - APInt value; - - if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) { - if (value.getBitWidth() == 1) { - // mux(a, 0, b) -> and(~a, b) for single-bit values. - if (value.isZero()) { - auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter); - replaceOpWithNewOpAndCopyName(rewriter, op, notCond, - op.getFalseValue(), false); - return success(); - } - - // mux(a, 1, b) -> or(a, b) for single-bit values. - replaceOpWithNewOpAndCopyName(rewriter, op, op.getCond(), - op.getFalseValue(), false); - return success(); - } - - // Check for mux of two constants. There are many ways to simplify them. - APInt value2; - if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) { - // When both inputs are constants and differ by only one bit, we can - // simplify by splitting the mux into up to three contiguous chunks: one - // for the differing bit and up to two for the bits that are the same. - // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0) - APInt xorValue = value ^ value2; - if (xorValue.isPowerOf2()) { - unsigned leadingZeros = xorValue.countLeadingZeros(); - unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1; - SmallVector operands; - - // Concat operands go from MSB to LSB, so we handle chunks in reverse - // order of bit indexes. - // For the chunks that are identical (i.e. correspond to 0s in - // xorValue), we can extract directly from either input value, and we - // arbitrarily pick the trueValue(). - - if (leadingZeros > 0) - operands.push_back(rewriter.createOrFold( - op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros)); - - // Handle the differing bit, which should simplify into either cond or - // ~cond. - auto v1 = rewriter.createOrFold( - op.getLoc(), op.getTrueValue(), trailingZeros, 1); - auto v2 = rewriter.createOrFold( - op.getLoc(), op.getFalseValue(), trailingZeros, 1); - operands.push_back(rewriter.createOrFold( - op.getLoc(), op.getCond(), v1, v2, false)); - - if (trailingZeros > 0) - operands.push_back(rewriter.createOrFold( - op.getLoc(), op.getTrueValue(), 0, trailingZeros)); - - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - operands); - return success(); - } - - // If the true value is all ones and the false is all zeros then we have a - // replicate pattern. - if (value.isAllOnes() && value2.isZero()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - op.getCond()); - return success(); - } - } - } - - if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) && - value.getBitWidth() == 1) { - // mux(a, b, 0) -> and(a, b) for single-bit values. - if (value.isZero()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getCond(), - op.getTrueValue(), false); - return success(); - } - - // mux(a, b, 1) -> or(~a, b) for single-bit values. - // falseValue() is known to be a single-bit 1, which we can use for - // the 1 in the representation of ~ using xor. - auto notCond = rewriter.createOrFold(op.getLoc(), op.getCond(), - op.getFalseValue(), false); - replaceOpWithNewOpAndCopyName(rewriter, op, notCond, - op.getTrueValue(), false); - return success(); - } - - // mux(!a, b, c) -> mux(a, c, b) - Value subExpr; - Operation *condOp = op.getCond().getDefiningOp(); - if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) && - op.getTwoState()) { - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), subExpr, - op.getFalseValue(), op.getTrueValue(), - true); - return success(); - } - - // Same but with Demorgan's law. - // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x) - // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x) - if (condOp && condOp->hasOneUse()) { - SmallVector invertedOperands; - - /// Scan all the operands to see if they are complemented. If so, build a - /// vector of them and return true, otherwise return false. - auto getInvertedOperands = [&]() -> bool { - for (Value operand : condOp->getOperands()) { - if (matchPattern(operand, m_Complement(m_Any(&subExpr)))) - invertedOperands.push_back(subExpr); - else - return false; - } - return true; - }; - - if (isa(condOp) && getInvertedOperands()) { - auto newOr = - rewriter.createOrFold(op.getLoc(), invertedOperands, false); - replaceOpWithNewOpAndCopyName(rewriter, op, newOr, - op.getFalseValue(), - op.getTrueValue(), op.getTwoState()); - return success(); - } - if (isa(condOp) && getInvertedOperands()) { - auto newAnd = - rewriter.createOrFold(op.getLoc(), invertedOperands, false); - replaceOpWithNewOpAndCopyName(rewriter, op, newAnd, - op.getFalseValue(), - op.getTrueValue(), op.getTwoState()); - return success(); - } - } - - if (auto falseMux = - dyn_cast_or_null(op.getFalseValue().getDefiningOp())) { - // mux(selector, x, mux(selector, y, z) = mux(selector, x, z) - if (op.getCond() == falseMux.getCond()) { - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getCond(), op.getTrueValue(), - falseMux.getFalseValue(), op.getTwoStateAttr()); - return success(); - } - - // Check to see if we can fold a mux tree into an array_create/get pair. - if (foldMuxChain(op, /*isFalse*/ true, rewriter)) return success(); - } - - if (auto trueMux = - dyn_cast_or_null(op.getTrueValue().getDefiningOp())) { - // mux(selector, mux(selector, a, b), c) = mux(selector, a, c) - if (op.getCond() == trueMux.getCond()) { - replaceOpWithNewOpAndCopyName( - rewriter, op, op.getCond(), trueMux.getTrueValue(), - op.getFalseValue(), op.getTwoStateAttr()); - return success(); - } - - // Check to see if we can fold a mux tree into an array_create/get pair. - if (foldMuxChain(op, /*isFalseSide*/ false, rewriter)) return success(); - } - - // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c)) - if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), - falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); - trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && - trueMux.getTrueValue() == falseMux.getTrueValue()) { - auto subMux = rewriter.create( - rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), - op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue()); - replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), - trueMux.getTrueValue(), subMux, - op.getTwoStateAttr()); - return success(); - } - - // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b) - if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), - falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); - trueMux && falseMux && trueMux.getCond() == falseMux.getCond() && - trueMux.getFalseValue() == falseMux.getFalseValue()) { - auto subMux = rewriter.create( - rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}), - op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue()); - replaceOpWithNewOpAndCopyName(rewriter, op, trueMux.getCond(), - subMux, trueMux.getFalseValue(), - op.getTwoStateAttr()); - return success(); - } - - // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b) - if (auto trueMux = dyn_cast_or_null(op.getTrueValue().getDefiningOp()), - falseMux = dyn_cast_or_null(op.getFalseValue().getDefiningOp()); - trueMux && falseMux && - trueMux.getTrueValue() == falseMux.getTrueValue() && - trueMux.getFalseValue() == falseMux.getFalseValue()) { - auto subMux = rewriter.create( - rewriter.getFusedLoc( - {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}), - op.getCond(), trueMux.getCond(), falseMux.getCond()); - replaceOpWithNewOpAndCopyName( - rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(), - op.getTwoStateAttr()); - return success(); - } - - // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a - if (foldCommonMuxValue(op, false, rewriter)) return success(); - // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a - if (foldCommonMuxValue(op, true, rewriter)) return success(); - - // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))` - if (Operation *trueOp = op.getTrueValue().getDefiningOp()) - if (Operation *falseOp = op.getFalseValue().getDefiningOp()) - if (trueOp->getName() == falseOp->getName()) - if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter)) - return success(); - - // extracts only of mux(...) -> mux(extract()...) - if (narrowOperationWidth(op, true, rewriter)) return success(); - - // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2)) - if (foldMuxOfUniformArrays(op, rewriter)) return success(); - - return failure(); -} - -static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) { - // Do not fold uniform or singleton arrays to avoid duplicating muxes. - if (op.getInputs().empty() || op.isUniform()) return false; - auto inputs = op.getInputs(); - if (inputs.size() <= 1) return false; - - // Check the operands to the array create. Ensure all of them are the - // same op with the same number of operands. - auto first = inputs[0].getDefiningOp(); - if (!first || hasSVAttributes(first)) return false; - - // Check whether all operands are muxes with the same condition. - for (size_t i = 1, n = inputs.size(); i < n; ++i) { - auto input = inputs[i].getDefiningOp(); - if (!input || first.getCond() != input.getCond()) return false; - } - - // Collect the true and the false branches into arrays. - SmallVector trues{first.getTrueValue()}; - SmallVector falses{first.getFalseValue()}; - SmallVector locs{first->getLoc()}; - bool isTwoState = true; - for (size_t i = 1, n = inputs.size(); i < n; ++i) { - auto input = inputs[i].getDefiningOp(); - trues.push_back(input.getTrueValue()); - falses.push_back(input.getFalseValue()); - locs.push_back(input->getLoc()); - if (!input.getTwoState()) isTwoState = false; - } - - // Define the location of the array create as the aggregate of all muxes. - auto loc = FusedLoc::get(op.getContext(), locs); - - // Replace the create with an aggregate operation. Push the create op - // into the operands of the aggregate operation. - auto arrayTy = op.getType(); - auto trueValues = rewriter.create(loc, arrayTy, trues); - auto falseValues = rewriter.create(loc, arrayTy, falses); - rewriter.replaceOpWithNewOp(op, arrayTy, first.getCond(), - trueValues, falseValues, isTwoState); - return true; -} - -struct ArrayRewriter : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(hw::ArrayCreateOp op, - PatternRewriter &rewriter) const override { - if (foldArrayOfMuxes(op, rewriter)) return success(); - return failure(); - } -}; - -} // namespace - -void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.insert(context); -} - -//===----------------------------------------------------------------------===// -// ICmpOp -//===----------------------------------------------------------------------===// - -// Calculate the result of a comparison when the LHS and RHS are both -// constants. -static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, - const APInt &rhs) { - switch (predicate) { - case ICmpPredicate::eq: - return lhs.eq(rhs); - case ICmpPredicate::ne: - return lhs.ne(rhs); - case ICmpPredicate::slt: - return lhs.slt(rhs); - case ICmpPredicate::sle: - return lhs.sle(rhs); - case ICmpPredicate::sgt: - return lhs.sgt(rhs); - case ICmpPredicate::sge: - return lhs.sge(rhs); - case ICmpPredicate::ult: - return lhs.ult(rhs); - case ICmpPredicate::ule: - return lhs.ule(rhs); - case ICmpPredicate::ugt: - return lhs.ugt(rhs); - case ICmpPredicate::uge: - return lhs.uge(rhs); - case ICmpPredicate::ceq: - return lhs.eq(rhs); - case ICmpPredicate::cne: - return lhs.ne(rhs); - case ICmpPredicate::weq: - return lhs.eq(rhs); - case ICmpPredicate::wne: - return lhs.ne(rhs); - } - llvm_unreachable("unknown comparison predicate"); -} - -// Returns the result of applying the predicate when the LHS and RHS are the -// exact same value. -static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) { - switch (predicate) { - case ICmpPredicate::eq: - case ICmpPredicate::sle: - case ICmpPredicate::sge: - case ICmpPredicate::ule: - case ICmpPredicate::uge: - case ICmpPredicate::ceq: - case ICmpPredicate::weq: - return true; - case ICmpPredicate::ne: - case ICmpPredicate::slt: - case ICmpPredicate::sgt: - case ICmpPredicate::ult: - case ICmpPredicate::ugt: - case ICmpPredicate::cne: - case ICmpPredicate::wne: - return false; - } - llvm_unreachable("unknown comparison predicate"); -} - -OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) { - // gt a, a -> false - // gte a, a -> true - if (getLhs() == getRhs()) { - auto val = applyCmpPredicateToEqualOperands(getPredicate()); - return IntegerAttr::get(getType(), val); - } - - // gt 1, 2 -> false - if (auto lhs = adaptor.getLhs().dyn_cast_or_null()) { - if (auto rhs = adaptor.getRhs().dyn_cast_or_null()) { - auto val = - applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return IntegerAttr::get(getType(), val); - } - } - return {}; -} - -// Given a range of operands, computes the number of matching prefix and -// suffix elements. This does not perform cross-element matching. -template -static size_t computeCommonPrefixLength(const Range &a, const Range &b) { - size_t commonPrefixLength = 0; - auto ia = a.begin(); - auto ib = b.begin(); - - for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) { - if (*ia != *ib) { - break; - } - } - - return commonPrefixLength; -} - -static size_t getTotalWidth(ArrayRef operands) { - size_t totalWidth = 0; - for (auto operand : operands) { - // getIntOrFloatBitWidth should never raise, since all arguments to - // ConcatOp are integers. - ssize_t width = operand.getType().getIntOrFloatBitWidth(); - assert(width >= 0); - totalWidth += width; - } - return totalWidth; -} - -/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise -/// comparison on common prefix and suffixes. Returns success() if a rewriting -/// happens. This handles both concat and replicate. -static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, - Operation *rhs, - PatternRewriter &rewriter) { - // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and - // all elements have non-zero length. Both these invariants are verified - // by the ConcatOp verifier. - SmallVector lhsOperands, rhsOperands; - getConcatOperands(lhs->getResult(0), lhsOperands); - getConcatOperands(rhs->getResult(0), rhsOperands); - ArrayRef lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands; - - auto formCatOrReplicate = [&](Location loc, - ArrayRef operands) -> Value { - assert(!operands.empty()); - Value sameElement = operands[0]; - for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i) - if (sameElement != operands[i]) sameElement = Value(); - if (sameElement) - return rewriter.createOrFold(loc, sameElement, - operands.size()); - return rewriter.createOrFold(loc, operands); - }; - - auto replaceWith = [&](ICmpPredicate predicate, Value lhs, - Value rhs) -> LogicalResult { - replaceOpWithNewOpAndCopyName(rewriter, op, predicate, lhs, rhs, - op.getTwoState()); - return success(); - }; - - size_t commonPrefixLength = - computeCommonPrefixLength(lhsOperands, rhsOperands); - if (commonPrefixLength == lhsOperands.size()) { - // cat(a, b, c) == cat(a, b, c) -> 1 - bool result = applyCmpPredicateToEqualOperands(op.getPredicate()); - replaceOpWithNewOpAndCopyName(rewriter, op, - APInt(1, result)); - return success(); - } - - size_t commonSuffixLength = computeCommonPrefixLength( - llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef)); - - size_t commonPrefixTotalWidth = - getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength)); - size_t commonSuffixTotalWidth = - getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength)); - auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength) - .drop_back(commonSuffixLength); - auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength) - .drop_back(commonSuffixLength); - - auto replaceWithoutReplicatingSignBit = [&]() { - auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly); - auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly); - return replaceWith(op.getPredicate(), newLhs, newRhs); - }; - - auto replaceWithReplicatingSignBit = [&]() { - auto firstNonEmptyValue = lhsOperands[0]; - auto firstNonEmptyElemWidth = - firstNonEmptyValue.getType().getIntOrFloatBitWidth(); - Value signBit = rewriter.createOrFold( - op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1); - - auto newLhs = rewriter.create(lhs->getLoc(), signBit, lhsOnly); - auto newRhs = rewriter.create(rhs->getLoc(), signBit, rhsOnly); - return replaceWith(op.getPredicate(), newLhs, newRhs); - }; - - if (ICmpOp::isPredicateSigned(op.getPredicate())) { - // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y)) - if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0) - return replaceWithoutReplicatingSignBit(); - - // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x), - // cat(sgn(b), ..y)) Note that we cannot perform this optimization if - // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign - // bit. Doing the rewrite can result in an infinite loop. - if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0) - return replaceWithReplicatingSignBit(); - - } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) { - // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y)) - return replaceWithoutReplicatingSignBit(); - } - - return failure(); -} - -/// Given an equality comparison with a constant value and some operand that has -/// known bits, simplify the comparison to check only the unknown bits of the -/// input. -/// -/// One simple example of this is that `concat(0, stuff) == 0` can be simplified -/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to -/// `extract x[1:0] == 0` -static void combineEqualityICmpWithKnownBitsAndConstant( - ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, - PatternRewriter &rewriter) { - // If any of the known bits disagree with any of the comparison bits, then - // we can constant fold this comparison right away. - APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One; - if ((bitsKnown & rhsCst) != bitAnalysis.One) { - // If we discover a mismatch then we know an "eq" comparison is false - // and a "ne" comparison is true! - bool result = cmpOp.getPredicate() == ICmpPredicate::ne; - replaceOpWithNewOpAndCopyName(rewriter, cmpOp, - APInt(1, result)); - return; - } - - // Check to see if we can prove the result entirely of the comparison (in - // which we bail out early), otherwise build a list of values to concat and a - // smaller constant to compare against. - SmallVector newConcatOperands; - auto newConstant = APInt::getZeroWidth(); - - // Ok, some (maybe all) bits are known and some others may be unknown. - // Extract out segments of the operand and compare against the - // corresponding bits. - unsigned knownMSB = bitsKnown.countLeadingOnes(); - - Value operand = cmpOp.getLhs(); - - // Ok, some bits are known but others are not. Extract out sequences of - // bits that are unknown and compare just those bits. We work from MSB to - // LSB. - while (knownMSB != bitsKnown.getBitWidth()) { - // Drop any high bits that are known. - if (knownMSB) - bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB); - - // Find the span of unknown bits, and extract it. - unsigned unknownBits = bitsKnown.countLeadingZeros(); - unsigned lowBit = bitsKnown.getBitWidth() - unknownBits; - auto spanOperand = rewriter.createOrFold( - operand.getLoc(), operand, /*lowBit=*/lowBit, - /*bitWidth=*/unknownBits); - auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits); - - // Add this info to the concat we're generating. - newConcatOperands.push_back(spanOperand); - // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values - // newConstant = newConstant.concat(spanConstant); - if (newConstant.getBitWidth() != 0) - newConstant = newConstant.concat(spanConstant); - else - newConstant = spanConstant; - - // Drop the unknown bits in prep for the next chunk. - unsigned newWidth = bitsKnown.getBitWidth() - unknownBits; - bitsKnown = bitsKnown.trunc(newWidth); - knownMSB = bitsKnown.countLeadingOnes(); - } - - // If all the operands to the concat are foldable then we have an identity - // situation where all the sub-elements equal each other. This implies that - // the overall result is foldable. - if (newConcatOperands.empty()) { - bool result = cmpOp.getPredicate() == ICmpPredicate::eq; - replaceOpWithNewOpAndCopyName(rewriter, cmpOp, - APInt(1, result)); - return; - } - - // If we have a single operand remaining, use it, otherwise form a concat. - Value concatResult = - rewriter.createOrFold(operand.getLoc(), newConcatOperands); - - // Form the comparison against the smaller constant. - auto newConstantOp = rewriter.create( - cmpOp.getOperand(1).getLoc(), newConstant); - - replaceOpWithNewOpAndCopyName(rewriter, cmpOp, cmpOp.getPredicate(), - concatResult, newConstantOp, - cmpOp.getTwoState()); -} - -// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2). -static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, - const APInt &rhs, - PatternRewriter &rewriter) { - auto ip = rewriter.saveInsertionPoint(); - rewriter.setInsertionPoint(xorOp); - - auto xorRHS = xorOp.getOperands().back().getDefiningOp(); - auto newRHS = rewriter.create(xorRHS->getLoc(), - xorRHS.getValue() ^ rhs); - Value newLHS; - switch (xorOp.getNumOperands()) { - case 1: - // This isn't common but is defined so we need to handle it. - newLHS = rewriter.create( - xorOp.getLoc(), APInt::getZero(rhs.getBitWidth())); - break; - case 2: - // The binary case is the most common. - newLHS = xorOp.getOperand(0); - break; - default: - // The general case forces us to form a new xor with the remaining - // operands. - SmallVector newOperands(xorOp.getOperands()); - newOperands.pop_back(); - newLHS = rewriter.create(xorOp.getLoc(), newOperands, false); - break; - } - - bool xorMultipleUses = !xorOp->hasOneUse(); - - // If the xor has multiple uses (not just the compare, then we need/want to - // replace them as well. - if (xorMultipleUses) - replaceOpWithNewOpAndCopyName(rewriter, xorOp, newLHS, xorRHS, - false); - - // Replace the comparison. - rewriter.restoreInsertionPoint(ip); - replaceOpWithNewOpAndCopyName(rewriter, cmpOp, cmpOp.getPredicate(), - newLHS, newRHS, false); -} - -LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) { - APInt lhs, rhs; - - // icmp 1, x -> icmp x, 1 - if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) { - assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) && - "Should be folded"); - replaceOpWithNewOpAndCopyName( - rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()), - op.getRhs(), op.getLhs(), op.getTwoState()); - return success(); - } - - // Canonicalize with RHS constant - if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) { - auto getConstant = [&](APInt constant) -> Value { - return rewriter.create(op.getLoc(), std::move(constant)); - }; - - auto replaceWith = [&](ICmpPredicate predicate, Value lhs, - Value rhs) -> LogicalResult { - replaceOpWithNewOpAndCopyName(rewriter, op, predicate, lhs, rhs, - op.getTwoState()); - return success(); - }; - - auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult { - replaceOpWithNewOpAndCopyName(rewriter, op, - APInt(1, constant)); - return success(); - }; - - switch (op.getPredicate()) { - case ICmpPredicate::slt: - // x < max -> x != max - if (rhs.isMaxSignedValue()) - return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); - // x < min -> false - if (rhs.isMinSignedValue()) return replaceWithConstantI1(0); - // x < min+1 -> x == min - if ((rhs - 1).isMinSignedValue()) - return replaceWith(ICmpPredicate::eq, op.getLhs(), - getConstant(rhs - 1)); - break; - case ICmpPredicate::sgt: - // x > min -> x != min - if (rhs.isMinSignedValue()) - return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); - // x > max -> false - if (rhs.isMaxSignedValue()) return replaceWithConstantI1(0); - // x > max-1 -> x == max - if ((rhs + 1).isMaxSignedValue()) - return replaceWith(ICmpPredicate::eq, op.getLhs(), - getConstant(rhs + 1)); - break; - case ICmpPredicate::ult: - // x < max -> x != max - if (rhs.isAllOnes()) - return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); - // x < min -> false - if (rhs.isZero()) return replaceWithConstantI1(0); - // x < min+1 -> x == min - if ((rhs - 1).isZero()) - return replaceWith(ICmpPredicate::eq, op.getLhs(), - getConstant(rhs - 1)); - - // x < 0xE0 -> extract(x, 5..7) != 0b111 - if (rhs.countLeadingOnes() + rhs.countTrailingZeros() == - rhs.getBitWidth()) { - auto numOnes = rhs.countLeadingOnes(); - auto smaller = rewriter.create( - op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes); - return replaceWith(ICmpPredicate::ne, smaller, - getConstant(APInt::getAllOnes(numOnes))); - } - - break; - case ICmpPredicate::ugt: - // x > min -> x != min - if (rhs.isZero()) - return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs()); - // x > max -> false - if (rhs.isAllOnes()) return replaceWithConstantI1(0); - // x > max-1 -> x == max - if ((rhs + 1).isAllOnes()) - return replaceWith(ICmpPredicate::eq, op.getLhs(), - getConstant(rhs + 1)); - - // x > 0x07 -> extract(x, 3..7) != 0b00000 - if ((rhs + 1).isPowerOf2()) { - auto numOnes = rhs.countTrailingOnes(); - auto newWidth = rhs.getBitWidth() - numOnes; - auto smaller = rewriter.create(op.getLoc(), op.getLhs(), - numOnes, newWidth); - return replaceWith(ICmpPredicate::ne, smaller, - getConstant(APInt::getZero(newWidth))); - } - - break; - case ICmpPredicate::sle: - // x <= max -> true - if (rhs.isMaxSignedValue()) return replaceWithConstantI1(1); - // x <= c -> x < (c+1) - return replaceWith(ICmpPredicate::slt, op.getLhs(), - getConstant(rhs + 1)); - case ICmpPredicate::sge: - // x >= min -> true - if (rhs.isMinSignedValue()) return replaceWithConstantI1(1); - // x >= c -> x > (c-1) - return replaceWith(ICmpPredicate::sgt, op.getLhs(), - getConstant(rhs - 1)); - case ICmpPredicate::ule: - // x <= max -> true - if (rhs.isAllOnes()) return replaceWithConstantI1(1); - // x <= c -> x < (c+1) - return replaceWith(ICmpPredicate::ult, op.getLhs(), - getConstant(rhs + 1)); - case ICmpPredicate::uge: - // x >= min -> true - if (rhs.isZero()) return replaceWithConstantI1(1); - // x >= c -> x > (c-1) - return replaceWith(ICmpPredicate::ugt, op.getLhs(), - getConstant(rhs - 1)); - case ICmpPredicate::eq: - if (rhs.getBitWidth() == 1) { - if (rhs.isZero()) { - // x == 0 -> x ^ 1 - replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), - getConstant(APInt(1, 1)), - op.getTwoState()); - return success(); - } - if (rhs.isAllOnes()) { - // x == 1 -> x - replaceOpAndCopyName(rewriter, op, op.getLhs()); - return success(); - } - } - break; - case ICmpPredicate::ne: - if (rhs.getBitWidth() == 1) { - if (rhs.isZero()) { - // x != 0 -> x - replaceOpAndCopyName(rewriter, op, op.getLhs()); - return success(); - } - if (rhs.isAllOnes()) { - // x != 1 -> x ^ 1 - replaceOpWithNewOpAndCopyName(rewriter, op, op.getLhs(), - getConstant(APInt(1, 1)), - op.getTwoState()); - return success(); - } - } - break; - case ICmpPredicate::ceq: - case ICmpPredicate::cne: - case ICmpPredicate::weq: - case ICmpPredicate::wne: - break; - } - - // We have some specific optimizations for comparison with a constant that - // are only supported for equality comparisons. - if (op.getPredicate() == ICmpPredicate::eq || - op.getPredicate() == ICmpPredicate::ne) { - // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts - // with a smaller constant. We only support equality comparisons for - // this. - auto knownBits = computeKnownBits(op.getLhs()); - if (!knownBits.isUnknown()) - return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs, - rewriter), - success(); - - // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), - // cst1^cst2). - if (auto xorOp = op.getLhs().getDefiningOp()) - if (xorOp.getOperands().back().getDefiningOp()) - return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter), - success(); - - // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or - // all one. - if (auto replicateOp = op.getLhs().getDefiningOp()) - if (rhs.isAllOnes() || rhs.isZero()) { - auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth(); - auto cst = rewriter.create( - op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width) - : APInt::getZero(width)); - replaceOpWithNewOpAndCopyName(rewriter, op, op.getPredicate(), - replicateOp.getInput(), cst, - op.getTwoState()); - return success(); - } - } - } - - // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a, - // b), cat(c, d)). contains special handling for sign bit in signed - // compressions. - if (Operation *opLHS = op.getLhs().getDefiningOp()) - if (Operation *opRHS = op.getRhs().getDefiningOp()) - if (isa(opLHS) && - isa(opRHS)) { - if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter))) - return success(); - } - - return failure(); -} diff --git a/lib/circt/Dialect/Comb/CombOps.cpp b/lib/circt/Dialect/Comb/CombOps.cpp index 92ba7aa21e..bc7ed6fa1e 100644 --- a/lib/circt/Dialect/Comb/CombOps.cpp +++ b/lib/circt/Dialect/Comb/CombOps.cpp @@ -12,7 +12,6 @@ #include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWOps.h" #include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project @@ -23,38 +22,39 @@ using namespace comb; /// Create a sign extension operation from a value of integer type to an equal /// or larger integer type. -Value comb::createOrFoldSExt(Location loc, Value value, Type destTy, - OpBuilder &builder) { - IntegerType valueType = value.getType().dyn_cast(); - assert(valueType && destTy.isa() && - valueType.getWidth() <= destTy.getIntOrFloatBitWidth() && - valueType.getWidth() != 0 && "invalid sext operands"); - // If already the right size, we are done. - if (valueType == destTy) return value; - - // sext is concat with a replicate of the sign bits and the bottom part. - auto signBit = - builder.createOrFold(loc, value, valueType.getWidth() - 1, 1); - auto signBits = builder.createOrFold( - loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth()); - return builder.createOrFold(loc, signBits, value); -} - -Value comb::createOrFoldSExt(Value value, Type destTy, - ImplicitLocOpBuilder &builder) { - return createOrFoldSExt(builder.getLoc(), value, destTy, builder); -} - -Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder, - bool twoState) { - auto allOnes = builder.create(loc, value.getType(), -1); - return builder.createOrFold(loc, value, allOnes, twoState); -} - -Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, - bool twoState) { - return createOrFoldNot(builder.getLoc(), value, builder, twoState); -} +// Value comb::createOrFoldSExt(Location loc, Value value, Type destTy, +// OpBuilder &builder) { +// IntegerType valueType = value.getType().dyn_cast(); +// assert(valueType && destTy.isa() && +// valueType.getWidth() <= destTy.getIntOrFloatBitWidth() && +// valueType.getWidth() != 0 && "invalid sext operands"); +// // If already the right size, we are done. +// if (valueType == destTy) return value; + +// // sext is concat with a replicate of the sign bits and the bottom part. +// auto signBit = +// builder.createOrFold(loc, value, valueType.getWidth() - 1, +// 1); +// auto signBits = builder.createOrFold( +// loc, signBit, destTy.getIntOrFloatBitWidth() - valueType.getWidth()); +// return builder.createOrFold(loc, signBits, value); +// } + +// Value comb::createOrFoldSExt(Value value, Type destTy, +// ImplicitLocOpBuilder &builder) { +// return createOrFoldSExt(builder.getLoc(), value, destTy, builder); +// } + +// Value comb::createOrFoldNot(Location loc, Value value, OpBuilder &builder, +// bool twoState) { +// auto allOnes = builder.create(loc, value.getType(), -1); +// return builder.createOrFold(loc, value, allOnes, twoState); +// } + +// Value comb::createOrFoldNot(Value value, ImplicitLocOpBuilder &builder, +// bool twoState) { +// return createOrFoldNot(builder.getLoc(), value, builder, twoState); +// } //===----------------------------------------------------------------------===// // ICmpOp @@ -152,27 +152,27 @@ ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) { llvm_unreachable("unknown comparison predicate"); } -/// Return true if this is an equality test with -1, which is a "reduction -/// and" operation in Verilog. -bool ICmpOp::isEqualAllOnes() { - if (getPredicate() != ICmpPredicate::eq) return false; +// /// Return true if this is an equality test with -1, which is a "reduction +// /// and" operation in Verilog. +// bool ICmpOp::isEqualAllOnes() { +// if (getPredicate() != ICmpPredicate::eq) return false; - if (auto op1 = - dyn_cast_or_null(getOperand(1).getDefiningOp())) - return op1.getValue().isAllOnes(); - return false; -} +// if (auto op1 = +// dyn_cast_or_null(getOperand(1).getDefiningOp())) +// return op1.getValue().isAllOnes(); +// return false; +// } -/// Return true if this is a not equal test with 0, which is a "reduction -/// or" operation in Verilog. -bool ICmpOp::isNotEqualZero() { - if (getPredicate() != ICmpPredicate::ne) return false; +// /// Return true if this is a not equal test with 0, which is a "reduction +// /// or" operation in Verilog. +// bool ICmpOp::isNotEqualZero() { +// if (getPredicate() != ICmpPredicate::ne) return false; - if (auto op1 = - dyn_cast_or_null(getOperand(1).getDefiningOp())) - return op1.getValue().isZero(); - return false; -} +// if (auto op1 = +// dyn_cast_or_null(getOperand(1).getDefiningOp())) +// return op1.getValue().isZero(); +// return false; +// } //===----------------------------------------------------------------------===// // Unary Operations @@ -219,12 +219,12 @@ LogicalResult XorOp::verify() { return verifyUTBinOp(*this); } /// Return true if this is a two operand xor with an all ones constant as its /// RHS operand. -bool XorOp::isBinaryNot() { - if (getNumOperands() != 2) return false; - if (auto cst = getOperand(1).getDefiningOp()) - if (cst.getValue().isAllOnes()) return true; - return false; -} +// bool XorOp::isBinaryNot() { +// if (getNumOperands() != 2) return false; +// if (auto cst = getOperand(1).getDefiningOp()) +// if (cst.getValue().isAllOnes()) return true; +// return false; +// } //===----------------------------------------------------------------------===// // ConcatOp diff --git a/lib/circt/Dialect/Comb/Transforms/CMakeLists.txt b/lib/circt/Dialect/Comb/Transforms/CMakeLists.txt deleted file mode 100644 index 9043ccf5de..0000000000 --- a/lib/circt/Dialect/Comb/Transforms/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_circt_dialect_library(CIRCTCombTransforms - LowerComb.cpp - - DEPENDS - CIRCTCombTransformsIncGen - - LINK_LIBS PUBLIC - CIRCTHW - CIRCTSV - CIRCTComb - CIRCTSupport - MLIRIR - MLIRPass - MLIRTransformUtils -) diff --git a/lib/circt/Dialect/Comb/Transforms/LowerComb.cpp b/lib/circt/Dialect/Comb/Transforms/LowerComb.cpp deleted file mode 100644 index a22aae313a..0000000000 --- a/lib/circt/Dialect/Comb/Transforms/LowerComb.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===- LowerComb.cpp - Lower some ops in comb -------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/Comb/CombPasses.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - -using namespace circt; -using namespace circt::comb; - -namespace circt { -namespace comb { -#define GEN_PASS_DEF_LOWERCOMB -#include "include/circt/Dialect/Comb/Passes.h.inc" -} // namespace comb -} // namespace circt - -namespace { -/// Lower truth tables to mux trees. -struct TruthTableToMuxTree : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - private: - /// Get a mux tree for `inputs` corresponding to the given truth table. Do - /// this recursively by dividing the table in half for each input. - // NOLINTNEXTLINE(misc-no-recursion) - Value getMux(Location loc, OpBuilder &b, Value t, Value f, - ArrayRef table, Operation::operand_range inputs) const { - assert(table.size() == (1ull << inputs.size())); - if (table.size() == 1) return table.front() ? t : f; - - size_t half = table.size() / 2; - Value if1 = - getMux(loc, b, t, f, table.drop_front(half), inputs.drop_front()); - Value if0 = - getMux(loc, b, t, f, table.drop_back(half), inputs.drop_front()); - return b.create(loc, inputs.front(), if1, if0, false); - } - - public: - LogicalResult matchAndRewrite(TruthTableOp op, OpAdaptor adaptor, - ConversionPatternRewriter &b) const override { - Location loc = op.getLoc(); - SmallVector table( - llvm::map_range(op.getLookupTableAttr().getAsValueRange(), - [](const APInt &a) { return !a.isZero(); })); - Value t = b.create(loc, b.getIntegerAttr(b.getI1Type(), 1)); - Value f = b.create(loc, b.getIntegerAttr(b.getI1Type(), 0)); - - Value tree = getMux(loc, b, t, f, table, op.getInputs()); - b.updateRootInPlace(tree.getDefiningOp(), [&]() { - tree.getDefiningOp()->setDialectAttrs(op->getDialectAttrs()); - }); - b.replaceOp(op, tree); - return success(); - } -}; -} // namespace - -namespace { -class LowerCombPass : public impl::LowerCombBase { - public: - using LowerCombBase::LowerCombBase; - - void runOnOperation() override; -}; -} // namespace - -void LowerCombPass::runOnOperation() { - ModuleOp module = getOperation(); - - ConversionTarget target(getContext()); - RewritePatternSet patterns(&getContext()); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - target.addIllegalOp(); - - patterns.add(patterns.getContext()); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) - return signalPassFailure(); -} diff --git a/lib/circt/Dialect/Comb/Transforms/PassDetails.h b/lib/circt/Dialect/Comb/Transforms/PassDetails.h deleted file mode 100644 index 32345aa242..0000000000 --- a/lib/circt/Dialect/Comb/Transforms/PassDetails.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- PassDetails.h - Comb pass class details ------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -// clang-tidy seems to expect the absolute path in the header guard on some -// systems, so just disable it. -// NOLINTNEXTLINE(llvm-header-guard) -#ifndef DIALECT_COMB_TRANSFORMS_PASSDETAILS_H -#define DIALECT_COMB_TRANSFORMS_PASSDETAILS_H - -#include "include/circt/Dialect/HW/HWOps.h" -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace circt { -namespace comb { - -#define GEN_PASS_CLASSES -#include "include/circt/Dialect/Comb/Passes.h.inc" - -} // namespace comb -} // namespace circt - -#endif // DIALECT_COMB_TRANSFORMS_PASSDETAILS_H diff --git a/lib/circt/Dialect/HW/BUILD b/lib/circt/Dialect/HW/BUILD deleted file mode 100644 index 30727b0e18..0000000000 --- a/lib/circt/Dialect/HW/BUILD +++ /dev/null @@ -1,55 +0,0 @@ -package( - default_applicable_licenses = ["@heir//:license"], - default_visibility = ["//visibility:public"], -) - -cc_library( - name = "Dialect", - srcs = glob( - [ - "*.cpp", - ], - exclude = [ - "HWReductions.cpp", - ], - ), - hdrs = [ - "@heir//include/circt/Dialect/Comb:CombDialect.h", - "@heir//include/circt/Dialect/Comb:CombOps.h", - "@heir//include/circt/Dialect/HW:ConversionPatterns.h", - "@heir//include/circt/Dialect/HW:CustomDirectiveImpl.h", - "@heir//include/circt/Dialect/HW:HWAttributes.h", - "@heir//include/circt/Dialect/HW:HWDialect.h", - "@heir//include/circt/Dialect/HW:HWInstanceGraph.h", - "@heir//include/circt/Dialect/HW:HWOpInterfaces.h", - "@heir//include/circt/Dialect/HW:HWOps.h", - "@heir//include/circt/Dialect/HW:HWSymCache.h", - "@heir//include/circt/Dialect/HW:HWTypeInterfaces.h", - "@heir//include/circt/Dialect/HW:HWTypes.h", - "@heir//include/circt/Dialect/HW:HWVisitors.h", - "@heir//include/circt/Dialect/HW:InnerSymbolTable.h", - "@heir//include/circt/Dialect/HW:InstanceImplementation.h", - "@heir//include/circt/Dialect/HW:ModuleImplementation.h", - "@heir//include/circt/Dialect/HW:PortConverter.h", - "@llvm-project//mlir:include/mlir/Transforms/InliningUtils.h", - ], - deps = [ - "@heir//include/circt/Dialect/Comb:dialect_inc_gen", - "@heir//include/circt/Dialect/Comb:enum_inc_gen", - "@heir//include/circt/Dialect/Comb:ops_inc_gen", - "@heir//include/circt/Dialect/HW:attributes_inc_gen", - "@heir//include/circt/Dialect/HW:dialect_inc_gen", - "@heir//include/circt/Dialect/HW:enum_inc_gen", - "@heir//include/circt/Dialect/HW:op_interfaces_inc_gen", - "@heir//include/circt/Dialect/HW:ops_inc_gen", - "@heir//include/circt/Dialect/HW:type_interfaces_inc_gen", - "@heir//include/circt/Dialect/HW:types_inc_gen", - "@heir//lib/circt/Support", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ControlFlowInterfaces", - "@llvm-project//mlir:FunctionInterfaces", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:InferTypeOpInterface", - "@llvm-project//mlir:Transforms", - ], -) diff --git a/lib/circt/Dialect/HW/CMakeLists.txt b/lib/circt/Dialect/HW/CMakeLists.txt deleted file mode 100644 index 83cc1559a4..0000000000 --- a/lib/circt/Dialect/HW/CMakeLists.txt +++ /dev/null @@ -1,52 +0,0 @@ -set(CIRCT_HW_Sources - ConversionPatterns.cpp - CustomDirectiveImpl.cpp - HWAttributes.cpp - HWDialect.cpp - HWInstanceGraph.cpp - HWModuleOpInterface.cpp - HWOpInterfaces.cpp - HWOps.cpp - HWTypeInterfaces.cpp - HWTypes.cpp - InstanceImplementation.cpp - ModuleImplementation.cpp - InnerSymbolTable.cpp - PortConverter.cpp -) - -set(LLVM_OPTIONAL_SOURCES - ${CIRCT_HW_Sources} - HWReductions.cpp -) - -add_circt_dialect_library(CIRCTHW - ${CIRCT_HW_Sources} - - ADDITIONAL_HEADER_DIRS - ${CIRCT_MAIN_INCLUDE_DIR}/circt/Dialect/HW - - DEPENDS - MLIRHWIncGen - MLIRHWAttrIncGen - MLIRHWEnumsIncGen - - LINK_COMPONENTS - Support - - LINK_LIBS PUBLIC - CIRCTSupport - MLIRIR - MLIRInferTypeOpInterface -) - -add_circt_library(CIRCTHWReductions - HWReductions.cpp - - LINK_LIBS PUBLIC - CIRCTReduceLib - CIRCTHW - MLIRIR -) - -add_subdirectory(Transforms) diff --git a/lib/circt/Dialect/HW/ConversionPatterns.cpp b/lib/circt/Dialect/HW/ConversionPatterns.cpp deleted file mode 100644 index 4a0db49e0a..0000000000 --- a/lib/circt/Dialect/HW/ConversionPatterns.cpp +++ /dev/null @@ -1,103 +0,0 @@ -//===- ConversionPatterns.cpp - Common Conversion patterns ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/ConversionPatterns.h" - -#include "include/circt/Dialect/HW/HWTypes.h" - -using namespace circt; - -// Converts a function type wrt. the given type converter. -static FunctionType convertFunctionType(const TypeConverter &typeConverter, - FunctionType type) { - // Convert the original function types. - llvm::SmallVector res, arg; - llvm::transform(type.getResults(), std::back_inserter(res), - [&](Type t) { return typeConverter.convertType(t); }); - llvm::transform(type.getInputs(), std::back_inserter(arg), - [&](Type t) { return typeConverter.convertType(t); }); - - return FunctionType::get(type.getContext(), arg, res); -} - -// Converts a function type wrt. the given type converter. -static hw::ModuleType convertModuleType(const TypeConverter &typeConverter, - hw::ModuleType type) { - // Convert the original function types. - SmallVector ports(type.getPorts()); - for (auto &p : ports) p.type = typeConverter.convertType(p.type); - return hw::ModuleType::get(type.getContext(), ports); -} - -LogicalResult TypeConversionPattern::matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - // Convert the TypeAttrs. - llvm::SmallVector newAttrs; - newAttrs.reserve(op->getAttrs().size()); - for (auto attr : op->getAttrs()) { - if (auto typeAttr = attr.getValue().dyn_cast()) { - auto innerType = typeAttr.getValue(); - // TypeConvert::convertType doesn't handle function types, so we need to - // handle them manually. - if (auto funcType = innerType.dyn_cast()) - innerType = convertFunctionType(*getTypeConverter(), funcType); - else if (auto modType = innerType.dyn_cast()) - innerType = convertModuleType(*getTypeConverter(), modType); - else - innerType = getTypeConverter()->convertType(innerType); - newAttrs.emplace_back(attr.getName(), TypeAttr::get(innerType)); - } else { - newAttrs.push_back(attr); - } - } - - // Convert the result types. - llvm::SmallVector newResults; - if (failed( - getTypeConverter()->convertTypes(op->getResultTypes(), newResults))) - return rewriter.notifyMatchFailure(op->getLoc(), "type conversion failed"); - - // Build the state for the edited clone. - OperationState state(op->getLoc(), op->getName().getStringRef(), operands, - newResults, newAttrs, op->getSuccessors()); - for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) state.addRegion(); - - // Must create the op before running any modifications on the regions so that - // we don't crash with '-debug' and so we have something to 'root update'. - Operation *newOp = rewriter.create(state); - - // Move the regions over, converting the signatures as we go. - rewriter.startRootUpdate(newOp); - for (size_t i = 0, e = op->getNumRegions(); i < e; ++i) { - Region ®ion = op->getRegion(i); - Region *newRegion = &newOp->getRegion(i); - - // TypeConverter::SignatureConversion drops argument locations, so we need - // to manually copy them over (a verifier in e.g. HWModule checks this). - llvm::SmallVector argLocs; - for (auto arg : region.getArguments()) argLocs.push_back(arg.getLoc()); - - // Move the region and convert the region args. - rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin()); - TypeConverter::SignatureConversion result(newRegion->getNumArguments()); - if (failed(getTypeConverter()->convertSignatureArgs( - newRegion->getArgumentTypes(), result))) - return rewriter.notifyMatchFailure(op->getLoc(), - "type conversion failed"); - rewriter.applySignatureConversion(newRegion, result, getTypeConverter()); - - // Apply the argument locations. - for (auto [arg, loc] : llvm::zip(newRegion->getArguments(), argLocs)) - arg.setLoc(loc); - } - rewriter.finalizeRootUpdate(newOp); - - rewriter.replaceOp(op, newOp->getResults()); - return success(); -} diff --git a/lib/circt/Dialect/HW/CustomDirectiveImpl.cpp b/lib/circt/Dialect/HW/CustomDirectiveImpl.cpp deleted file mode 100644 index b782516ec5..0000000000 --- a/lib/circt/Dialect/HW/CustomDirectiveImpl.cpp +++ /dev/null @@ -1,136 +0,0 @@ -//===- CustomDirectiveImpl.cpp - Custom directive definitions -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/CustomDirectiveImpl.h" - -#include "include/circt/Dialect/HW/HWAttributes.h" - -using namespace circt; -using namespace circt::hw; - -ParseResult circt::parseInputPortList( - OpAsmParser &parser, - SmallVectorImpl &inputs, - SmallVectorImpl &inputTypes, ArrayAttr &inputNames) { - SmallVector argNames; - auto parseInputPort = [&]() -> ParseResult { - std::string portName; - if (parser.parseKeywordOrString(&portName)) return failure(); - argNames.push_back(StringAttr::get(parser.getContext(), portName)); - inputs.push_back({}); - inputTypes.push_back({}); - return failure(parser.parseColon() || parser.parseOperand(inputs.back()) || - parser.parseColon() || parser.parseType(inputTypes.back())); - }; - llvm::SMLoc inputsOperandsLoc; - if (parser.getCurrentLocation(&inputsOperandsLoc) || - parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, - parseInputPort)) - return failure(); - - inputNames = ArrayAttr::get(parser.getContext(), argNames); - - return success(); -} - -void circt::printInputPortList(OpAsmPrinter &p, Operation *op, - OperandRange inputs, TypeRange inputTypes, - ArrayAttr inputNames) { - p << "("; - llvm::interleaveComma(llvm::zip(inputs, inputNames), p, - [&](std::tuple input) { - Value val = std::get<0>(input); - p.printKeywordOrString( - std::get<1>(input).cast().getValue()); - p << ": " << val << ": " << val.getType(); - }); - p << ")"; -} - -ParseResult circt::parseOutputPortList(OpAsmParser &parser, - SmallVectorImpl &resultTypes, - ArrayAttr &resultNames) { - SmallVector names; - auto parseResultPort = [&]() -> ParseResult { - std::string portName; - if (parser.parseKeywordOrString(&portName)) return failure(); - names.push_back(StringAttr::get(parser.getContext(), portName)); - resultTypes.push_back({}); - return parser.parseColonType(resultTypes.back()); - }; - llvm::SMLoc inputsOperandsLoc; - if (parser.getCurrentLocation(&inputsOperandsLoc) || - parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, - parseResultPort)) - return failure(); - - resultNames = ArrayAttr::get(parser.getContext(), names); - - return success(); -} - -void circt::printOutputPortList(OpAsmPrinter &p, Operation *op, - TypeRange resultTypes, ArrayAttr resultNames) { - p << "("; - llvm::interleaveComma( - llvm::zip(resultTypes, resultNames), p, - [&](std::tuple result) { - p.printKeywordOrString( - std::get<1>(result).cast().getValue()); - p << ": " << std::get<0>(result); - }); - p << ")"; -} - -ParseResult circt::parseOptionalParameterList(OpAsmParser &parser, - ArrayAttr ¶meters) { - SmallVector params; - - auto parseParameter = [&]() { - std::string name; - Type type; - Attribute value; - - if (parser.parseKeywordOrString(&name) || parser.parseColonType(type)) - return failure(); - - // Parse the default value if present. - if (succeeded(parser.parseOptionalEqual())) { - if (parser.parseAttribute(value, type)) return failure(); - } - - auto &builder = parser.getBuilder(); - params.push_back(ParamDeclAttr::get( - builder.getContext(), builder.getStringAttr(name), type, value)); - return success(); - }; - - if (failed(parser.parseCommaSeparatedList( - OpAsmParser::Delimiter::OptionalLessGreater, parseParameter))) - return failure(); - - parameters = ArrayAttr::get(parser.getContext(), params); - - return success(); -} - -void circt::printOptionalParameterList(OpAsmPrinter &p, Operation *op, - ArrayAttr parameters) { - if (parameters.empty()) return; - - p << '<'; - llvm::interleaveComma(parameters, p, [&](Attribute param) { - auto paramAttr = param.cast(); - p << paramAttr.getName().getValue() << ": " << paramAttr.getType(); - if (auto value = paramAttr.getValue()) { - p << " = "; - p.printAttributeWithoutType(value); - } - }); - p << '>'; -} diff --git a/lib/circt/Dialect/HW/HWAttributes.cpp b/lib/circt/Dialect/HW/HWAttributes.cpp deleted file mode 100644 index 0130339788..0000000000 --- a/lib/circt/Dialect/HW/HWAttributes.cpp +++ /dev/null @@ -1,1032 +0,0 @@ -//===- HWAttributes.cpp - Implement HW attributes -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWAttributes.h" - -#include "include/circt/Dialect/HW/HWDialect.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWTypes.h" -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/FileSystem.h" // from @llvm-project -#include "llvm/include/llvm/Support/Path.h" // from @llvm-project -#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project - -using namespace circt; -using namespace circt::hw; -using mlir::TypedAttr; - -// Internal method used for .mlir file parsing, defined below. -static Attribute parseParamExprWithOpcode(StringRef opcode, DialectAsmParser &p, - Type type); - -//===----------------------------------------------------------------------===// -// ODS Boilerplate -//===----------------------------------------------------------------------===// - -#define GET_ATTRDEF_CLASSES -#include "include/circt/Dialect/HW/HWAttributes.cpp.inc" - -void HWDialect::registerAttributes() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "include/circt/Dialect/HW/HWAttributes.cpp.inc" - >(); -} - -Attribute HWDialect::parseAttribute(DialectAsmParser &p, Type type) const { - StringRef attrName; - Attribute attr; - auto parseResult = generatedAttributeParser(p, &attrName, type, attr); - if (parseResult.has_value()) return attr; - - // Parse "#hw.param.expr.add" as ParamExprAttr. - if (attrName.startswith(ParamExprAttr::getMnemonic())) { - auto string = attrName.drop_front(ParamExprAttr::getMnemonic().size()); - if (string.front() == '.') - return parseParamExprWithOpcode(string.drop_front(), p, type); - } - - p.emitError(p.getNameLoc(), "Unexpected hw attribute '" + attrName + "'"); - return {}; -} - -void HWDialect::printAttribute(Attribute attr, DialectAsmPrinter &p) const { - if (succeeded(generatedAttributePrinter(attr, p))) return; - llvm_unreachable("Unexpected attribute"); -} - -//===----------------------------------------------------------------------===// -// OutputFileAttr -//===----------------------------------------------------------------------===// - -static std::string canonicalizeFilename(const Twine &directory, - const Twine &filename) { - // Convert the filename to a native style path. - SmallString<128> nativeFilename; - llvm::sys::path::native(filename, nativeFilename); - - // If the filename is an absolute path, ignore the directory. - // e.g. `directory/` + `/etc/filename` -> `/etc/filename`. - if (llvm::sys::path::is_absolute(nativeFilename)) - return std::string(nativeFilename); - - // Convert the directory to a native style path. - SmallString<128> nativeDirectory; - llvm::sys::path::native(directory, nativeDirectory); - - // If the filename component is empty, then ensure that the path ends in a - // separator and return it. - // e.g. `directory` + `` -> `directory/`. - auto separator = llvm::sys::path::get_separator(); - if (nativeFilename.empty() && !nativeDirectory.endswith(separator)) { - nativeDirectory += separator; - return std::string(nativeDirectory); - } - - // Append the directory and filename together. - // e.g. `/tmp/` + `out/filename` -> `/tmp/out/filename`. - SmallString<128> fullPath; - llvm::sys::path::append(fullPath, nativeDirectory, nativeFilename); - return std::string(fullPath); -} - -OutputFileAttr OutputFileAttr::getFromFilename(MLIRContext *context, - const Twine &filename, - bool excludeFromFileList, - bool includeReplicatedOps) { - return OutputFileAttr::getFromDirectoryAndFilename( - context, "", filename, excludeFromFileList, includeReplicatedOps); -} - -OutputFileAttr OutputFileAttr::getFromDirectoryAndFilename( - MLIRContext *context, const Twine &directory, const Twine &filename, - bool excludeFromFileList, bool includeReplicatedOps) { - auto canonicalized = canonicalizeFilename(directory, filename); - return OutputFileAttr::get(StringAttr::get(context, canonicalized), - BoolAttr::get(context, excludeFromFileList), - BoolAttr::get(context, includeReplicatedOps)); -} - -OutputFileAttr OutputFileAttr::getAsDirectory(MLIRContext *context, - const Twine &directory, - bool excludeFromFileList, - bool includeReplicatedOps) { - return getFromDirectoryAndFilename(context, directory, "", - excludeFromFileList, includeReplicatedOps); -} - -bool OutputFileAttr::isDirectory() { - return getFilename().getValue().endswith(llvm::sys::path::get_separator()); -} - -/// Option ::= 'excludeFromFileList' | 'includeReplicatedOp' -/// OutputFileAttr ::= 'output_file<' directory ',' name (',' Option)* '>' -Attribute OutputFileAttr::parse(AsmParser &p, Type type) { - StringAttr filename; - if (p.parseLess() || p.parseAttribute(filename)) - return Attribute(); - - // Parse the additional keyword attributes. Its easier to let people specify - // these more than once than to detect the problem and do something about it. - bool excludeFromFileList = false; - bool includeReplicatedOps = false; - while (true) { - if (p.parseOptionalComma()) break; - if (!p.parseOptionalKeyword("excludeFromFileList")) - excludeFromFileList = true; - else if (!p.parseKeyword("includeReplicatedOps", - "or 'excludeFromFileList'")) - includeReplicatedOps = true; - else - return Attribute(); - } - - if (p.parseGreater()) return Attribute(); - - return OutputFileAttr::getFromFilename(p.getContext(), filename.getValue(), - excludeFromFileList, - includeReplicatedOps); -} - -void OutputFileAttr::print(AsmPrinter &p) const { - p << "<" << getFilename(); - if (getExcludeFromFilelist().getValue()) p << ", excludeFromFileList"; - if (getIncludeReplicatedOps().getValue()) p << ", includeReplicatedOps"; - p << ">"; -} - -//===----------------------------------------------------------------------===// -// FileListAttr -//===----------------------------------------------------------------------===// - -FileListAttr FileListAttr::getFromFilename(MLIRContext *context, - const Twine &filename) { - auto canonicalized = canonicalizeFilename("", filename); - return FileListAttr::get(StringAttr::get(context, canonicalized)); -} - -//===----------------------------------------------------------------------===// -// EnumFieldAttr -//===----------------------------------------------------------------------===// - -Attribute EnumFieldAttr::parse(AsmParser &p, Type) { - StringRef field; - Type type; - if (p.parseLess() || p.parseKeyword(&field) || p.parseComma() || - p.parseType(type) || p.parseGreater()) - return Attribute(); - return EnumFieldAttr::get(p.getEncodedSourceLoc(p.getCurrentLocation()), - StringAttr::get(p.getContext(), field), type); -} - -void EnumFieldAttr::print(AsmPrinter &p) const { - p << "<" << getField().getValue() << ", "; - p.printType(getType().getValue()); - p << ">"; -} - -EnumFieldAttr EnumFieldAttr::get(Location loc, StringAttr value, - mlir::Type type) { - if (!hw::isHWEnumType(type)) emitError(loc) << "expected enum type"; - - // Check whether the provided value is a member of the enum type. - EnumType enumType = getCanonicalType(type).cast(); - if (!enumType.contains(value.getValue())) { - emitError(loc) << "enum value '" << value.getValue() - << "' is not a member of enum type " << enumType; - return nullptr; - } - - return Base::get(value.getContext(), value, TypeAttr::get(type)); -} - -//===----------------------------------------------------------------------===// -// InnerRefAttr -//===----------------------------------------------------------------------===// - -Attribute InnerRefAttr::parse(AsmParser &p, Type type) { - SymbolRefAttr attr; - if (p.parseLess() || p.parseAttribute(attr) || - p.parseGreater()) - return Attribute(); - if (attr.getNestedReferences().size() != 1) return Attribute(); - return InnerRefAttr::get(attr.getRootReference(), attr.getLeafReference()); -} - -void InnerRefAttr::print(AsmPrinter &p) const { - p << "<"; - p.printSymbolName(getModule().getValue()); - p << "::"; - p.printSymbolName(getName().getValue()); - p << ">"; -} - -//===----------------------------------------------------------------------===// -// InnerSymAttr and InnerSymPropertiesAttr -//===----------------------------------------------------------------------===// - -Attribute InnerSymPropertiesAttr::parse(AsmParser &parser, Type type) { - StringAttr name; - NamedAttrList dummyList; - int64_t fieldId = 0; - if (parser.parseLess() || parser.parseSymbolName(name, "name", dummyList) || - parser.parseComma() || parser.parseInteger(fieldId) || - parser.parseComma()) - return Attribute(); - - StringRef visibility; - auto loc = parser.getCurrentLocation(); - if (parser.parseOptionalKeyword(&visibility, - {"public", "private", "nested"})) { - parser.emitError(loc, "expected 'public', 'private', or 'nested'"); - return Attribute(); - } - auto visibilityAttr = parser.getBuilder().getStringAttr(visibility); - - if (parser.parseGreater()) return Attribute(); - - return parser.getChecked(parser.getContext(), name, - fieldId, visibilityAttr); -} - -void InnerSymPropertiesAttr::print(AsmPrinter &odsPrinter) const { - odsPrinter << "<@" << getName().getValue() << "," << getFieldID() << "," - << getSymVisibility().getValue() << ">"; -} - -LogicalResult InnerSymPropertiesAttr::verify( - ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - ::mlir::StringAttr name, uint64_t fieldID, - ::mlir::StringAttr symVisibility) { - if (!name || name.getValue().empty()) - return emitError() << "inner symbol cannot have empty name"; - return success(); -} - -StringAttr InnerSymAttr::getSymIfExists(uint64_t fieldId) const { - const auto *it = - llvm::find_if(getImpl()->props, [&](const InnerSymPropertiesAttr &p) { - return p.getFieldID() == fieldId; - }); - if (it != getProps().end()) return it->getName(); - return {}; -} - -InnerSymAttr InnerSymAttr::erase(uint64_t fieldID) const { - SmallVector syms(getProps()); - const auto *it = llvm::find_if(syms, [fieldID](InnerSymPropertiesAttr p) { - return p.getFieldID() == fieldID; - }); - assert(it != syms.end()); - syms.erase(it); - return InnerSymAttr::get(getContext(), syms); -} - -LogicalResult InnerSymAttr::walkSymbols( - llvm::function_ref callback) const { - for (auto p : getImpl()->props) - if (callback(p.getName()).failed()) return failure(); - return success(); -} - -Attribute InnerSymAttr::parse(AsmParser &parser, Type type) { - StringAttr sym; - NamedAttrList dummyList; - SmallVector names; - if (!parser.parseOptionalSymbolName(sym, "dummy", dummyList)) { - auto prop = parser.getChecked( - parser.getContext(), sym, 0, - StringAttr::get(parser.getContext(), "public")); - if (!prop) return {}; - names.push_back(prop); - } else if (parser.parseCommaSeparatedList( - OpAsmParser::Delimiter::Square, [&]() -> ParseResult { - InnerSymPropertiesAttr prop; - if (parser.parseCustomAttributeWithFallback( - prop, mlir::Type{}, "dummy", dummyList)) - return failure(); - - names.push_back(prop); - - return success(); - })) - return Attribute(); - - std::sort(names.begin(), names.end(), - [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) { - return a.getFieldID() < b.getFieldID(); - }); - - return InnerSymAttr::get(parser.getContext(), names); -} - -void InnerSymAttr::print(AsmPrinter &odsPrinter) const { - auto props = getProps(); - if (props.size() == 1 && - props[0].getSymVisibility().getValue().equals("public") && - props[0].getFieldID() == 0) { - odsPrinter << "@" << props[0].getName().getValue(); - return; - } - auto names = props.vec(); - - std::sort(names.begin(), names.end(), - [&](InnerSymPropertiesAttr a, InnerSymPropertiesAttr b) { - return a.getFieldID() < b.getFieldID(); - }); - odsPrinter << "["; - llvm::interleaveComma(names, odsPrinter, [&](InnerSymPropertiesAttr attr) { - attr.print(odsPrinter); - }); - odsPrinter << "]"; -} - -//===----------------------------------------------------------------------===// -// ParamDeclAttr -//===----------------------------------------------------------------------===// - -Attribute ParamDeclAttr::parse(AsmParser &p, Type trailing) { - std::string name; - Type type; - Attribute value; - // < "FOO" : i32 > : i32 - // < "FOO" : i32 = 0 > : i32 - // < "FOO" : none > - if (p.parseLess() || p.parseString(&name) || p.parseColonType(type)) - return Attribute(); - - if (succeeded(p.parseOptionalEqual())) { - if (p.parseAttribute(value, type)) return Attribute(); - } - - if (p.parseGreater()) return Attribute(); - - if (value) - return ParamDeclAttr::get(p.getContext(), - p.getBuilder().getStringAttr(name), type, value); - return ParamDeclAttr::get(name, type); -} - -void ParamDeclAttr::print(AsmPrinter &p) const { - p << "<" << getName() << ": " << getType(); - if (getValue()) { - p << " = "; - p.printAttributeWithoutType(getValue()); - } - p << ">"; -} - -//===----------------------------------------------------------------------===// -// ParamDeclRefAttr -//===----------------------------------------------------------------------===// - -Attribute ParamDeclRefAttr::parse(AsmParser &p, Type type) { - StringAttr name; - if (p.parseLess() || p.parseAttribute(name) || p.parseGreater() || - (!type && (p.parseColon() || p.parseType(type)))) - return Attribute(); - - return ParamDeclRefAttr::get(name, type); -} - -void ParamDeclRefAttr::print(AsmPrinter &p) const { - p << "<" << getName() << ">"; -} - -//===----------------------------------------------------------------------===// -// ParamVerbatimAttr -//===----------------------------------------------------------------------===// - -Attribute ParamVerbatimAttr::parse(AsmParser &p, Type type) { - StringAttr text; - if (p.parseLess() || p.parseAttribute(text) || p.parseGreater() || - (!type && (p.parseColon() || p.parseType(type)))) - return Attribute(); - - return ParamVerbatimAttr::get(p.getContext(), text, type); -} - -void ParamVerbatimAttr::print(AsmPrinter &p) const { - p << "<" << getValue() << ">"; -} - -//===----------------------------------------------------------------------===// -// ParamExprAttr -//===----------------------------------------------------------------------===// - -/// Given a binary function, if the two operands are known constant integers, -/// use the specified fold function to compute the result. -static TypedAttr foldBinaryOp( - ArrayRef operands, - llvm::function_ref calculate) { - assert(operands.size() == 2 && "binary operator always has two operands"); - if (auto lhs = operands[0].dyn_cast()) - if (auto rhs = operands[1].dyn_cast()) - return IntegerAttr::get(lhs.getType(), - calculate(lhs.getValue(), rhs.getValue())); - return {}; -} - -/// Given a unary function, if the operand is a known constant integer, -/// use the specified fold function to compute the result. -static TypedAttr foldUnaryOp( - ArrayRef operands, - llvm::function_ref calculate) { - assert(operands.size() == 1 && "unary operator always has one operand"); - if (auto intAttr = operands[0].dyn_cast()) - return IntegerAttr::get(intAttr.getType(), calculate(intAttr.getValue())); - return {}; -} - -/// If the specified attribute is a ParamExprAttr with the specified opcode, -/// return it. Otherwise return null. -static ParamExprAttr dyn_castPE(PEO opcode, Attribute value) { - if (auto expr = value.dyn_cast()) - if (expr.getOpcode() == opcode) return expr; - return {}; -} - -/// This implements a < comparison for two operands to an associative operation -/// imposing an ordering upon them. -/// -/// The ordering provided puts more complex things to the start of the list, -/// from left to right: -/// expressions :: verbatims :: decl.refs :: constant -/// -static bool paramExprOperandSortPredicate(Attribute lhs, Attribute rhs) { - // Simplify the code below - we never have to care about exactly equal values. - if (lhs == rhs) return false; - - // All expressions are "less than" a constant, since they appear on the right. - if (rhs.isa()) { - // We don't bother to order constants w.r.t. each other since they will be - // folded - they can all compare equal. - return !lhs.isa(); - } - if (lhs.isa()) return false; - - // Next up are named parameters. - if (auto rhsParam = rhs.dyn_cast()) { - // Parameters are sorted lexically w.r.t. each other. - if (auto lhsParam = lhs.dyn_cast()) - return lhsParam.getName().getValue() < rhsParam.getName().getValue(); - // They otherwise appear on the right of other things. - return true; - } - if (lhs.isa()) return false; - - // Next up are verbatim parameters. - if (auto rhsParam = rhs.dyn_cast()) { - // Verbatims are sorted lexically w.r.t. each other. - if (auto lhsParam = lhs.dyn_cast()) - return lhsParam.getValue().getValue() < rhsParam.getValue().getValue(); - // They otherwise appear on the right of other things. - return true; - } - if (lhs.isa()) return false; - - // The only thing left are nested expressions. - auto lhsExpr = lhs.cast(), rhsExpr = rhs.cast(); - // Sort by the string form of the opcode, e.g. add, .. mul,... then xor. - if (lhsExpr.getOpcode() != rhsExpr.getOpcode()) - return stringifyPEO(lhsExpr.getOpcode()) < - stringifyPEO(rhsExpr.getOpcode()); - - // If they are the same opcode, then sort by arity: more complex to the left. - ArrayRef lhsOperands = lhsExpr.getOperands(), - rhsOperands = rhsExpr.getOperands(); - if (lhsOperands.size() != rhsOperands.size()) - return lhsOperands.size() > rhsOperands.size(); - - // We know the two subexpressions are different (they'd otherwise be pointer - // equivalent) so just go compare all of the elements. - for (size_t i = 0, e = lhsOperands.size(); i != e; ++i) { - if (paramExprOperandSortPredicate(lhsOperands[i], rhsOperands[i])) - return true; - if (paramExprOperandSortPredicate(rhsOperands[i], lhsOperands[i])) - return false; - } - - llvm_unreachable("expressions should never be equivalent"); - return false; -} - -/// Given a fully associative variadic integer operation, constant fold any -/// constant operands and move them to the right. If the whole expression is -/// constant, then return that, otherwise update the operands list. -static TypedAttr simplifyAssocOp( - PEO opcode, SmallVector &operands, - llvm::function_ref calculateFn, - llvm::function_ref identityConstantFn, - llvm::function_ref destructiveConstantFn = {}) { - auto type = operands[0].getType(); - assert(isHWIntegerType(type)); - if (operands.size() == 1) return operands[0]; - - // Flatten any of the same operation into the operand list: - // `(add x, (add y, z))` => `(add x, y, z)`. - for (size_t i = 0, e = operands.size(); i != e; ++i) { - if (auto subexpr = dyn_castPE(opcode, operands[i])) { - std::swap(operands[i], operands.back()); - operands.pop_back(); - --e; - --i; - operands.append(subexpr.getOperands().begin(), - subexpr.getOperands().end()); - } - } - - // Impose an ordering on the operands, pushing subexpressions to the left and - // constants to the right, with verbatims and parameters in the middle - but - // predictably ordered w.r.t. each other. - llvm::stable_sort(operands, paramExprOperandSortPredicate); - - // Merge any constants, they will appear at the back of the operand list now. - if (operands.back().isa()) { - while (operands.size() >= 2 && - operands[operands.size() - 2].isa()) { - APInt c1 = operands.pop_back_val().cast().getValue(); - APInt c2 = operands.pop_back_val().cast().getValue(); - auto resultConstant = IntegerAttr::get(type, calculateFn(c1, c2)); - operands.push_back(resultConstant); - } - - auto resultCst = operands.back().cast(); - - // If the resulting constant is the destructive constant (e.g. `x*0`), then - // return it. - if (destructiveConstantFn && destructiveConstantFn(resultCst.getValue())) - return resultCst; - - // Remove the constant back to our operand list if it is the identity - // constant for this operator (e.g. `x*1`) and there are other operands. - if (identityConstantFn(resultCst.getValue()) && operands.size() != 1) - operands.pop_back(); - } - - return operands.size() == 1 ? operands[0] : TypedAttr(); -} - -/// Analyze an operand to an add. If it is a multiplication by a constant (e.g. -/// `(a*b*42)` then split it into the non-constant and the constant portions -/// (e.g. `a*b` and `42`). Otherwise return the operand as the first value and -/// null as the second (standin for "multiplication by 1"). -static std::pair decomposeAddend(TypedAttr operand) { - if (auto mul = dyn_castPE(PEO::Mul, operand)) - if (auto cst = mul.getOperands().back().dyn_cast()) { - auto nonCst = ParamExprAttr::get(PEO::Mul, mul.getOperands().drop_back()); - return {nonCst, cst}; - } - return {operand, TypedAttr()}; -} - -static TypedAttr getOneOfType(Type type) { - return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), 1)); -} - -static TypedAttr simplifyAdd(SmallVector &operands) { - if (auto result = simplifyAssocOp( - PEO::Add, operands, [](auto a, auto b) { return a + b; }, - /*identityCst*/ [](auto cst) { return cst.isZero(); })) - return result; - - // Canonicalize the add by splitting all addends into their variable and - // constant factors. - SmallVector> decomposedOperands; - llvm::SmallDenseSet nonConstantParts; - for (auto &op : operands) { - decomposedOperands.push_back(decomposeAddend(op)); - - // Keep track of non-constant parts we've already seen. If we see multiple - // uses of the same value, then we can fold them together with a multiply. - // This handles things like `(a+b+a)` => `(a*2 + b)` and `(a*2 + b + a)` => - // `(a*3 + b)`. - if (!nonConstantParts.insert(decomposedOperands.back().first).second) { - // The thing we multiply will be the common expression. - TypedAttr mulOperand = decomposedOperands.back().first; - - // Find the index of the first occurrence. - size_t i = 0; - while (decomposedOperands[i].first != mulOperand) ++i; - // Remove both occurrences from the operand list. - operands.erase(operands.begin() + (&op - &operands[0])); - operands.erase(operands.begin() + i); - - auto type = mulOperand.getType(); - auto c1 = decomposedOperands[i].second, - c2 = decomposedOperands.back().second; - // Fill in missing constant multiplicands with 1. - if (!c1) c1 = getOneOfType(type); - if (!c2) c2 = getOneOfType(type); - // Re-add the "a"*(c1+c2) expression to the operand list and - // re-canonicalize. - auto constant = ParamExprAttr::get(PEO::Add, c1, c2); - auto mulCst = ParamExprAttr::get(PEO::Mul, mulOperand, constant); - operands.push_back(mulCst); - return ParamExprAttr::get(PEO::Add, operands); - } - } - - return {}; -} - -static TypedAttr simplifyMul(SmallVector &operands) { - if (auto result = simplifyAssocOp( - PEO::Mul, operands, [](auto a, auto b) { return a * b; }, - /*identityCst*/ [](auto cst) { return cst.isOne(); }, - /*destructiveCst*/ [](auto cst) { return cst.isZero(); })) - return result; - - // We always build a sum-of-products representation, so if we see an addition - // as a subexpr, we need to pull it out: (a+b)*c*d ==> (a*c*d + b*c*d). - for (size_t i = 0, e = operands.size(); i != e; ++i) { - if (auto subexpr = dyn_castPE(PEO::Add, operands[i])) { - // Pull the `c*d` operands out - it is whatever operands remain after - // removing the `(a+b)` term. - SmallVector mulOperands(operands.begin(), operands.end()); - mulOperands.erase(mulOperands.begin() + i); - - // Build each add operand. - SmallVector addOperands; - for (auto addOperand : subexpr.getOperands()) { - mulOperands.push_back(addOperand); - addOperands.push_back(ParamExprAttr::get(PEO::Mul, mulOperands)); - mulOperands.pop_back(); - } - // Canonicalize and form the add expression. - return ParamExprAttr::get(PEO::Add, addOperands); - } - } - - return {}; -} -static TypedAttr simplifyAnd(SmallVector &operands) { - return simplifyAssocOp( - PEO::And, operands, [](auto a, auto b) { return a & b; }, - /*identityCst*/ [](auto cst) { return cst.isAllOnes(); }, - /*destructiveCst*/ [](auto cst) { return cst.isZero(); }); -} - -static TypedAttr simplifyOr(SmallVector &operands) { - return simplifyAssocOp( - PEO::Or, operands, [](auto a, auto b) { return a | b; }, - /*identityCst*/ [](auto cst) { return cst.isZero(); }, - /*destructiveCst*/ [](auto cst) { return cst.isAllOnes(); }); -} - -static TypedAttr simplifyXor(SmallVector &operands) { - return simplifyAssocOp( - PEO::Xor, operands, [](auto a, auto b) { return a ^ b; }, - /*identityCst*/ [](auto cst) { return cst.isZero(); }); -} - -static TypedAttr simplifyShl(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - - if (auto rhs = operands[1].dyn_cast()) { - // Constant fold simple integers. - if (auto lhs = operands[0].dyn_cast()) - return IntegerAttr::get(lhs.getType(), - lhs.getValue().shl(rhs.getValue())); - - // Canonicalize `x << cst` => `x * (1< &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x >> 0`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isZero()) return operands[0]; - - return foldBinaryOp(operands, [](auto a, auto b) { return a.lshr(b); }); -} - -static TypedAttr simplifyShrS(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x >> 0`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isZero()) return operands[0]; - - return foldBinaryOp(operands, [](auto a, auto b) { return a.ashr(b); }); -} - -static TypedAttr simplifyDivU(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x/1`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isOne()) return operands[0]; - - return foldBinaryOp(operands, [](auto a, auto b) { return a.udiv(b); }); -} - -static TypedAttr simplifyDivS(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x/1`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isOne()) return operands[0]; - - return foldBinaryOp(operands, [](auto a, auto b) { return a.sdiv(b); }); -} - -static TypedAttr simplifyModU(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x%1`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); - - return foldBinaryOp(operands, [](auto a, auto b) { return a.urem(b); }); -} - -static TypedAttr simplifyModS(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - // Implement support for identities like `x%1`. - if (auto rhs = operands[1].dyn_cast()) - if (rhs.getValue().isOne()) return IntegerAttr::get(rhs.getType(), 0); - - return foldBinaryOp(operands, [](auto a, auto b) { return a.srem(b); }); -} - -static TypedAttr simplifyCLog2(SmallVector &operands) { - assert(isHWIntegerType(operands[0].getType())); - return foldUnaryOp(operands, [](auto a) { - // Following the Verilog spec, clog2(0) is 0 - return APInt(a.getBitWidth(), a == 0 ? 0 : a.ceilLogBase2()); - }); -} - -static TypedAttr simplifyStrConcat(SmallVector &operands) { - // Combine all adjacent strings. - SmallVector newOperands; - SmallVector stringsToCombine; - auto combineAndPush = [&]() { - if (stringsToCombine.empty()) return; - // Concatenate buffered strings, push to ops. - SmallString<32> newString; - for (auto part : stringsToCombine) newString.append(part.getValue()); - newOperands.push_back( - StringAttr::get(stringsToCombine[0].getContext(), newString)); - stringsToCombine.clear(); - }; - - for (TypedAttr op : operands) { - if (auto strOp = op.dyn_cast()) { - // Queue up adjacent strings. - stringsToCombine.push_back(strOp); - } else { - combineAndPush(); - newOperands.push_back(op); - } - } - combineAndPush(); - - assert(!newOperands.empty()); - if (newOperands.size() == 1) return newOperands[0]; - if (newOperands.size() < operands.size()) - return ParamExprAttr::get(PEO::StrConcat, newOperands); - return {}; -} - -/// Build a parameter expression. This automatically canonicalizes and -/// folds, so it may not necessarily return a ParamExprAttr. -TypedAttr ParamExprAttr::get(PEO opcode, ArrayRef operandsIn) { - assert(!operandsIn.empty() && "Cannot have expr with no operands"); - // All operands must have the same type, which is the type of the result. - auto type = operandsIn.front().getType(); - assert(llvm::all_of(operandsIn.drop_front(), - [&](auto op) { return op.getType() == type; })); - - SmallVector operands(operandsIn.begin(), operandsIn.end()); - - // Verify and canonicalize parameter expressions. - TypedAttr result; - switch (opcode) { - case PEO::Add: - result = simplifyAdd(operands); - break; - case PEO::Mul: - result = simplifyMul(operands); - break; - case PEO::And: - result = simplifyAnd(operands); - break; - case PEO::Or: - result = simplifyOr(operands); - break; - case PEO::Xor: - result = simplifyXor(operands); - break; - case PEO::Shl: - result = simplifyShl(operands); - break; - case PEO::ShrU: - result = simplifyShrU(operands); - break; - case PEO::ShrS: - result = simplifyShrS(operands); - break; - case PEO::DivU: - result = simplifyDivU(operands); - break; - case PEO::DivS: - result = simplifyDivS(operands); - break; - case PEO::ModU: - result = simplifyModU(operands); - break; - case PEO::ModS: - result = simplifyModS(operands); - break; - case PEO::CLog2: - result = simplifyCLog2(operands); - break; - case PEO::StrConcat: - result = simplifyStrConcat(operands); - break; - } - - // If we folded to an operand, return it. - if (result) return result; - - return Base::get(operands[0].getContext(), opcode, operands, type); -} - -Attribute ParamExprAttr::parse(AsmParser &p, Type type) { - // We require an opcode suffix like `#hw.param.expr.add`, we don't allow - // parsing a plain `#hw.param.expr` on its own. - p.emitError(p.getNameLoc(), "#hw.param.expr should have opcode suffix"); - return {}; -} - -/// Internal method used for .mlir file parsing when parsing the -/// "#hw.param.expr.mul" form of the attribute. -static Attribute parseParamExprWithOpcode(StringRef opcodeStr, - DialectAsmParser &p, Type type) { - SmallVector operands; - if (p.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { - operands.push_back({}); - return p.parseAttribute(operands.back(), type); - })) - return {}; - - std::optional opcode = symbolizePEO(opcodeStr); - if (!opcode.has_value()) { - p.emitError(p.getNameLoc(), "unknown parameter expr operator name"); - return {}; - } - - return ParamExprAttr::get(*opcode, operands); -} - -void ParamExprAttr::print(AsmPrinter &p) const { - p << "." << stringifyPEO(getOpcode()) << '<'; - llvm::interleaveComma(getOperands(), p.getStream(), - [&](Attribute op) { p.printAttributeWithoutType(op); }); - p << '>'; -} - -// Replaces any ParamDeclRefAttr within a parametric expression with its -// corresponding value from the map of provided parameters. -static FailureOr replaceDeclRefInExpr( - Location loc, const std::map ¶meters, - Attribute paramAttr) { - if (paramAttr.dyn_cast()) { - // Nothing to do, constant value. - return paramAttr; - } - if (auto paramRefAttr = paramAttr.dyn_cast()) { - // Get the value from the provided parameters. - auto it = parameters.find(paramRefAttr.getName().str()); - if (it == parameters.end()) - return emitError(loc) - << "Could not find parameter " << paramRefAttr.getName().str() - << " in the provided parameters for the expression!"; - return it->second; - } - if (auto paramExprAttr = paramAttr.dyn_cast()) { - // Recurse into all operands of the expression. - llvm::SmallVector replacedOperands; - for (auto operand : paramExprAttr.getOperands()) { - auto res = replaceDeclRefInExpr(loc, parameters, operand); - if (failed(res)) return {failure()}; - replacedOperands.push_back(res->cast()); - } - return { - hw::ParamExprAttr::get(paramExprAttr.getOpcode(), replacedOperands)}; - } - llvm_unreachable("Unhandled parametric attribute"); - return {}; -} - -FailureOr hw::evaluateParametricAttr(Location loc, - ArrayAttr parameters, - Attribute paramAttr) { - // Create a map of the provided parameters for faster lookup. - std::map parameterMap; - for (auto param : parameters) { - auto paramDecl = param.cast(); - parameterMap[paramDecl.getName().str()] = paramDecl.getValue(); - } - - // First, replace any ParamDeclRefAttr in the expression with its - // corresponding value in 'parameters'. - auto paramAttrRes = replaceDeclRefInExpr(loc, parameterMap, paramAttr); - if (failed(paramAttrRes)) return {failure()}; - paramAttr = *paramAttrRes; - - // Then, evaluate the parametric attribute. - if (paramAttr.isa()) - return paramAttr.cast(); - if (auto paramExprAttr = paramAttr.dyn_cast()) { - // Since any ParamDeclRefAttr was replaced within the expression, - // we re-evaluate the expression through the existing ParamExprAttr - // canonicalizer. - return ParamExprAttr::get(paramExprAttr.getOpcode(), - paramExprAttr.getOperands()); - } - - llvm_unreachable("Unhandled parametric attribute"); - return TypedAttr(); -} - -template -FailureOr evaluateParametricArrayType(Location loc, ArrayAttr parameters, - TArray arrayType) { - auto size = evaluateParametricAttr(loc, parameters, arrayType.getSizeAttr()); - if (failed(size)) return failure(); - auto elementType = - evaluateParametricType(loc, parameters, arrayType.getElementType()); - if (failed(elementType)) return failure(); - - // If the size was evaluated to a constant, use a 64-bit integer - // attribute version of it - if (auto intAttr = size->template dyn_cast()) - return TArray::get( - arrayType.getContext(), *elementType, - IntegerAttr::get(IntegerType::get(arrayType.getContext(), 64), - intAttr.getValue().getSExtValue())); - - // Otherwise parameter references are still involved - return TArray::get(arrayType.getContext(), *elementType, *size); -} - -FailureOr hw::evaluateParametricType(Location loc, ArrayAttr parameters, - Type type) { - return llvm::TypeSwitch>(type) - .Case([&](hw::IntType t) -> FailureOr { - auto evaluatedWidth = - evaluateParametricAttr(loc, parameters, t.getWidth()); - if (failed(evaluatedWidth)) return {failure()}; - - // If the width was evaluated to a constant, return an `IntegerType` - if (auto intAttr = evaluatedWidth->dyn_cast()) - return {IntegerType::get(type.getContext(), - intAttr.getValue().getSExtValue())}; - - // Otherwise parameter references are still involved - return hw::IntType::get(evaluatedWidth->cast()); - }) - .Case( - [&](auto arrayType) -> FailureOr { - return evaluateParametricArrayType(loc, parameters, arrayType); - }) - .Default([&](auto) { return type; }); -} - -// Returns true if any part of this parametric attribute contains a reference -// to a parameter declaration. -static bool isParamAttrWithParamRef(Attribute expr) { - return llvm::TypeSwitch(expr) - .Case([](ParamExprAttr attr) { - return llvm::any_of(attr.getOperands(), isParamAttrWithParamRef); - }) - .Case([](ParamDeclRefAttr) { return true; }) - .Default([](auto) { return false; }); -} - -bool hw::isParametricType(mlir::Type t) { - return llvm::TypeSwitch(t) - .Case( - [&](hw::IntType t) { return isParamAttrWithParamRef(t.getWidth()); }) - .Case([&](auto arrayType) { - return isParametricType(arrayType.getElementType()) || - isParamAttrWithParamRef(arrayType.getSizeAttr()); - }) - .Default([](auto) { return false; }); -} diff --git a/lib/circt/Dialect/HW/HWDialect.cpp b/lib/circt/Dialect/HW/HWDialect.cpp deleted file mode 100644 index 3b230c14ec..0000000000 --- a/lib/circt/Dialect/HW/HWDialect.cpp +++ /dev/null @@ -1,116 +0,0 @@ -//===- HWDialect.cpp - Implement the HW dialect ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the HW dialect. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWDialect.h" - -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWTypes.h" -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/InliningUtils.h" // from @llvm-project - -using namespace circt; -using namespace hw; - -//===----------------------------------------------------------------------===// -// Dialect specification. -//===----------------------------------------------------------------------===// - -// Pull in the dialect definition. -#include "include/circt/Dialect/HW/HWDialect.cpp.inc" - -namespace { - -// We implement the OpAsmDialectInterface so that HW dialect operations -// automatically interpret the name attribute on operations as their SSA name. -struct HWOpAsmDialectInterface : public OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; - - /// Get a special name to use when printing the given operation. See - /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. - void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const {} -}; -} // end anonymous namespace - -namespace { -/// This class defines the interface for handling inlining with HW operations. -struct HWInlinerInterface : public mlir::DialectInlinerInterface { - using mlir::DialectInlinerInterface::DialectInlinerInterface; - - bool isLegalToInline(Operation *op, Region *, bool, - mlir::IRMapping &) const final { - return isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op) || - isa(op) || isa(op); - } - - bool isLegalToInline(Region *, Region *, bool, - mlir::IRMapping &) const final { - return false; - } -}; -} // end anonymous namespace - -void HWDialect::initialize() { - // Register types and attributes. - registerTypes(); - registerAttributes(); - - // Register operations. - addOperations< -#define GET_OP_LIST -#include "include/circt/Dialect/HW/HW.cpp.inc" - >(); - - // Register interface implementations. - addInterfaces(); -} - -// Registered hook to materialize a single constant operation from a given -/// attribute value with the desired resultant type. This method should use -/// the provided builder to create the operation without changing the -/// insertion position. The generated operation is expected to be constant -/// like, i.e. single result, zero operands, non side-effecting, etc. On -/// success, this hook should return the value generated to represent the -/// constant value. Otherwise, it should return null on failure. -Operation *HWDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { - // Integer constants can materialize into hw.constant - if (auto intType = type.dyn_cast()) - if (auto attrValue = value.dyn_cast()) - return builder.create(loc, type, attrValue); - - // Aggregate constants. - if (auto arrayAttr = value.dyn_cast()) { - if (type.isa()) - return builder.create(loc, type, arrayAttr); - } - - // Parameter expressions materialize into hw.param.value. - auto parentOp = builder.getBlock()->getParentOp(); - auto curModule = dyn_cast(parentOp); - if (!curModule) curModule = parentOp->getParentOfType(); - if (curModule && isValidParameterExpression(value, curModule)) - return builder.create(loc, type, value); - - return nullptr; -} - -// Provide implementations for the enums we use. -#include "include/circt/Dialect/HW/HWEnums.cpp.inc" diff --git a/lib/circt/Dialect/HW/HWInstanceGraph.cpp b/lib/circt/Dialect/HW/HWInstanceGraph.cpp deleted file mode 100644 index 8696b648b4..0000000000 --- a/lib/circt/Dialect/HW/HWInstanceGraph.cpp +++ /dev/null @@ -1,33 +0,0 @@ -//===- HWInstanceGraph.cpp - Instance Graph ---------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWInstanceGraph.h" - -using namespace circt; -using namespace hw; - -InstanceGraph::InstanceGraph(Operation *operation) - : igraph::InstanceGraph(operation) { - for (auto &node : nodes) - if (cast(node.getModule().getOperation()).isPublic()) - entry.addInstance({}, &node); -} - -igraph::InstanceGraphNode *InstanceGraph::addHWModule(HWModuleLike module) { - auto *node = igraph::InstanceGraph::addModule( - cast(module.getOperation())); - if (module.isPublic()) entry.addInstance({}, node); - return node; -} - -void InstanceGraph::erase(igraph::InstanceGraphNode *node) { - for (auto *instance : llvm::make_early_inc_range(entry)) { - if (instance->getTarget() == node) instance->erase(); - } - igraph::InstanceGraph::erase(node); -} diff --git a/lib/circt/Dialect/HW/HWModuleOpInterface.cpp b/lib/circt/Dialect/HW/HWModuleOpInterface.cpp deleted file mode 100644 index 98c5dab04f..0000000000 --- a/lib/circt/Dialect/HW/HWModuleOpInterface.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===- HWModuleOpInterface.cpp.h - Implement HWModuleLike ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements HWModuleLike related functionality. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWOpInterfaces.h" -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - -using namespace circt; -using namespace hw; - -//===----------------------------------------------------------------------===// -// HWModuleLike Signature Conversion -//===----------------------------------------------------------------------===// - -static LogicalResult convertModuleOpTypes(HWModuleLike modOp, - const TypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { - ModuleType type = modOp.getHWModuleType(); - if (!type) return failure(); - - // Convert the original port types. - // Update the module signature in-place. - SmallVector newPorts; - TypeConverter::SignatureConversion result(type.getNumInputs()); - unsigned atInput = 0; - unsigned curInputs = 0; - for (auto &p : type.getPorts()) { - if (p.dir == ModulePort::Direction::Output) { - SmallVector newResults; - if (failed(typeConverter.convertType(p.type, newResults))) - return failure(); - for (auto np : newResults) newPorts.push_back({p.name, np, p.dir}); - } else { - if (failed(typeConverter.convertSignatureArg( - atInput++, - /* inout ports need to be wrapped in the appropriate type */ - p.dir == ModulePort::Direction::Input ? p.type - : InOutType::get(p.type), - result))) - return failure(); - for (auto np : result.getConvertedTypes().drop_front(curInputs)) - newPorts.push_back({p.name, np, p.dir}); - curInputs = result.getConvertedTypes().size(); - } - } - - if (failed(rewriter.convertRegionTypes(&modOp->getRegion(0), typeConverter, - &result))) - return failure(); - - auto newType = ModuleType::get(rewriter.getContext(), newPorts); - rewriter.updateRootInPlace(modOp, [&] { modOp.setHWModuleType(newType); }); - - return success(); -} - -/// Create a default conversion pattern that rewrites the type signature of a -/// FunctionOpInterface op. This only supports ops which use FunctionType to -/// represent their type. -namespace { -struct HWModuleLikeSignatureConversion : public ConversionPattern { - HWModuleLikeSignatureConversion(StringRef moduleLikeOpName, MLIRContext *ctx, - const TypeConverter &converter) - : ConversionPattern(converter, moduleLikeOpName, /*benefit=*/1, ctx) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef /*operands*/, - ConversionPatternRewriter &rewriter) const override { - HWModuleLike modOp = cast(op); - return convertModuleOpTypes(modOp, *typeConverter, rewriter); - } -}; -} // namespace - -void circt::hw::populateHWModuleLikeTypeConversionPattern( - StringRef moduleLikeOpName, RewritePatternSet &patterns, - TypeConverter &converter) { - patterns.add( - moduleLikeOpName, patterns.getContext(), converter); -} diff --git a/lib/circt/Dialect/HW/HWOpInterfaces.cpp b/lib/circt/Dialect/HW/HWOpInterfaces.cpp deleted file mode 100644 index c7c1c6e731..0000000000 --- a/lib/circt/Dialect/HW/HWOpInterfaces.cpp +++ /dev/null @@ -1,99 +0,0 @@ -//===- HWOpInterfaces.cpp - Implement the HW op interfaces ----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implement the HW operation interfaces. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWOpInterfaces.h" - -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWTypeInterfaces.h" -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/SmallBitVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallPtrSet.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project - -using namespace circt; - -LogicalResult hw::verifyInnerSymAttr(InnerSymbolOpInterface op) { - auto innerSym = op.getInnerSymAttr(); - // If does not have any inner sym then ignore. - if (!innerSym) return success(); - - if (innerSym.empty()) - return op->emitOpError("has empty list of inner symbols"); - - if (!op.supportsPerFieldSymbols()) { - // The inner sym can only be specified on fieldID=0. - if (innerSym.size() > 1 || !innerSym.getSymName()) { - op->emitOpError("does not support per-field inner symbols"); - return failure(); - } - return success(); - } - - auto result = op.getTargetResult(); - // If op supports per-field symbols, but does not have a target result, - // its up to the operation to verify itself. - // (there are no uses for this presently, but be open to this anyway.) - if (!result) return success(); - auto resultType = result.getType(); - auto maxFields = FieldIdImpl::getMaxFieldID(resultType); - llvm::SmallBitVector indices(maxFields + 1); - llvm::SmallPtrSet symNames; - // Ensure fieldID and symbol names are unique. - auto uniqSyms = [&](InnerSymPropertiesAttr p) { - if (maxFields < p.getFieldID()) { - op->emitOpError("field id:'" + Twine(p.getFieldID()) + - "' is greater than the maximum field id:'" + - Twine(maxFields) + "'"); - return false; - } - if (indices.test(p.getFieldID())) { - op->emitOpError("cannot assign multiple symbol names to the field id:'" + - Twine(p.getFieldID()) + "'"); - return false; - } - indices.set(p.getFieldID()); - auto it = symNames.insert(p.getName()); - if (!it.second) { - op->emitOpError("cannot reuse symbol name:'" + p.getName().getValue() + - "'"); - return false; - } - return true; - }; - - if (!llvm::all_of(innerSym.getProps(), uniqSyms)) return failure(); - - return success(); -} - -raw_ostream &circt::hw::operator<<(raw_ostream &printer, PortInfo port) { - StringRef dirstr; - switch (port.dir) { - case ModulePort::Direction::Input: - dirstr = "input"; - break; - case ModulePort::Direction::Output: - dirstr = "output"; - break; - case ModulePort::Direction::InOut: - dirstr = "inout"; - break; - } - printer << dirstr << " " << port.name << " : " << port.type << " (argnum " - << port.argNum << ", sym " << port.sym << ", loc " << port.loc - << ", args " << port.attrs << ")"; - return printer; -} - -#include "include/circt/Dialect/HW/HWOpInterfaces.cpp.inc" diff --git a/lib/circt/Dialect/HW/HWOps.cpp b/lib/circt/Dialect/HW/HWOps.cpp deleted file mode 100644 index ea9c6990bf..0000000000 --- a/lib/circt/Dialect/HW/HWOps.cpp +++ /dev/null @@ -1,3376 +0,0 @@ -//===- HWOps.cpp - Implement the HW operations ----------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implement the HW ops. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWOps.h" - -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/CustomDirectiveImpl.h" -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Dialect/HW/HWSymCache.h" -#include "include/circt/Dialect/HW/HWVisitors.h" -#include "include/circt/Dialect/HW/InstanceImplementation.h" -#include "include/circt/Dialect/HW/ModuleImplementation.h" -#include "include/circt/Support/CustomDirectiveImpl.h" -#include "include/circt/Support/Namespace.h" -#include "llvm/include/llvm/ADT/BitVector.h" // from @llvm-project -#include "llvm/include/llvm/ADT/SmallPtrSet.h" // from @llvm-project -#include "llvm/include/llvm/ADT/StringSet.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionImplementation.h" // from @llvm-project - -using namespace circt; -using namespace hw; -using mlir::TypedAttr; - -/// Flip a port direction. -ModulePort::Direction hw::flip(ModulePort::Direction direction) { - switch (direction) { - case ModulePort::Direction::Input: - return ModulePort::Direction::Output; - case ModulePort::Direction::Output: - return ModulePort::Direction::Input; - case ModulePort::Direction::InOut: - return ModulePort::Direction::InOut; - } - llvm_unreachable("unknown PortDirection"); -} - -bool hw::isValidIndexBitWidth(Value index, Value array) { - hw::ArrayType arrayType = - hw::getCanonicalType(array.getType()).dyn_cast(); - assert(arrayType && "expected array type"); - unsigned indexWidth = index.getType().getIntOrFloatBitWidth(); - auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements()); - return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1) - : indexWidth == requiredWidth; -} - -/// Return true if the specified operation is a combinational logic op. -bool hw::isCombinational(Operation *op) { - struct IsCombClassifier : public TypeOpVisitor { - bool visitInvalidTypeOp(Operation *op) { return false; } - bool visitUnhandledTypeOp(Operation *op) { return true; } - }; - - return (op->getDialect() && op->getDialect()->getNamespace() == "comb") || - IsCombClassifier().dispatchTypeOpVisitor(op); -} - -static Value foldStructExtract(Operation *inputOp, StringRef field) { - // A struct extract of a struct create -> corresponding struct create operand. - if (auto structCreate = dyn_cast_or_null(inputOp)) { - auto ty = type_cast(structCreate.getResult().getType()); - if (auto idx = ty.getFieldIndex(field)) - return structCreate.getOperand(*idx); - return {}; - } - // Extracting injected field -> corresponding field - if (auto structInject = dyn_cast_or_null(inputOp)) { - if (structInject.getField() != field) return {}; - return structInject.getNewValue(); - } - return {}; -} - -/// Get a special name to use when printing the entry block arguments of the -/// region contained by an operation in this dialect. -static void getAsmBlockArgumentNamesImpl(mlir::Region ®ion, - OpAsmSetValueNameFn setNameFn) { - if (region.empty()) return; - // Assign port names to the bbargs. - // auto *module = region.getParentOp(); - auto module = cast(region.getParentOp()); - - auto *block = ®ion.front(); - for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) { - auto name = module.getInputName(i); - if (!name.empty()) setNameFn(block->getArgument(i), name); - } -} - -enum class Delimiter { - None, - Paren, // () enclosed list - OptionalLessGreater, // <> enclosed list or absent -}; - -/// Check parameter specified by `value` to see if it is valid according to the -/// module's parameters. If not, emit an error to the diagnostic provided as an -/// argument to the lambda 'instanceError' and return failure, otherwise return -/// success. -/// -/// If `disallowParamRefs` is true, then parameter references are not allowed. -LogicalResult hw::checkParameterInContext( - Attribute value, ArrayAttr moduleParameters, - const instance_like_impl::EmitErrorFn &instanceError, - bool disallowParamRefs) { - // Literals are always ok. Their types are already known to match - // expectations. - if (value.isa() || value.isa() || - value.isa() || value.isa()) - return success(); - - // Check both subexpressions of an expression. - if (auto expr = value.dyn_cast()) { - for (auto op : expr.getOperands()) - if (failed(checkParameterInContext(op, moduleParameters, instanceError, - disallowParamRefs))) - return failure(); - return success(); - } - - // Parameter references need more analysis to make sure they are valid within - // this module. - if (auto parameterRef = value.dyn_cast()) { - auto nameAttr = parameterRef.getName(); - - // Don't allow references to parameters from the default values of a - // parameter list. - if (disallowParamRefs) { - instanceError([&](auto &diag) { - diag << "parameter " << nameAttr - << " cannot be used as a default value for a parameter"; - return false; - }); - return failure(); - } - - // Find the corresponding attribute in the module. - for (auto param : moduleParameters) { - auto paramAttr = param.cast(); - if (paramAttr.getName() != nameAttr) continue; - - // If the types match then the reference is ok. - if (paramAttr.getType() == parameterRef.getType()) return success(); - - instanceError([&](auto &diag) { - diag << "parameter " << nameAttr << " used with type " - << parameterRef.getType() << "; should have type " - << paramAttr.getType(); - return true; - }); - return failure(); - } - - instanceError([&](auto &diag) { - diag << "use of unknown parameter " << nameAttr; - return true; - }); - return failure(); - } - - instanceError([&](auto &diag) { - diag << "invalid parameter value " << value; - return false; - }); - return failure(); -} - -/// Check parameter specified by `value` to see if it is valid within the scope -/// of the specified module `module`. If not, emit an error at the location of -/// `usingOp` and return failure, otherwise return success. If `usingOp` is -/// null, then no diagnostic is generated. -/// -/// If `disallowParamRefs` is true, then parameter references are not allowed. -LogicalResult hw::checkParameterInContext(Attribute value, Operation *module, - Operation *usingOp, - bool disallowParamRefs) { - instance_like_impl::EmitErrorFn emitError = - [&](const std::function &fn) { - if (usingOp) { - auto diag = usingOp->emitOpError(); - if (fn(diag)) - diag.attachNote(module->getLoc()) << "module declared here"; - } - }; - - return checkParameterInContext(value, - module->getAttrOfType("parameters"), - emitError, disallowParamRefs); -} - -/// Return true if the specified attribute tree is made up of nodes that are -/// valid in a parameter expression. -bool hw::isValidParameterExpression(Attribute attr, Operation *module) { - return succeeded(checkParameterInContext(attr, module, nullptr, false)); -} - -/// Return the name of the arg attributes list used for both modules and -/// instances. Normally we'd use the FunctionOpInterface for this, but both -/// modules and instances use the same attribute name, and instances don't -/// implement that interface. -StringAttr getArgAttrsName(MLIRContext *context) { - return HWModuleOp::getArgAttrsAttrName( - mlir::OperationName(HWModuleOp::getOperationName(), context)); -} - -/// Return the name of the result attributes list used for both modules and -/// instances. Normally we'd use the FunctionOpInterface for this, but both -/// modules and instances use the same attribute name, and instances don't -/// implement that interface. -StringAttr getResAttrsName(MLIRContext *context) { - return HWModuleOp::getResAttrsAttrName( - mlir::OperationName(HWModuleOp::getOperationName(), context)); -} - -HWModulePortAccessor::HWModulePortAccessor(Location loc, - const ModulePortInfo &info, - Region &bodyRegion) - : info(info) { - inputArgs.resize(info.sizeInputs()); - for (auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) { - inputIdx[info.at(i).name.str()] = i; - inputArgs[i] = barg; - } - - outputOperands.resize(info.sizeOutputs()); - for (auto [i, outputInfo] : llvm::enumerate(info.getOutputs())) { - outputIdx[outputInfo.name.str()] = i; - } -} - -void HWModulePortAccessor::setOutput(unsigned i, Value v) { - assert(outputOperands.size() > i && "invalid output index"); - assert(outputOperands[i] == Value() && "output already set"); - outputOperands[i] = v; -} - -Value HWModulePortAccessor::getInput(unsigned i) { - assert(inputArgs.size() > i && "invalid input index"); - return inputArgs[i]; -} -Value HWModulePortAccessor::getInput(StringRef name) { - return getInput(inputIdx.find(name.str())->second); -} -void HWModulePortAccessor::setOutput(StringRef name, Value v) { - setOutput(outputIdx.find(name.str())->second, v); -} - -//===----------------------------------------------------------------------===// -// ConstantOp -//===----------------------------------------------------------------------===// - -void ConstantOp::print(OpAsmPrinter &p) { - p << " "; - p.printAttribute(getValueAttr()); - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value"}); -} - -ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) { - IntegerAttr valueAttr; - - if (parser.parseAttribute(valueAttr, "value", result.attributes) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - result.addTypes(valueAttr.getType()); - return success(); -} - -LogicalResult ConstantOp::verify() { - // If the result type has a bitwidth, then the attribute must match its width. - if (getValue().getBitWidth() != getType().cast().getWidth()) - return emitError( - "hw.constant attribute bitwidth doesn't match return type"); - - return success(); -} - -/// Build a ConstantOp from an APInt, infering the result type from the -/// width of the APInt. -void ConstantOp::build(OpBuilder &builder, OperationState &result, - const APInt &value) { - auto type = IntegerType::get(builder.getContext(), value.getBitWidth()); - auto attr = builder.getIntegerAttr(type, value); - return build(builder, result, type, attr); -} - -/// Build a ConstantOp from an APInt, infering the result type from the -/// width of the APInt. -void ConstantOp::build(OpBuilder &builder, OperationState &result, - IntegerAttr value) { - return build(builder, result, value.getType(), value); -} - -/// This builder allows construction of small signed integers like 0, 1, -1 -/// matching a specified MLIR IntegerType. This shouldn't be used for general -/// constant folding because it only works with values that can be expressed in -/// an int64_t. Use APInt's instead. -void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type, - int64_t value) { - auto numBits = type.cast().getWidth(); - build(builder, result, APInt(numBits, (uint64_t)value, /*isSigned=*/true)); -} - -void ConstantOp::getAsmResultNames( - function_ref setNameFn) { - auto intTy = getType(); - auto intCst = getValue(); - - // Sugar i1 constants with 'true' and 'false'. - if (intTy.cast().getWidth() == 1) - return setNameFn(getResult(), intCst.isZero() ? "false" : "true"); - - // Otherwise, build a complex name with the value and type. - SmallVector specialNameBuffer; - llvm::raw_svector_ostream specialName(specialNameBuffer); - specialName << 'c' << intCst << '_' << intTy; - setNameFn(getResult(), specialName.str()); -} - -OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { - assert(adaptor.getOperands().empty() && "constant has no operands"); - return getValueAttr(); -} - -//===----------------------------------------------------------------------===// -// WireOp -//===----------------------------------------------------------------------===// - -/// Check whether an operation has any additional attributes set beyond its -/// standard list of attributes returned by `getAttributeNames`. -template -static bool hasAdditionalAttributes(Op op, - ArrayRef ignoredAttrs = {}) { - auto names = op.getAttributeNames(); - llvm::SmallDenseSet nameSet; - nameSet.reserve(names.size() + ignoredAttrs.size()); - nameSet.insert(names.begin(), names.end()); - nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end()); - return llvm::any_of(op->getAttrs(), [&](auto namedAttr) { - return !nameSet.contains(namedAttr.getName()); - }); -} - -void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { - // If the wire has an optional 'name' attribute, use it. - auto nameAttr = (*this)->getAttrOfType("name"); - if (!nameAttr.getValue().empty()) setNameFn(getResult(), nameAttr.getValue()); -} - -std::optional WireOp::getTargetResultIndex() { return 0; } - -OpFoldResult WireOp::fold(FoldAdaptor adaptor) { - // If the wire has no additional attributes, no name, and no symbol, just - // forward its input. - if (!hasAdditionalAttributes(*this, {"sv.namehint"}) && !getNameAttr() && - !getInnerSymAttr()) - return getInput(); - return {}; -} - -LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) { - // Block if the wire has any attributes. - if (hasAdditionalAttributes(wire, {"sv.namehint"})) return failure(); - - // If the wire has a symbol, then we can't delete it. - if (wire.getInnerSymAttr()) return failure(); - - // If the wire has a name or an `sv.namehint` attribute, propagate it as an - // `sv.namehint` to the expression. - if (auto *inputOp = wire.getInput().getDefiningOp()) { - auto name = wire.getNameAttr(); - if (!name || name.getValue().empty()) - name = wire->getAttrOfType("sv.namehint"); - if (name) - rewriter.updateRootInPlace( - inputOp, [&] { inputOp->setAttr("sv.namehint", name); }); - } - - rewriter.replaceOp(wire, wire.getInput()); - return success(); -} - -//===----------------------------------------------------------------------===// -// AggregateConstantOp -//===----------------------------------------------------------------------===// - -static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type) { - // If this is a type alias, get the underlying type. - if (auto typeAlias = type.dyn_cast()) - type = typeAlias.getCanonicalType(); - - if (auto structType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); - if (!arrayAttr) - return op->emitOpError("expected array attribute for constant of type ") - << type; - for (auto [attr, fieldInfo] : - llvm::zip(arrayAttr.getValue(), structType.getElements())) { - if (failed(checkAttributes(op, attr, fieldInfo.type))) return failure(); - } - } else if (auto arrayType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); - if (!arrayAttr) - return op->emitOpError("expected array attribute for constant of type ") - << type; - auto elementType = arrayType.getElementType(); - for (auto attr : arrayAttr.getValue()) { - if (failed(checkAttributes(op, attr, elementType))) return failure(); - } - } else if (auto arrayType = type.dyn_cast()) { - auto arrayAttr = attr.dyn_cast(); - if (!arrayAttr) - return op->emitOpError("expected array attribute for constant of type ") - << type; - auto elementType = arrayType.getElementType(); - for (auto attr : arrayAttr.getValue()) { - if (failed(checkAttributes(op, attr, elementType))) return failure(); - } - } else if (auto enumType = type.dyn_cast()) { - auto stringAttr = attr.dyn_cast(); - if (!stringAttr) - return op->emitOpError("expected string attribute for constant of type ") - << type; - } else if (auto intType = type.dyn_cast()) { - // Check the attribute kind is correct. - auto intAttr = attr.dyn_cast(); - if (!intAttr) - return op->emitOpError("expected integer attribute for constant of type ") - << type; - // Check the bitwidth is correct. - if (intAttr.getValue().getBitWidth() != intType.getWidth()) - return op->emitOpError( - "hw.constant attribute bitwidth " - "doesn't match return type"); - } else { - return op->emitOpError("unknown element type") << type; - } - return success(); -} - -LogicalResult AggregateConstantOp::verify() { - return checkAttributes(*this, getFieldsAttr(), getType()); -} - -OpFoldResult AggregateConstantOp::fold(FoldAdaptor) { return getFieldsAttr(); } - -//===----------------------------------------------------------------------===// -// ParamValueOp -//===----------------------------------------------------------------------===// - -static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, - Type &resultType) { - if (p.parseType(resultType) || p.parseEqual() || - p.parseAttribute(value, resultType)) - return failure(); - return success(); -} - -static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, - Type resultType) { - p << resultType << " = "; - p.printAttributeWithoutType(value); -} - -LogicalResult ParamValueOp::verify() { - // Check that the attribute expression is valid in this module. - return checkParameterInContext( - getValue(), (*this)->getParentOfType(), *this); -} - -OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) { - assert(adaptor.getOperands().empty() && "hw.param.value has no operands"); - return getValueAttr(); -} - -//===----------------------------------------------------------------------===// -// HWModuleOp -//===----------------------------------------------------------------------===/ - -/// Return true if isAnyModule or instance. -bool hw::isAnyModuleOrInstance(Operation *moduleOrInstance) { - return isa(moduleOrInstance); -} - -/// Return the signature for a module as a function type from the module itself -/// or from an hw::InstanceOp. -FunctionType hw::getModuleType(Operation *moduleOrInstance) { - if (auto instance = dyn_cast(moduleOrInstance)) { - SmallVector inputs(instance->getOperandTypes()); - SmallVector results(instance->getResultTypes()); - return FunctionType::get(instance->getContext(), inputs, results); - } - - if (auto mod = dyn_cast(moduleOrInstance)) - return mod.getModuleType().getFuncType(); - - if (auto mod = dyn_cast(moduleOrInstance)) - return mod.getHWModuleType().getFuncType(); - - return cast(moduleOrInstance) - .getFunctionType() - .cast(); -} - -/// Return the name to use for the Verilog module that we're referencing -/// here. This is typically the symbol, but can be overridden with the -/// verilogName attribute. -StringAttr hw::getVerilogModuleNameAttr(Operation *module) { - auto nameAttr = module->getAttrOfType("verilogName"); - if (nameAttr) return nameAttr; - - return module->getAttrOfType(SymbolTable::getSymbolAttrName()); -} - -// Flag for parsing different module types -enum ExternModKind { PlainMod, ExternMod, GenMod }; - -template -static void buildModule(OpBuilder &builder, OperationState &result, - StringAttr name, const ModulePortInfo &ports, - ArrayAttr parameters, - ArrayRef attributes, - StringAttr comment) { - using namespace mlir::function_interface_impl; - LocationAttr unknownLoc = builder.getUnknownLoc(); - - // Add an attribute for the name. - result.addAttribute(SymbolTable::getSymbolAttrName(), name); - - SmallVector argNames, resultNames; - SmallVector argTypes, resultTypes; - SmallVector argAttrs, resultAttrs; - SmallVector argLocs, resultLocs; - SmallVector portTypes; - auto exportPortIdent = StringAttr::get(builder.getContext(), "hw.exportPort"); - - for (auto elt : ports.getInputs()) { - portTypes.push_back(elt); - if (elt.dir == ModulePort::Direction::InOut && - !elt.type.isa()) - elt.type = hw::InOutType::get(elt.type); - argTypes.push_back(elt.type); - argNames.push_back(elt.name); - argLocs.push_back(elt.loc ? elt.loc : unknownLoc); - Attribute attr; - if (elt.sym && !elt.sym.empty()) - attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}}); - else - attr = builder.getDictionaryAttr({}); - argAttrs.push_back(attr); - } - - for (auto elt : ports.getOutputs()) { - portTypes.push_back(elt); - resultTypes.push_back(elt.type); - resultNames.push_back(elt.name); - resultLocs.push_back(elt.loc ? elt.loc : unknownLoc); - Attribute attr; - if (elt.sym && !elt.sym.empty()) - attr = builder.getDictionaryAttr({{exportPortIdent, elt.sym}}); - else - attr = builder.getDictionaryAttr({}); - resultAttrs.push_back(attr); - } - - // Allow clients to pass in null for the parameters list. - if (!parameters) parameters = builder.getArrayAttr({}); - - // Record the argument and result types as an attribute. - auto type = ModuleType::get(builder.getContext(), portTypes); - result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), - TypeAttr::get(type)); - result.addAttribute("argLocs", builder.getArrayAttr(argLocs)); - result.addAttribute("resultLocs", builder.getArrayAttr(resultLocs)); - result.addAttribute(ModuleTy::getArgAttrsAttrName(result.name), - builder.getArrayAttr(argAttrs)); - result.addAttribute(ModuleTy::getResAttrsAttrName(result.name), - builder.getArrayAttr(resultAttrs)); - result.addAttribute("parameters", parameters); - if (!comment) comment = builder.getStringAttr(""); - result.addAttribute("comment", comment); - result.addAttributes(attributes); - result.addRegion(); -} - -/// Internal implementation of argument/result insertion and removal on modules. -static void modifyModuleArgs( - MLIRContext *context, ArrayRef> insertArgs, - ArrayRef removeArgs, ArrayRef oldArgNames, - ArrayRef oldArgTypes, ArrayRef oldArgAttrs, - ArrayRef oldArgLocs, SmallVector &newArgNames, - SmallVector &newArgTypes, SmallVector &newArgAttrs, - SmallVector &newArgLocs, Block *body = nullptr) { -#ifndef NDEBUG - // Check that the `insertArgs` and `removeArgs` indices are in ascending - // order. - assert(llvm::is_sorted(insertArgs, - [](auto &a, auto &b) { return a.first < b.first; }) && - "insertArgs must be in ascending order"); - assert(llvm::is_sorted(removeArgs, [](auto &a, auto &b) { return a < b; }) && - "removeArgs must be in ascending order"); -#endif - - auto oldArgCount = oldArgTypes.size(); - auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size(); - assert((int)newArgCount >= 0); - - newArgNames.reserve(newArgCount); - newArgTypes.reserve(newArgCount); - newArgAttrs.reserve(newArgCount); - newArgLocs.reserve(newArgCount); - - auto exportPortAttrName = StringAttr::get(context, "hw.exportPort"); - auto emptyDictAttr = DictionaryAttr::get(context, {}); - auto unknownLoc = UnknownLoc::get(context); - - BitVector erasedIndices; - if (body) erasedIndices.resize(oldArgCount + insertArgs.size()); - - for (unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) { - // Insert new ports at this position. - while (!insertArgs.empty() && insertArgs[0].first == argIdx) { - auto port = insertArgs[0].second; - if (port.dir == ModulePort::Direction::InOut && - !port.type.isa()) - port.type = InOutType::get(port.type); - Attribute attr = - (port.sym && !port.sym.empty()) - ? DictionaryAttr::get(context, {{exportPortAttrName, port.sym}}) - : emptyDictAttr; - newArgNames.push_back(port.name); - newArgTypes.push_back(port.type); - newArgAttrs.push_back(attr); - insertArgs = insertArgs.drop_front(); - LocationAttr loc = port.loc ? port.loc : unknownLoc; - newArgLocs.push_back(loc); - if (body) body->insertArgument(idx++, port.type, loc); - } - if (argIdx == oldArgCount) break; - - // Migrate the old port at this position. - bool removed = false; - while (!removeArgs.empty() && removeArgs[0] == argIdx) { - removeArgs = removeArgs.drop_front(); - removed = true; - } - - if (removed) { - if (body) erasedIndices.set(idx); - } else { - newArgNames.push_back(oldArgNames[argIdx]); - newArgTypes.push_back(oldArgTypes[argIdx]); - newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr - : oldArgAttrs[argIdx]); - newArgLocs.push_back(oldArgLocs[argIdx]); - } - } - - if (body) body->eraseArguments(erasedIndices); - - assert(newArgNames.size() == newArgCount); - assert(newArgTypes.size() == newArgCount); - assert(newArgAttrs.size() == newArgCount); - assert(newArgLocs.size() == newArgCount); -} - -/// Insert and remove ports of a module. The insertion and removal indices must -/// be in ascending order. The indices refer to the port positions before any -/// insertion or removal occurs. Ports inserted at the same index will appear in -/// the module in the same order as they were listed in the `insert*` array. -/// -/// The operation must be any of the module-like operations. -void hw::modifyModulePorts( - Operation *op, ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef removeInputs, ArrayRef removeOutputs, - Block *body) { - auto moduleOp = cast(op); - auto *context = moduleOp.getContext(); - - // Dig up the old argument and result data. - auto oldArgNames = moduleOp.getInputNames(); - auto oldArgTypes = moduleOp.getInputTypes(); - auto oldArgAttrs = moduleOp.getAllInputAttrs(); - auto oldArgLocs = moduleOp.getInputLocs(); - - auto oldResultNames = moduleOp.getOutputNames(); - auto oldResultTypes = moduleOp.getOutputTypes(); - auto oldResultAttrs = moduleOp.getAllOutputAttrs(); - auto oldResultLocs = moduleOp.getOutputLocs(); - - // Modify the ports. - SmallVector newArgNames, newResultNames; - SmallVector newArgTypes, newResultTypes; - SmallVector newArgAttrs, newResultAttrs; - SmallVector newArgLocs, newResultLocs; - - modifyModuleArgs(context, insertInputs, removeInputs, oldArgNames, - oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames, - newArgTypes, newArgAttrs, newArgLocs, body); - - modifyModuleArgs(context, insertOutputs, removeOutputs, oldResultNames, - oldResultTypes, oldResultAttrs, oldResultLocs, - newResultNames, newResultTypes, newResultAttrs, - newResultLocs); - - // Update the module operation types and attributes. - auto fnty = FunctionType::get(context, newArgTypes, newResultTypes); - auto modty = detail::fnToMod(fnty, newArgNames, newResultNames); - moduleOp.setHWModuleType(modty); - moduleOp.setAllInputAttrs(newArgAttrs); - moduleOp.setInputLocs(newArgLocs); - moduleOp.setAllOutputAttrs(newResultAttrs); - moduleOp.setOutputLocs(newResultLocs); -} - -void HWModuleOp::build(OpBuilder &builder, OperationState &result, - StringAttr name, const ModulePortInfo &ports, - ArrayAttr parameters, - ArrayRef attributes, StringAttr comment, - bool shouldEnsureTerminator) { - buildModule(builder, result, name, ports, parameters, attributes, - comment); - - // Create a region and a block for the body. - auto *bodyRegion = result.regions[0].get(); - Block *body = new Block(); - bodyRegion->push_back(body); - - // Add arguments to the body block. - auto unknownLoc = builder.getUnknownLoc(); - for (auto port : ports.getInputs()) { - auto loc = port.loc ? Location(port.loc) : unknownLoc; - auto type = port.type; - if (port.isInOut() && !type.isa()) type = InOutType::get(type); - body->addArgument(type, loc); - } - - if (shouldEnsureTerminator) - HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location); -} - -void HWModuleOp::build(OpBuilder &builder, OperationState &result, - StringAttr name, ArrayRef ports, - ArrayAttr parameters, - ArrayRef attributes, - StringAttr comment) { - build(builder, result, name, ModulePortInfo(ports), parameters, attributes, - comment); -} - -void HWModuleOp::build(OpBuilder &builder, OperationState &odsState, - StringAttr name, const ModulePortInfo &ports, - HWModuleBuilder modBuilder, ArrayAttr parameters, - ArrayRef attributes, - StringAttr comment) { - build(builder, odsState, name, ports, parameters, attributes, comment, - /*shouldEnsureTerminator=*/false); - auto *bodyRegion = odsState.regions[0].get(); - OpBuilder::InsertionGuard guard(builder); - auto accessor = HWModulePortAccessor(odsState.location, ports, *bodyRegion); - builder.setInsertionPointToEnd(&bodyRegion->front()); - modBuilder(builder, accessor); - // Create output operands. - llvm::SmallVector outputOperands = accessor.getOutputOperands(); - builder.create(odsState.location, outputOperands); -} - -void HWModuleOp::modifyPorts( - ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef eraseInputs, ArrayRef eraseOutputs) { - hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, - eraseOutputs); -} - -/// Return the name to use for the Verilog module that we're referencing -/// here. This is typically the symbol, but can be overridden with the -/// verilogName attribute. -StringAttr HWModuleExternOp::getVerilogModuleNameAttr() { - if (auto vName = getVerilogNameAttr()) return vName; - - return (*this)->getAttrOfType(SymbolTable::getSymbolAttrName()); -} - -StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() { - if (auto vName = getVerilogNameAttr()) { - return vName; - } - return (*this)->getAttrOfType( - ::mlir::SymbolTable::getSymbolAttrName()); -} - -void HWModuleExternOp::build(OpBuilder &builder, OperationState &result, - StringAttr name, const ModulePortInfo &ports, - StringRef verilogName, ArrayAttr parameters, - ArrayRef attributes) { - buildModule(builder, result, name, ports, parameters, - attributes, {}); - - if (!verilogName.empty()) - result.addAttribute("verilogName", builder.getStringAttr(verilogName)); -} - -void HWModuleExternOp::build(OpBuilder &builder, OperationState &result, - StringAttr name, ArrayRef ports, - StringRef verilogName, ArrayAttr parameters, - ArrayRef attributes) { - build(builder, result, name, ModulePortInfo(ports), verilogName, parameters, - attributes); -} - -void HWModuleExternOp::modifyPorts( - ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef eraseInputs, ArrayRef eraseOutputs) { - hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, - eraseOutputs); -} - -void HWModuleExternOp::appendOutputs( - ArrayRef> outputs) {} - -void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result, - FlatSymbolRefAttr genKind, StringAttr name, - const ModulePortInfo &ports, - StringRef verilogName, ArrayAttr parameters, - ArrayRef attributes) { - buildModule(builder, result, name, ports, parameters, - attributes, {}); - result.addAttribute("generatorKind", genKind); - if (!verilogName.empty()) - result.addAttribute("verilogName", builder.getStringAttr(verilogName)); -} - -void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result, - FlatSymbolRefAttr genKind, StringAttr name, - ArrayRef ports, StringRef verilogName, - ArrayAttr parameters, - ArrayRef attributes) { - build(builder, result, genKind, name, ModulePortInfo(ports), verilogName, - parameters, attributes); -} - -void HWModuleGeneratedOp::modifyPorts( - ArrayRef> insertInputs, - ArrayRef> insertOutputs, - ArrayRef eraseInputs, ArrayRef eraseOutputs) { - hw::modifyModulePorts(*this, insertInputs, insertOutputs, eraseInputs, - eraseOutputs); -} - -void HWModuleGeneratedOp::appendOutputs( - ArrayRef> outputs) {} - -static InnerSymAttr extractSym(DictionaryAttr attrs) { - if (attrs) - if (auto symRef = attrs.get("hw.exportPort")) - return symRef.cast(); - return {}; -} - -static bool hasAttribute(StringRef name, ArrayRef attrs) { - for (auto &argAttr : attrs) - if (argAttr.getName() == name) return true; - return false; -} - -static Attribute getAttribute(StringRef name, ArrayRef attrs) { - for (auto &argAttr : attrs) - if (argAttr.getName() == name) return argAttr.getValue(); - return {}; -} - -template -static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result, - ExternModKind modKind = PlainMod) { - using namespace mlir::function_interface_impl; - auto loc = parser.getCurrentLocation(); - - // Parse the visibility attribute. - (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); - - // Parse the name as a symbol. - StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - - // Parse the generator information. - FlatSymbolRefAttr kindAttr; - if (modKind == GenMod) { - if (parser.parseComma() || - parser.parseAttribute(kindAttr, "generatorKind", result.attributes)) { - return failure(); - } - } - - // Parse the parameters. - ArrayAttr parameters; - if (parseOptionalParameterList(parser, parameters)) return failure(); - - // Parse the function signature. - bool isVariadic = false; - SmallVector entryArgs; - SmallVector argNames; - SmallVector argLocs; - SmallVector resultNames; - SmallVector resultAttrs; - SmallVector resultLocs; - TypeAttr functionType; - if (failed(module_like_impl::parseModuleFunctionSignature( - parser, isVariadic, entryArgs, argNames, argLocs, resultNames, - resultAttrs, resultLocs, functionType))) - return failure(); - - // Parse the attribute dict. - if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) - return failure(); - - if (hasAttribute("resultNames", result.attributes) || - hasAttribute("parameters", result.attributes)) { - parser.emitError( - loc, "explicit `resultNames` / `parameters` attributes not allowed"); - return failure(); - } - - auto *context = result.getContext(); - // prefer the attribute over the ssa values - auto attr = getAttribute("argNames", result.attributes); - auto modType = detail::fnToMod( - cast(functionType.getValue()), - attr ? cast(attr).getValue() : argNames, resultNames); - result.attributes.erase("argNames"); - result.addAttribute("argLocs", ArrayAttr::get(context, argLocs)); - result.addAttribute("resultLocs", ArrayAttr::get(context, resultLocs)); - result.addAttribute("parameters", parameters); - if (!hasAttribute("comment", result.attributes)) - result.addAttribute("comment", StringAttr::get(context, "")); - result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), - TypeAttr::get(modType)); - - // Add the attributes to the function arguments. - addArgAndResultAttrs(parser.getBuilder(), result, entryArgs, resultAttrs, - ModuleTy::getArgAttrsAttrName(result.name), - ModuleTy::getResAttrsAttrName(result.name)); - - // Parse the optional function body. - auto *body = result.addRegion(); - if (modKind == PlainMod) { - if (parser.parseRegion(*body, entryArgs)) return failure(); - - HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); - } - return success(); -} - -ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) { - return parseHWModuleOp(parser, result); -} - -ParseResult HWModuleExternOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseHWModuleOp(parser, result, ExternMod); -} - -ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseHWModuleOp(parser, result, GenMod); -} - -FunctionType getHWModuleOpType(Operation *op) { - if (auto mod = dyn_cast(op)) - return mod.getHWModuleType().getFuncType(); - return cast(op) - .getFunctionType() - .cast(); -} - -template -static void printModuleOp(OpAsmPrinter &p, ModuleTy mod, - ExternModKind modKind) { - using namespace mlir::function_interface_impl; - - FunctionType fnType = mod.getHWModuleType().getFuncType(); - auto argTypes = fnType.getInputs(); - auto resultTypes = fnType.getResults(); - - p << ' '; - - // Print the visibility of the module. - StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); - if (auto visibility = mod.getOperation()->template getAttrOfType( - visibilityAttrName)) - p << visibility.getValue() << ' '; - - // Print the operation and the function name. - p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue()); - if (modKind == GenMod) { - p << ", "; - p.printSymbolName( - cast(mod.getOperation()).getGeneratorKind()); - } - - // Print the parameter list if present. - printOptionalParameterList( - p, mod.getOperation(), - mod.getOperation()->template getAttrOfType("parameters")); - - bool needArgNamesAttr = false; - module_like_impl::printModuleSignature(p, mod.getOperation(), argTypes, - /*isVariadic=*/false, resultTypes, - needArgNamesAttr); - - SmallVector omittedAttrs; - if (modKind == GenMod) omittedAttrs.push_back("generatorKind"); - if (!needArgNamesAttr) omittedAttrs.push_back("argNames"); - omittedAttrs.push_back("argLocs"); - omittedAttrs.push_back( - ModuleTy::getModuleTypeAttrName(mod.getOperation()->getName())); - omittedAttrs.push_back( - ModuleTy::getArgAttrsAttrName(mod.getOperation()->getName())); - omittedAttrs.push_back( - ModuleTy::getResAttrsAttrName(mod.getOperation()->getName())); - omittedAttrs.push_back("resultNames"); - omittedAttrs.push_back("resultLocs"); - omittedAttrs.push_back("parameters"); - omittedAttrs.push_back(visibilityAttrName); - omittedAttrs.push_back(SymbolTable::getSymbolAttrName()); - if (mod.getOperation() - ->template getAttrOfType("comment") - .getValue() - .empty()) - omittedAttrs.push_back("comment"); - // inject argNames - auto attrs = mod->getAttrs(); - SmallVector realAttrs(attrs.begin(), attrs.end()); - realAttrs.push_back( - NamedAttribute(StringAttr::get(mod.getContext(), "argNames"), - ArrayAttr::get(mod.getContext(), mod.getInputNames()))); - p.printOptionalAttrDictWithKeyword(realAttrs, omittedAttrs); -} - -void HWModuleExternOp::print(OpAsmPrinter &p) { - printModuleOp(p, *this, ExternMod); -} -void HWModuleGeneratedOp::print(OpAsmPrinter &p) { - printModuleOp(p, *this, GenMod); -} - -void HWModuleOp::print(OpAsmPrinter &p) { - printModuleOp(p, *this, PlainMod); - - // Print the body if this is not an external function. - Region &body = getBody(); - if (!body.empty()) { - p << " "; - p.printRegion(body, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - } -} - -static LogicalResult verifyModuleCommon(HWModuleLike module) { - assert(isa(module) && - "verifier hook should only be called on modules"); - - auto moduleType = module.getHWModuleType(); - - auto argLocs = module.getInputLocs(); - if (argLocs.size() != moduleType.getNumInputs()) - return module->emitOpError("incorrect number of argument locations"); - - auto resultLocs = module.getOutputLocs(); - if (resultLocs.size() != moduleType.getNumOutputs()) - return module->emitOpError("incorrect number of result locations"); - - SmallPtrSet paramNames; - - // Check parameter default values are sensible. - for (auto param : module->getAttrOfType("parameters")) { - auto paramAttr = param.cast(); - - // Check that we don't have any redundant parameter names. These are - // resolved by string name: reuse of the same name would cause ambiguities. - if (!paramNames.insert(paramAttr.getName()).second) - return module->emitOpError("parameter ") - << paramAttr << " has the same name as a previous parameter"; - - // Default values are allowed to be missing, check them if present. - auto value = paramAttr.getValue(); - if (!value) continue; - - auto typedValue = value.dyn_cast(); - if (!typedValue) - return module->emitOpError("parameter ") - << paramAttr << " should have a typed value; has value " << value; - - if (typedValue.getType() != paramAttr.getType()) - return module->emitOpError("parameter ") - << paramAttr << " should have type " << paramAttr.getType() - << "; has type " << typedValue.getType(); - - // Verify that this is a valid parameter value, disallowing parameter - // references. We could allow parameters to refer to each other in the - // future with lexical ordering if there is a need. - if (failed(checkParameterInContext(value, module, module, - /*disallowParamRefs=*/true))) - return failure(); - } - return success(); -} - -LogicalResult HWModuleOp::verify() { - if (failed(verifyModuleCommon(*this))) return failure(); - - auto type = getModuleType(); - auto *body = getBodyBlock(); - - // Verify the number of block arguments. - auto numInputs = type.getNumInputs(); - if (body->getNumArguments() != numInputs) - return emitOpError("entry block must have") - << numInputs << " arguments to match module signature"; - - // Verify that the block arguments match the op's attributes. - for (auto [arg, type, loc] : llvm::zip(getBodyBlock()->getArguments(), - getInputTypes(), getInputLocs())) { - if (arg.getType() != type) - return emitOpError("block argument types should match signature types"); - if (arg.getLoc() != loc.cast()) - return emitOpError( - "block argument locations should match signature locations"); - } - - return success(); -} - -LogicalResult HWModuleExternOp::verify() { return verifyModuleCommon(*this); } - -std::pair HWModuleOp::insertInput(unsigned index, - StringAttr name, - Type ty) { - // Find a unique name for the wire. - Namespace ns; - auto ports = getPortList(); - for (auto port : ports) ns.newName(port.name.getValue()); - auto nameAttr = StringAttr::get(getContext(), ns.newName(name.getValue())); - - Block *body = getBodyBlock(); - - // Create a new port for the host clock. - PortInfo port; - port.name = nameAttr; - port.dir = ModulePort::Direction::Input; - port.type = ty; - hw::modifyModulePorts(getOperation(), {std::make_pair(index, port)}, {}, {}, - {}, body); - - // Add a new argument. - return {nameAttr, body->getArgument(index)}; -} - -void HWModuleOp::insertOutputs(unsigned index, - ArrayRef> outputs) { - auto output = cast(getBodyBlock()->getTerminator()); - assert(index <= output->getNumOperands() && "invalid output index"); - - // Rewrite the port list of the module. - SmallVector> indexedNewPorts; - for (auto &[name, value] : outputs) { - PortInfo port; - port.name = name; - port.dir = ModulePort::Direction::Output; - port.type = value.getType(); - indexedNewPorts.emplace_back(index, port); - } - hw::modifyModulePorts(getOperation(), {}, indexedNewPorts, {}, {}, - getBodyBlock()); - - // Rewrite the output op. - for (auto &[name, value] : outputs) output->insertOperands(index++, value); -} - -void HWModuleOp::appendOutputs(ArrayRef> outputs) { - return insertOutputs(getNumOutputPorts(), outputs); -} - -void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion, - mlir::OpAsmSetValueNameFn setNameFn) { - getAsmBlockArgumentNamesImpl(region, setNameFn); -} - -void HWModuleExternOp::getAsmBlockArgumentNames( - mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { - getAsmBlockArgumentNamesImpl(region, setNameFn); -} - -template -static SmallVector getAllPortLocs(ModTy module) { - SmallVector retval; - auto empty = UnknownLoc::get(module.getContext()); - auto locs = module.getArgLocs(); - if (locs) - for (auto l : locs) retval.push_back(cast(l)); - retval.resize(module.getNumInputPorts(), empty); - locs = module.getResultLocs(); - if (locs) - for (auto l : locs) retval.push_back(cast(l)); - retval.resize(module.getNumInputPorts() + module.getNumOutputPorts(), empty); - return retval; -} - -SmallVector HWModuleOp::getAllPortLocs() { - return ::getAllPortLocs(*this); -} - -SmallVector HWModuleExternOp::getAllPortLocs() { - return ::getAllPortLocs(*this); -} - -SmallVector HWModuleGeneratedOp::getAllPortLocs() { - return ::getAllPortLocs(*this); -} - -template -static void setAllPortLocs(ArrayRef locs, ModTy module) { - auto numInputs = module.getNumInputPorts(); - SmallVector argLocs(locs.begin(), locs.begin() + numInputs); - SmallVector resLocs(locs.begin() + numInputs, locs.end()); - module.setArgLocsAttr(ArrayAttr::get(module.getContext(), argLocs)); - module.setResultLocsAttr(ArrayAttr::get(module.getContext(), resLocs)); -} - -void HWModuleOp::setAllPortLocs(ArrayRef locs) { - ::setAllPortLocs(locs, *this); -} - -void HWModuleExternOp::setAllPortLocs(ArrayRef locs) { - ::setAllPortLocs(locs, *this); -} - -void HWModuleGeneratedOp::setAllPortLocs(ArrayRef locs) { - ::setAllPortLocs(locs, *this); -} - -template -static void setAllPortNames(ArrayRef names, ModTy module) { - auto numInputs = module.getNumInputPorts(); - SmallVector argNames(names.begin(), names.begin() + numInputs); - SmallVector resNames(names.begin() + numInputs, names.end()); - auto oldType = module.getModuleType(); - SmallVector newPorts(oldType.getPorts().begin(), - oldType.getPorts().end()); - for (size_t i = 0UL, e = newPorts.size(); i != e; ++i) - newPorts[i].name = cast(names[i]); - auto newType = ModuleType::get(module.getContext(), newPorts); - module.setModuleType(newType); -} - -void HWModuleOp::setAllPortNames(ArrayRef names) { - ::setAllPortNames(names, *this); -} - -void HWModuleExternOp::setAllPortNames(ArrayRef names) { - ::setAllPortNames(names, *this); -} - -void HWModuleGeneratedOp::setAllPortNames(ArrayRef names) { - ::setAllPortNames(names, *this); -} - -template -static SmallVector getAllPortAttrs(ModTy &mod) { - SmallVector retval; - auto empty = DictionaryAttr::get(mod.getContext()); - auto attrs = mod.getArgAttrs(); - if (attrs) - for (auto a : *attrs) retval.push_back(a); - retval.resize(mod.getNumInputPorts(), empty); - attrs = mod.getResAttrs(); - if (attrs) - for (auto a : *attrs) retval.push_back(a); - retval.resize(mod.getNumInputPorts() + mod.getNumOutputPorts(), empty); - return retval; -} - -SmallVector HWModuleOp::getAllPortAttrs() { - return ::getAllPortAttrs(*this); -} - -SmallVector HWModuleExternOp::getAllPortAttrs() { - return ::getAllPortAttrs(*this); -} - -SmallVector HWModuleGeneratedOp::getAllPortAttrs() { - return ::getAllPortAttrs(*this); -} - -template -static void setAllPortAttrs(ModTy &mod, ArrayRef attrs) { - auto numInputs = mod.getNumInputPorts(); - SmallVector argAttrs(attrs.begin(), attrs.begin() + numInputs); - SmallVector resAttrs(attrs.begin() + numInputs, attrs.end()); - - mod.setArgAttrsAttr(ArrayAttr::get(mod.getContext(), argAttrs)); - mod.setResAttrsAttr(ArrayAttr::get(mod.getContext(), resAttrs)); -} - -void HWModuleOp::setAllPortAttrs(ArrayRef attrs) { - return ::setAllPortAttrs(*this, attrs); -} - -void HWModuleExternOp::setAllPortAttrs(ArrayRef attrs) { - return ::setAllPortAttrs(*this, attrs); -} - -void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef attrs) { - return ::setAllPortAttrs(*this, attrs); -} - -template -static void removeAllPortAttrs(ModTy &mod) { - mod.setArgAttrsAttr(ArrayAttr::get(mod.getContext(), {})); - mod.setResAttrsAttr(ArrayAttr::get(mod.getContext(), {})); -} - -void HWModuleOp::removeAllPortAttrs() { return ::removeAllPortAttrs(*this); } - -void HWModuleExternOp::removeAllPortAttrs() { - return ::removeAllPortAttrs(*this); -} - -void HWModuleGeneratedOp::removeAllPortAttrs() { - return ::removeAllPortAttrs(*this); -} - -// This probably does really unexpected stuff when you change the number of - -template -static void setHWModuleType(ModTy &mod, ModuleType type) { - auto argAttrs = mod.getAllInputAttrs(); - auto resAttrs = mod.getAllOutputAttrs(); - mod.setModuleTypeAttr(TypeAttr::get(type)); - unsigned newNumArgs = type.getNumInputs(); - unsigned newNumResults = type.getNumOutputs(); - - auto emptyDict = DictionaryAttr::get(mod.getContext()); - argAttrs.resize(newNumArgs, emptyDict); - resAttrs.resize(newNumResults, emptyDict); - - SmallVector attrs; - attrs.append(argAttrs.begin(), argAttrs.end()); - attrs.append(resAttrs.begin(), resAttrs.end()); - - if (attrs.empty()) return mod.removeAllPortAttrs(); - mod.setAllPortAttrs(attrs); -} - -void HWModuleOp::setHWModuleType(ModuleType type) { - return ::setHWModuleType(*this, type); -} - -void HWModuleExternOp::setHWModuleType(ModuleType type) { - return ::setHWModuleType(*this, type); -} - -void HWModuleGeneratedOp::setHWModuleType(ModuleType type) { - return ::setHWModuleType(*this, type); -} - -/// Lookup the generator for the symbol. This returns null on -/// invalid IR. -Operation *HWModuleGeneratedOp::getGeneratorKindOp() { - auto topLevelModuleOp = (*this)->getParentOfType(); - return topLevelModuleOp.lookupSymbol(getGeneratorKind()); -} - -LogicalResult HWModuleGeneratedOp::verifySymbolUses( - SymbolTableCollection &symbolTable) { - auto *referencedKind = - symbolTable.lookupNearestSymbolFrom(*this, getGeneratorKindAttr()); - - if (referencedKind == nullptr) - return emitError("Cannot find generator definition '") - << getGeneratorKind() << "'"; - - if (!isa(referencedKind)) - return emitError("Symbol resolved to '") - << referencedKind->getName() - << "' which is not a HWGeneratorSchemaOp"; - - auto referencedKindOp = dyn_cast(referencedKind); - auto paramRef = referencedKindOp.getRequiredAttrs(); - auto dict = (*this)->getAttrDictionary(); - for (auto str : paramRef) { - auto strAttr = str.dyn_cast(); - if (!strAttr) return emitError("Unknown attribute type, expected a string"); - if (!dict.get(strAttr.getValue())) - return emitError("Missing attribute '") << strAttr.getValue() << "'"; - } - - return success(); -} - -LogicalResult HWModuleGeneratedOp::verify() { - return verifyModuleCommon(*this); -} - -void HWModuleGeneratedOp::getAsmBlockArgumentNames( - mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { - getAsmBlockArgumentNamesImpl(region, setNameFn); -} - -LogicalResult HWModuleOp::verifyBody() { return success(); } - -//===----------------------------------------------------------------------===// -// InstanceOp -//===----------------------------------------------------------------------===// - -/// Create a instance that refers to a known module. -void InstanceOp::build(OpBuilder &builder, OperationState &result, - Operation *module, StringAttr name, - ArrayRef inputs, ArrayAttr parameters, - InnerSymAttr innerSym) { - if (!parameters) parameters = builder.getArrayAttr({}); - - auto mod = cast(module); - auto argNames = builder.getArrayAttr(mod.getInputNames()); - auto resultNames = builder.getArrayAttr(mod.getOutputNames()); - FunctionType modType = mod.getHWModuleType().getFuncType(); - build(builder, result, modType.getResults(), name, - FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs, - argNames, resultNames, parameters, innerSym); -} - -std::optional InstanceOp::getTargetResultIndex() { - // Inner symbols on instance operations target the op not any result. - return std::nullopt; -} - -/// Lookup the module or extmodule for the symbol. This returns null on -/// invalid IR. -Operation *InstanceOp::getReferencedModule(const HWSymbolCache *cache) { - return instance_like_impl::getReferencedModule(cache, *this, - getModuleNameAttr()); -} - -Operation *InstanceOp::getReferencedModule(SymbolTable &symtbl) { - return symtbl.lookup(getModuleNameAttr().getValue()); -} - -Operation *InstanceOp::getReferencedModuleSlow() { - return getReferencedModule(/*cache=*/nullptr); -} - -LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - return instance_like_impl::verifyInstanceOfHWModule( - *this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(), - getResultNames(), getParameters(), symbolTable); -} - -LogicalResult InstanceOp::verify() { - auto module = (*this)->getParentOfType(); - if (!module) return success(); - - auto moduleParameters = module->getAttrOfType("parameters"); - instance_like_impl::EmitErrorFn emitError = - [&](const std::function &fn) { - auto diag = emitOpError(); - if (fn(diag)) - diag.attachNote(module->getLoc()) << "module declared here"; - }; - return instance_like_impl::verifyParameterStructure( - getParameters(), moduleParameters, emitError); -} - -ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) { - StringAttr instanceNameAttr; - InnerSymAttr innerSym; - FlatSymbolRefAttr moduleNameAttr; - SmallVector inputsOperands; - SmallVector inputsTypes, allResultTypes; - ArrayAttr argNames, resultNames, parameters; - auto noneType = parser.getBuilder().getType(); - - if (parser.parseAttribute(instanceNameAttr, noneType, "instanceName", - result.attributes)) - return failure(); - - if (succeeded(parser.parseOptionalKeyword("sym"))) { - // Parsing an optional symbol name doesn't fail, so no need to check the - // result. - if (parser.parseCustomAttributeWithFallback(innerSym)) return failure(); - result.addAttribute(InnerSymbolTable::getInnerSymbolAttrName(), innerSym); - } - - llvm::SMLoc parametersLoc, inputsOperandsLoc; - if (parser.parseAttribute(moduleNameAttr, noneType, "moduleName", - result.attributes) || - parser.getCurrentLocation(¶metersLoc) || - parseOptionalParameterList(parser, parameters) || - parseInputPortList(parser, inputsOperands, inputsTypes, argNames) || - parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc, - result.operands) || - parser.parseArrow() || - parseOutputPortList(parser, allResultTypes, resultNames) || - parser.parseOptionalAttrDict(result.attributes)) { - return failure(); - } - - result.addAttribute("argNames", argNames); - result.addAttribute("resultNames", resultNames); - result.addAttribute("parameters", parameters); - result.addTypes(allResultTypes); - return success(); -} - -void InstanceOp::print(OpAsmPrinter &p) { - p << ' '; - p.printAttributeWithoutType(getInstanceNameAttr()); - if (auto attr = getInnerSymAttr()) { - p << " sym "; - attr.print(p); - } - p << ' '; - p.printAttributeWithoutType(getModuleNameAttr()); - printOptionalParameterList(p, *this, getParameters()); - printInputPortList(p, *this, getInputs(), getInputs().getTypes(), - getArgNames()); - p << " -> "; - printOutputPortList(p, *this, getResultTypes(), getResultNames()); - - p.printOptionalAttrDict( - (*this)->getAttrs(), - /*elidedAttrs=*/{"instanceName", - InnerSymbolTable::getInnerSymbolAttrName(), "moduleName", - "argNames", "resultNames", "parameters"}); -} - -/// Return the name of the specified input port or null if it cannot be -/// determined. -StringAttr InstanceOp::getArgumentName(size_t idx) { - return instance_like_impl::getName(getArgNames(), idx); -} - -/// Return the name of the specified result or null if it cannot be -/// determined. -StringAttr InstanceOp::getResultName(size_t idx) { - return instance_like_impl::getName(getResultNames(), idx); -} - -/// Change the name of the specified input port. -void InstanceOp::setArgumentName(size_t i, StringAttr name) { - setInputNames(instance_like_impl::updateName(getArgNames(), i, name)); -} - -/// Change the name of the specified output port. -void InstanceOp::setResultName(size_t i, StringAttr name) { - setOutputNames(instance_like_impl::updateName(getResultNames(), i, name)); -} - -/// Suggest a name for each result value based on the saved result names -/// attribute. -void InstanceOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { - instance_like_impl::getAsmResultNames(setNameFn, getInstanceName(), - getResultNames(), getResults()); -} - -ModulePortInfo InstanceOp::getPortList() { - SmallVector inputs, outputs; - auto emptyDict = DictionaryAttr::get(getContext()); - auto argNames = (*this)->getAttrOfType("argNames"); - auto argTypes = getModuleType(*this).getInputs(); - auto argLocs = (*this)->getAttrOfType("argLocs"); - for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { - auto type = argTypes[i]; - auto direction = ModulePort::Direction::Input; - - if (auto inout = type.dyn_cast()) { - type = inout.getElementType(); - direction = ModulePort::Direction::InOut; - } - - LocationAttr loc; - if (argLocs) loc = argLocs[i].cast(); - inputs.push_back({{argNames[i].cast(), type, direction}, - i, - {}, - emptyDict, - loc}); - } - - auto resultNames = (*this)->getAttrOfType("resultNames"); - auto resultTypes = getModuleType(*this).getResults(); - auto resultLocs = (*this)->getAttrOfType("resultLocs"); - for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) { - LocationAttr loc; - if (resultLocs) loc = resultLocs[i].cast(); - outputs.push_back({{resultNames[i].cast(), resultTypes[i], - ModulePort::Direction::Output}, - i, - {}, - emptyDict, - loc}); - } - return ModulePortInfo(inputs, outputs); -} - -size_t InstanceOp::getNumPorts() { - return getNumInputPorts() + getNumOutputPorts(); -} - -size_t InstanceOp::getNumInputPorts() { return getNumOperands(); } - -size_t InstanceOp::getNumOutputPorts() { return getNumResults(); } - -size_t InstanceOp::getPortIdForInputId(size_t idx) { return idx; } - -size_t InstanceOp::getPortIdForOutputId(size_t idx) { - return idx + getNumInputPorts(); -} - -void InstanceOp::getValues(SmallVectorImpl &values, - const ModulePortInfo &mpi) { - size_t inputPort = 0, resultPort = 0; - values.resize(mpi.size()); - auto results = getResults(); - auto inputs = getInputs(); - for (auto [idx, port] : llvm::enumerate(mpi)) - if (mpi.at(idx).isOutput()) - values[idx] = results[resultPort++]; - else - values[idx] = inputs[inputPort++]; -} - -//===----------------------------------------------------------------------===// -// HWOutputOp -//===----------------------------------------------------------------------===// - -/// Verify that the num of operands and types fit the declared results. -LogicalResult OutputOp::verify() { - // Check that the we (hw.output) have the same number of operands as our - // region has results. - ModuleType modType; - if (auto mod = dyn_cast((*this)->getParentOp())) - modType = mod.getHWModuleType(); - else if (auto mod = dyn_cast((*this)->getParentOp())) - modType = mod.getModuleType(); - else { - emitOpError("must have a module parent"); - return failure(); - } - auto modResults = modType.getOutputTypes(); - OperandRange outputValues = getOperands(); - if (modResults.size() != outputValues.size()) { - emitOpError("must have same number of operands as region results."); - return failure(); - } - - // Check that the types of our operands and the region's results match. - for (size_t i = 0, e = modResults.size(); i < e; ++i) { - if (modResults[i] != outputValues[i].getType()) { - emitOpError( - "output types must match module. In " - "operand ") - << i << ", expected " << modResults[i] << ", but got " - << outputValues[i].getType() << "."; - return failure(); - } - } - - return success(); -} - -//===----------------------------------------------------------------------===// -// Other Operations -//===----------------------------------------------------------------------===// - -LogicalResult GlobalRefOp::verifySymbolUses( - mlir::SymbolTableCollection &symTables) { - Operation *parent = (*this)->getParentOp(); - SymbolTable &symTable = symTables.getSymbolTable(parent); - StringAttr symNameAttr = (*this).getSymNameAttr(); - auto hasGlobalRef = [&](Attribute attr) -> bool { - if (!attr) return false; - for (auto ref : attr.cast().getAsRange()) - if (ref.getGlblSym().getAttr() == symNameAttr) return true; - return false; - }; - // For all inner refs in the namepath, ensure they have a corresponding - // GlobalRefAttr to this GlobalRefOp. - for (auto innerRef : getNamepath().getAsRange()) { - StringAttr modName = innerRef.getModule(); - StringAttr symName = innerRef.getName(); - Operation *mod = symTable.lookup(modName); - if (!mod) { - (*this)->emitOpError("module:'" + modName.str() + "' not found"); - return failure(); - } - bool glblSymNotFound = true; - bool innerSymOpNotFound = true; - mod->walk([&](InnerSymbolOpInterface op) -> WalkResult { - // If this is one of the ops in the instance path for the GlobalRefOp. - if (op.getInnerNameAttr() == symName) { - innerSymOpNotFound = false; - // Each op can have an array of GlobalRefAttr, check if this op is one - // of them. - if (hasGlobalRef(op->getAttr(GlobalRefAttr::DialectAttrName))) { - glblSymNotFound = false; - return WalkResult::interrupt(); - } - // If cannot find the ref, then its an error. - return failure(); - } - return WalkResult::advance(); - }); - if (glblSymNotFound) { - // TODO: Doesn't yet work for symbls on FIRRTL module ports. Need to - // implement an interface. - if (isa(mod)) { - auto hwmod = cast(mod); - auto inAttrs = hwmod.getAllInputAttrs(); - for (auto attr : inAttrs) - if (auto symRef = cast(attr).getAs( - "hw.exportPort")) - if (symRef.getSymName() == symName) - if (hasGlobalRef(cast(attr).get( - GlobalRefAttr::DialectAttrName))) - return success(); - - auto outAttrs = hwmod.getAllOutputAttrs(); - for (auto attr : outAttrs) - if (auto symRef = cast(attr).getAs( - "hw.exportPort")) - if (symRef.getSymName() == symName) - if (hasGlobalRef(cast(attr).get( - GlobalRefAttr::DialectAttrName))) - return success(); - } - } - if (innerSymOpNotFound) - return (*this)->emitOpError("operation:'" + symName.str() + - "' in module:'" + modName.str() + - "' could not be found"); - if (glblSymNotFound) - return (*this)->emitOpError( - "operation:'" + symName.str() + "' in module:'" + modName.str() + - "' does not contain a reference to '" + symNameAttr.str() + "'"); - } - return success(); -} - -static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, - Type &idxType) { - Type type; - if (p.parseType(type)) - return p.emitError(p.getCurrentLocation(), "Expected type"); - auto arrType = type_dyn_cast(type); - if (!arrType) - return p.emitError(p.getCurrentLocation(), "Expected !hw.array type"); - srcType = type; - unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements()); - idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth); - return success(); -} - -static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, - Type idxType) { - p.printType(srcType); -} - -ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) { - llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); - llvm::SmallVector operands; - Type elemType; - - if (parser.parseOperandList(operands) || - parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || - parser.parseType(elemType)) - return failure(); - - if (operands.size() == 0) - return parser.emitError(inputOperandsLoc, - "Cannot construct an array of length 0"); - result.addTypes({ArrayType::get(elemType, operands.size())}); - - for (auto operand : operands) - if (parser.resolveOperand(operand, elemType, result.operands)) - return failure(); - return success(); -} - -void ArrayCreateOp::print(OpAsmPrinter &p) { - p << " "; - p.printOperands(getInputs()); - p.printOptionalAttrDict((*this)->getAttrs()); - p << " : " << getInputs()[0].getType(); -} - -void ArrayCreateOp::build(OpBuilder &b, OperationState &state, - ValueRange values) { - assert(values.size() > 0 && "Cannot build array of zero elements"); - Type elemType = values[0].getType(); - assert(llvm::all_of( - values, - [elemType](Value v) -> bool { return v.getType() == elemType; }) && - "All values must have same type."); - build(b, state, ArrayType::get(elemType, values.size()), values); -} - -LogicalResult ArrayCreateOp::verify() { - unsigned returnSize = getType().cast().getNumElements(); - if (getInputs().size() != returnSize) return failure(); - return success(); -} - -OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) { - if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; })) - return {}; - return ArrayAttr::get(getContext(), adaptor.getInputs()); -} - -// Check whether an integer value is an offset from a base. -bool hw::isOffset(Value base, Value index, uint64_t offset) { - if (auto constBase = base.getDefiningOp()) { - if (auto constIndex = index.getDefiningOp()) { - // If both values are a constant, check if index == base + offset. - // To account for overflow, the addition is performed with an extra bit - // and the offset is asserted to fit in the bit width of the base. - auto baseValue = constBase.getValue(); - auto indexValue = constIndex.getValue(); - - unsigned bits = baseValue.getBitWidth(); - assert(bits == indexValue.getBitWidth() && "mismatched widths"); - - if (bits < 64 && offset >= (1ull << bits)) return false; - - APInt baseExt = baseValue.zextOrTrunc(bits + 1); - APInt indexExt = indexValue.zextOrTrunc(bits + 1); - return baseExt + offset == indexExt; - } - } - return false; -} - -// Canonicalize a create of consecutive elements to a slice. -static LogicalResult foldCreateToSlice(ArrayCreateOp op, - PatternRewriter &rewriter) { - // Do not canonicalize create of get into a slice. - auto arrayTy = hw::type_cast(op.getType()); - if (arrayTy.getNumElements() <= 1) return failure(); - auto elemTy = arrayTy.getElementType(); - - // Check if create arguments are consecutive elements of the same array. - // Attempt to break a create of gets into a sequence of consecutive intervals. - struct Chunk { - Value input; - Value index; - size_t size; - }; - SmallVector chunks; - for (Value value : llvm::reverse(op.getInputs())) { - auto get = value.getDefiningOp(); - if (!get) return failure(); - - Value input = get.getInput(); - Value index = get.getIndex(); - if (!chunks.empty()) { - auto &c = *chunks.rbegin(); - if (c.input == get.getInput() && isOffset(c.index, index, c.size)) { - c.size++; - continue; - } - } - - chunks.push_back(Chunk{input, index, 1}); - } - - // If there is a single slice, eliminate the create. - if (chunks.size() == 1) { - auto &chunk = chunks[0]; - rewriter.replaceOp(op, rewriter.createOrFold( - op.getLoc(), arrayTy, chunk.input, chunk.index)); - return success(); - } - - // If the number of chunks is significantly less than the number of - // elements, replace the create with a concat of the identified slices. - if (chunks.size() * 2 < arrayTy.getNumElements()) { - SmallVector slices; - for (auto &chunk : llvm::reverse(chunks)) { - auto sliceTy = ArrayType::get(elemTy, chunk.size); - slices.push_back(rewriter.createOrFold( - op.getLoc(), sliceTy, chunk.input, chunk.index)); - } - rewriter.replaceOpWithNewOp(op, arrayTy, slices); - return success(); - } - - return failure(); -} - -LogicalResult ArrayCreateOp::canonicalize(ArrayCreateOp op, - PatternRewriter &rewriter) { - if (succeeded(foldCreateToSlice(op, rewriter))) return success(); - return failure(); -} - -Value ArrayCreateOp::getUniformElement() { - if (!getInputs().empty() && llvm::all_equal(getInputs())) - return getInputs()[0]; - return {}; -} - -static std::optional getUIntFromValue(Value value) { - auto idxOp = dyn_cast_or_null(value.getDefiningOp()); - if (!idxOp) return std::nullopt; - APInt idxAttr = idxOp.getValue(); - if (idxAttr.getBitWidth() > 64) return std::nullopt; - return idxAttr.getLimitedValue(); -} - -LogicalResult ArraySliceOp::verify() { - unsigned inputSize = - type_cast(getInput().getType()).getNumElements(); - if (llvm::Log2_64_Ceil(inputSize) != - getLowIndex().getType().getIntOrFloatBitWidth()) - return emitOpError( - "ArraySlice: index width must match clog2 of array size"); - return success(); -} - -OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) { - // If we are slicing the entire input, then return it. - if (getType() == getInput().getType()) return getInput(); - return {}; -} - -LogicalResult ArraySliceOp::canonicalize(ArraySliceOp op, - PatternRewriter &rewriter) { - auto sliceTy = hw::type_cast(op.getType()); - auto elemTy = sliceTy.getElementType(); - uint64_t sliceSize = sliceTy.getNumElements(); - if (sliceSize == 0) return failure(); - - if (sliceSize == 1) { - // slice(a, n) -> create(a[n]) - auto get = rewriter.create(op.getLoc(), op.getInput(), - op.getLowIndex()); - rewriter.replaceOpWithNewOp(op, op.getType(), - get.getResult()); - return success(); - } - - auto offsetOpt = getUIntFromValue(op.getLowIndex()); - if (!offsetOpt) return failure(); - - auto inputOp = op.getInput().getDefiningOp(); - if (auto inputSlice = dyn_cast_or_null(inputOp)) { - // slice(slice(a, n), m) -> slice(a, n + m) - if (inputSlice == op) return failure(); - - auto inputIndex = inputSlice.getLowIndex(); - auto inputOffsetOpt = getUIntFromValue(inputIndex); - if (!inputOffsetOpt) return failure(); - - uint64_t offset = *offsetOpt + *inputOffsetOpt; - auto lowIndex = - rewriter.create(op.getLoc(), inputIndex.getType(), offset); - rewriter.replaceOpWithNewOp(op, op.getType(), - inputSlice.getInput(), lowIndex); - return success(); - } - - if (auto inputCreate = dyn_cast_or_null(inputOp)) { - // slice(create(a0, a1, ..., an), m) -> create(am, ...) - auto inputs = inputCreate.getInputs(); - - uint64_t begin = inputs.size() - *offsetOpt - sliceSize; - rewriter.replaceOpWithNewOp(op, op.getType(), - inputs.slice(begin, sliceSize)); - return success(); - } - - if (auto inputConcat = dyn_cast_or_null(inputOp)) { - // slice(concat(a1, a2, ...)) -> concat(a2, slice(a3, ..), ...) - SmallVector chunks; - uint64_t sliceStart = *offsetOpt; - for (auto input : llvm::reverse(inputConcat.getInputs())) { - // Check whether the input intersects with the slice. - uint64_t inputSize = - hw::type_cast(input.getType()).getNumElements(); - if (inputSize == 0 || inputSize <= sliceStart) { - sliceStart -= inputSize; - continue; - } - - // Find the indices to slice from this input by intersection. - uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize); - uint64_t cutSize = cutEnd - sliceStart; - assert(cutSize != 0 && "slice cannot be empty"); - - if (cutSize == inputSize) { - // The whole input fits in the slice, add it. - assert(sliceStart == 0 && "invalid cut size"); - chunks.push_back(input); - } else { - // Slice the required bits from the input. - unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize); - auto lowIndex = rewriter.create( - op.getLoc(), rewriter.getIntegerType(width), sliceStart); - chunks.push_back(rewriter.create( - op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input, lowIndex)); - } - - sliceStart = 0; - sliceSize -= cutSize; - if (sliceSize == 0) break; - } - - assert(chunks.size() > 0 && "missing sliced items"); - if (chunks.size() == 1) - rewriter.replaceOp(op, chunks[0]); - else - rewriter.replaceOpWithNewOp( - op, llvm::to_vector(llvm::reverse(chunks))); - return success(); - } - return failure(); -} - -//===----------------------------------------------------------------------===// -// ArrayConcatOp -//===----------------------------------------------------------------------===// - -static ParseResult parseArrayConcatTypes(OpAsmParser &p, - SmallVectorImpl &inputTypes, - Type &resultType) { - Type elemType; - uint64_t resultSize = 0; - - auto parseElement = [&]() -> ParseResult { - Type ty; - if (p.parseType(ty)) return failure(); - auto arrTy = type_dyn_cast(ty); - if (!arrTy) - return p.emitError(p.getCurrentLocation(), "Expected !hw.array type"); - if (elemType && elemType != arrTy.getElementType()) - return p.emitError(p.getCurrentLocation(), "Expected array element type ") - << elemType; - - elemType = arrTy.getElementType(); - inputTypes.push_back(ty); - resultSize += arrTy.getNumElements(); - return success(); - }; - - if (p.parseCommaSeparatedList(parseElement)) return failure(); - - resultType = ArrayType::get(elemType, resultSize); - return success(); -} - -static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, - TypeRange inputTypes, Type resultType) { - llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; }); -} - -void ArrayConcatOp::build(OpBuilder &b, OperationState &state, - ValueRange values) { - assert(!values.empty() && "Cannot build array of zero elements"); - ArrayType arrayTy = values[0].getType().cast(); - Type elemTy = arrayTy.getElementType(); - assert(llvm::all_of(values, - [elemTy](Value v) -> bool { - return v.getType().isa() && - v.getType().cast().getElementType() == - elemTy; - }) && - "All values must be of ArrayType with the same element type."); - - uint64_t resultSize = 0; - for (Value val : values) - resultSize += val.getType().cast().getNumElements(); - build(b, state, ArrayType::get(elemTy, resultSize), values); -} - -OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) { - auto inputs = adaptor.getInputs(); - SmallVector array; - for (size_t i = 0, e = getNumOperands(); i < e; ++i) { - if (!inputs[i]) return {}; - llvm::copy(inputs[i].cast(), std::back_inserter(array)); - } - return ArrayAttr::get(getContext(), array); -} - -// Flatten a concatenation of array creates into a single create. -static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter) { - for (auto input : op.getInputs()) - if (!input.getDefiningOp()) return false; - - SmallVector items; - for (auto input : op.getInputs()) { - auto create = cast(input.getDefiningOp()); - for (auto item : create.getInputs()) items.push_back(item); - } - - rewriter.replaceOpWithNewOp(op, items); - return true; -} - -// Merge consecutive slice expressions in a concatenation. -static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter) { - struct Slice { - Value input; - Value index; - size_t size; - Value op; - SmallVector locs; - }; - - SmallVector items; - std::optional last; - bool changed = false; - - auto concatenate = [&] { - // If there is only one op in the slice, place it to the items list. - if (!last) return; - if (last->op) { - items.push_back(last->op); - last.reset(); - return; - } - - // Otherwise, create a new slice of with the given size and place it. - // In this case, the concat op is replaced, using the new argument. - changed = true; - auto loc = FusedLoc::get(op.getContext(), last->locs); - auto origTy = hw::type_cast(last->input.getType()); - auto arrayTy = ArrayType::get(origTy.getElementType(), last->size); - items.push_back(rewriter.createOrFold( - loc, arrayTy, last->input, last->index)); - - last.reset(); - }; - - auto append = [&](Value op, Value input, Value index, size_t size) { - // If this slice is an extension of the previous one, extend the size - // saved. In this case, a new slice of is created and the concatenation - // operator is rewritten. Otherwise, flush the last slice. - if (last) { - if (last->input == input && isOffset(last->index, index, last->size)) { - last->size += size; - last->op = {}; - last->locs.push_back(op.getLoc()); - return; - } - concatenate(); - } - last.emplace(Slice{input, index, size, op, {op.getLoc()}}); - }; - - for (auto item : llvm::reverse(op.getInputs())) { - if (auto slice = item.getDefiningOp()) { - auto size = hw::type_cast(slice.getType()).getNumElements(); - append(item, slice.getInput(), slice.getLowIndex(), size); - continue; - } - - if (auto create = item.getDefiningOp()) { - if (create.getInputs().size() == 1) { - if (auto get = create.getInputs()[0].getDefiningOp()) { - append(item, get.getInput(), get.getIndex(), 1); - continue; - } - } - } - - concatenate(); - items.push_back(item); - } - concatenate(); - - if (!changed) return false; - - if (items.size() == 1) { - rewriter.replaceOp(op, items[0]); - } else { - std::reverse(items.begin(), items.end()); - rewriter.replaceOpWithNewOp(op, items); - } - return true; -} - -LogicalResult ArrayConcatOp::canonicalize(ArrayConcatOp op, - PatternRewriter &rewriter) { - // concat(create(a1, ...), create(a3, ...), ...) -> create(a1, ..., a3, ...) - if (flattenConcatOp(op, rewriter)) return success(); - - // concat(slice(a, n, m), slice(a, n + m, p)) -> concat(slice(a, n, m + p)) - if (mergeConcatSlices(op, rewriter)) return success(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// EnumConstantOp -//===----------------------------------------------------------------------===// - -ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse a Type instead of an EnumType since the type might be a type alias. - // The validity of the canonical type is checked during construction of the - // EnumFieldAttr. - Type type; - StringRef field; - - auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); - if (parser.parseKeyword(&field) || parser.parseColonType(type)) - return failure(); - - auto fieldAttr = EnumFieldAttr::get( - loc, StringAttr::get(parser.getContext(), field), type); - - if (!fieldAttr) return failure(); - - result.addAttribute("field", fieldAttr); - result.addTypes(type); - - return success(); -} - -void EnumConstantOp::print(OpAsmPrinter &p) { - p << " " << getField().getField().getValue() << " : " - << getField().getType().getValue(); -} - -void EnumConstantOp::getAsmResultNames( - function_ref setNameFn) { - setNameFn(getResult(), getField().getField().str()); -} - -void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState, - EnumFieldAttr field) { - return build(builder, odsState, field.getType().getValue(), field); -} - -OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) { - assert(adaptor.getOperands().empty() && "constant has no operands"); - return getFieldAttr(); -} - -LogicalResult EnumConstantOp::verify() { - auto fieldAttr = getFieldAttr(); - auto fieldType = fieldAttr.getType().getValue(); - // This check ensures that we are using the exact same type, without looking - // through type aliases. - if (fieldType != getType()) - emitOpError("return type ") - << getType() << " does not match attribute type " << fieldAttr; - return success(); -} - -//===----------------------------------------------------------------------===// -// EnumCmpOp -//===----------------------------------------------------------------------===// - -LogicalResult EnumCmpOp::verify() { - // Compare the canonical types. - auto lhsType = type_cast(getLhs().getType()); - auto rhsType = type_cast(getRhs().getType()); - if (rhsType != lhsType) emitOpError("types do not match"); - return success(); -} - -//===----------------------------------------------------------------------===// -// StructCreateOp -//===----------------------------------------------------------------------===// - -ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) { - llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); - llvm::SmallVector operands; - Type declOrAliasType; - - if (parser.parseLParen() || parser.parseOperandList(operands) || - parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declOrAliasType)) - return failure(); - - auto declType = type_dyn_cast(declOrAliasType); - if (!declType) - return parser.emitError(parser.getNameLoc(), - "expected !hw.struct type or alias"); - - llvm::SmallVector structInnerTypes; - declType.getInnerTypes(structInnerTypes); - result.addTypes(declOrAliasType); - - if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc, - result.operands)) - return failure(); - return success(); -} - -void StructCreateOp::print(OpAsmPrinter &printer) { - printer << " ("; - printer.printOperands(getInput()); - printer << ")"; - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : " << getType(); -} - -LogicalResult StructCreateOp::verify() { - auto elements = hw::type_cast(getType()).getElements(); - - if (elements.size() != getInput().size()) - return emitOpError("structure field count mismatch"); - - for (const auto &[field, value] : llvm::zip(elements, getInput())) - if (field.type != value.getType()) - return emitOpError("structure field `") - << field.name << "` type does not match"; - - return success(); -} - -OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) { - // struct_create(struct_explode(x)) => x - if (!getInput().empty()) - if (auto explodeOp = getInput()[0].getDefiningOp(); - explodeOp && getInput() == explodeOp.getResults() && - getResult().getType() == explodeOp.getInput().getType()) - return explodeOp.getInput(); - - auto inputs = adaptor.getInput(); - if (llvm::any_of(inputs, [](Attribute attr) { return !attr; })) return {}; - return ArrayAttr::get(getContext(), inputs); -} - -//===----------------------------------------------------------------------===// -// StructExplodeOp -//===----------------------------------------------------------------------===// - -ParseResult StructExplodeOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand operand; - Type declType; - - if (parser.parseOperand(operand) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declType)) - return failure(); - auto structType = type_dyn_cast(declType); - if (!structType) - return parser.emitError(parser.getNameLoc(), - "invalid kind of type specified"); - - llvm::SmallVector structInnerTypes; - structType.getInnerTypes(structInnerTypes); - result.addTypes(structInnerTypes); - - if (parser.resolveOperand(operand, declType, result.operands)) - return failure(); - return success(); -} - -void StructExplodeOp::print(OpAsmPrinter &printer) { - printer << " "; - printer.printOperand(getInput()); - printer.printOptionalAttrDict((*this)->getAttrs()); - printer << " : " << getInput().getType(); -} - -LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor, - SmallVectorImpl &results) { - auto input = adaptor.getInput(); - if (!input) return failure(); - llvm::copy(input.cast(), std::back_inserter(results)); - return success(); -} - -LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op, - PatternRewriter &rewriter) { - auto *inputOp = op.getInput().getDefiningOp(); - auto elements = type_cast(op.getInput().getType()).getElements(); - auto result = failure(); - for (auto [element, res] : llvm::zip(elements, op.getResults())) { - if (auto foldResult = foldStructExtract(inputOp, element.name.str())) { - rewriter.replaceAllUsesWith(res, foldResult); - result = success(); - } - } - return result; -} - -void StructExplodeOp::getAsmResultNames( - function_ref setNameFn) { - auto structType = type_cast(getInput().getType()); - for (auto [res, field] : llvm::zip(getResults(), structType.getElements())) - setNameFn(res, field.name.str()); -} - -void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - Value input) { - StructType inputType = input.getType().dyn_cast(); - assert(inputType); - SmallVector fieldTypes; - for (auto field : inputType.getElements()) fieldTypes.push_back(field.type); - build(odsBuilder, odsState, fieldTypes, input); -} - -//===----------------------------------------------------------------------===// -// StructExtractOp -//===----------------------------------------------------------------------===// - -/// Use the same parser for both struct_extract and union_extract since the -/// syntax is identical. -template -static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand operand; - StringAttr fieldName; - Type declType; - - if (parser.parseOperand(operand) || parser.parseLSquare() || - parser.parseAttribute(fieldName, "field", result.attributes) || - parser.parseRSquare() || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declType)) - return failure(); - auto aggType = type_dyn_cast(declType); - if (!aggType) - return parser.emitError(parser.getNameLoc(), - "invalid kind of type specified"); - - Type resultType = aggType.getFieldType(fieldName.getValue()); - if (!resultType) { - parser.emitError(parser.getNameLoc(), "invalid field name specified"); - return failure(); - } - result.addTypes(resultType); - - if (parser.resolveOperand(operand, declType, result.operands)) - return failure(); - return success(); -} - -/// Use the same printer for both struct_extract and union_extract since the -/// syntax is identical. -template -static void printExtractOp(OpAsmPrinter &printer, AggType op) { - printer << " "; - printer.printOperand(op.getInput()); - printer << "[\"" << op.getField() << "\"]"; - printer.printOptionalAttrDict(op->getAttrs(), {"field"}); - printer << " : " << op.getInput().getType(); -} - -ParseResult StructExtractOp::parse(OpAsmParser &parser, - OperationState &result) { - return parseExtractOp(parser, result); -} - -void StructExtractOp::print(OpAsmPrinter &printer) { - printExtractOp(printer, *this); -} - -void StructExtractOp::build(OpBuilder &builder, OperationState &odsState, - Value input, StructType::FieldInfo field) { - build(builder, odsState, field.type, input, field.name); -} - -void StructExtractOp::build(OpBuilder &builder, OperationState &odsState, - Value input, StringAttr fieldAttr) { - auto structType = type_cast(input.getType()); - auto resultType = structType.getFieldType(fieldAttr); - build(builder, odsState, resultType, input, fieldAttr); -} - -OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) { - if (auto constOperand = adaptor.getInput()) { - // Fold extract from aggregate constant - auto operandType = type_cast(getOperand().getType()); - auto fieldIdx = operandType.getFieldIndex(getField()); - auto operandAttr = llvm::cast(constOperand); - return operandAttr.getValue()[*fieldIdx]; - } - - if (auto foldResult = - foldStructExtract(getInput().getDefiningOp(), getField())) - return foldResult; - return {}; -} - -LogicalResult StructExtractOp::canonicalize(StructExtractOp op, - PatternRewriter &rewriter) { - auto inputOp = op.getInput().getDefiningOp(); - - // b = extract(inject(x["a"], v0)["b"]) => extract(x, "b") - if (auto structInject = dyn_cast_or_null(inputOp)) { - if (structInject.getField() != op.getField()) { - rewriter.replaceOpWithNewOp( - op, op.getType(), structInject.getInput(), op.getField()); - return success(); - } - } - - return failure(); -} - -void StructExtractOp::getAsmResultNames( - function_ref setNameFn) { - auto structType = type_cast(getInput().getType()); - for (auto field : structType.getElements()) { - if (field.name == getField()) { - setNameFn(getResult(), field.name.str()); - return; - } - } -} - -//===----------------------------------------------------------------------===// -// StructInjectOp -//===----------------------------------------------------------------------===// - -ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) { - llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation(); - OpAsmParser::UnresolvedOperand operand, val; - StringAttr fieldName; - Type declType; - - if (parser.parseOperand(operand) || parser.parseLSquare() || - parser.parseAttribute(fieldName, "field", result.attributes) || - parser.parseRSquare() || parser.parseComma() || - parser.parseOperand(val) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declType)) - return failure(); - auto structType = type_dyn_cast(declType); - if (!structType) - return parser.emitError(inputOperandsLoc, "invalid kind of type specified"); - - Type resultType = structType.getFieldType(fieldName.getValue()); - if (!resultType) { - parser.emitError(inputOperandsLoc, "invalid field name specified"); - return failure(); - } - result.addTypes(declType); - - if (parser.resolveOperands({operand, val}, {declType, resultType}, - inputOperandsLoc, result.operands)) - return failure(); - return success(); -} - -void StructInjectOp::print(OpAsmPrinter &printer) { - printer << " "; - printer.printOperand(getInput()); - printer << "[\"" << getField() << "\"], "; - printer.printOperand(getNewValue()); - printer.printOptionalAttrDict((*this)->getAttrs(), {"field"}); - printer << " : " << getInput().getType(); -} - -OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) { - auto input = adaptor.getInput(); - auto newValue = adaptor.getNewValue(); - if (!input || !newValue) return {}; - SmallVector array; - llvm::copy(input.cast(), std::back_inserter(array)); - StructType structType = getInput().getType(); - auto index = *structType.getFieldIndex(getField()); - array[index] = newValue; - return ArrayAttr::get(getContext(), array); -} - -LogicalResult StructInjectOp::canonicalize(StructInjectOp op, - PatternRewriter &rewriter) { - // Canonicalize multiple injects into a create op and eliminate overwrites. - SmallPtrSet injects; - DenseMap fields; - - // Chase a chain of injects. Bail out if cycles are present. - StructInjectOp inject = op; - Value input; - do { - if (!injects.insert(inject).second) return failure(); - - fields.try_emplace(inject.getFieldAttr(), inject.getNewValue()); - input = inject.getInput(); - inject = dyn_cast_or_null(input.getDefiningOp()); - } while (inject); - assert(input && "missing input to inject chain"); - - auto ty = hw::type_cast(op.getType()); - auto elements = ty.getElements(); - - // If the inject chain sets all fields, canonicalize to create. - if (fields.size() == elements.size()) { - SmallVector createFields; - for (const auto &field : elements) { - auto it = fields.find(field.name); - assert(it != fields.end() && "missing field"); - createFields.push_back(it->second); - } - rewriter.replaceOpWithNewOp(op, ty, createFields); - return success(); - } - - // Nothing to canonicalize, only the original inject in the chain. - if (injects.size() == fields.size()) return failure(); - - // Eliminate overwrites. The hash map contains the last write to each field. - for (const auto &field : elements) { - auto it = fields.find(field.name); - if (it == fields.end()) continue; - input = rewriter.create(op.getLoc(), ty, input, field.name, - it->second); - } - - rewriter.replaceOp(op, input); - return success(); -} - -//===----------------------------------------------------------------------===// -// UnionCreateOp -//===----------------------------------------------------------------------===// - -ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) { - Type declOrAliasType; - StringAttr field; - OpAsmParser::UnresolvedOperand input; - llvm::SMLoc fieldLoc = parser.getCurrentLocation(); - - if (parser.parseAttribute(field, "field", result.attributes) || - parser.parseComma() || parser.parseOperand(input) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(declOrAliasType)) - return failure(); - - auto declType = type_dyn_cast(declOrAliasType); - if (!declType) - return parser.emitError(parser.getNameLoc(), - "expected !hw.union type or alias"); - - Type inputType = declType.getFieldType(field.getValue()); - if (!inputType) { - parser.emitError(fieldLoc, "cannot find union field '") - << field.getValue() << '\''; - return failure(); - } - - if (parser.resolveOperand(input, inputType, result.operands)) - return failure(); - result.addTypes({declOrAliasType}); - return success(); -} - -void UnionCreateOp::print(OpAsmPrinter &printer) { - printer << " \"" << getField() << "\", "; - printer.printOperand(getInput()); - printer.printOptionalAttrDict((*this)->getAttrs(), {"field"}); - printer << " : " << getType(); -} - -//===----------------------------------------------------------------------===// -// UnionExtractOp -//===----------------------------------------------------------------------===// - -ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) { - return parseExtractOp(parser, result); -} - -void UnionExtractOp::print(OpAsmPrinter &printer) { - printExtractOp(printer, *this); -} - -LogicalResult UnionExtractOp::inferReturnTypes( - MLIRContext *context, std::optional loc, ValueRange operands, - DictionaryAttr attrs, mlir::OpaqueProperties properties, - mlir::RegionRange regions, SmallVectorImpl &results) { - results.push_back(cast(getCanonicalType(operands[0].getType())) - .getFieldType(attrs.getAs("field"))); - return success(); -} - -//===----------------------------------------------------------------------===// -// ArrayGetOp -//===----------------------------------------------------------------------===// - -// An array_get of an array_create with a constant index can just be the -// array_create operand at the constant index. If the array_create has a -// single uniform value for each element, just return that value regardless of -// the index. If the array is constructed from a constant by a bitcast -// operation, we can fold into a constant. -OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) { - auto inputCst = adaptor.getInput().dyn_cast_or_null(); - auto indexCst = adaptor.getIndex().dyn_cast_or_null(); - - if (inputCst) { - // Constant array index. - if (indexCst) { - auto indexVal = indexCst.getValue(); - if (indexVal.getBitWidth() < 64) { - auto index = indexVal.getZExtValue(); - return inputCst[inputCst.size() - 1 - index]; - } - } - // If all elements of the array are the same, we can return any element of - // array. - if (!inputCst.empty() && llvm::all_equal(inputCst)) return inputCst[0]; - } - - // array_get(bitcast(c), i) -> c[i*w+w-1:i*w] - if (auto bitcast = getInput().getDefiningOp()) { - auto intTy = getType().dyn_cast(); - if (!intTy) return {}; - auto bitcastInputOp = bitcast.getInput().getDefiningOp(); - if (!bitcastInputOp) return {}; - if (!indexCst) return {}; - auto bitcastInputCst = bitcastInputOp.getValue(); - // Calculate the index. Make sure to zero-extend the index value before - // multiplying the element width. - auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) * - getType().getIntOrFloatBitWidth(); - // Extract [startIdx + width - 1: startIdx]. - return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc( - intTy.getIntOrFloatBitWidth())); - } - - auto inputCreate = getInput().getDefiningOp(); - if (!inputCreate) return {}; - - if (auto uniformValue = inputCreate.getUniformElement()) return uniformValue; - - if (!indexCst || indexCst.getValue().getBitWidth() > 64) return {}; - - uint64_t index = indexCst.getValue().getLimitedValue(); - auto createInputs = inputCreate.getInputs(); - if (index >= createInputs.size()) return {}; - return createInputs[createInputs.size() - index - 1]; -} - -LogicalResult ArrayGetOp::canonicalize(ArrayGetOp op, - PatternRewriter &rewriter) { - auto idxOpt = getUIntFromValue(op.getIndex()); - if (!idxOpt) return failure(); - - auto *inputOp = op.getInput().getDefiningOp(); - if (auto inputSlice = dyn_cast_or_null(inputOp)) { - // get(slice(a, n), m) -> get(a, n + m) - auto offsetOp = inputSlice.getLowIndex(); - auto offsetOpt = getUIntFromValue(offsetOp); - if (!offsetOpt) return failure(); - - uint64_t offset = *offsetOpt + *idxOpt; - auto newOffset = - rewriter.create(op.getLoc(), offsetOp.getType(), offset); - rewriter.replaceOpWithNewOp(op, inputSlice.getInput(), - newOffset); - return success(); - } - - if (auto inputConcat = dyn_cast_or_null(inputOp)) { - // get(concat(a0, a1, ...), m) -> get(an, m - s0 - s1 - ...) - uint64_t elemIndex = *idxOpt; - for (auto input : llvm::reverse(inputConcat.getInputs())) { - size_t size = hw::type_cast(input.getType()).getNumElements(); - if (elemIndex >= size) { - elemIndex -= size; - continue; - } - - unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size); - auto newIdxOp = rewriter.create( - op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex); - - rewriter.replaceOpWithNewOp(op, input, newIdxOp); - return success(); - } - return failure(); - } - - // array_get const, (array_get sel, (array_create a, b, c, d)) --> - // array_get sel, (array_create (array_get const a), (array_get const b), - // (array_get const, c), (array_get const, d)) - if (auto innerGet = dyn_cast_or_null(inputOp)) { - if (!innerGet.getIndex().getDefiningOp()) { - if (auto create = - innerGet.getInput().getDefiningOp()) { - SmallVector newValues; - for (auto operand : create.getOperands()) - newValues.push_back(rewriter.createOrFold( - op.getLoc(), operand, op.getIndex())); - - rewriter.replaceOpWithNewOp( - op, - rewriter.createOrFold(op.getLoc(), newValues), - innerGet.getIndex()); - return success(); - } - } - } - - return failure(); -} - -//===----------------------------------------------------------------------===// -// TypedeclOp -//===----------------------------------------------------------------------===// - -StringRef TypedeclOp::getPreferredName() { - return getVerilogName().value_or(getName()); -} - -Type TypedeclOp::getAliasType() { - auto parentScope = cast(getOperation()->getParentOp()); - return hw::TypeAliasType::get( - SymbolRefAttr::get(parentScope.getSymNameAttr(), - {FlatSymbolRefAttr::get(*this)}), - getType()); -} - -//===----------------------------------------------------------------------===// -// BitcastOp -//===----------------------------------------------------------------------===// - -OpFoldResult BitcastOp::fold(FoldAdaptor) { - // Identity. - // bitcast(%a) : A -> A ==> %a - if (getOperand().getType() == getType()) return getOperand(); - - return {}; -} - -LogicalResult BitcastOp::canonicalize(BitcastOp op, PatternRewriter &rewriter) { - // Composition. - // %b = bitcast(%a) : A -> B - // bitcast(%b) : B -> C - // ===> bitcast(%a) : A -> C - auto inputBitcast = - dyn_cast_or_null(op.getInput().getDefiningOp()); - if (!inputBitcast) return failure(); - auto bitcast = rewriter.createOrFold(op.getLoc(), op.getType(), - inputBitcast.getInput()); - rewriter.replaceOp(op, bitcast); - return success(); -} - -LogicalResult BitcastOp::verify() { - if (getBitWidth(getInput().getType()) != getBitWidth(getResult().getType())) - return this->emitOpError("Bitwidth of input must match result"); - return success(); -} - -//===----------------------------------------------------------------------===// -// HierPathOp helpers. -//===----------------------------------------------------------------------===// - -bool HierPathOp::dropModule(StringAttr moduleToDrop) { - SmallVector newPath; - bool updateMade = false; - for (auto nameRef : getNamepath()) { - // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. - if (auto ref = nameRef.dyn_cast()) { - if (ref.getModule() == moduleToDrop) - updateMade = true; - else - newPath.push_back(ref); - } else { - if (nameRef.cast().getAttr() == moduleToDrop) - updateMade = true; - else - newPath.push_back(nameRef); - } - } - if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); - return updateMade; -} - -bool HierPathOp::inlineModule(StringAttr moduleToDrop) { - SmallVector newPath; - bool updateMade = false; - StringRef inlinedInstanceName = ""; - for (auto nameRef : getNamepath()) { - // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. - if (auto ref = nameRef.dyn_cast()) { - if (ref.getModule() == moduleToDrop) { - inlinedInstanceName = ref.getName().getValue(); - updateMade = true; - } else if (!inlinedInstanceName.empty()) { - newPath.push_back(hw::InnerRefAttr::get( - ref.getModule(), - StringAttr::get(getContext(), inlinedInstanceName + "_" + - ref.getName().getValue()))); - inlinedInstanceName = ""; - } else - newPath.push_back(ref); - } else { - if (nameRef.cast().getAttr() == moduleToDrop) - updateMade = true; - else - newPath.push_back(nameRef); - } - } - if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); - return updateMade; -} - -bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) { - SmallVector newPath; - bool updateMade = false; - for (auto nameRef : getNamepath()) { - // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. - if (auto ref = nameRef.dyn_cast()) { - if (ref.getModule() == oldMod) { - newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName())); - updateMade = true; - } else - newPath.push_back(ref); - } else { - if (nameRef.cast().getAttr() == oldMod) { - newPath.push_back(FlatSymbolRefAttr::get(newMod)); - updateMade = true; - } else - newPath.push_back(nameRef); - } - } - if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); - return updateMade; -} - -bool HierPathOp::updateModuleAndInnerRef( - StringAttr oldMod, StringAttr newMod, - const llvm::DenseMap &innerSymRenameMap) { - auto fromRef = FlatSymbolRefAttr::get(oldMod); - if (oldMod == newMod) return false; - - auto namepathNew = getNamepath().getValue().vec(); - bool updateMade = false; - // Break from the loop if the module is found, since it can occur only once. - for (auto &element : namepathNew) { - if (auto innerRef = element.dyn_cast()) { - if (innerRef.getModule() != oldMod) continue; - auto symName = innerRef.getName(); - // Since the module got updated, the old innerRef symbol inside oldMod - // should also be updated to the new symbol inside the newMod. - auto to = innerSymRenameMap.find(symName); - if (to != innerSymRenameMap.end()) symName = to->second; - updateMade = true; - element = hw::InnerRefAttr::get(newMod, symName); - break; - } - if (element != fromRef) continue; - - updateMade = true; - element = FlatSymbolRefAttr::get(newMod); - break; - } - if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), namepathNew)); - return updateMade; -} - -bool HierPathOp::truncateAtModule(StringAttr atMod, bool includeMod) { - SmallVector newPath; - bool updateMade = false; - for (auto nameRef : getNamepath()) { - // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. - if (auto ref = nameRef.dyn_cast()) { - if (ref.getModule() == atMod) { - updateMade = true; - if (includeMod) newPath.push_back(ref); - } else - newPath.push_back(ref); - } else { - if (nameRef.cast().getAttr() == atMod && !includeMod) - updateMade = true; - else - newPath.push_back(nameRef); - } - if (updateMade) break; - } - if (updateMade) setNamepathAttr(ArrayAttr::get(getContext(), newPath)); - return updateMade; -} - -/// Return just the module part of the namepath at a specific index. -StringAttr HierPathOp::modPart(unsigned i) { - return TypeSwitch(getNamepath()[i]) - .Case([](auto a) { return a.getAttr(); }) - .Case([](auto a) { return a.getModule(); }); -} - -/// Return the root module. -StringAttr HierPathOp::root() { - assert(!getNamepath().empty()); - return modPart(0); -} - -/// Return true if the NLA has the module in its path. -bool HierPathOp::hasModule(StringAttr modName) { - for (auto nameRef : getNamepath()) { - // nameRef is either an InnerRefAttr or a FlatSymbolRefAttr. - if (auto ref = nameRef.dyn_cast()) { - if (ref.getModule() == modName) return true; - } else { - if (nameRef.cast().getAttr() == modName) return true; - } - } - return false; -} - -/// Return true if the NLA has the InnerSym . -bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName) const { - for (auto nameRef : const_cast(this)->getNamepath()) - if (auto ref = nameRef.dyn_cast()) - if (ref.getName() == symName && ref.getModule() == modName) return true; - - return false; -} - -/// Return just the reference part of the namepath at a specific index. This -/// will return an empty attribute if this is the leaf and the leaf is a module. -StringAttr HierPathOp::refPart(unsigned i) { - return TypeSwitch(getNamepath()[i]) - .Case([](auto a) { return StringAttr({}); }) - .Case([](auto a) { return a.getName(); }); -} - -/// Return the leaf reference. This returns an empty attribute if the leaf -/// reference is a module. -StringAttr HierPathOp::ref() { - assert(!getNamepath().empty()); - return refPart(getNamepath().size() - 1); -} - -/// Return the leaf module. -StringAttr HierPathOp::leafMod() { - assert(!getNamepath().empty()); - return modPart(getNamepath().size() - 1); -} - -/// Returns true if this NLA targets an instance of a module (as opposed to -/// an instance's port or something inside an instance). -bool HierPathOp::isModule() { return !ref(); } - -/// Returns true if this NLA targets something inside a module (as opposed -/// to a module or an instance of a module); -bool HierPathOp::isComponent() { return (bool)ref(); } - -// Verify the HierPathOp. -// 1. Iterate over the namepath. -// 2. The namepath should be a valid instance path, specified either on a -// module or a declaration inside a module. -// 3. Each element in the namepath is an InnerRefAttr except possibly the -// last element. -// 4. Make sure that the InnerRefAttr is legal, by verifying the module name -// and the corresponding inner_sym on the instance. -// 5. Make sure that the instance path is legal, by verifying the sequence of -// instance and the expected module occurs as the next element in the path. -// 6. The last element of the namepath, can be an InnerRefAttr on either a -// module port or a declaration inside the module. -// 7. The last element of the namepath can also be a module symbol. -LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) { - StringAttr expectedModuleName = {}; - if (!getNamepath() || getNamepath().empty()) - return emitOpError() << "the instance path cannot be empty"; - for (unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) { - hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast(); - if (!innerRef) - return emitOpError() - << "the instance path can only contain inner sym reference" - << ", only the leaf can refer to a module symbol"; - - if (expectedModuleName && expectedModuleName != innerRef.getModule()) - return emitOpError() << "instance path is incorrect. Expected module: " - << expectedModuleName - << " instead found: " << innerRef.getModule(); - auto instOp = ns.lookupOp(innerRef); - if (!instOp) - return emitOpError() << " module: " << innerRef.getModule() - << " does not contain any instance with symbol: " - << innerRef.getName(); - expectedModuleName = instOp.getReferencedModuleNameAttr(); - } - // The instance path has been verified. Now verify the last element. - auto leafRef = getNamepath()[getNamepath().size() - 1]; - if (auto innerRef = leafRef.dyn_cast()) { - if (!ns.lookup(innerRef)) { - return emitOpError() << " operation with symbol: " << innerRef - << " was not found "; - } - if (expectedModuleName && expectedModuleName != innerRef.getModule()) - return emitOpError() << "instance path is incorrect. Expected module: " - << expectedModuleName - << " instead found: " << innerRef.getModule(); - } else if (expectedModuleName && - expectedModuleName != - leafRef.cast().getAttr()) { - // This is the case when the nla is applied to a module. - return emitOpError() << "instance path is incorrect. Expected module: " - << expectedModuleName << " instead found: " - << leafRef.cast().getAttr(); - } - return success(); -} - -void HierPathOp::print(OpAsmPrinter &p) { - p << " "; - - // Print visibility if present. - StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); - if (auto visibility = - getOperation()->getAttrOfType(visibilityAttrName)) - p << visibility.getValue() << ' '; - - p.printSymbolName(getSymName()); - p << " ["; - llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) { - if (auto ref = attr.dyn_cast()) { - p.printSymbolName(ref.getModule().getValue()); - p << "::"; - p.printSymbolName(ref.getName().getValue()); - } else { - p.printSymbolName(attr.cast().getValue()); - } - }); - p << "]"; - p.printOptionalAttrDict( - (*this)->getAttrs(), - {SymbolTable::getSymbolAttrName(), "namepath", visibilityAttrName}); -} - -ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the visibility attribute. - (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); - - // Parse the symbol name. - StringAttr symName; - if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - - // Parse the namepath. - SmallVector namepath; - if (parser.parseCommaSeparatedList( - OpAsmParser::Delimiter::Square, [&]() -> ParseResult { - auto loc = parser.getCurrentLocation(); - SymbolRefAttr ref; - if (parser.parseAttribute(ref)) return failure(); - - // "A" is a Ref, "A::b" is a InnerRef, "A::B::c" is an error. - auto pathLength = ref.getNestedReferences().size(); - if (pathLength == 0) - namepath.push_back( - FlatSymbolRefAttr::get(ref.getRootReference())); - else if (pathLength == 1) - namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(), - ref.getLeafReference())); - else - return parser.emitError(loc, - "only one nested reference is allowed"); - return success(); - })) - return failure(); - result.addAttribute("namepath", - ArrayAttr::get(parser.getContext(), namepath)); - - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// TriggeredOp -//===----------------------------------------------------------------------===// - -void TriggeredOp::build(OpBuilder &builder, OperationState &odsState, - EventControlAttr event, Value trigger, - ValueRange inputs) { - odsState.addOperands(trigger); - odsState.addOperands(inputs); - odsState.addAttribute(getEventAttrName(odsState.name), event); - auto *r = odsState.addRegion(); - Block *b = new Block(); - r->push_back(b); - - llvm::SmallVector argLocs; - llvm::transform(inputs, std::back_inserter(argLocs), - [&](Value v) { return v.getLoc(); }); - b->addArguments(inputs.getTypes(), argLocs); -} - -//===----------------------------------------------------------------------===// -// Temporary test module -//===----------------------------------------------------------------------===// - -static void addPortAttrsAndLocs( - Builder &builder, OperationState &result, - SmallVectorImpl &ports, - StringAttr portAttrsName, StringAttr portLocsName) { - auto unknownLoc = builder.getUnknownLoc(); - auto nonEmptyAttrsFn = [](Attribute attr) { - return attr && !cast(attr).empty(); - }; - auto nonEmptyLocsFn = [unknownLoc](Attribute attr) { - return attr && cast(attr) != unknownLoc; - }; - - // Convert the specified array of dictionary attrs (which may have null - // entries) to an ArrayAttr of dictionaries. - SmallVector attrs; - SmallVector locs; - for (auto &port : ports) { - attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({})); - locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc); - } - - // Add the attributes to the ports. - if (llvm::any_of(attrs, nonEmptyAttrsFn)) - result.addAttribute(portAttrsName, builder.getArrayAttr(attrs)); - - if (llvm::any_of(locs, nonEmptyLocsFn)) - result.addAttribute(portLocsName, builder.getArrayAttr(locs)); -} - -void HWTestModuleOp::print(OpAsmPrinter &p) { - p << ' '; - // Print the visibility of the module. - StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); - if (auto visibility = (*this)->getAttrOfType(visibilityAttrName)) - p << visibility.getValue() << ' '; - - // Print the operation and the function name. - p.printSymbolName(SymbolTable::getSymbolName(*this).getValue()); - - // Print the parameter list if present. - printOptionalParameterList(p, *this, getParameters()); - - module_like_impl::printModuleSignatureNew(p, *this); - SmallVector omittedAttrs; - omittedAttrs.push_back(getPortLocsAttrName()); - omittedAttrs.push_back(getModuleTypeAttrName()); - omittedAttrs.push_back(getPortAttrsAttrName()); - omittedAttrs.push_back(getParametersAttrName()); - omittedAttrs.push_back(visibilityAttrName); - if (auto cmt = (*this)->getAttrOfType("comment")) - if (cmt.getValue().empty()) omittedAttrs.push_back("comment"); - - mlir::function_interface_impl::printFunctionAttributes(p, *this, - omittedAttrs); - - // Print the body if this is not an external function. - Region &body = getBody(); - if (!body.empty()) { - p << " "; - p.printRegion(body, /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/true); - } -} - -ParseResult HWTestModuleOp::parse(OpAsmParser &parser, OperationState &result) { - auto loc = parser.getCurrentLocation(); - - // Parse the visibility attribute. - (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes); - - // Parse the name as a symbol. - StringAttr nameAttr; - if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), - result.attributes)) - return failure(); - - // Parse the parameters. - ArrayAttr parameters; - if (parseOptionalParameterList(parser, parameters)) return failure(); - - SmallVector ports; - TypeAttr modType; - if (failed(module_like_impl::parseModuleSignature(parser, ports, modType))) - return failure(); - - // Parse the attribute dict. - if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) - return failure(); - - if (hasAttribute("parameters", result.attributes)) { - parser.emitError(loc, "explicit `parameters` attributes not allowed"); - return failure(); - } - - result.addAttribute("parameters", parameters); - result.addAttribute(getModuleTypeAttrName(result.name), modType); - addPortAttrsAndLocs(parser.getBuilder(), result, ports, - getPortAttrsAttrName(result.name), - getPortLocsAttrName(result.name)); - - SmallVector entryArgs; - for (auto &port : ports) - if (port.direction != ModulePort::Direction::Output) - entryArgs.push_back(port); - - // Parse the optional function body. - auto *body = result.addRegion(); - if (parser.parseRegion(*body, entryArgs)) return failure(); - - HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location); - - return success(); -} - -void HWTestModuleOp::getAsmBlockArgumentNames( - mlir::Region ®ion, mlir::OpAsmSetValueNameFn setNameFn) { - if (region.empty()) return; - // Assign port names to the bbargs. - auto *block = ®ion.front(); - auto mt = getModuleType(); - for (size_t i = 0, e = block->getNumArguments(); i != e; ++i) { - auto name = mt.getInputName(i); - if (!name.empty()) setNameFn(block->getArgument(i), name); - } -} - -ModulePortInfo HWTestModuleOp::getPortList() { - SmallVector ports; - auto refPorts = getModuleType().getPorts(); - for (auto [i, port] : enumerate(refPorts)) { - auto loc = getPortLocs() ? cast((*getPortLocs())[i]) - : LocationAttr(); - auto attr = getPortAttrs() ? cast((*getPortAttrs())[i]) - : DictionaryAttr(); - InnerSymAttr sym = {}; - ports.push_back({{port}, i, sym, attr, loc}); - } - return ModulePortInfo(ports); -} - -size_t HWTestModuleOp::getNumPorts() { return getModuleType().getNumPorts(); } -size_t HWTestModuleOp::getNumInputPorts() { - return getModuleType().getNumInputs(); -} -size_t HWTestModuleOp::getNumOutputPorts() { - return getModuleType().getNumOutputs(); -} - -size_t HWTestModuleOp::getPortIdForInputId(size_t idx) { - return getModuleType().getPortIdForInputId(idx); -} - -size_t HWTestModuleOp::getPortIdForOutputId(size_t idx) { - return getModuleType().getPortIdForOutputId(idx); -} - -hw::InnerSymAttr HWTestModuleOp::getPortSymbolAttr(size_t portIndex) { - auto pa = getPortAttrs(); - if (pa) return cast((*pa)[portIndex]); - return nullptr; -} - -//===----------------------------------------------------------------------===// -// TableGen generated logic. -//===----------------------------------------------------------------------===// - -// Provide the autogenerated implementation guts for the Op classes. -#define GET_OP_CLASSES -#include "include/circt/Dialect/HW/HW.cpp.inc" diff --git a/lib/circt/Dialect/HW/HWReductions.cpp b/lib/circt/Dialect/HW/HWReductions.cpp deleted file mode 100644 index 4361c30651..0000000000 --- a/lib/circt/Dialect/HW/HWReductions.cpp +++ /dev/null @@ -1,157 +0,0 @@ -//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWReductions.h" - -#include "include/circt/Dialect/HW/HWInstanceGraph.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Reduce/ReductionUtils.h" -#include "llvm/include/llvm/ADT/SmallSet.h" // from @llvm-project -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project - -#define DEBUG_TYPE "hw-reductions" - -using namespace mlir; -using namespace circt; -using namespace hw; - -//===----------------------------------------------------------------------===// -// Utilities -//===----------------------------------------------------------------------===// - -/// Utility to track the transitive size of modules. -struct ModuleSizeCache { - void clear() { moduleSizes.clear(); } - - uint64_t getModuleSize(HWModuleLike module, - hw::InstanceGraph &instanceGraph) { - if (auto it = moduleSizes.find(module); it != moduleSizes.end()) - return it->second; - uint64_t size = 1; - module->walk([&](Operation *op) { - size += 1; - if (auto instOp = dyn_cast(op)) - if (auto instModule = - instanceGraph.getReferencedModule(instOp)) - size += getModuleSize(instModule, instanceGraph); - }); - moduleSizes.insert({module, size}); - return size; - } - - private: - llvm::DenseMap moduleSizes; -}; - -//===----------------------------------------------------------------------===// -// Reduction patterns -//===----------------------------------------------------------------------===// - -/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`. -struct ModuleExternalizer : public OpReduction { - void beforeReduction(mlir::ModuleOp op) override { - instanceGraph = std::make_unique(op); - moduleSizes.clear(); - } - - uint64_t match(HWModuleOp op) override { - return moduleSizes.getModuleSize(op, *instanceGraph); - } - - LogicalResult rewrite(HWModuleOp op) override { - OpBuilder builder(op); - builder.create(op->getLoc(), op.getModuleNameAttr(), - op.getPortList(), StringRef(), - op.getParameters()); - op->erase(); - return success(); - } - - std::string getName() const override { return "hw-module-externalizer"; } - - std::unique_ptr instanceGraph; - ModuleSizeCache moduleSizes; -}; - -/// A sample reduction pattern that replaces all uses of an operation with one -/// of its operands. This can help pruning large parts of the expression tree -/// rapidly. -template -struct HWOperandForwarder : public Reduction { - uint64_t match(Operation *op) override { - if (op->getNumResults() != 1 || op->getNumOperands() < 2 || - OpNum >= op->getNumOperands()) - return 0; - auto resultTy = op->getResult(0).getType().dyn_cast(); - auto opTy = op->getOperand(OpNum).getType().dyn_cast(); - return resultTy && opTy && resultTy == opTy && - op->getResult(0) != op->getOperand(OpNum); - } - LogicalResult rewrite(Operation *op) override { - assert(match(op)); - ImplicitLocOpBuilder builder(op->getLoc(), op); - auto result = op->getResult(0); - auto operand = op->getOperand(OpNum); - LLVM_DEBUG(llvm::dbgs() - << "Forwarding " << operand << " in " << *op << "\n"); - result.replaceAllUsesWith(operand); - reduce::pruneUnusedOps(op, *this); - return success(); - } - std::string getName() const override { - return ("hw-operand" + Twine(OpNum) + "-forwarder").str(); - } -}; - -/// A sample reduction pattern that replaces integer operations with a constant -/// zero of their type. -struct HWConstantifier : public Reduction { - uint64_t match(Operation *op) override { - if (op->getNumResults() == 0 || op->getNumOperands() == 0) return 0; - return llvm::all_of(op->getResults(), [](Value result) { - return result.getType().isa(); - }); - } - LogicalResult rewrite(Operation *op) override { - assert(match(op)); - OpBuilder builder(op); - for (auto result : op->getResults()) { - auto type = result.getType().cast(); - auto newOp = builder.create(op->getLoc(), type, 0); - result.replaceAllUsesWith(newOp); - } - reduce::pruneUnusedOps(op, *this); - return success(); - } - std::string getName() const override { return "hw-constantifier"; } -}; - -//===----------------------------------------------------------------------===// -// Reduction Registration -//===----------------------------------------------------------------------===// - -void HWReducePatternDialectInterface::populateReducePatterns( - circt::ReducePatternSet &patterns) const { - // Gather a list of reduction patterns that we should try. Ideally these are - // assigned reasonable benefit indicators (higher benefit patterns are - // prioritized). For example, things that can knock out entire modules while - // being cheap should be tried first (and thus have higher benefit), before - // trying to tweak operands of individual arithmetic ops. - patterns.add(); - patterns.add(); - patterns.add, 4>(); - patterns.add, 3>(); - patterns.add, 2>(); -} - -void hw::registerReducePatternDialectInterface( - mlir::DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) { - dialect->addInterfaces(); - }); -} diff --git a/lib/circt/Dialect/HW/HWTypeInterfaces.cpp b/lib/circt/Dialect/HW/HWTypeInterfaces.cpp deleted file mode 100644 index 3d38ef1078..0000000000 --- a/lib/circt/Dialect/HW/HWTypeInterfaces.cpp +++ /dev/null @@ -1,74 +0,0 @@ -//===- HWTypeInterfaces.cpp - Implement HW type interfaces ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements type interfaces of the HW Dialect. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWTypeInterfaces.h" - -using namespace mlir; -using namespace circt; -using namespace hw; -using namespace FieldIdImpl; - -Type circt::hw::FieldIdImpl::getFinalTypeByFieldID(Type type, - uint64_t fieldID) { - std::pair pair(type, fieldID); - while (pair.second) { - if (auto ftype = dyn_cast(pair.first)) - pair = ftype.getSubTypeByFieldID(pair.second); - else - llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); - } - return pair.first; -} - -std::pair circt::hw::FieldIdImpl::getSubTypeByFieldID( - Type type, uint64_t fieldID) { - if (!fieldID) return {type, 0}; - if (auto ftype = dyn_cast(type)) - return ftype.getSubTypeByFieldID(fieldID); - - llvm::report_fatal_error("fieldID indexing into a non-aggregate type"); -} - -uint64_t circt::hw::FieldIdImpl::getMaxFieldID(Type type) { - if (auto ftype = dyn_cast(type)) - return ftype.getMaxFieldID(); - return 0; -} - -std::pair circt::hw::FieldIdImpl::projectToChildFieldID( - Type type, uint64_t fieldID, uint64_t index) { - if (auto ftype = dyn_cast(type)) - return ftype.projectToChildFieldID(fieldID, index); - return {0, fieldID == 0}; -} - -uint64_t circt::hw::FieldIdImpl::getIndexForFieldID(Type type, - uint64_t fieldID) { - if (auto ftype = dyn_cast(type)) - return ftype.getIndexForFieldID(fieldID); - return 0; -} - -uint64_t circt::hw::FieldIdImpl::getFieldID(Type type, uint64_t fieldID) { - if (auto ftype = dyn_cast(type)) - return ftype.getFieldID(fieldID); - return 0; -} - -std::pair circt::hw::FieldIdImpl::getIndexAndSubfieldID( - Type type, uint64_t fieldID) { - if (auto ftype = dyn_cast(type)) - return ftype.getIndexAndSubfieldID(fieldID); - return {0, fieldID == 0}; -} - -#include "include/circt/Dialect/HW/HWTypeInterfaces.cpp.inc" diff --git a/lib/circt/Dialect/HW/HWTypes.cpp b/lib/circt/Dialect/HW/HWTypes.cpp deleted file mode 100644 index efbe531f50..0000000000 --- a/lib/circt/Dialect/HW/HWTypes.cpp +++ /dev/null @@ -1,974 +0,0 @@ -//===- HWTypes.cpp - HW types code defs -----------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Implementation logic for HW data types. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/HWTypes.h" - -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Dialect/HW/HWDialect.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWSymCache.h" -#include "include/circt/Support/LLVM.h" -#include "llvm/include/llvm/ADT/StringExtras.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project -#include "mlir/include/mlir/IR/StorageUniquerSupport.h" // from @llvm-project -#include "mlir/include/mlir/IR/Types.h" // from @llvm-project - -using namespace circt; -using namespace circt::hw; -using namespace circt::hw::detail; - -#define GET_TYPEDEF_CLASSES -#include "include/circt/Dialect/HW/HWTypes.cpp.inc" - -//===----------------------------------------------------------------------===// -// Type Helpers -//===----------------------------------------------------------------------===/ - -mlir::Type circt::hw::getCanonicalType(mlir::Type type) { - Type canonicalType; - if (auto typeAlias = type.dyn_cast()) - canonicalType = typeAlias.getCanonicalType(); - else - canonicalType = type; - return canonicalType; -} - -/// Return true if the specified type is a value HW Integer type. This checks -/// that it is a signless standard dialect type or a hw::IntType. -bool circt::hw::isHWIntegerType(mlir::Type type) { - Type canonicalType = getCanonicalType(type); - - if (canonicalType.isa()) return true; - - auto intType = canonicalType.dyn_cast(); - if (!intType || !intType.isSignless()) return false; - - return true; -} - -bool circt::hw::isHWEnumType(mlir::Type type) { - return getCanonicalType(type).isa(); -} - -/// Return true if the specified type can be used as an HW value type, that is -/// the set of types that can be composed together to represent synthesized, -/// hardware but not marker types like InOutType. -bool circt::hw::isHWValueType(Type type) { - // Signless and signed integer types are both valid. - if (type.isa()) return true; - - if (auto array = type.dyn_cast()) - return isHWValueType(array.getElementType()); - - if (auto array = type.dyn_cast()) - return isHWValueType(array.getElementType()); - - if (auto t = type.dyn_cast()) - return llvm::all_of(t.getElements(), - [](auto f) { return isHWValueType(f.type); }); - - if (auto t = type.dyn_cast()) - return llvm::all_of(t.getElements(), - [](auto f) { return isHWValueType(f.type); }); - - if (auto t = type.dyn_cast()) - return isHWValueType(t.getCanonicalType()); - - return false; -} - -/// Return the hardware bit width of a type. Does not reflect any encoding, -/// padding, or storage scheme, just the bit (and wire width) of a -/// statically-size type. Reflects the number of wires needed to transmit a -/// value of this type. Returns -1 if the type is not known or cannot be -/// statically computed. -int64_t circt::hw::getBitWidth(mlir::Type type) { - return llvm::TypeSwitch<::mlir::Type, size_t>(type) - .Case( - [](IntegerType t) { return t.getIntOrFloatBitWidth(); }) - .Case([](auto a) { - int64_t elementBitWidth = getBitWidth(a.getElementType()); - if (elementBitWidth < 0) return elementBitWidth; - int64_t dimBitWidth = a.getNumElements(); - if (dimBitWidth < 0) return static_cast(-1L); - return (int64_t)a.getNumElements() * elementBitWidth; - }) - .Case([](StructType s) { - int64_t total = 0; - for (auto field : s.getElements()) { - int64_t fieldSize = getBitWidth(field.type); - if (fieldSize < 0) return fieldSize; - total += fieldSize; - } - return total; - }) - .Case([](UnionType u) { - int64_t maxSize = 0; - for (auto field : u.getElements()) { - int64_t fieldSize = getBitWidth(field.type) + field.offset; - if (fieldSize > maxSize) maxSize = fieldSize; - } - return maxSize; - }) - .Case([](EnumType e) { return e.getBitWidth(); }) - .Case( - [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); }) - .Default([](Type) { return -1; }); -} - -/// Return true if the specified type contains known marker types like -/// InOutType. Unlike isHWValueType, this is not conservative, it only returns -/// false on known InOut types, rather than any unknown types. -bool circt::hw::hasHWInOutType(Type type) { - if (auto array = type.dyn_cast()) - return hasHWInOutType(array.getElementType()); - - if (auto array = type.dyn_cast()) - return hasHWInOutType(array.getElementType()); - - if (auto t = type.dyn_cast()) { - return std::any_of(t.getElements().begin(), t.getElements().end(), - [](const auto &f) { return hasHWInOutType(f.type); }); - } - - if (auto t = type.dyn_cast()) - return hasHWInOutType(t.getCanonicalType()); - - return type.isa(); -} - -/// Parse and print nested HW types nicely. These helper methods allow eliding -/// the "hw." prefix on array, inout, and other types when in a context that -/// expects HW subelement types. -static ParseResult parseHWElementType(Type &result, AsmParser &p) { - // If this is an HW dialect type, then we don't need/want the !hw. prefix - // redundantly specified. - auto fullString = static_cast(p).getFullSymbolSpec(); - auto *curPtr = p.getCurrentLocation().getPointer(); - auto typeString = - StringRef(curPtr, fullString.size() - (curPtr - fullString.data())); - - if (typeString.startswith("array<") || typeString.startswith("inout<") || - typeString.startswith("uarray<") || typeString.startswith("struct<") || - typeString.startswith("typealias<") || typeString.startswith("int<") || - typeString.startswith("enum<")) { - llvm::StringRef mnemonic; - auto parseResult = generatedTypeParser(p, &mnemonic, result); - return parseResult.has_value() ? success() : failure(); - } - - return p.parseType(result); -} - -static void printHWElementType(Type element, AsmPrinter &p) { - if (succeeded(generatedTypePrinter(element, p))) return; - p.printType(element); -} - -//===----------------------------------------------------------------------===// -// Int Type -//===----------------------------------------------------------------------===// - -Type IntType::get(mlir::TypedAttr width) { - // The width expression must always be a 32-bit wide integer type itself. - auto widthWidth = width.getType().dyn_cast(); - assert(widthWidth && widthWidth.getWidth() == 32 && - "!hw.int width must be 32-bits"); - (void)widthWidth; - - if (auto cstWidth = width.dyn_cast()) - return IntegerType::get(width.getContext(), - cstWidth.getValue().getZExtValue()); - - return Base::get(width.getContext(), width); -} - -Type IntType::parse(AsmParser &p) { - // The bitwidth of the parameter size is always 32 bits. - auto int32Type = p.getBuilder().getIntegerType(32); - - mlir::TypedAttr width; - if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater()) - return Type(); - return get(width); -} - -void IntType::print(AsmPrinter &p) const { - p << "<"; - p.printAttributeWithoutType(getWidth()); - p << '>'; -} - -//===----------------------------------------------------------------------===// -// Struct Type -//===----------------------------------------------------------------------===// - -namespace circt { -namespace hw { -namespace detail { -bool operator==(const FieldInfo &a, const FieldInfo &b) { - return a.name == b.name && a.type == b.type; -} -llvm::hash_code hash_value(const FieldInfo &fi) { - return llvm::hash_combine(fi.name, fi.type); -} -} // namespace detail -} // namespace hw -} // namespace circt - -/// Parse a list of field names and types within <>. E.g.: -/// -static ParseResult parseFields(AsmParser &p, - SmallVectorImpl ¶meters) { - return p.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { - StringRef name; - Type type; - if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type)) - return failure(); - parameters.push_back( - FieldInfo{StringAttr::get(p.getContext(), name), type}); - return success(); - }); -} - -/// Print out a list of named fields surrounded by <>. -static void printFields(AsmPrinter &p, ArrayRef fields) { - p << '<'; - llvm::interleaveComma(fields, p, [&](const FieldInfo &field) { - p << field.name.getValue() << ": " << field.type; - }); - p << ">"; -} - -Type StructType::parse(AsmParser &p) { - llvm::SmallVector parameters; - if (parseFields(p, parameters)) return Type(); - return get(p.getContext(), parameters); -} - -void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); } - -Type StructType::getFieldType(mlir::StringRef fieldName) { - for (const auto &field : getElements()) - if (field.name == fieldName) return field.type; - return Type(); -} - -std::optional StructType::getFieldIndex(mlir::StringRef fieldName) { - ArrayRef elems = getElements(); - for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) - if (elems[idx].name == fieldName) return idx; - return {}; -} - -std::optional StructType::getFieldIndex(mlir::StringAttr fieldName) { - ArrayRef elems = getElements(); - for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx) - if (elems[idx].name == fieldName) return idx; - return {}; -} - -static std::pair> getFieldIDsStruct( - const StructType &st) { - uint64_t fieldID = 0; - auto elements = st.getElements(); - SmallVector fieldIDs; - fieldIDs.reserve(elements.size()); - for (auto &element : elements) { - auto type = element.type; - fieldID += 1; - fieldIDs.push_back(fieldID); - // Increment the field ID for the next field by the number of subfields. - fieldID += hw::FieldIdImpl::getMaxFieldID(type); - } - return {fieldID, fieldIDs}; -} - -void StructType::getInnerTypes(SmallVectorImpl &types) { - for (const auto &field : getElements()) types.push_back(field.type); -} - -uint64_t StructType::getMaxFieldID() const { - uint64_t fieldID = 0; - for (const auto &field : getElements()) - fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type); - return fieldID; -} - -std::pair StructType::getSubTypeByFieldID( - uint64_t fieldID) const { - if (fieldID == 0) return {*this, 0}; - auto [maxId, fieldIDs] = getFieldIDsStruct(*this); - auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); - auto subfieldIndex = std::distance(fieldIDs.begin(), it); - auto subfieldType = getElements()[subfieldIndex].type; - auto subfieldID = fieldID - fieldIDs[subfieldIndex]; - return {subfieldType, subfieldID}; -} - -std::pair StructType::projectToChildFieldID( - uint64_t fieldID, uint64_t index) const { - auto [maxId, fieldIDs] = getFieldIDsStruct(*this); - auto childRoot = fieldIDs[index]; - auto rangeEnd = - index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1); - return std::make_pair(fieldID - childRoot, - fieldID >= childRoot && fieldID <= rangeEnd); -} - -uint64_t StructType::getFieldID(uint64_t index) const { - auto [maxId, fieldIDs] = getFieldIDsStruct(*this); - return fieldIDs[index]; -} - -uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const { - assert(!getElements().empty() && "Bundle must have >0 fields"); - auto [maxId, fieldIDs] = getFieldIDsStruct(*this); - auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID)); - return std::distance(fieldIDs.begin(), it); -} - -std::pair StructType::getIndexAndSubfieldID( - uint64_t fieldID) const { - auto index = getIndexForFieldID(fieldID); - auto elementFieldID = getFieldID(index); - return {index, fieldID - elementFieldID}; -} - -//===----------------------------------------------------------------------===// -// Union Type -//===----------------------------------------------------------------------===// - -namespace circt { -namespace hw { -namespace detail { -bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) { - return a.name == b.name && a.type == b.type && a.offset == b.offset; -} -// NOLINTNEXTLINE -llvm::hash_code hash_value(const OffsetFieldInfo &fi) { - return llvm::hash_combine(fi.name, fi.type, fi.offset); -} -} // namespace detail -} // namespace hw -} // namespace circt - -Type UnionType::parse(AsmParser &p) { - llvm::SmallVector parameters; - if (p.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { - StringRef name; - Type type; - if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type)) - return failure(); - size_t offset = 0; - if (succeeded(p.parseOptionalKeyword("offset"))) - if (p.parseInteger(offset)) return failure(); - parameters.push_back(UnionType::FieldInfo{ - StringAttr::get(p.getContext(), name), type, offset}); - return success(); - })) - return Type(); - return get(p.getContext(), parameters); -} - -void UnionType::print(AsmPrinter &odsPrinter) const { - odsPrinter << '<'; - llvm::interleaveComma( - getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) { - odsPrinter << field.name.getValue() << ": " << field.type; - if (field.offset) odsPrinter << " offset " << field.offset; - }); - odsPrinter << ">"; -} - -UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) { - for (const auto &field : getElements()) - if (field.name == fieldName) return field; - return FieldInfo(); -} - -Type UnionType::getFieldType(mlir::StringRef fieldName) { - return getFieldInfo(fieldName).type; -} - -//===----------------------------------------------------------------------===// -// Enum Type -//===----------------------------------------------------------------------===// - -Type EnumType::parse(AsmParser &p) { - llvm::SmallVector fields; - - if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() { - StringRef name; - if (p.parseKeyword(&name)) return failure(); - fields.push_back(StringAttr::get(p.getContext(), name)); - return success(); - })) - return Type(); - - return get(p.getContext(), ArrayAttr::get(p.getContext(), fields)); -} - -void EnumType::print(AsmPrinter &p) const { - p << '<'; - llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) { - p << enumerator.cast().getValue(); - }); - p << ">"; -} - -bool EnumType::contains(mlir::StringRef field) { - return indexOf(field).has_value(); -} - -std::optional EnumType::indexOf(mlir::StringRef field) { - for (auto it : llvm::enumerate(getFields())) - if (it.value().cast().getValue() == field) return it.index(); - return {}; -} - -size_t EnumType::getBitWidth() { - auto w = getFields().size(); - if (w > 1) return llvm::Log2_64_Ceil(getFields().size()); - return 1; -} - -//===----------------------------------------------------------------------===// -// ArrayType -//===----------------------------------------------------------------------===// - -static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) { - if (p.parseLess()) return failure(); - - uint64_t dimLiteral; - auto int64Type = p.getBuilder().getIntegerType(64); - - if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value()) - dim = p.getBuilder().getI64IntegerAttr(dimLiteral); - else if (!p.parseOptionalAttribute(dim, int64Type).has_value()) - return failure(); - - if (!dim.isa()) { - p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array"); - return failure(); - } - - if (p.parseXInDimensionList() || parseHWElementType(inner, p) || - p.parseGreater()) - return failure(); - - return success(); -} - -Type ArrayType::parse(AsmParser &p) { - Attribute dim; - Type inner; - - if (failed(parseArray(p, dim, inner))) return Type(); - - auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); - if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim))) - return Type(); - - return get(inner.getContext(), inner, dim); -} - -void ArrayType::print(AsmPrinter &p) const { - p << "<"; - p.printAttributeWithoutType(getSizeAttr()); - p << "x"; - printHWElementType(getElementType(), p); - p << '>'; -} - -size_t ArrayType::getNumElements() const { - if (auto intAttr = getSizeAttr().dyn_cast()) - return intAttr.getInt(); - return -1; -} - -LogicalResult ArrayType::verify(function_ref emitError, - Type innerType, Attribute size) { - if (hasHWInOutType(innerType)) - return emitError() << "hw.array cannot contain InOut types"; - return success(); -} - -uint64_t ArrayType::getMaxFieldID() const { - return getNumElements() * - (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -std::pair ArrayType::getSubTypeByFieldID( - uint64_t fieldID) const { - if (fieldID == 0) return {*this, 0}; - return {getElementType(), getIndexAndSubfieldID(fieldID).second}; -} - -std::pair ArrayType::projectToChildFieldID( - uint64_t fieldID, uint64_t index) const { - auto childRoot = getFieldID(index); - auto rangeEnd = - index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1); - return std::make_pair(fieldID - childRoot, - fieldID >= childRoot && fieldID <= rangeEnd); -} - -uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const { - assert(fieldID && "fieldID must be at least 1"); - // Divide the field ID by the number of fieldID's per element. - return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -std::pair ArrayType::getIndexAndSubfieldID( - uint64_t fieldID) const { - auto index = getIndexForFieldID(fieldID); - auto elementFieldID = getFieldID(index); - return {index, fieldID - elementFieldID}; -} - -uint64_t ArrayType::getFieldID(uint64_t index) const { - return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -//===----------------------------------------------------------------------===// -// UnpackedArrayType -//===----------------------------------------------------------------------===// - -Type UnpackedArrayType::parse(AsmParser &p) { - Attribute dim; - Type inner; - - if (failed(parseArray(p, dim, inner))) return Type(); - - auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); - if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim))) - return Type(); - - return get(inner.getContext(), inner, dim); -} - -void UnpackedArrayType::print(AsmPrinter &p) const { - p << "<"; - p.printAttributeWithoutType(getSizeAttr()); - p << "x"; - printHWElementType(getElementType(), p); - p << '>'; -} - -LogicalResult UnpackedArrayType::verify( - function_ref emitError, Type innerType, - Attribute size) { - if (!isHWValueType(innerType)) - return emitError() << "invalid element for uarray type"; - return success(); -} - -size_t UnpackedArrayType::getNumElements() const { - return getSizeAttr().cast().getInt(); -} - -uint64_t UnpackedArrayType::getMaxFieldID() const { - return getNumElements() * - (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -std::pair UnpackedArrayType::getSubTypeByFieldID( - uint64_t fieldID) const { - if (fieldID == 0) return {*this, 0}; - return {getElementType(), getIndexAndSubfieldID(fieldID).second}; -} - -std::pair UnpackedArrayType::projectToChildFieldID( - uint64_t fieldID, uint64_t index) const { - auto childRoot = getFieldID(index); - auto rangeEnd = - index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1); - return std::make_pair(fieldID - childRoot, - fieldID >= childRoot && fieldID <= rangeEnd); -} - -uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const { - assert(fieldID && "fieldID must be at least 1"); - // Divide the field ID by the number of fieldID's per element. - return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -std::pair UnpackedArrayType::getIndexAndSubfieldID( - uint64_t fieldID) const { - auto index = getIndexForFieldID(fieldID); - auto elementFieldID = getFieldID(index); - return {index, fieldID - elementFieldID}; -} - -uint64_t UnpackedArrayType::getFieldID(uint64_t index) const { - return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1); -} - -//===----------------------------------------------------------------------===// -// InOutType -//===----------------------------------------------------------------------===// - -Type InOutType::parse(AsmParser &p) { - Type inner; - if (p.parseLess() || parseHWElementType(inner, p) || p.parseGreater()) - return Type(); - - auto loc = p.getEncodedSourceLoc(p.getCurrentLocation()); - if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner))) - return Type(); - - return get(p.getContext(), inner); -} - -void InOutType::print(AsmPrinter &p) const { - p << "<"; - printHWElementType(getElementType(), p); - p << '>'; -} - -LogicalResult InOutType::verify(function_ref emitError, - Type innerType) { - if (!isHWValueType(innerType)) - return emitError() << "invalid element for hw.inout type " << innerType; - return success(); -} - -//===----------------------------------------------------------------------===// -// TypeAliasType -//===----------------------------------------------------------------------===// - -static Type computeCanonicalType(Type type) { - return llvm::TypeSwitch(type) - .Case([](TypeAliasType t) { - return computeCanonicalType(t.getCanonicalType()); - }) - .Case([](ArrayType t) { - return ArrayType::get(computeCanonicalType(t.getElementType()), - t.getNumElements()); - }) - .Case([](UnpackedArrayType t) { - return UnpackedArrayType::get(computeCanonicalType(t.getElementType()), - t.getNumElements()); - }) - .Case([](StructType t) { - SmallVector fieldInfo; - for (auto field : t.getElements()) - fieldInfo.push_back(StructType::FieldInfo{ - field.name, computeCanonicalType(field.type)}); - return StructType::get(t.getContext(), fieldInfo); - }) - .Default([](Type t) { return t; }); -} - -TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) { - return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType)); -} - -Type TypeAliasType::parse(AsmParser &p) { - SymbolRefAttr ref; - Type type; - if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() || - p.parseType(type) || p.parseGreater()) - return Type(); - - return get(ref, type); -} - -void TypeAliasType::print(AsmPrinter &p) const { - p << "<" << getRef() << ", " << getInnerType() << ">"; -} - -/// Return the Typedecl referenced by this TypeAlias, given the module to look -/// in. This returns null when the IR is malformed. -TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) { - SymbolRefAttr ref = getRef(); - auto typeScope = ::dyn_cast_or_null( - cache.getDefinition(ref.getRootReference())); - if (!typeScope) return {}; - - return typeScope.lookupSymbol(ref.getLeafReference()); -} - -//////////////////////////////////////////////////////////////////////////////// -// ModuleType -//////////////////////////////////////////////////////////////////////////////// - -LogicalResult ModuleType::verify(function_ref emitError, - ArrayRef ports) { - if (llvm::any_of(ports, [](const ModulePort &port) { - return hasHWInOutType(port.type); - })) - return emitError() << "Ports cannot be inout types"; - return success(); -} - -size_t ModuleType::getPortIdForInputId(size_t idx) { - for (auto [i, p] : llvm::enumerate(getPorts())) { - if (p.dir != ModulePort::Direction::Output) { - if (!idx) return i; - --idx; - } - } - assert(0 && "Out of bounds input port id"); - return ~0UL; -} - -size_t ModuleType::getPortIdForOutputId(size_t idx) { - for (auto [i, p] : llvm::enumerate(getPorts())) { - if (p.dir == ModulePort::Direction::Output) { - if (!idx) return i; - --idx; - } - } - assert(0 && "Out of bounds output port id"); - return ~0UL; -} - -size_t ModuleType::getNumInputs() { - return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { - return p.dir != ModulePort::Direction::Output; - }); -} - -size_t ModuleType::getNumOutputs() { - return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) { - return p.dir == ModulePort::Direction::Output; - }); -} - -size_t ModuleType::getNumPorts() { return getPorts().size(); } - -SmallVector ModuleType::getInputTypes() { - SmallVector retval; - for (auto &p : getPorts()) { - if (p.dir == ModulePort::Direction::Input) - retval.push_back(p.type); - else if (p.dir == ModulePort::Direction::InOut) { - retval.push_back(hw::InOutType::get(p.type)); - } - } - return retval; -} - -SmallVector ModuleType::getOutputTypes() { - SmallVector retval; - for (auto &p : getPorts()) - if (p.dir == ModulePort::Direction::Output) retval.push_back(p.type); - return retval; -} - -SmallVector ModuleType::getPortTypes() { - SmallVector retval; - for (auto &p : getPorts()) retval.push_back(p.type); - return retval; -} - -Type ModuleType::getInputType(size_t idx) { - return getPorts()[getPortIdForInputId(idx)].type; -} - -Type ModuleType::getOutputType(size_t idx) { - return getPorts()[getPortIdForOutputId(idx)].type; -} - -SmallVector ModuleType::getInputNamesStr() { - SmallVector retval; - for (auto &p : getPorts()) - if (p.dir != ModulePort::Direction::Output) retval.push_back(p.name); - return retval; -} - -SmallVector ModuleType::getOutputNamesStr() { - SmallVector retval; - for (auto &p : getPorts()) - if (p.dir == ModulePort::Direction::Output) retval.push_back(p.name); - return retval; -} - -SmallVector ModuleType::getInputNames() { - SmallVector retval; - for (auto &p : getPorts()) - if (p.dir != ModulePort::Direction::Output) retval.push_back(p.name); - return retval; -} - -SmallVector ModuleType::getOutputNames() { - SmallVector retval; - for (auto &p : getPorts()) - if (p.dir == ModulePort::Direction::Output) retval.push_back(p.name); - return retval; -} - -StringAttr ModuleType::getPortNameAttr(size_t idx) { - return getPorts()[idx].name; -} - -StringRef ModuleType::getPortName(size_t idx) { - auto sa = getPortNameAttr(idx); - if (sa) return sa.getValue(); - return {}; -} - -StringAttr ModuleType::getInputNameAttr(size_t idx) { - return getPorts()[getPortIdForInputId(idx)].name; -} - -StringRef ModuleType::getInputName(size_t idx) { - auto sa = getInputNameAttr(idx); - if (sa) return sa.getValue(); - return {}; -} - -StringAttr ModuleType::getOutputNameAttr(size_t idx) { - return getPorts()[getPortIdForOutputId(idx)].name; -} - -StringRef ModuleType::getOutputName(size_t idx) { - auto sa = getOutputNameAttr(idx); - if (sa) return sa.getValue(); - return {}; -} - -FunctionType ModuleType::getFuncType() { - SmallVector inputs, outputs; - for (auto p : getPorts()) - if (p.dir == ModulePort::Input) - inputs.push_back(p.type); - else if (p.dir == ModulePort::InOut) - inputs.push_back(InOutType::get(p.type)); - else - outputs.push_back(p.type); - return FunctionType::get(getContext(), inputs, outputs); -} - -static StringRef dirToStr(ModulePort::Direction dir) { - switch (dir) { - case ModulePort::Direction::Input: - return "input"; - case ModulePort::Direction::Output: - return "output"; - case ModulePort::Direction::InOut: - return "inout"; - } -} - -static ModulePort::Direction strToDir(StringRef str) { - if (str == "input") return ModulePort::Direction::Input; - if (str == "output") return ModulePort::Direction::Output; - if (str == "inout") return ModulePort::Direction::InOut; - llvm::report_fatal_error("invalid direction"); -} - -/// Parse a list of field names and types within <>. E.g.: -/// -static ParseResult parsePorts(AsmParser &p, - SmallVectorImpl &ports) { - return p.parseCommaSeparatedList( - mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult { - StringRef dir; - StringRef name; - Type type; - if (p.parseKeyword(&dir) || p.parseKeyword(&name) || p.parseColon() || - p.parseType(type)) - return failure(); - ports.push_back( - {StringAttr::get(p.getContext(), name), type, strToDir(dir)}); - return success(); - }); -} - -/// Print out a list of named fields surrounded by <>. -static void printPorts(AsmPrinter &p, ArrayRef ports) { - p << '<'; - llvm::interleaveComma(ports, p, [&](const ModulePort &port) { - p << dirToStr(port.dir) << " " << port.name.getValue() << " : " - << port.type; - }); - p << ">"; -} - -Type ModuleType::parse(AsmParser &odsParser) { - llvm::SmallVector ports; - if (parsePorts(odsParser, ports)) return Type(); - return get(odsParser.getContext(), ports); -} - -void ModuleType::print(AsmPrinter &odsPrinter) const { - printPorts(odsPrinter, getPorts()); -} - -namespace circt { -namespace hw { - -static bool operator==(const ModulePort &a, const ModulePort &b) { - return a.dir == b.dir && a.name == b.name && a.type == b.type; -} -static llvm::hash_code hash_value(const ModulePort &port) { - return llvm::hash_combine(port.dir, port.name, port.type); -} -} // namespace hw -} // namespace circt - -ModuleType circt::hw::detail::fnToMod(Operation *op, - ArrayRef inputNames, - ArrayRef outputNames) { - return fnToMod( - cast(cast(op).getFunctionType()), - inputNames, outputNames); -} - -ModuleType circt::hw::detail::fnToMod(FunctionType fnty, - ArrayRef inputNames, - ArrayRef outputNames) { - SmallVector ports; - if (!inputNames.empty()) { - for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames)) - if (auto iot = dyn_cast(t)) - ports.push_back({cast(n), iot.getElementType(), - ModulePort::Direction::InOut}); - else - ports.push_back({cast(n), t, ModulePort::Direction::Input}); - } else { - for (auto t : fnty.getInputs()) - if (auto iot = dyn_cast(t)) - ports.push_back( - {{}, iot.getElementType(), ModulePort::Direction::InOut}); - else - ports.push_back({{}, t, ModulePort::Direction::Input}); - } - if (!outputNames.empty()) { - for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames)) - ports.push_back({cast(n), t, ModulePort::Direction::Output}); - } else { - for (auto t : fnty.getResults()) - ports.push_back({{}, t, ModulePort::Direction::Output}); - } - return ModuleType::get(fnty.getContext(), ports); -} - -//////////////////////////////////////////////////////////////////////////////// -// BoilerPlate -//////////////////////////////////////////////////////////////////////////////// - -void HWDialect::registerTypes() { - addTypes< -#define GET_TYPEDEF_LIST -#include "include/circt/Dialect/HW/HWTypes.cpp.inc" - >(); -} diff --git a/lib/circt/Dialect/HW/InnerSymbolTable.cpp b/lib/circt/Dialect/HW/InnerSymbolTable.cpp deleted file mode 100644 index 6856376727..0000000000 --- a/lib/circt/Dialect/HW/InnerSymbolTable.cpp +++ /dev/null @@ -1,251 +0,0 @@ -//===- InnerSymbolTable.cpp - InnerSymbolTable and InnerRef verification --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements InnerSymbolTable and verification for InnerRef's. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/InnerSymbolTable.h" - -#include "include/circt/Dialect/HW/HWOpInterfaces.h" -#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project -#include "mlir/include/mlir/IR/Threading.h" // from @llvm-project - -using namespace circt; -using namespace hw; - -namespace circt { -namespace hw { - -//===----------------------------------------------------------------------===// -// InnerSymbolTable -//===----------------------------------------------------------------------===// -InnerSymbolTable::InnerSymbolTable(Operation *op) { - assert(op->hasTrait()); - // Save the operation this table is for. - this->innerSymTblOp = op; - - walkSymbols(op, [&](StringAttr name, const InnerSymTarget &target) { - auto it = symbolTable.try_emplace(name, target); - (void)it; - assert(it.second && "repeated symbol found"); - }); -} - -FailureOr InnerSymbolTable::get(Operation *op) { - assert(op); - if (!op->hasTrait()) - return op->emitError("expected operation to have InnerSymbolTable trait"); - - TableTy table; - auto result = walkSymbols( - op, [&](StringAttr name, const InnerSymTarget &target) -> LogicalResult { - auto it = table.try_emplace(name, target); - if (it.second) return success(); - auto existing = it.first->second; - return target.getOp() - ->emitError() - .append("redefinition of inner symbol named '", name.strref(), "'") - .attachNote(existing.getOp()->getLoc()) - .append("see existing inner symbol definition here"); - }); - if (failed(result)) return failure(); - return InnerSymbolTable(op, std::move(table)); -} - -LogicalResult InnerSymbolTable::walkSymbols(Operation *op, - InnerSymCallbackFn callback) { - auto walkSym = [&](StringAttr name, const InnerSymTarget &target) { - assert(name && !name.getValue().empty()); - return callback(name, target); - }; - - auto walkSyms = [&](hw::InnerSymAttr symAttr, - const InnerSymTarget &baseTarget) -> LogicalResult { - assert(baseTarget.getField() == 0); - for (auto symProp : symAttr) { - if (failed(walkSym(symProp.getName(), - InnerSymTarget::getTargetForSubfield( - baseTarget, symProp.getFieldID())))) - return failure(); - } - return success(); - }; - - // Walk the operation and add InnerSymbolTarget's to the table. - return success( - !op->walk([&](Operation *curOp) -> WalkResult { - if (auto symOp = dyn_cast(curOp)) - if (auto symAttr = symOp.getInnerSymAttr()) - if (failed(walkSyms(symAttr, InnerSymTarget(symOp)))) - return WalkResult::interrupt(); - - // Check for ports - if (auto mod = dyn_cast(curOp)) { - for (size_t i = 0, e = mod.getNumPorts(); i < e; ++i) { - if (auto symAttr = mod.getPortSymbolAttr(i)) - if (failed(walkSyms(symAttr, InnerSymTarget(i, curOp)))) - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }).wasInterrupted()); -} - -/// Look up a symbol with the specified name, returning empty InnerSymTarget if -/// no such name exists. Names never include the @ on them. -InnerSymTarget InnerSymbolTable::lookup(StringRef name) const { - return lookup(StringAttr::get(innerSymTblOp->getContext(), name)); -} -InnerSymTarget InnerSymbolTable::lookup(StringAttr name) const { - return symbolTable.lookup(name); -} - -/// Look up a symbol with the specified name, returning null if no such -/// name exists or doesn't target just an operation. -Operation *InnerSymbolTable::lookupOp(StringRef name) const { - return lookupOp(StringAttr::get(innerSymTblOp->getContext(), name)); -} -Operation *InnerSymbolTable::lookupOp(StringAttr name) const { - auto result = lookup(name); - if (result.isOpOnly()) return result.getOp(); - return nullptr; -} - -/// Get InnerSymbol for an operation. -StringAttr InnerSymbolTable::getInnerSymbol(Operation *op) { - if (auto innerSymOp = dyn_cast(op)) - return innerSymOp.getInnerNameAttr(); - return {}; -} - -/// Get InnerSymbol for a target. Be robust to queries on unexpected -/// operations to avoid users needing to know the details. -StringAttr InnerSymbolTable::getInnerSymbol(const InnerSymTarget &target) { - // Assert on misuse, but try to handle queries otherwise. - assert(target); - - // Obtain the base InnerSymAttr for the specified target. - auto getBase = [](auto &target) -> hw::InnerSymAttr { - if (target.isPort()) { - if (auto mod = dyn_cast(target.getOp())) { - assert(target.getPort() < mod.getNumPorts()); - return mod.getPortSymbolAttr(target.getPort()); - } - } else { - // InnerSymbols only supported if op implements the interface. - if (auto symOp = dyn_cast(target.getOp())) - return symOp.getInnerSymAttr(); - } - return {}; - }; - - if (auto base = getBase(target)) - return base.getSymIfExists(target.getField()); - return {}; -} - -//===----------------------------------------------------------------------===// -// InnerSymbolTableCollection -//===----------------------------------------------------------------------===// - -InnerSymbolTable &InnerSymbolTableCollection::getInnerSymbolTable( - Operation *op) { - auto it = symbolTables.try_emplace(op, nullptr); - if (it.second) it.first->second = ::std::make_unique(op); - return *it.first->second; -} - -LogicalResult InnerSymbolTableCollection::populateAndVerifyTables( - Operation *innerRefNSOp) { - // Gather top-level operations that have the InnerSymbolTable trait. - SmallVector innerSymTableOps(llvm::make_filter_range( - llvm::make_pointer_range(innerRefNSOp->getRegion(0).front()), - [&](Operation *op) { - return op->hasTrait(); - })); - - // Ensure entries exist for each operation. - llvm::for_each(innerSymTableOps, - [&](auto *op) { symbolTables.try_emplace(op, nullptr); }); - - // Construct the tables in parallel (if context allows it). - return mlir::failableParallelForEach( - innerRefNSOp->getContext(), innerSymTableOps, [&](auto *op) { - auto it = symbolTables.find(op); - assert(it != symbolTables.end()); - if (!it->second) { - auto result = InnerSymbolTable::get(op); - if (failed(result)) return failure(); - it->second = std::make_unique(std::move(*result)); - return success(); - } - return failure(); - }); -} - -//===----------------------------------------------------------------------===// -// InnerRefNamespace -//===----------------------------------------------------------------------===// - -InnerSymTarget InnerRefNamespace::lookup(hw::InnerRefAttr inner) { - auto *mod = symTable.lookup(inner.getModule()); - if (!mod) return {}; - assert(mod->hasTrait()); - return innerSymTables.getInnerSymbolTable(mod).lookup(inner.getName()); -} - -Operation *InnerRefNamespace::lookupOp(hw::InnerRefAttr inner) { - auto *mod = symTable.lookup(inner.getModule()); - if (!mod) return nullptr; - assert(mod->hasTrait()); - return innerSymTables.getInnerSymbolTable(mod).lookupOp(inner.getName()); -} - -//===----------------------------------------------------------------------===// -// InnerRefNamespace verification -//===----------------------------------------------------------------------===// - -namespace detail { - -LogicalResult verifyInnerRefNamespace(Operation *op) { - // Construct the symbol tables. - InnerSymbolTableCollection innerSymTables; - if (failed(innerSymTables.populateAndVerifyTables(op))) return failure(); - - SymbolTable symbolTable(op); - InnerRefNamespace ns{symbolTable, innerSymTables}; - - // Conduct parallel walks of the top-level children of this - // InnerRefNamespace, verifying all InnerRefUserOp's discovered within. - auto verifySymbolUserFn = [&](Operation *op) -> WalkResult { - if (auto user = dyn_cast(op)) - return WalkResult(user.verifyInnerRefs(ns)); - return WalkResult::advance(); - }; - return mlir::failableParallelForEach( - op->getContext(), op->getRegion(0).front(), [&](auto &op) { - return success(!op.walk(verifySymbolUserFn).wasInterrupted()); - }); -} - -} // namespace detail - -bool InnerRefNamespaceLike::classof(mlir::Operation *op) { - return op->hasTrait() || - op->hasTrait(); -} - -bool InnerRefNamespaceLike::classof( - const mlir::RegisteredOperationName *opInfo) { - return opInfo->hasTrait() || - opInfo->hasTrait(); -} - -} // namespace hw -} // namespace circt diff --git a/lib/circt/Dialect/HW/InstanceImplementation.cpp b/lib/circt/Dialect/HW/InstanceImplementation.cpp deleted file mode 100644 index 6294ca1a4a..0000000000 --- a/lib/circt/Dialect/HW/InstanceImplementation.cpp +++ /dev/null @@ -1,347 +0,0 @@ -//===- InstanceImplementation.cpp - Utilities for instance-like ops -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/InstanceImplementation.h" - -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWSymCache.h" - -using namespace circt; -using namespace circt::hw; - -Operation *instance_like_impl::getReferencedModule( - const HWSymbolCache *cache, Operation *instanceOp, - mlir::FlatSymbolRefAttr moduleName) { - if (cache) - if (auto *result = cache->getDefinition(moduleName)) return result; - - auto topLevelModuleOp = instanceOp->getParentOfType(); - return topLevelModuleOp.lookupSymbol(moduleName.getValue()); -} - -LogicalResult instance_like_impl::verifyReferencedModule( - Operation *instanceOp, SymbolTableCollection &symbolTable, - mlir::FlatSymbolRefAttr moduleName, Operation *&module) { - module = symbolTable.lookupNearestSymbolFrom(instanceOp, moduleName); - if (module == nullptr) - return instanceOp->emitError("Cannot find module definition '") - << moduleName.getValue() << "'"; - - // It must be some sort of module. - if (!isa(module)) - return instanceOp->emitError("symbol reference '") - << moduleName.getValue() << "' isn't a module"; - - return success(); -} - -LogicalResult instance_like_impl::resolveParametricTypes( - Location loc, ArrayAttr parameters, ArrayRef types, - SmallVectorImpl &resolvedTypes, const EmitErrorFn &emitError) { - for (auto type : types) { - auto expectedType = evaluateParametricType(loc, parameters, type); - if (failed(expectedType)) { - emitError([&](auto &diag) { - diag << "failed to resolve parametric input of instantiated module"; - return true; - }); - return failure(); - } - - resolvedTypes.push_back(*expectedType); - } - - return success(); -} - -LogicalResult instance_like_impl::verifyInputs(ArrayAttr argNames, - ArrayAttr moduleArgNames, - TypeRange inputTypes, - ArrayRef moduleInputTypes, - const EmitErrorFn &emitError) { - // Check operand types first. - if (moduleInputTypes.size() != inputTypes.size()) { - emitError([&](auto &diag) { - diag << "has a wrong number of operands; expected " - << moduleInputTypes.size() << " but got " << inputTypes.size(); - return true; - }); - return failure(); - } - - if (argNames.size() != inputTypes.size()) { - emitError([&](auto &diag) { - diag << "has a wrong number of input port names; expected " - << inputTypes.size() << " but got " << argNames.size(); - return true; - }); - return failure(); - } - - for (size_t i = 0; i != inputTypes.size(); ++i) { - auto expectedType = moduleInputTypes[i]; - auto operandType = inputTypes[i]; - - if (operandType != expectedType) { - emitError([&](auto &diag) { - diag << "operand type #" << i << " must be " << expectedType - << ", but got " << operandType; - return true; - }); - return failure(); - } - - if (argNames[i] != moduleArgNames[i]) { - emitError([&](auto &diag) { - diag << "input label #" << i << " must be " << moduleArgNames[i] - << ", but got " << argNames[i]; - return true; - }); - return failure(); - } - } - - return success(); -} - -LogicalResult instance_like_impl::verifyOutputs( - ArrayAttr resultNames, ArrayAttr moduleResultNames, TypeRange resultTypes, - ArrayRef moduleResultTypes, const EmitErrorFn &emitError) { - // Check result types and labels. - if (moduleResultTypes.size() != resultTypes.size()) { - emitError([&](auto &diag) { - diag << "has a wrong number of results; expected " - << moduleResultTypes.size() << " but got " << resultTypes.size(); - return true; - }); - return failure(); - } - - if (resultNames.size() != resultTypes.size()) { - emitError([&](auto &diag) { - diag << "has a wrong number of results port labels; expected " - << resultTypes.size() << " but got " << resultNames.size(); - return true; - }); - return failure(); - } - - for (size_t i = 0; i != resultTypes.size(); ++i) { - auto expectedType = moduleResultTypes[i]; - auto resultType = resultTypes[i]; - - if (resultType != expectedType) { - emitError([&](auto &diag) { - diag << "result type #" << i << " must be " << expectedType - << ", but got " << resultType; - return true; - }); - return failure(); - } - - if (resultNames[i] != moduleResultNames[i]) { - emitError([&](auto &diag) { - diag << "result label #" << i << " must be " << moduleResultNames[i] - << ", but got " << resultNames[i]; - return true; - }); - return failure(); - } - } - - return success(); -} - -LogicalResult instance_like_impl::verifyParameters( - ArrayAttr parameters, ArrayAttr moduleParameters, - const EmitErrorFn &emitError) { - // Check parameters match up. - auto numParameters = parameters.size(); - if (numParameters != moduleParameters.size()) { - emitError([&](auto &diag) { - diag << "expected " << moduleParameters.size() << " parameters but had " - << numParameters; - return true; - }); - return failure(); - } - - for (size_t i = 0; i != numParameters; ++i) { - auto param = parameters[i].cast(); - auto modParam = moduleParameters[i].cast(); - - auto paramName = param.getName(); - if (paramName != modParam.getName()) { - emitError([&](auto &diag) { - diag << "parameter #" << i << " should have name " << modParam.getName() - << " but has name " << paramName; - return true; - }); - return failure(); - } - - if (param.getType() != modParam.getType()) { - emitError([&](auto &diag) { - diag << "parameter " << paramName << " should have type " - << modParam.getType() << " but has type " << param.getType(); - return true; - }); - return failure(); - } - - // All instance parameters must have a value. Specify the same value as - // a module's default value if you want the default. - if (!param.getValue()) { - emitError([&](auto &diag) { - diag << "parameter " << paramName << " must have a value"; - return false; - }); - return failure(); - } - } - - return success(); -} - -LogicalResult instance_like_impl::verifyInstanceOfHWModule( - Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, - TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, - ArrayAttr parameters, SymbolTableCollection &symbolTable) { - // Verify that we reference some kind of HW module and get the module on - // success. - Operation *module; - if (failed(instance_like_impl::verifyReferencedModule(instance, symbolTable, - moduleRef, module))) - return failure(); - - // Emit an error message on the instance, with a note indicating which module - // is being referenced. The error message on the instance is added by the - // verification function this lambda is passed to. - EmitErrorFn emitError = - [&](const std::function &fn) { - auto diag = instance->emitOpError(); - if (fn(diag)) - diag.attachNote(module->getLoc()) << "module declared here"; - }; - - // Check that input types are consistent with the referenced module. - auto mod = cast(module); - auto modArgNames = - ArrayAttr::get(instance->getContext(), mod.getInputNames()); - auto modResultNames = - ArrayAttr::get(instance->getContext(), mod.getOutputNames()); - - ArrayRef resolvedModInputTypesRef = getModuleType(module).getInputs(); - SmallVector resolvedModInputTypes; - if (parameters) { - if (failed(instance_like_impl::resolveParametricTypes( - instance->getLoc(), parameters, getModuleType(module).getInputs(), - resolvedModInputTypes, emitError))) - return failure(); - resolvedModInputTypesRef = resolvedModInputTypes; - } - if (failed(instance_like_impl::verifyInputs( - argNames, modArgNames, inputs.getTypes(), resolvedModInputTypesRef, - emitError))) - return failure(); - - // Check that result types are consistent with the referenced module. - ArrayRef resolvedModResultTypesRef = getModuleType(module).getResults(); - SmallVector resolvedModResultTypes; - if (parameters) { - if (failed(instance_like_impl::resolveParametricTypes( - instance->getLoc(), parameters, getModuleType(module).getResults(), - resolvedModResultTypes, emitError))) - return failure(); - resolvedModResultTypesRef = resolvedModResultTypes; - } - if (failed(instance_like_impl::verifyOutputs( - resultNames, modResultNames, results, resolvedModResultTypesRef, - emitError))) - return failure(); - - if (parameters) { - // Check that the parameters are consistent with the referenced module. - ArrayAttr modParameters = module->getAttrOfType("parameters"); - if (failed(instance_like_impl::verifyParameters(parameters, modParameters, - emitError))) - return failure(); - } - - return success(); -} - -LogicalResult instance_like_impl::verifyParameterStructure( - ArrayAttr parameters, ArrayAttr moduleParameters, - const EmitErrorFn &emitError) { - // Check that all the parameter values specified to the instance are - // structurally valid. - for (auto param : parameters) { - auto paramAttr = param.cast(); - auto value = paramAttr.getValue(); - // The SymbolUses verifier which checks that this exists may not have been - // run yet. Let it issue the error. - if (!value) continue; - - auto typedValue = value.dyn_cast(); - if (!typedValue) { - emitError([&](auto &diag) { - diag << "parameter " << paramAttr - << " should have a typed value; has value " << value; - return false; - }); - return failure(); - } - - if (typedValue.getType() != paramAttr.getType()) { - emitError([&](auto &diag) { - diag << "parameter " << paramAttr << " should have type " - << paramAttr.getType() << "; has type " << typedValue.getType(); - return false; - }); - return failure(); - } - - if (failed(checkParameterInContext(value, moduleParameters, emitError))) - return failure(); - } - return success(); -} - -StringAttr instance_like_impl::getName(ArrayAttr names, size_t idx) { - // Tolerate malformed IR here to enable debug printing etc. - if (names && idx < names.size()) return names[idx].cast(); - return StringAttr(); -} - -ArrayAttr instance_like_impl::updateName(ArrayAttr oldNames, size_t i, - StringAttr name) { - SmallVector newNames(oldNames.begin(), oldNames.end()); - if (newNames[i] == name) return oldNames; - newNames[i] = name; - return ArrayAttr::get(oldNames.getContext(), oldNames); -} - -void instance_like_impl::getAsmResultNames(OpAsmSetValueNameFn setNameFn, - StringRef instanceName, - ArrayAttr resultNames, - ValueRange results) { - // Provide default names for instance results. - std::string name = instanceName.str() + "."; - size_t baseNameLen = name.size(); - - for (size_t i = 0, e = resultNames.size(); i != e; ++i) { - auto resName = getName(resultNames, i); - name.resize(baseNameLen); - if (resName && !resName.getValue().empty()) - name += resName.getValue().str(); - else - name += std::to_string(i); - setNameFn(results[i], name); - } -} diff --git a/lib/circt/Dialect/HW/ModuleImplementation.cpp b/lib/circt/Dialect/HW/ModuleImplementation.cpp deleted file mode 100644 index 02ad5501ca..0000000000 --- a/lib/circt/Dialect/HW/ModuleImplementation.cpp +++ /dev/null @@ -1,328 +0,0 @@ -//===- ModuleImplementation.cpp - Utilities for module-like ops -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/ModuleImplementation.h" - -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Support/LLVM.h" -#include "include/circt/Support/ParsingUtils.h" -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project -#include "mlir/include/mlir/Interfaces/FunctionImplementation.h" // from @llvm-project - -using namespace circt; -using namespace circt::hw; - -/// Parse a function result list. -/// -/// function-result-list ::= function-result-list-parens -/// function-result-list-parens ::= `(` `)` -/// | `(` function-result-list-no-parens `)` -/// function-result-list-no-parens ::= function-result (`,` function-result)* -/// function-result ::= (percent-identifier `:`) type attribute-dict? -/// -static ParseResult parseFunctionResultList( - OpAsmParser &parser, SmallVectorImpl &resultNames, - SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs, - SmallVectorImpl &resultLocs) { - auto parseElt = [&]() -> ParseResult { - // Stash the current location parser location. - auto irLoc = parser.getCurrentLocation(); - - // Parse the result name. - std::string portName; - if (parser.parseKeywordOrString(&portName)) return failure(); - resultNames.push_back(StringAttr::get(parser.getContext(), portName)); - - // Parse the results type. - resultTypes.emplace_back(); - if (parser.parseColonType(resultTypes.back())) return failure(); - - // Parse the result attributes. - NamedAttrList attrs; - if (failed(parser.parseOptionalAttrDict(attrs))) return failure(); - resultAttrs.push_back(attrs.getDictionary(parser.getContext())); - - // Parse the result location. - std::optional maybeLoc; - if (failed(parser.parseOptionalLocationSpecifier(maybeLoc))) - return failure(); - Location loc = maybeLoc ? *maybeLoc : parser.getEncodedSourceLoc(irLoc); - resultLocs.push_back(loc); - - return success(); - }; - - return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, - parseElt); -} - -/// Return the port name for the specified argument or result. -static StringRef getModuleArgumentName(Operation *module, size_t argNo) { - if (auto mod = dyn_cast(module)) { - if (argNo < mod.getNumInputPorts()) return mod.getInputName(argNo); - return StringRef(); - } - auto argNames = module->getAttrOfType("argNames"); - // Tolerate malformed IR here to enable debug printing etc. - if (argNames && argNo < argNames.size()) - return argNames[argNo].cast().getValue(); - return StringRef(); -} - -static StringRef getModuleResultName(Operation *module, size_t resultNo) { - if (auto mod = dyn_cast(module)) { - if (resultNo < mod.getNumOutputPorts()) return mod.getOutputName(resultNo); - return StringRef(); - } - auto resultNames = module->getAttrOfType("resultNames"); - // Tolerate malformed IR here to enable debug printing etc. - if (resultNames && resultNo < resultNames.size()) - return resultNames[resultNo].cast().getValue(); - return StringRef(); -} - -void module_like_impl::printModuleSignature(OpAsmPrinter &p, Operation *op, - ArrayRef argTypes, - bool isVariadic, - ArrayRef resultTypes, - bool &needArgNamesAttr) { - using namespace mlir::function_interface_impl; - - Region &body = op->getRegion(0); - bool isExternal = body.empty(); - SmallString<32> resultNameStr; - mlir::OpPrintingFlags flags; - - // Handle either old FunctionOpInterface modules or new-style hwmodulelike - // This whole thing should be split up into two functions, but the delta is - // so small, we are leaving this for now. - auto modOp = dyn_cast(op); - auto funcOp = dyn_cast(op); - SmallVector inputAttrs, outputAttrs; - if (funcOp) { - if (auto args = funcOp.getAllArgAttrs()) - for (auto a : args.getValue()) inputAttrs.push_back(a); - inputAttrs.resize(funcOp.getNumArguments()); - if (auto results = funcOp.getAllResultAttrs()) - for (auto a : results.getValue()) outputAttrs.push_back(a); - outputAttrs.resize(funcOp.getNumResults()); - } else { - inputAttrs = modOp.getAllInputAttrs(); - outputAttrs = modOp.getAllOutputAttrs(); - } - - p << '('; - for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { - if (i > 0) p << ", "; - - auto argName = modOp ? modOp.getInputName(i) : getModuleArgumentName(op, i); - if (!isExternal) { - // Get the printed format for the argument name. - resultNameStr.clear(); - llvm::raw_svector_ostream tmpStream(resultNameStr); - p.printOperand(body.front().getArgument(i), tmpStream); - // If the name wasn't printable in a way that agreed with argName, make - // sure to print out an explicit argNames attribute. - if (tmpStream.str().drop_front() != argName) needArgNamesAttr = true; - - p << tmpStream.str() << ": "; - } else if (!argName.empty()) { - p << '%' << argName << ": "; - } - - p.printType(argTypes[i]); - auto inputAttr = inputAttrs[i]; - p.printOptionalAttrDict(inputAttr - ? cast(inputAttr).getValue() - : ArrayRef()); - - // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, - // even if they are not printed. This will have to be fixed upstream. For - // now, use what was specified on the command line. - if (flags.shouldPrintDebugInfo()) { - auto loc = modOp.getInputLoc(i); - if (!isa(loc)) p.printOptionalLocationSpecifier(loc); - } - } - - if (isVariadic) { - if (!argTypes.empty()) p << ", "; - p << "..."; - } - - p << ')'; - - // We print result types specially since we support named arguments. - if (!resultTypes.empty()) { - p << " -> ("; - for (size_t i = 0, e = resultTypes.size(); i < e; ++i) { - if (i != 0) p << ", "; - p.printKeywordOrString(getModuleResultName(op, i)); - p << ": "; - p.printType(resultTypes[i]); - auto outputAttr = outputAttrs[i]; - p.printOptionalAttrDict(outputAttr - ? cast(outputAttr).getValue() - : ArrayRef()); - - // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, - // even if they are not printed. This will have to be fixed upstream. For - // now, use what was specified on the command line. - if (flags.shouldPrintDebugInfo()) { - auto loc = modOp.getOutputLoc(i); - if (!isa(loc)) p.printOptionalLocationSpecifier(loc); - } - } - p << ')'; - } -} - -ParseResult module_like_impl::parseModuleFunctionSignature( - OpAsmParser &parser, bool &isVariadic, - SmallVectorImpl &args, - SmallVectorImpl &argNames, SmallVectorImpl &argLocs, - SmallVectorImpl &resultNames, - SmallVectorImpl &resultAttrs, - SmallVectorImpl &resultLocs, TypeAttr &type) { - using namespace mlir::function_interface_impl; - auto *context = parser.getContext(); - - // Parse the argument list. - if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, - /*allowType=*/true, /*allowAttrs=*/true)) - return failure(); - - // Parse the result list. - SmallVector resultTypes; - if (succeeded(parser.parseOptionalArrow())) - if (failed(parseFunctionResultList(parser, resultNames, resultTypes, - resultAttrs, resultLocs))) - return failure(); - - // Process the ssa args for the information we're looking for. - SmallVector argTypes; - for (auto &arg : args) { - argNames.push_back(parsing_util::getNameFromSSA(context, arg.ssaName.name)); - argTypes.push_back(arg.type); - if (!arg.sourceLoc) - arg.sourceLoc = parser.getEncodedSourceLoc(arg.ssaName.location); - argLocs.push_back(*arg.sourceLoc); - } - - type = TypeAttr::get(FunctionType::get(context, argTypes, resultTypes)); - - return success(); -} - -//////////////////////////////////////////////////////////////////////////////// -// New Style -//////////////////////////////////////////////////////////////////////////////// - -static ParseResult parseDirection(OpAsmParser &p, ModulePort::Direction &dir) { - StringRef key; - if (failed(p.parseKeyword(&key))) - return p.emitError(p.getCurrentLocation(), "expected port direction"); - if (key == "input") - dir = ModulePort::Direction::Input; - else if (key == "output") - dir = ModulePort::Direction::Output; - else if (key == "inout") - dir = ModulePort::Direction::InOut; - else - return p.emitError(p.getCurrentLocation(), "unknown port direction '") - << key << "'"; - return success(); -} - -/// Parse a single argument with the following syntax: -/// -/// direction `%ssaname : !type { optionalAttrDict} loc(optionalSourceLoc)` -/// -/// If `allowType` is false or `allowAttrs` are false then the respective -/// parts of the grammar are not parsed. -static ParseResult parsePort(OpAsmParser &p, - module_like_impl::PortParse &result) { - NamedAttrList attrs; - if (parseDirection(p, result.direction) || - p.parseOperand(result.ssaName, /*allowResultNumber=*/false) || - p.parseColonType(result.type) || p.parseOptionalAttrDict(attrs) || - p.parseOptionalLocationSpecifier(result.sourceLoc)) - return failure(); - result.attrs = attrs.getDictionary(p.getContext()); - return success(); -} - -static ParseResult parsePortList( - OpAsmParser &p, SmallVectorImpl &result) { - auto parseOnePort = [&]() -> ParseResult { - return parsePort(p, result.emplace_back()); - }; - return p.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOnePort, - " in port list"); -} - -ParseResult module_like_impl::parseModuleSignature( - OpAsmParser &parser, SmallVectorImpl &args, TypeAttr &modType) { - auto *context = parser.getContext(); - - // Parse the port list. - if (parsePortList(parser, args)) return failure(); - - // Process the ssa args for the information we're looking for. - SmallVector ports; - for (auto &arg : args) { - ports.push_back({parsing_util::getNameFromSSA(context, arg.ssaName.name), - arg.type, arg.direction}); - if (!arg.sourceLoc) - arg.sourceLoc = parser.getEncodedSourceLoc(arg.ssaName.location); - } - modType = TypeAttr::get(ModuleType::get(context, ports)); - - return success(); -} - -static const char *directionAsString(ModulePort::Direction dir) { - if (dir == ModulePort::Direction::Input) return "input"; - if (dir == ModulePort::Direction::Output) return "output"; - if (dir == ModulePort::Direction::InOut) return "inout"; - assert(0 && "Unknown port direction"); - abort(); - return "unknown"; -} - -void module_like_impl::printModuleSignatureNew(OpAsmPrinter &p, Operation *op) { - mlir::OpPrintingFlags flags; - - auto typeAttr = op->getAttrOfType("module_type"); - auto modType = cast(typeAttr.getValue()); - auto portAttrs = op->getAttrOfType("port_attrs"); - auto locAttrs = op->getAttrOfType("port_locs"); - - p << '('; - for (auto [i, port] : llvm::enumerate(modType.getPorts())) { - if (i > 0) p << ", "; - p.printKeywordOrString(directionAsString(port.dir)); - p << " %"; - p.printKeywordOrString(port.name); - p << " : "; - p.printType(port.type); - if (auto attr = dyn_cast(portAttrs[i])) - p.printOptionalAttrDict(attr.getValue()); - - // TODO: `printOptionalLocationSpecifier` will emit aliases for locations, - // even if they are not printed. This will have to be fixed upstream. For - // now, use what was specified on the command line. - if (flags.shouldPrintDebugInfo()) - if (auto loc = locAttrs[i]) - p.printOptionalLocationSpecifier(cast(loc)); - } - - p << ')'; -} diff --git a/lib/circt/Dialect/HW/PortConverter.cpp b/lib/circt/Dialect/HW/PortConverter.cpp deleted file mode 100644 index 6118d86efa..0000000000 --- a/lib/circt/Dialect/HW/PortConverter.cpp +++ /dev/null @@ -1,233 +0,0 @@ -//===- PortConverter.cpp - Module I/O rewriting utility ---------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Dialect/HW/PortConverter.h" - -#include - -using namespace circt; -using namespace hw; - -/// Return a attribute with the specified suffix appended. -static StringAttr append(StringAttr base, const Twine &suffix) { - if (suffix.isTriviallyEmpty()) return base; - auto *context = base.getContext(); - return StringAttr::get(context, base.getValue() + suffix); -} - -namespace { - -/// We consider non-caught ports to be ad-hoc signaling or 'untouched'. (Which -/// counts as a signaling protocol if one squints pretty hard). We mostly do -/// this since it allows us a more consistent internal API. -class UntouchedPortConversion : public PortConversion { - public: - UntouchedPortConversion(PortConverterImpl &converter, hw::PortInfo origPort) - : PortConversion(converter, origPort) { - // Set the "RTTI flag" to true (see comment in header for this variable). - isUntouchedFlag = true; - } - - void mapInputSignals(OpBuilder &b, Operation *inst, Value instValue, - SmallVectorImpl &newOperands, - ArrayRef newResults) override { - newOperands[portInfo.argNum] = instValue; - } - void mapOutputSignals(OpBuilder &b, Operation *inst, Value instValue, - SmallVectorImpl &newOperands, - ArrayRef newResults) override { - instValue.replaceAllUsesWith(newResults[portInfo.argNum]); - } - - private: - void buildInputSignals() override { - Value newValue = - converter.createNewInput(origPort, "", origPort.type, portInfo); - if (body) body->getArgument(origPort.argNum).replaceAllUsesWith(newValue); - } - - void buildOutputSignals() override { - Value output; - if (body) output = body->getTerminator()->getOperand(origPort.argNum); - converter.createNewOutput(origPort, "", origPort.type, output, portInfo); - } - - hw::PortInfo portInfo; -}; - -} // namespace - -FailureOr> PortConversionBuilder::build( - hw::PortInfo port) { - // Default builder is the 'untouched' port conversion which will simply - // pass ports through unmodified. - return {std::make_unique(converter, port)}; -} - -PortConverterImpl::PortConverterImpl(igraph::InstanceGraphNode *moduleNode) - : moduleNode(moduleNode), b(moduleNode->getModule()->getContext()) { - mod = dyn_cast(*moduleNode->getModule()); - assert(mod && "PortConverter only works on HWMutableModuleLike"); - - if (mod->getNumRegions() == 1 && mod->getRegion(0).hasOneBlock()) { - body = &mod->getRegion(0).front(); - terminator = body->getTerminator(); - } -} - -Value PortConverterImpl::createNewInput(PortInfo origPort, const Twine &suffix, - Type type, PortInfo &newPort) { - newPort = PortInfo{ - {append(origPort.name, suffix), type, ModulePort::Direction::Input}, - newInputs.size(), - {}, - {}, - origPort.loc}; - newInputs.emplace_back(0, newPort); - - if (!body) return {}; - return body->addArgument(type, origPort.loc); -} - -void PortConverterImpl::createNewOutput(PortInfo origPort, const Twine &suffix, - Type type, Value output, - PortInfo &newPort) { - newPort = PortInfo{ - {append(origPort.name, suffix), type, ModulePort::Direction::Output}, - newOutputs.size(), - {}, - {}, - origPort.loc}; - newOutputs.emplace_back(0, newPort); - - if (!body) return; - - OpBuilder::InsertionGuard g(b); - b.setInsertionPointToStart(body); - terminator->insertOperands(terminator->getNumOperands(), output); -} - -LogicalResult PortConverterImpl::run() { - ModulePortInfo ports = mod.getPortList(); - - bool foundLoweredPorts = false; - - auto createPortLowering = [&](PortInfo port) { - auto &loweredPorts = port.dir == ModulePort::Direction::Output - ? loweredOutputs - : loweredInputs; - - auto loweredPort = ssb->build(port); - if (failed(loweredPort)) return failure(); - - foundLoweredPorts |= !(*loweredPort)->isUntouched(); - loweredPorts.emplace_back(std::move(*loweredPort)); - - if (failed(loweredPorts.back()->init())) return failure(); - - return success(); - }; - - // Dispatch the port conversion builder on the I/O of the module. - for (PortInfo port : ports) - if (failed(createPortLowering(port))) return failure(); - - // Bail early if we didn't find anything to convert. - if (!foundLoweredPorts) { - // Memory optimization. - loweredInputs.clear(); - loweredOutputs.clear(); - return success(); - } - - // Lower the ports -- this mutates the body directly and builds the port - // lists. - for (auto &lowering : loweredInputs) lowering->lowerPort(); - for (auto &lowering : loweredOutputs) lowering->lowerPort(); - - // Set up vectors to erase _all_ the ports. It's easier to rebuild everything - // than reason about interleaving the newly lowered ports with the non lowered - // ports. Also, the 'modifyPorts' method ends up rebuilding the port lists - // anyway, so this isn't nearly as expensive as it may seem. - SmallVector inputsToErase(mod.getNumInputPorts()); - std::iota(inputsToErase.begin(), inputsToErase.end(), 0); - SmallVector outputsToErase(mod.getNumOutputPorts()); - std::iota(outputsToErase.begin(), outputsToErase.end(), 0); - - mod.modifyPorts(newInputs, newOutputs, inputsToErase, outputsToErase); - - if (body) { - // We should only erase the original arguments. New ones were appended - // with the `createInput` method call. - body->eraseArguments([&ports](BlockArgument arg) { - return arg.getArgNumber() < ports.sizeInputs(); - }); - - // And erase the first ports.sizeOutputs operands from the terminator. - terminator->eraseOperands(0, ports.sizeOutputs()); - } - - // Rewrite instances pointing to this module. - for (auto *instance : moduleNode->uses()) { - auto instanceLike = instance->getInstance(); - if (!instanceLike) continue; - hw::InstanceOp hwInstance = dyn_cast_or_null(*instanceLike); - if (!hwInstance) { - return instanceLike->emitOpError( - "This code only converts hw.instance instances - ask your friendly " - "neighborhood compiler engineers to implement support for something " - "like an hw::HWMutableInstanceLike interface"); - } - updateInstance(hwInstance); - } - - // Memory optimization -- we don't need these anymore. - newInputs.clear(); - newOutputs.clear(); - return success(); -} - -void PortConverterImpl::updateInstance(hw::InstanceOp inst) { - ImplicitLocOpBuilder b(inst.getLoc(), inst); - BackedgeBuilder beb(b, inst.getLoc()); - ModulePortInfo ports = mod.getPortList(); - - // Create backedges for the future instance results so the signal mappers can - // use the future results as values. - SmallVector newResults; - for (PortInfo outputPort : ports.getOutputs()) - newResults.push_back(beb.get(outputPort.type)); - - // Map the operands. - SmallVector newOperands(ports.sizeInputs(), {}); - for (size_t oldOpIdx = 0, e = inst.getNumOperands(); oldOpIdx < e; ++oldOpIdx) - loweredInputs[oldOpIdx]->mapInputSignals( - b, inst, inst->getOperand(oldOpIdx), newOperands, newResults); - - // Map the results. - for (size_t oldResIdx = 0, e = inst.getNumResults(); oldResIdx < e; - ++oldResIdx) - loweredOutputs[oldResIdx]->mapOutputSignals( - b, inst, inst->getResult(oldResIdx), newOperands, newResults); - - // Clone the instance. We cannot just modifiy the existing one since the - // result types might have changed types and number of them. - assert(llvm::none_of(newOperands, [](Value v) { return !v; })); - b.setInsertionPointAfter(inst); - auto newInst = - b.create(mod, inst.getInstanceNameAttr(), newOperands, - inst.getParameters(), inst.getInnerSymAttr()); - newInst->setDialectAttrs(inst->getDialectAttrs()); - - // Assign the backedges to the new results. - for (auto [idx, be] : llvm::enumerate(newResults)) - be.setValue(newInst.getResult(idx)); - - // Erase the old instance. - inst.erase(); -} diff --git a/lib/circt/Dialect/HW/Transforms/CMakeLists.txt b/lib/circt/Dialect/HW/Transforms/CMakeLists.txt deleted file mode 100644 index 86fbb1b64a..0000000000 --- a/lib/circt/Dialect/HW/Transforms/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_circt_dialect_library(CIRCTHWTransforms - HWPrintInstanceGraph.cpp - HWSpecialize.cpp - PrintHWModuleGraph.cpp - FlattenIO.cpp - VerifyInnerRefNamespace.cpp - - DEPENDS - CIRCTHWTransformsIncGen - - LINK_LIBS PUBLIC - CIRCTHW - CIRCTSV - CIRCTSeq - CIRCTComb - CIRCTSupport - MLIRIR - MLIRPass - MLIRTransformUtils -) diff --git a/lib/circt/Dialect/HW/Transforms/FlattenIO.cpp b/lib/circt/Dialect/HW/Transforms/FlattenIO.cpp deleted file mode 100644 index eb66cf4378..0000000000 --- a/lib/circt/Dialect/HW/Transforms/FlattenIO.cpp +++ /dev/null @@ -1,429 +0,0 @@ -//===- FlattenIO.cpp - HW I/O flattening pass -------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWPasses.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - -using namespace mlir; -using namespace circt; - -static bool isStructType(Type type) { - return hw::getCanonicalType(type).isa(); -} - -static hw::StructType getStructType(Type type) { - return hw::getCanonicalType(type).dyn_cast(); -} - -// Legal if no in- or output type is a struct. -static bool isLegalModLikeOp(hw::HWModuleLike moduleLikeOp) { - return llvm::none_of(moduleLikeOp.getHWModuleType().getPortTypes(), - isStructType); -} - -static llvm::SmallVector getInnerTypes(hw::StructType t) { - llvm::SmallVector inner; - t.getInnerTypes(inner); - for (auto [index, innerType] : llvm::enumerate(inner)) - inner[index] = hw::getCanonicalType(innerType); - return inner; -} - -namespace { - -// Replaces an output op with a new output with flattened (exploded) structs. -struct OutputOpConversion : public OpConversionPattern { - OutputOpConversion(TypeConverter &typeConverter, MLIRContext *context, - DenseSet *opVisited) - : OpConversionPattern(typeConverter, context), opVisited(opVisited) {} - - LogicalResult matchAndRewrite( - hw::OutputOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector convOperands; - - // Flatten the operands. - for (auto operand : adaptor.getOperands()) { - if (auto structType = getStructType(operand.getType())) { - auto explodedStruct = rewriter.create( - op.getLoc(), getInnerTypes(structType), operand); - llvm::copy(explodedStruct.getResults(), - std::back_inserter(convOperands)); - } else { - convOperands.push_back(operand); - } - } - - // And replace. - rewriter.replaceOpWithNewOp(op, convOperands); - opVisited->insert(op->getParentOp()); - return success(); - } - DenseSet *opVisited; -}; - -struct InstanceOpConversion : public OpConversionPattern { - InstanceOpConversion(TypeConverter &typeConverter, MLIRContext *context, - DenseSet *convertedOps) - : OpConversionPattern(typeConverter, context), - convertedOps(convertedOps) {} - - LogicalResult matchAndRewrite( - hw::InstanceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - // Flatten the operands. - llvm::SmallVector convOperands; - for (auto operand : adaptor.getOperands()) { - if (auto structType = getStructType(operand.getType())) { - auto explodedStruct = rewriter.create( - loc, getInnerTypes(structType), operand); - llvm::copy(explodedStruct.getResults(), - std::back_inserter(convOperands)); - } else { - convOperands.push_back(operand); - } - } - - // Create the new instance... - auto newInstance = rewriter.create( - loc, op.getReferencedModuleSlow(), op.getInstanceName(), convOperands); - - // re-create any structs in the result. - llvm::SmallVector convResults; - size_t oldResultCntr = 0; - for (size_t resIndex = 0; resIndex < newInstance.getNumResults(); - ++resIndex) { - Type oldResultType = op.getResultTypes()[oldResultCntr]; - if (auto structType = getStructType(oldResultType)) { - size_t nElements = structType.getElements().size(); - auto implodedStruct = rewriter.create( - loc, structType, - newInstance.getResults().slice(resIndex, nElements)); - convResults.push_back(implodedStruct.getResult()); - resIndex += nElements - 1; - } else - convResults.push_back(newInstance.getResult(resIndex)); - - ++oldResultCntr; - } - rewriter.replaceOp(op, convResults); - convertedOps->insert(newInstance); - return success(); - } - - DenseSet *convertedOps; -}; - -using IOTypes = std::pair; - -struct IOInfo { - // A mapping between an arg/res index and the struct type of the given field. - DenseMap argStructs, resStructs; - - // Records of the original arg/res types. - SmallVector argTypes, resTypes; -}; - -class FlattenIOTypeConverter : public TypeConverter { - public: - FlattenIOTypeConverter() { - addConversion([](Type type, SmallVectorImpl &results) { - auto structType = getStructType(type); - if (!structType) - results.push_back(type); - else { - for (auto field : structType.getElements()) - results.push_back(field.type); - } - return success(); - }); - - addTargetMaterialization([](OpBuilder &builder, hw::StructType type, - ValueRange inputs, Location loc) { - auto result = builder.create(loc, type, inputs); - return result.getResult(); - }); - - addTargetMaterialization([](OpBuilder &builder, hw::TypeAliasType type, - ValueRange inputs, Location loc) { - auto structType = getStructType(type); - assert(structType && "expected struct type"); - auto result = builder.create(loc, structType, inputs); - return result.getResult(); - }); - } -}; - -} // namespace - -template -static void addSignatureConversion(DenseMap &ioMap, - ConversionTarget &target, - RewritePatternSet &patterns, - FlattenIOTypeConverter &typeConverter) { - (hw::populateHWModuleLikeTypeConversionPattern(TOp::getOperationName(), - patterns, typeConverter), - ...); - - // Legality is defined by a module having been processed once. This is due to - // that a pattern cannot be applied multiple times (a 'pattern was already - // applied' error - a case that would occur for nested structs). Additionally, - // if a pattern could be applied multiple times, this would complicate - // updating arg/res names. - - // Instead, we define legality as when a module has had a modification to its - // top-level i/o. This ensures that only a single level of structs are - // processed during signature conversion, which then allows us to use the - // signature conversion in a recursive manner. - target.addDynamicallyLegalOp([&](hw::HWModuleLike moduleLikeOp) { - if (isLegalModLikeOp(moduleLikeOp)) return true; - - // This op is involved in conversion. Check if the signature has changed. - auto ioInfoIt = ioMap.find(moduleLikeOp); - if (ioInfoIt == ioMap.end()) { - // Op wasn't primed in the map. Do the safe thing, assume - // that it's not considered in this pass, and mark it as legal - return true; - } - auto ioInfo = ioInfoIt->second; - - auto compareTypes = [&](TypeRange oldTypes, TypeRange newTypes) { - return llvm::any_of(llvm::zip(oldTypes, newTypes), [&](auto typePair) { - auto oldType = std::get<0>(typePair); - auto newType = std::get<1>(typePair); - return oldType != newType; - }); - }; - auto mtype = moduleLikeOp.getHWModuleType(); - if (compareTypes(mtype.getOutputTypes(), ioInfo.resTypes) || - compareTypes(mtype.getInputTypes(), ioInfo.argTypes)) - return true; - - // We're pre-conversion for an op that was primed in the map - it will - // always be illegal since it has to-be-converted struct types at its I/O. - return false; - }); -} - -template -static bool hasUnconvertedOps(mlir::ModuleOp module) { - return llvm::any_of(module.getBody()->getOps(), - [](T op) { return !isLegalModLikeOp(op); }); -} - -template -static DenseMap populateIOMap(mlir::ModuleOp module) { - DenseMap ioMap; - for (auto op : module.getOps()) - ioMap[op] = {op.getArgumentTypes(), op.getResultTypes()}; - return ioMap; -} - -template -static llvm::SmallVector updateNameAttribute( - ModTy op, StringRef attrName, DenseMap &structMap, - T oldNames) { - llvm::SmallVector newNames; - for (auto [i, oldName] : llvm::enumerate(oldNames)) { - // Was this arg/res index a struct? - auto it = structMap.find(i); - if (it == structMap.end()) { - // No, keep old name. - newNames.push_back(StringAttr::get(op->getContext(), oldName)); - continue; - } - - // Yes - create new names from the struct fields and the old name at the - // index. - auto structType = it->second; - for (auto field : structType.getElements()) - newNames.push_back( - StringAttr::get(op->getContext(), oldName + "." + field.name.str())); - } - return newNames; -} - -static llvm::SmallVector updateLocAttribute( - DenseMap &structMap, ArrayAttr oldLocs) { - llvm::SmallVector newLocs; - if (!oldLocs) return newLocs; - for (auto [i, oldLoc] : llvm::enumerate(oldLocs.getAsRange())) { - // Was this arg/res index a struct? - auto it = structMap.find(i); - if (it == structMap.end()) { - // No, keep old name. - newLocs.push_back(oldLoc); - continue; - } - - auto structType = it->second; - for (size_t i = 0, e = structType.getElements().size(); i < e; ++i) - newLocs.push_back(oldLoc); - } - return newLocs; -} - -/// The conversion framework seems to throw away block argument locations. We -/// use this function to copy the location from the original argument to the -/// set of flattened arguments. -static void updateBlockLocations( - hw::HWModuleLike op, StringRef attrName, - DenseMap &structMap) { - auto locs = op.getOperation()->getAttrOfType(attrName); - if (!locs || op.getModuleBody().empty()) return; - for (auto [arg, loc] : llvm::zip(op.getBodyBlock()->getArguments(), - locs.getAsRange())) - arg.setLoc(loc); -} - -template -static DenseMap populateIOInfoMap(mlir::ModuleOp module) { - DenseMap ioInfoMap; - for (auto op : module.getOps()) { - IOInfo ioInfo; - ioInfo.argTypes = op.getInputTypes(); - ioInfo.resTypes = op.getOutputTypes(); - for (auto [i, arg] : llvm::enumerate(ioInfo.argTypes)) { - if (auto structType = getStructType(arg)) - ioInfo.argStructs[i] = structType; - } - for (auto [i, res] : llvm::enumerate(ioInfo.resTypes)) { - if (auto structType = getStructType(res)) - ioInfo.resStructs[i] = structType; - } - ioInfoMap[op] = ioInfo; - } - return ioInfoMap; -} - -template -static LogicalResult flattenOpsOfType(ModuleOp module, bool recursive) { - auto *ctx = module.getContext(); - FlattenIOTypeConverter typeConverter; - - // Recursively (in case of nested structs) lower the module. We do this one - // conversion at a time to allow for updating the arg/res names of the - // module in between flattening each level of structs. - while (hasUnconvertedOps(module)) { - ConversionTarget target(*ctx); - RewritePatternSet patterns(ctx); - target.addLegalDialect(); - - // Record any struct types at the module signature. This will be used - // post-conversion to update the argument and result names. - auto ioInfoMap = populateIOInfoMap(module); - - // Record the instances that were converted. We keep these around since we - // need to update their arg/res attribute names after the modules themselves - // have been updated. - llvm::DenseSet convertedInstances; - - // Argument conversion for output ops. Similarly to the signature - // conversion, legality is based on the op having been visited once, due to - // the possibility of nested structs. - DenseSet opVisited; - patterns.add(typeConverter, ctx, &opVisited); - - patterns.add(typeConverter, ctx, &convertedInstances); - target.addDynamicallyLegalOp( - [&](auto op) { return opVisited.contains(op->getParentOp()); }); - target.addDynamicallyLegalOp([&](auto op) { - return llvm::none_of(op->getOperands(), [](auto operand) { - return isStructType(operand.getType()); - }); - }); - - DenseMap oldArgNames, oldResNames, oldArgLocs, - oldResLocs; - for (auto op : module.getOps()) { - oldArgNames[op] = ArrayAttr::get(module.getContext(), op.getInputNames()); - oldResNames[op] = - ArrayAttr::get(module.getContext(), op.getOutputNames()); - oldArgLocs[op] = op.getInputLocsAttr(); - oldResLocs[op] = op.getOutputLocsAttr(); - } - - // Signature conversion and legalization patterns. - addSignatureConversion(ioInfoMap, target, patterns, typeConverter); - - if (failed(applyPartialConversion(module, target, std::move(patterns)))) - return failure(); - - // Update the arg/res names of the module. - for (auto op : module.getOps()) { - auto ioInfo = ioInfoMap[op]; - auto newArgNames = updateNameAttribute( - op, "argNames", ioInfo.argStructs, - oldArgNames[op].template getAsValueRange()); - auto newResNames = updateNameAttribute( - op, "resultNames", ioInfo.resStructs, - oldResNames[op].template getAsValueRange()); - newArgNames.append(newResNames.begin(), newResNames.end()); - op.setAllPortNames(newArgNames); - op.setInputLocs(updateLocAttribute(ioInfo.argStructs, oldArgLocs[op])); - op.setOutputLocs(updateLocAttribute(ioInfo.resStructs, oldResLocs[op])); - updateBlockLocations(op, "argLocs", ioInfo.argStructs); - } - - // And likewise with the converted instance ops. - for (auto instanceOp : convertedInstances) { - Operation *targetModule = instanceOp.getReferencedModuleSlow(); - auto ioInfo = ioInfoMap[targetModule]; - instanceOp.setInputNames(ArrayAttr::get( - instanceOp.getContext(), - updateNameAttribute(instanceOp, "argNames", ioInfo.argStructs, - oldArgNames[targetModule] - .template getAsValueRange()))); - instanceOp.setOutputNames(ArrayAttr::get( - instanceOp.getContext(), - updateNameAttribute(instanceOp, "resultNames", ioInfo.resStructs, - oldResNames[targetModule] - .template getAsValueRange()))); - instanceOp.dump(); - } - - // Break if we've only lowering a single level of structs. - if (!recursive) break; - } - return success(); -} - -//===----------------------------------------------------------------------===// -// Pass driver -//===----------------------------------------------------------------------===// - -template -static bool flattenIO(ModuleOp module, bool recursive) { - return (failed(flattenOpsOfType(module, recursive)) || ...); -} - -namespace { - -class FlattenIOPass : public circt::hw::FlattenIOBase { - public: - void runOnOperation() override { - ModuleOp module = getOperation(); - if (flattenIO(module, recursive)) - signalPassFailure(); - }; -}; - -} // namespace - -//===----------------------------------------------------------------------===// -// Pass initialization -//===----------------------------------------------------------------------===// - -std::unique_ptr circt::hw::createFlattenIOPass() { - return std::make_unique(); -} diff --git a/lib/circt/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp b/lib/circt/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp deleted file mode 100644 index 455c46ff0c..0000000000 --- a/lib/circt/Dialect/HW/Transforms/HWPrintInstanceGraph.cpp +++ /dev/null @@ -1,36 +0,0 @@ -//===- HWPrintInstanceGraph.cpp - Print the instance graph ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===----------------------------------------------------------------------===// -// -// Print the module hierarchy. -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/HW/HWInstanceGraph.h" -#include "include/circt/Dialect/HW/HWPasses.h" -#include "llvm/include/llvm/Support/GraphWriter.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project - -using namespace circt; -using namespace hw; - -namespace { -struct PrintInstanceGraphPass - : public PrintInstanceGraphBase { - PrintInstanceGraphPass(raw_ostream &os) : os(os) {} - void runOnOperation() override { - InstanceGraph &instanceGraph = getAnalysis(); - llvm::WriteGraph(os, &instanceGraph, /*ShortNames=*/false); - markAllAnalysesPreserved(); - } - raw_ostream &os; -}; -} // end anonymous namespace - -std::unique_ptr circt::hw::createPrintInstanceGraphPass() { - return std::make_unique(llvm::errs()); -} diff --git a/lib/circt/Dialect/HW/Transforms/HWSpecialize.cpp b/lib/circt/Dialect/HW/Transforms/HWSpecialize.cpp deleted file mode 100644 index a99a9c81fa..0000000000 --- a/lib/circt/Dialect/HW/Transforms/HWSpecialize.cpp +++ /dev/null @@ -1,422 +0,0 @@ -//===- HWSpecialize.cpp - hw module specialization ------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This transform performs specialization of parametric hw.module's. -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/Comb/CombOps.h" -#include "include/circt/Dialect/HW/HWAttributes.h" -#include "include/circt/Dialect/HW/HWOps.h" -#include "include/circt/Dialect/HW/HWPasses.h" -#include "include/circt/Dialect/HW/HWSymCache.h" -#include "include/circt/Support/Namespace.h" -#include "include/circt/Support/ValueMapper.h" -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project -#include "mlir/include/mlir/IR/IRMapping.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project - -using namespace llvm; -using namespace mlir; -using namespace circt; -using namespace hw; - -namespace { - -// Generates a module name by composing the name of 'moduleOp' and the set of -// provided 'parameters'. -static std::string generateModuleName(Namespace &ns, hw::HWModuleOp moduleOp, - ArrayAttr parameters) { - assert(parameters.size() != 0); - std::string name = moduleOp.getName().str(); - for (auto param : parameters) { - auto paramAttr = param.cast(); - int64_t paramValue = paramAttr.getValue().cast().getInt(); - name += "_" + paramAttr.getName().str() + "_" + std::to_string(paramValue); - } - - // Query the namespace to generate a unique name. - return ns.newName(name).str(); -} - -// Returns true if any operand or result of 'op' is parametric. -static bool isParametricOp(Operation *op) { - return llvm::any_of(op->getOperandTypes(), isParametricType) || - llvm::any_of(op->getResultTypes(), isParametricType); -} - -// Narrows 'value' using a comb.extract operation to the width of the -// hw.array-typed 'array'. -static FailureOr narrowValueToArrayWidth(OpBuilder &builder, Value array, - Value value) { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPointAfterValue(value); - auto arrayType = array.getType().cast(); - unsigned hiBit = llvm::Log2_64_Ceil(arrayType.getNumElements()); - - return hiBit == 0 - ? builder - .create(value.getLoc(), - APInt(arrayType.getNumElements(), 0)) - .getResult() - : builder - .create(value.getLoc(), value, - /*lowBit=*/0, hiBit) - .getResult(); -} - -static hw::HWModuleOp targetModuleOp(hw::InstanceOp instanceOp, - const SymbolCache &sc) { - auto *targetOp = sc.getDefinition(instanceOp.getModuleNameAttr()); - auto targetHWModule = dyn_cast(targetOp); - if (!targetHWModule) return {}; // Won't specialize external modules. - - if (targetHWModule.getParameters().size() == 0) - return {}; // nothing to record or specialize - - return targetHWModule; -} - -// Stores unique module parameters and references to them -struct ParameterSpecializationRegistry { - llvm::MapVector> - uniqueModuleParameters; - - bool isRegistered(hw::HWModuleOp moduleOp, ArrayAttr parameters) const { - auto it = uniqueModuleParameters.find(moduleOp); - return it != uniqueModuleParameters.end() && - it->second.contains(parameters); - } - - void registerModuleOp(hw::HWModuleOp moduleOp, ArrayAttr parameters) { - uniqueModuleParameters[moduleOp].insert(parameters); - } -}; - -struct EliminateParamValueOpPattern : public OpRewritePattern { - EliminateParamValueOpPattern(MLIRContext *context, ArrayAttr parameters) - : OpRewritePattern(context), parameters(parameters) {} - - LogicalResult matchAndRewrite(ParamValueOp op, - PatternRewriter &rewriter) const override { - // Substitute the param value op with an evaluated constant operation. - FailureOr evaluated = - evaluateParametricAttr(op.getLoc(), parameters, op.getValue()); - if (failed(evaluated)) return failure(); - rewriter.replaceOpWithNewOp( - op, op.getType(), - evaluated->cast().getValue().getSExtValue()); - return success(); - } - - ArrayAttr parameters; -}; - -// hw.array_get operations require indexes to be of equal width of the -// array itself. Since indexes may originate from constants or parameters, -// emit comb.extract operations to fulfill this invariant. -struct NarrowArrayGetIndexPattern : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult matchAndRewrite( - ArrayGetOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto inputType = type_cast(op.getInput().getType()); - Type targetIndexType = IntegerType::get( - getContext(), inputType.getNumElements() == 1 - ? 1 - : llvm::Log2_64_Ceil(inputType.getNumElements())); - - if (op.getIndex().getType().getIntOrFloatBitWidth() == - targetIndexType.getIntOrFloatBitWidth()) - return failure(); // nothing to do - - // Narrow the index value. - FailureOr narrowedIndex = - narrowValueToArrayWidth(rewriter, op.getInput(), op.getIndex()); - if (failed(narrowedIndex)) return failure(); - rewriter.replaceOpWithNewOp(op, op.getInput(), *narrowedIndex); - return success(); - } -}; - -// Generic pattern to convert parametric result types. -struct ParametricTypeConversionPattern : public ConversionPattern { - ParametricTypeConversionPattern(MLIRContext *ctx, - TypeConverter &typeConverter, - ArrayAttr parameters) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, - ctx), - parameters(parameters) {} - - LogicalResult matchAndRewrite( - Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - llvm::SmallVector convertedOperands; - // Update the result types of the operation - bool ok = true; - rewriter.updateRootInPlace(op, [&]() { - // Mutate result types - for (auto it : llvm::enumerate(op->getResultTypes())) { - FailureOr res = - evaluateParametricType(op->getLoc(), parameters, it.value()); - ok &= succeeded(res); - if (!ok) return; - op->getResult(it.index()).setType(*res); - } - - // Note: 'operands' have already been converted with the supplied type - // converter to this pattern. Make sure that we materialize this - // conversion by updating the operands to op. - op->setOperands(operands); - }); - - return success(ok); - }; - ArrayAttr parameters; -}; - -struct HWSpecializePass : public hw::HWSpecializeBase { - void runOnOperation() override; -}; - -static void populateTypeConversion(Location loc, TypeConverter &typeConverter, - ArrayAttr parameters) { - // Possibly parametric types - typeConverter.addConversion([=](hw::IntType type) { - return evaluateParametricType(loc, parameters, type); - }); - typeConverter.addConversion([=](hw::ArrayType type) { - return evaluateParametricType(loc, parameters, type); - }); - - // Valid target types. - typeConverter.addConversion([](mlir::IntegerType type) { return type; }); -} - -// Registers any nested parametric instance ops of `target` for the next -// specialization loop -static LogicalResult registerNestedParametricInstanceOps( - HWModuleOp target, ArrayAttr parameters, SymbolCache &sc, - const ParameterSpecializationRegistry ¤tRegistry, - ParameterSpecializationRegistry &nextRegistry, - llvm::DenseMap>> - ¶metersUsers) { - // Register any nested parametric instance ops for the next loop - auto walkResult = target->walk([&](InstanceOp instanceOp) -> WalkResult { - auto instanceParameters = instanceOp.getParameters(); - // We can ignore non-parametric instances - if (instanceParameters.empty()) return WalkResult::advance(); - - // Replace instance parameters with evaluated versions - llvm::SmallVector evaluatedInstanceParameters; - evaluatedInstanceParameters.reserve(instanceParameters.size()); - for (auto instanceParameter : instanceParameters) { - auto instanceParameterDecl = instanceParameter.cast(); - auto instanceParameterValue = instanceParameterDecl.getValue(); - auto evaluated = evaluateParametricAttr(target.getLoc(), parameters, - instanceParameterValue); - if (failed(evaluated)) return WalkResult::interrupt(); - evaluatedInstanceParameters.push_back( - hw::ParamDeclAttr::get(instanceParameterDecl.getName(), *evaluated)); - } - - auto evaluatedInstanceParametersAttr = - ArrayAttr::get(target.getContext(), evaluatedInstanceParameters); - - if (auto targetHWModule = targetModuleOp(instanceOp, sc)) { - if (!currentRegistry.isRegistered(targetHWModule, - evaluatedInstanceParametersAttr)) - nextRegistry.registerModuleOp(targetHWModule, - evaluatedInstanceParametersAttr); - parametersUsers[targetHWModule][evaluatedInstanceParametersAttr] - .push_back(instanceOp); - } - - return WalkResult::advance(); - }); - - return failure(walkResult.wasInterrupted()); -} - -// Specializes the provided 'base' module into the 'target' module. By doing -// so, we create a new module which -// 1. has no parameters -// 2. has a name composing the name of 'base' as well as the 'parameters' -// parameters. -// 3. Has a top-level interface with any parametric types resolved. -// 4. Any references to module parameters have been replaced with the -// parameter value. -static LogicalResult specializeModule( - OpBuilder builder, ArrayAttr parameters, SymbolCache &sc, Namespace &ns, - HWModuleOp source, HWModuleOp &target, - const ParameterSpecializationRegistry ¤tRegistry, - ParameterSpecializationRegistry &nextRegistry, - llvm::DenseMap>> - ¶metersUsers) { - auto *ctx = builder.getContext(); - // Update the types of the source module ports based on evaluating any - // parametric in/output ports. - auto ports = source.getPortList(); - for (auto in : llvm::enumerate(source.getInputTypes())) { - FailureOr resType = - evaluateParametricType(source.getLoc(), parameters, in.value()); - if (failed(resType)) return failure(); - ports.atInput(in.index()).type = *resType; - } - for (auto out : llvm::enumerate(source.getOutputTypes())) { - FailureOr resolvedType = - evaluateParametricType(source.getLoc(), parameters, out.value()); - if (failed(resolvedType)) return failure(); - ports.atOutput(out.index()).type = *resolvedType; - } - - // Create the specialized module using the evaluated port info. - target = builder.create( - source.getLoc(), - StringAttr::get(ctx, generateModuleName(ns, source, parameters)), ports); - - // Erase the default created hw.output op - we'll copy the correct operation - // during body elaboration. - (*target.getOps().begin()).erase(); - - // Clone body of the source into the target. Use ValueMapper to ensure safe - // cloning in the presence of backedges. - BackedgeBuilder bb(builder, source.getLoc()); - ValueMapper mapper(&bb); - for (auto &&[src, dst] : llvm::zip(source.getBodyBlock()->getArguments(), - target.getBodyBlock()->getArguments())) - mapper.set(src, dst); - builder.setInsertionPointToStart(target.getBodyBlock()); - - for (auto &op : source.getOps()) { - IRMapping bvMapper; - for (auto operand : op.getOperands()) - bvMapper.map(operand, mapper.get(operand)); - auto *newOp = builder.clone(op, bvMapper); - for (auto &&[oldRes, newRes] : - llvm::zip(op.getResults(), newOp->getResults())) - mapper.set(oldRes, newRes); - } - - // Register any nested parametric instance ops for the next loop - auto nestedRegistrationResult = registerNestedParametricInstanceOps( - target, parameters, sc, currentRegistry, nextRegistry, parametersUsers); - if (failed(nestedRegistrationResult)) return failure(); - - // We've now created a separate copy of the source module with a rewritten - // top-level interface. Next, we enter the module to convert parametric - // types within operations. - RewritePatternSet patterns(ctx); - TypeConverter t; - populateTypeConversion(target.getLoc(), t, parameters); - patterns.add(ctx, parameters); - patterns.add(ctx); - patterns.add(ctx, t, parameters); - ConversionTarget convTarget(*ctx); - convTarget.addLegalOp(); - convTarget.addIllegalOp(); - - // Generic legalization of converted operations. - convTarget.markUnknownOpDynamicallyLegal( - [](Operation *op) { return !isParametricOp(op); }); - - return applyPartialConversion(target, convTarget, std::move(patterns)); -} - -void HWSpecializePass::runOnOperation() { - ModuleOp module = getOperation(); - - // Record unique module parameters and references to these. - llvm::DenseMap>> - parametersUsers; - ParameterSpecializationRegistry registry; - - // Maintain a symbol cache for fast lookup during module specialization. - SymbolCache sc; - sc.addDefinitions(module); - Namespace ns; - ns.add(sc); - - for (auto hwModule : module.getOps()) { - // If this module is parametric, defer registering its parametric - // instantiations until this module is specialized - if (!hwModule.getParameters().empty()) continue; - for (auto instanceOp : hwModule.getOps()) { - if (auto targetHWModule = targetModuleOp(instanceOp, sc)) { - auto parameters = instanceOp.getParameters(); - registry.registerModuleOp(targetHWModule, parameters); - - parametersUsers[targetHWModule][parameters].push_back(instanceOp); - } - } - } - - // Create specialized modules. - OpBuilder builder = OpBuilder(&getContext()); - builder.setInsertionPointToStart(module.getBody()); - llvm::DenseMap> - specializations; - - // For every module specialization, any nested parametric modules will be - // registered for the next loop. We loop until no new nested modules have been - // registered. - while (!registry.uniqueModuleParameters.empty()) { - // The registry for the next specialization loop - ParameterSpecializationRegistry nextRegistry; - for (auto it : registry.uniqueModuleParameters) { - for (auto parameters : it.second) { - HWModuleOp specializedModule; - if (failed(specializeModule(builder, parameters, sc, ns, it.first, - specializedModule, registry, nextRegistry, - parametersUsers))) { - signalPassFailure(); - return; - } - - // Extend the symbol cache with the newly created module. - sc.addDefinition(specializedModule.getNameAttr(), specializedModule); - - // Add the specialization - specializations[it.first][parameters] = specializedModule; - } - } - - // Transfer newly registered specializations to iterate over - registry.uniqueModuleParameters = - std::move(nextRegistry.uniqueModuleParameters); - } - - // Rewrite instances of specialized modules to the specialized module. - for (auto it : specializations) { - auto unspecialized = it.getFirst(); - auto &users = parametersUsers[unspecialized]; - for (auto specialization : it.getSecond()) { - auto parameters = specialization.getFirst(); - auto specializedModule = specialization.getSecond(); - for (auto instanceOp : users[parameters]) { - instanceOp->setAttr("moduleName", - FlatSymbolRefAttr::get(specializedModule)); - instanceOp->setAttr("parameters", ArrayAttr::get(&getContext(), {})); - } - } - } -} - -} // namespace - -std::unique_ptr circt::hw::createHWSpecializePass() { - return std::make_unique(); -} diff --git a/lib/circt/Dialect/HW/Transforms/PassDetails.h b/lib/circt/Dialect/HW/Transforms/PassDetails.h deleted file mode 100644 index ca998bd1f4..0000000000 --- a/lib/circt/Dialect/HW/Transforms/PassDetails.h +++ /dev/null @@ -1,31 +0,0 @@ -//===- PassDetails.h - HW pass class details ----------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Stuff shared between the different HW passes. -// -//===----------------------------------------------------------------------===// - -// clang-tidy seems to expect the absolute path in the header guard on some -// systems, so just disable it. -// NOLINTNEXTLINE(llvm-header-guard) -#ifndef DIALECT_HW_TRANSFORMS_PASSDETAILS_H -#define DIALECT_HW_TRANSFORMS_PASSDETAILS_H - -#include "include/circt/Dialect/HW/HWOps.h" -#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project - -namespace circt { -namespace hw { - -#define GEN_PASS_CLASSES -#include "include/circt/Dialect/HW/Passes.h.inc" - -} // namespace hw -} // namespace circt - -#endif // DIALECT_HW_TRANSFORMS_PASSDETAILS_H diff --git a/lib/circt/Dialect/HW/Transforms/PrintHWModuleGraph.cpp b/lib/circt/Dialect/HW/Transforms/PrintHWModuleGraph.cpp deleted file mode 100644 index d10ef855cd..0000000000 --- a/lib/circt/Dialect/HW/Transforms/PrintHWModuleGraph.cpp +++ /dev/null @@ -1,42 +0,0 @@ -//===- PrintHWModuleGraph.cpp - Print the instance graph --------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===----------------------------------------------------------------------===// -// -// Prints an HW module as a .dot graph. -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/HW/HWModuleGraph.h" -#include "include/circt/Dialect/HW/HWPasses.h" -#include "llvm/include/llvm/Support/GraphWriter.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project - -using namespace circt; -using namespace hw; - -namespace { -struct PrintHWModuleGraphPass - : public PrintHWModuleGraphBase { - PrintHWModuleGraphPass(raw_ostream &os) : os(os) {} - void runOnOperation() override { - getOperation().walk([&](hw::HWModuleOp module) { - // We don't really have any other way of forwarding draw arguments to the - // DOTGraphTraits for HWModule except through the module itself - as an - // attribute. - module->setAttr("dot_verboseEdges", - BoolAttr::get(module.getContext(), verboseEdges)); - - llvm::WriteGraph(os, module, /*ShortNames=*/false); - }); - } - raw_ostream &os; -}; -} // end anonymous namespace - -std::unique_ptr circt::hw::createPrintHWModuleGraphPass() { - return std::make_unique(llvm::errs()); -} diff --git a/lib/circt/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp b/lib/circt/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp deleted file mode 100644 index 0a0d56d47f..0000000000 --- a/lib/circt/Dialect/HW/Transforms/VerifyInnerRefNamespace.cpp +++ /dev/null @@ -1,44 +0,0 @@ -//===- VerifyInnerRefNamespace.cpp - InnerRefNamespace verification Pass --===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a simple pass to drive verification of operations -// that want to be InnerRefNamespace's but don't have the trait to verify -// themselves. -// -//===----------------------------------------------------------------------===// - -#include "PassDetails.h" -#include "include/circt/Dialect/HW/HWOpInterfaces.h" -#include "include/circt/Dialect/HW/HWPasses.h" - -/// VerifyInnerRefNamespace pass until have container operation. - -namespace { - -class VerifyInnerRefNamespacePass - : public circt::hw::VerifyInnerRefNamespaceBase< - VerifyInnerRefNamespacePass> { - public: - void runOnOperation() override { - auto *irnLike = getOperation(); - if (!irnLike->hasTrait()) - if (failed(circt::hw::detail::verifyInnerRefNamespace(irnLike))) - return signalPassFailure(); - - return markAllAnalysesPreserved(); - }; - bool canScheduleOn(mlir::RegisteredOperationName opInfo) const override { - return llvm::isa(opInfo); - } -}; - -} // namespace - -std::unique_ptr circt::hw::createVerifyInnerRefNamespacePass() { - return std::make_unique(); -} diff --git a/lib/circt/Support/APInt.cpp b/lib/circt/Support/APInt.cpp deleted file mode 100644 index 67d6343756..0000000000 --- a/lib/circt/Support/APInt.cpp +++ /dev/null @@ -1,27 +0,0 @@ -//===- APInt.h - CIRCT Lowering Options -------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for working around limitations of upstream LLVM APInts. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/APInt.h" - -#include "llvm/include/llvm/ADT/APSInt.h" // from @llvm-project - -using namespace circt; - -APInt circt::sextZeroWidth(APInt value, unsigned width) { - return value.getBitWidth() ? value.sext(width) : value.zext(width); -} - -APSInt circt::extOrTruncZeroWidth(APSInt value, unsigned width) { - return value.getBitWidth() - ? value.extOrTrunc(width) - : APSInt(value.zextOrTrunc(width), value.isUnsigned()); -} diff --git a/lib/circt/Support/BUILD b/lib/circt/Support/BUILD index f9dd1ac2d9..1f178fb3ef 100644 --- a/lib/circt/Support/BUILD +++ b/lib/circt/Support/BUILD @@ -5,30 +5,11 @@ package( cc_library( name = "Support", - srcs = [ - "CustomDirectiveImpl.cpp", - "InstanceGraph.cpp", - "ValueMapper.cpp", - ], hdrs = [ - "@heir//include/circt/Support:BackedgeBuilder.h", - "@heir//include/circt/Support:BuilderUtils.h", - "@heir//include/circt/Support:CustomDirectiveImpl.h", - "@heir//include/circt/Support:InstanceGraph.h", - "@heir//include/circt/Support:InstanceGraphInterface.h", "@heir//include/circt/Support:LLVM.h", - "@heir//include/circt/Support:Namespace.h", - "@heir//include/circt/Support:ParsingUtils.h", - "@heir//include/circt/Support:SymCache.h", - "@heir//include/circt/Support:ValueMapper.h", ], deps = [ - "@heir//include/circt/Support:interfaces_inc_gen", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", ], ) diff --git a/lib/circt/Support/BackedgeBuilder.cpp b/lib/circt/Support/BackedgeBuilder.cpp deleted file mode 100644 index 6d41d5b1d4..0000000000 --- a/lib/circt/Support/BackedgeBuilder.cpp +++ /dev/null @@ -1,71 +0,0 @@ -//===- BackedgeBuilder.cpp - Support for building backedges ---------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provide support for building backedges. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/BackedgeBuilder.h" - -#include "include/circt/Support/LLVM.h" -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project - -using namespace circt; - -Backedge::Backedge(mlir::Operation *op) : value(op->getResult(0)) {} - -void Backedge::setValue(mlir::Value newValue) { - assert(value.getType() == newValue.getType()); - assert(!set && "backedge already set to a value!"); - value.replaceAllUsesWith(newValue); - value = newValue; // In case the backedge is still referred to after setting. - set = true; - - // If the backedge is referenced again, it should now point to the updated - // value. - value = newValue; -} - -BackedgeBuilder::~BackedgeBuilder() { (void)clearOrEmitError(); } - -LogicalResult BackedgeBuilder::clearOrEmitError() { - unsigned numInUse = 0; - for (Operation *op : edges) { - if (!op->use_empty()) { - auto diag = op->emitError("backedge of type `") - << op->getResult(0).getType() << "`still in use"; - for (auto user : op->getUsers()) - diag.attachNote(user->getLoc()) << "used by " << *user; - ++numInUse; - continue; - } - if (rewriter) - rewriter->eraseOp(op); - else - op->erase(); - } - edges.clear(); - if (numInUse > 0) - mlir::emitRemark(loc, "abandoned ") << numInUse << " backedges"; - return success(numInUse == 0); -} - -void BackedgeBuilder::abandon() { edges.clear(); } - -BackedgeBuilder::BackedgeBuilder(OpBuilder &builder, Location loc) - : builder(builder), rewriter(nullptr), loc(loc) {} -BackedgeBuilder::BackedgeBuilder(PatternRewriter &rewriter, Location loc) - : builder(rewriter), rewriter(&rewriter), loc(loc) {} -Backedge BackedgeBuilder::get(Type t, mlir::LocationAttr optionalLoc) { - if (!optionalLoc) optionalLoc = loc; - Operation *op = builder.create( - optionalLoc, t, ValueRange{}); - edges.push_back(op); - return Backedge(op); -} diff --git a/lib/circt/Support/CMakeLists.txt b/lib/circt/Support/CMakeLists.txt deleted file mode 100644 index 27d6fd639a..0000000000 --- a/lib/circt/Support/CMakeLists.txt +++ /dev/null @@ -1,52 +0,0 @@ -##===- CMakeLists.txt - Define a support library --------------*- cmake -*-===// -## -##===----------------------------------------------------------------------===// - -set(VERSION_CPP "${CMAKE_CURRENT_BINARY_DIR}/Version.cpp") -set_source_files_properties("${VERSION_CPP}" PROPERTIES GENERATED TRUE) - -add_circt_library(CIRCTSupport - APInt.cpp - BackedgeBuilder.cpp - CustomDirectiveImpl.cpp - FieldRef.cpp - JSON.cpp - LoweringOptions.cpp - Passes.cpp - Path.cpp - PrettyPrinter.cpp - PrettyPrinterHelpers.cpp - ParsingUtils.cpp - SymCache.cpp - ValueMapper.cpp - InstanceGraph.cpp - "${VERSION_CPP}" - - LINK_LIBS PUBLIC - MLIRIR - MLIRTransforms - MLIRTransformUtils - ) - -#------------------------------------------------------------------------------- -# Generate Version.cpp -#------------------------------------------------------------------------------- -find_first_existing_vc_file("${CIRCT_SOURCE_DIR}" CIRCT_GIT_LOGS_HEAD) -set(GEN_VERSION_SCRIPT "${CIRCT_SOURCE_DIR}/cmake/modules/GenVersionFile.cmake") - -if (CIRCT_RELEASE_TAG_ENABLED) - add_custom_command(OUTPUT "${VERSION_CPP}" - DEPENDS "${CIRCT_GIT_LOGS_HEAD}" "${GEN_VERSION_SCRIPT}" - COMMAND ${CMAKE_COMMAND} -DIN_FILE="${CMAKE_CURRENT_SOURCE_DIR}/Version.cpp.in" - -DOUT_FILE="${VERSION_CPP}" -DRELEASE_PATTERN=${CIRCT_RELEASE_TAG}* - -DDRY_RUN=OFF -DSOURCE_ROOT="${CIRCT_SOURCE_DIR}" - -P "${GEN_VERSION_SCRIPT}") -else () - # If the release tag generation is disabled, run the script only at the first - # cmake configuration. - add_custom_command(OUTPUT "${VERSION_CPP}" - DEPENDS "${GEN_VERSION_SCRIPT}" - COMMAND ${CMAKE_COMMAND} -DIN_FILE="${CMAKE_CURRENT_SOURCE_DIR}/Version.cpp.in" - -DOUT_FILE="${VERSION_CPP}" -DDRY_RUN=ON -DSOURCE_ROOT="${CIRCT_SOURCE_DIR}" - -P "${GEN_VERSION_SCRIPT}") -endif() diff --git a/lib/circt/Support/CustomDirectiveImpl.cpp b/lib/circt/Support/CustomDirectiveImpl.cpp deleted file mode 100644 index a3b1a35043..0000000000 --- a/lib/circt/Support/CustomDirectiveImpl.cpp +++ /dev/null @@ -1,130 +0,0 @@ -//===- CustomDirectiveImpl.cpp - Custom TableGen directives ---------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/CustomDirectiveImpl.h" - -#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project - -using namespace circt; - -ParseResult circt::parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr) { - // Use the explicit name if one is provided as `name "xyz"`. - if (!parser.parseOptionalKeyword("name")) { - std::string str; - if (parser.parseString(&str)) return failure(); - attr = parser.getBuilder().getStringAttr(str); - return success(); - } - - // Infer the name from the SSA name of the operation's first result. - auto resultName = parser.getResultName(0).first; - if (!resultName.empty() && isdigit(resultName[0])) resultName = ""; - attr = parser.getBuilder().getStringAttr(resultName); - return success(); -} - -ParseResult circt::parseImplicitSSAName(OpAsmParser &parser, - NamedAttrList &attrs) { - if (parser.parseOptionalAttrDict(attrs)) return failure(); - inferImplicitSSAName(parser, attrs); - return success(); -} - -bool circt::inferImplicitSSAName(OpAsmParser &parser, NamedAttrList &attrs) { - // Don't do anything if a `name` attribute is explicitly provided. - if (attrs.get("name")) return false; - - // Infer the name from the SSA name of the operation's first result. - auto resultName = parser.getResultName(0).first; - if (!resultName.empty() && isdigit(resultName[0])) resultName = ""; - auto nameAttr = parser.getBuilder().getStringAttr(resultName); - auto *context = parser.getBuilder().getContext(); - attrs.push_back({StringAttr::get(context, "name"), nameAttr}); - return true; -} - -void circt::printImplicitSSAName(OpAsmPrinter &printer, Operation *op, - StringAttr attr) { - SmallString<32> resultNameStr; - llvm::raw_svector_ostream tmpStream(resultNameStr); - printer.printOperand(op->getResult(0), tmpStream); - auto actualName = tmpStream.str().drop_front(); - auto expectedName = attr.getValue(); - // Anonymous names are printed as digits, which is fine. - if (actualName == expectedName || - (expectedName.empty() && isdigit(actualName[0]))) - return; - - printer << " name " << attr; -} - -void circt::printImplicitSSAName(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs, - ArrayRef extraElides) { - SmallVector elides(extraElides.begin(), extraElides.end()); - elideImplicitSSAName(printer, op, attrs, elides); - printer.printOptionalAttrDict(attrs.getValue(), elides); -} - -void circt::elideImplicitSSAName(OpAsmPrinter &printer, Operation *op, - DictionaryAttr attrs, - SmallVectorImpl &elides) { - SmallString<32> resultNameStr; - llvm::raw_svector_ostream tmpStream(resultNameStr); - printer.printOperand(op->getResult(0), tmpStream); - auto actualName = tmpStream.str().drop_front(); - auto expectedName = attrs.getAs("name").getValue(); - // Anonymous names are printed as digits, which is fine. - if (actualName == expectedName || - (expectedName.empty() && isdigit(actualName[0]))) - elides.push_back("name"); -} - -ParseResult circt::parseOptionalBinaryOpTypes(OpAsmParser &parser, Type &lhs, - Type &rhs) { - if (parser.parseType(lhs)) return failure(); - - // Parse an optional rhs type. - if (parser.parseOptionalComma()) { - rhs = lhs; - } else { - if (parser.parseType(rhs)) return failure(); - } - return success(); -} - -void circt::printOptionalBinaryOpTypes(OpAsmPrinter &p, Operation *op, Type lhs, - Type rhs) { - p << lhs; - // If operand types are not same, print a rhs type. - if (lhs != rhs) p << ", " << rhs; -} - -ParseResult circt::parseKeywordBool(OpAsmParser &parser, BoolAttr &attr, - StringRef trueKeyword, - StringRef falseKeyword) { - if (succeeded(parser.parseOptionalKeyword(trueKeyword))) { - attr = BoolAttr::get(parser.getContext(), true); - } else if (succeeded(parser.parseOptionalKeyword(falseKeyword))) { - attr = BoolAttr::get(parser.getContext(), false); - } else { - return parser.emitError(parser.getCurrentLocation()) - << "expected keyword \"" << trueKeyword << "\" or \"" << falseKeyword - << "\""; - } - return success(); -} - -void circt::printKeywordBool(OpAsmPrinter &printer, Operation *op, - BoolAttr attr, StringRef trueKeyword, - StringRef falseKeyword) { - if (attr.getValue()) - printer << trueKeyword; - else - printer << falseKeyword; -} diff --git a/lib/circt/Support/FieldRef.cpp b/lib/circt/Support/FieldRef.cpp deleted file mode 100644 index 409b2e7fa3..0000000000 --- a/lib/circt/Support/FieldRef.cpp +++ /dev/null @@ -1,23 +0,0 @@ -//===- FieldRef.cpp - Field Refs -------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header file defines FieldRef and helpers for them. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/FieldRef.h" - -#include "mlir/include/mlir/IR/Block.h" // from @llvm-project -#include "mlir/include/mlir/IR/Value.h" // from @llvm-project - -using namespace circt; - -Operation *FieldRef::getDefiningOp() const { - if (auto *op = value.getDefiningOp()) return op; - return value.cast().getOwner()->getParentOp(); -} diff --git a/lib/circt/Support/InstanceGraph.cpp b/lib/circt/Support/InstanceGraph.cpp deleted file mode 100644 index 67a9f45c4a..0000000000 --- a/lib/circt/Support/InstanceGraph.cpp +++ /dev/null @@ -1,314 +0,0 @@ -//===- InstanceGraph.cpp - Instance Graph -----------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/InstanceGraph.h" - -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/include/mlir/IR/Threading.h" // from @llvm-project - -using namespace circt; -using namespace igraph; - -void InstanceRecord::erase() { - // Update the prev node to point to the next node. - if (prevUse) - prevUse->nextUse = nextUse; - else - target->firstUse = nextUse; - // Update the next node to point to the prev node. - if (nextUse) nextUse->prevUse = prevUse; - getParent()->instances.erase(this); -} - -InstanceRecord *InstanceGraphNode::addInstance(InstanceOpInterface instance, - InstanceGraphNode *target) { - auto *instanceRecord = new InstanceRecord(this, instance, target); - target->recordUse(instanceRecord); - instances.push_back(instanceRecord); - return instanceRecord; -} - -void InstanceGraphNode::recordUse(InstanceRecord *record) { - record->nextUse = firstUse; - if (firstUse) firstUse->prevUse = record; - firstUse = record; -} - -InstanceGraphNode *InstanceGraph::getOrAddNode(StringAttr name) { - // Try to insert an InstanceGraphNode. If its not inserted, it returns - // an iterator pointing to the node. - auto *&node = nodeMap[name]; - if (!node) { - node = new InstanceGraphNode(); - nodes.push_back(node); - } - return node; -} - -InstanceGraph::InstanceGraph(Operation *parent) : parent(parent) { - assert(parent->hasTrait() && - "top-level operation must have a single block"); - SmallVector>> - moduleToInstances; - // First accumulate modules inside the parent op. - for (auto module : - parent->getRegion(0).front().getOps()) - moduleToInstances.push_back({module, {}}); - - // Populate instances in the module parallelly. - mlir::parallelFor(parent->getContext(), 0, moduleToInstances.size(), - [&](size_t idx) { - auto module = moduleToInstances[idx].first; - auto &instances = moduleToInstances[idx].second; - // Find all instance operations in the module body. - module.walk([&](InstanceOpInterface instanceOp) { - instances.push_back(instanceOp); - }); - }); - - // Construct an instance graph sequentially. - for (auto &[module, instances] : moduleToInstances) { - auto name = module.getModuleNameAttr(); - auto *currentNode = getOrAddNode(name); - currentNode->module = module; - for (auto instanceOp : instances) { - // Add an edge to indicate that this module instantiates the target. - auto *targetNode = getOrAddNode(instanceOp.getReferencedModuleNameAttr()); - currentNode->addInstance(instanceOp, targetNode); - } - } -} - -InstanceGraphNode *InstanceGraph::addModule(ModuleOpInterface module) { - assert(!nodeMap.count(module.getModuleNameAttr()) && "module already added"); - auto *node = new InstanceGraphNode(); - node->module = module; - nodeMap[module.getModuleNameAttr()] = node; - nodes.push_back(node); - return node; -} - -void InstanceGraph::erase(InstanceGraphNode *node) { - assert(node->noUses() && - "all instances of this module must have been erased."); - // Erase all instances inside this module. - for (auto *instance : llvm::make_early_inc_range(*node)) instance->erase(); - nodeMap.erase(node->getModule().getModuleNameAttr()); - nodes.erase(node); -} - -InstanceGraphNode *InstanceGraph::lookup(StringAttr name) { - auto it = nodeMap.find(name); - assert(it != nodeMap.end() && "Module not in InstanceGraph!"); - return it->second; -} - -InstanceGraphNode *InstanceGraph::lookup(ModuleOpInterface op) { - return lookup(cast(op).getModuleNameAttr()); -} - -ModuleOpInterface InstanceGraph::getReferencedModuleImpl( - InstanceOpInterface op) { - return lookup(op.getReferencedModuleNameAttr())->getModule(); -} - -void InstanceGraph::replaceInstance(InstanceOpInterface inst, - InstanceOpInterface newInst) { - assert(inst.getReferencedModuleName() == newInst.getReferencedModuleName() && - "Both instances must be targeting the same module"); - - // Find the instance record of this instance. - auto *node = lookup(inst.getReferencedModuleNameAttr()); - auto it = llvm::find_if(node->uses(), [&](InstanceRecord *record) { - return record->getInstance() == inst; - }); - assert(it != node->usesEnd() && "Instance of module not recorded in graph"); - - // We can just replace the instance op in the InstanceRecord without updating - // any instance lists. - (*it)->instance = newInst; -} - -bool InstanceGraph::isAncestor(ModuleOpInterface child, - ModuleOpInterface parent) { - DenseSet seen; - SmallVector worklist; - auto *cn = lookup(child); - worklist.push_back(cn); - seen.insert(cn); - while (!worklist.empty()) { - auto *node = worklist.back(); - worklist.pop_back(); - if (node->getModule() == parent) return true; - for (auto *use : node->uses()) { - auto *mod = use->getParent(); - if (!seen.count(mod)) { - seen.insert(mod); - worklist.push_back(mod); - } - } - } - return false; -} - -FailureOr> -InstanceGraph::getInferredTopLevelNodes() { - if (!inferredTopLevelNodes.empty()) return {inferredTopLevelNodes}; - - /// Topologically sort the instance graph. - llvm::SetVector visited, marked; - llvm::SetVector candidateTopLevels(this->begin(), - this->end()); - SmallVector cycleTrace; - - // Recursion function; returns true if a cycle was detected. - std::function)> - cycleUtil = - [&](InstanceGraphNode *node, SmallVector trace) { - if (visited.contains(node)) return false; - trace.push_back(node); - if (marked.contains(node)) { - // Cycle detected. - cycleTrace = trace; - return true; - } - marked.insert(node); - for (auto use : *node) { - InstanceGraphNode *targetModule = use->getTarget(); - candidateTopLevels.remove(targetModule); - if (cycleUtil(targetModule, trace)) - return true; // Cycle detected. - } - marked.remove(node); - visited.insert(node); - return false; - }; - - bool cyclic = false; - for (auto moduleIt : *this) { - if (visited.contains(moduleIt)) continue; - - cyclic |= cycleUtil(moduleIt, {}); - if (cyclic) break; - } - - if (cyclic) { - auto err = getParent()->emitOpError(); - err << "cannot deduce top level module - cycle " - "detected in instance graph ("; - llvm::interleave( - cycleTrace, err, - [&](auto node) { err << node->getModule().getModuleName(); }, "->"); - err << ")."; - return err; - } - assert(!candidateTopLevels.empty() && - "if non-cyclic, there should be at least 1 candidate top level"); - - inferredTopLevelNodes = llvm::SmallVector( - candidateTopLevels.begin(), candidateTopLevels.end()); - return {inferredTopLevelNodes}; -} - -static InstancePath empty{}; - -// NOLINTBEGIN(misc-no-recursion) -ArrayRef InstancePathCache::getAbsolutePaths( - ModuleOpInterface op) { - InstanceGraphNode *node = instanceGraph[op]; - - if (node == instanceGraph.getTopLevelNode()) { - return empty; - } - - // Fast path: hit the cache. - auto cached = absolutePathsCache.find(op); - if (cached != absolutePathsCache.end()) return cached->second; - - // For each instance, collect the instance paths to its parent and append the - // instance itself to each. - SmallVector extendedPaths; - for (auto *inst : node->uses()) { - if (auto module = inst->getParent()->getModule()) { - auto instPaths = getAbsolutePaths(module); - extendedPaths.reserve(instPaths.size()); - for (auto path : instPaths) { - extendedPaths.push_back(appendInstance( - path, cast(*inst->getInstance()))); - } - } else { - extendedPaths.emplace_back(empty); - } - } - - // Move the list of paths into the bump allocator for later quick retrieval. - ArrayRef pathList; - if (!extendedPaths.empty()) { - auto *paths = allocator.Allocate(extendedPaths.size()); - std::copy(extendedPaths.begin(), extendedPaths.end(), paths); - pathList = ArrayRef(paths, extendedPaths.size()); - } - absolutePathsCache.insert({op, pathList}); - return pathList; -} -// NOLINTEND(misc-no-recursion) - -InstancePath InstancePathCache::appendInstance(InstancePath path, - InstanceOpInterface inst) { - size_t n = path.size() + 1; - auto *newPath = allocator.Allocate(n); - std::copy(path.begin(), path.end(), newPath); - newPath[path.size()] = inst; - return InstancePath(ArrayRef(newPath, n)); -} - -InstancePath InstancePathCache::prependInstance(InstanceOpInterface inst, - InstancePath path) { - size_t n = path.size() + 1; - auto *newPath = allocator.Allocate(n); - std::copy(path.begin(), path.end(), newPath + 1); - newPath[0] = inst; - return InstancePath(ArrayRef(newPath, n)); -} - -void InstancePathCache::replaceInstance(InstanceOpInterface oldOp, - InstanceOpInterface newOp) { - instanceGraph.replaceInstance(oldOp, newOp); - - // Iterate over all the paths, and search for the old InstanceOpInterface. If - // found, then replace it with the new InstanceOpInterface, and create a new - // copy of the paths and update the cache. - auto instanceExists = [&](const ArrayRef &paths) -> bool { - return llvm::any_of( - paths, [&](InstancePath p) { return llvm::is_contained(p, oldOp); }); - }; - - for (auto &iter : absolutePathsCache) { - if (!instanceExists(iter.getSecond())) continue; - SmallVector updatedPaths; - for (auto path : iter.getSecond()) { - const auto *iter = llvm::find(path, oldOp); - if (iter == path.end()) { - // path does not contain the oldOp, just copy it as is. - updatedPaths.push_back(path); - continue; - } - auto *newPath = allocator.Allocate(path.size()); - llvm::copy(path, newPath); - newPath[iter - path.begin()] = newOp; - updatedPaths.push_back(InstancePath(ArrayRef(newPath, path.size()))); - } - // Move the list of paths into the bump allocator for later quick - // retrieval. - auto *paths = allocator.Allocate(updatedPaths.size()); - llvm::copy(updatedPaths, paths); - iter.getSecond() = ArrayRef(paths, updatedPaths.size()); - } -} - -#include "include/circt/Support/InstanceGraphInterface.cpp.inc" diff --git a/lib/circt/Support/JSON.cpp b/lib/circt/Support/JSON.cpp deleted file mode 100644 index b30171cbd3..0000000000 --- a/lib/circt/Support/JSON.cpp +++ /dev/null @@ -1,123 +0,0 @@ -//===- Json.cpp - Json Utilities --------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/JSON.h" - -#include "llvm/include/llvm/ADT/StringSwitch.h" // from @llvm-project -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/include/mlir/IR/OperationSupport.h" // from @llvm-project - -namespace json = llvm::json; - -using namespace circt; -using mlir::UnitAttr; - -// NOLINTBEGIN(misc-no-recursion) -LogicalResult circt::convertAttributeToJSON(llvm::json::OStream &json, - Attribute attr) { - return TypeSwitch(attr) - .Case([&](auto attr) { - json.objectBegin(); - for (auto subAttr : attr) { - json.attributeBegin(subAttr.getName()); - if (failed(convertAttributeToJSON(json, subAttr.getValue()))) - return failure(); - json.attributeEnd(); - } - json.objectEnd(); - return success(); - }) - .Case([&](auto attr) { - json.arrayBegin(); - for (auto subAttr : attr) - if (failed(convertAttributeToJSON(json, subAttr))) return failure(); - json.arrayEnd(); - return success(); - }) - .Case([&](auto attr) { - json.value(attr.getValue()); - return success(); - }) - .Case([&](auto attr) -> LogicalResult { - // If the integer can be accurately represented by a double, print - // it as an integer. Otherwise, convert it to an exact decimal string. - const auto &apint = attr.getValue(); - if (!apint.isSignedIntN(64)) return failure(); - json.value(apint.getSExtValue()); - return success(); - }) - .Case([&](auto attr) -> LogicalResult { - const auto &apfloat = attr.getValue(); - json.value(apfloat.convertToDouble()); - return success(); - }) - .Default([&](auto) -> LogicalResult { return failure(); }); -} -// NOLINTEND(misc-no-recursion) - -// NOLINTBEGIN(misc-no-recursion) -Attribute circt::convertJSONToAttribute(MLIRContext *context, - json::Value &value, json::Path p) { - // String or quoted JSON - if (auto a = value.getAsString()) { - // Test to see if this might be quoted JSON (a string that is actually - // JSON). Sometimes FIRRTL developers will do this to serialize objects - // that the Scala FIRRTL Compiler doesn't know about. - auto unquotedValue = json::parse(*a); - auto err = unquotedValue.takeError(); - // If this parsed without an error and we didn't just unquote a number, then - // it's more JSON and recurse on that. - // - // We intentionally do not want to unquote a number as, in JSON, the string - // "0" is different from the number 0. If we conflate these, then later - // expectations about annotation structure may be broken. I.e., an - // annotation expecting a string may see a number. - if (!err && !unquotedValue.get().getAsNumber()) - return convertJSONToAttribute(context, unquotedValue.get(), p); - // If there was an error, then swallow it and handle this as a string. - handleAllErrors(std::move(err), [&](const json::ParseError &a) {}); - return StringAttr::get(context, *a); - } - - // Integer - if (auto a = value.getAsInteger()) - return IntegerAttr::get(IntegerType::get(context, 64), *a); - - // Float - if (auto a = value.getAsNumber()) - return FloatAttr::get(mlir::FloatType::getF64(context), *a); - - // Boolean - if (auto a = value.getAsBoolean()) return BoolAttr::get(context, *a); - - // Null - if (auto a = value.getAsNull()) return mlir::UnitAttr::get(context); - - // Object - if (auto *a = value.getAsObject()) { - NamedAttrList metadata; - for (auto b : *a) - metadata.append( - b.first, convertJSONToAttribute(context, b.second, p.field(b.first))); - return DictionaryAttr::get(context, metadata); - } - - // Array - if (auto *a = value.getAsArray()) { - SmallVector metadata; - for (size_t i = 0, e = (*a).size(); i != e; ++i) - metadata.push_back(convertJSONToAttribute(context, (*a)[i], p.index(i))); - return ArrayAttr::get(context, metadata); - } - - llvm_unreachable("Impossible unhandled JSON type"); -} -// NOLINTEND(misc-no-recursion) diff --git a/lib/circt/Support/LoweringOptions.cpp b/lib/circt/Support/LoweringOptions.cpp deleted file mode 100644 index 7f60ce7e46..0000000000 --- a/lib/circt/Support/LoweringOptions.cpp +++ /dev/null @@ -1,187 +0,0 @@ -//===- LoweringOptions.cpp - CIRCT Lowering Options -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Options for controlling the lowering process. Contains command line -// option definitions and support. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/LoweringOptions.h" - -#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project - -using namespace circt; -using namespace mlir; - -//===----------------------------------------------------------------------===// -// LoweringOptions -//===----------------------------------------------------------------------===// - -LoweringOptions::LoweringOptions(StringRef options, ErrorHandlerT errorHandler) - : LoweringOptions() { - parse(options, errorHandler); -} - -LoweringOptions::LoweringOptions(mlir::ModuleOp module) : LoweringOptions() { - parseFromAttribute(module); -} - -static std::optional parseLocationInfoStyle( - StringRef option) { - return llvm::StringSwitch>( - option) - .Case("plain", LoweringOptions::Plain) - .Case("wrapInAtSquareBracket", LoweringOptions::WrapInAtSquareBracket) - .Case("none", LoweringOptions::None) - .Default(std::nullopt); -} - -static std::optional -parseWireSpillingHeuristic(StringRef option) { - return llvm::StringSwitch< - std::optional>(option) - .Case("spillLargeTermsWithNamehints", - LoweringOptions::SpillLargeTermsWithNamehints) - .Default(std::nullopt); -} - -void LoweringOptions::parse(StringRef text, ErrorHandlerT errorHandler) { - while (!text.empty()) { - // Remove the first option from the text. - auto split = text.split(","); - auto option = split.first.trim(); - text = split.second; - if (option == "") { - // Empty options are fine. - } else if (option == "noAlwaysComb") { - noAlwaysComb = true; - } else if (option == "exprInEventControl") { - allowExprInEventControl = true; - } else if (option == "disallowPackedArrays") { - disallowPackedArrays = true; - } else if (option == "disallowPackedStructAssignments") { - disallowPackedStructAssignments = true; - } else if (option == "disallowLocalVariables") { - disallowLocalVariables = true; - } else if (option == "verifLabels") { - enforceVerifLabels = true; - } else if (option.consume_front("emittedLineLength=")) { - if (option.getAsInteger(10, emittedLineLength)) { - errorHandler("expected integer source width"); - emittedLineLength = DEFAULT_LINE_LENGTH; - } - } else if (option == "explicitBitcast") { - explicitBitcast = true; - } else if (option == "emitReplicatedOpsToHeader") { - emitReplicatedOpsToHeader = true; - } else if (option.consume_front("maximumNumberOfTermsPerExpression=")) { - if (option.getAsInteger(10, maximumNumberOfTermsPerExpression)) { - errorHandler("expected integer source width"); - maximumNumberOfTermsPerExpression = DEFAULT_TERM_LIMIT; - } - } else if (option.consume_front("locationInfoStyle=")) { - if (auto style = parseLocationInfoStyle(option)) { - locationInfoStyle = *style; - } else { - errorHandler("expected 'plain', 'wrapInAtSquareBracket', or 'none'"); - } - } else if (option == "disallowPortDeclSharing") { - disallowPortDeclSharing = true; - } else if (option == "printDebugInfo") { - printDebugInfo = true; - } else if (option == "disallowExpressionInliningInPorts") { - disallowExpressionInliningInPorts = true; - } else if (option == "disallowMuxInlining") { - disallowMuxInlining = true; - } else if (option == "mitigateVivadoArrayIndexConstPropBug") { - mitigateVivadoArrayIndexConstPropBug = true; - } else if (option.consume_front("wireSpillingHeuristic=")) { - if (auto heuristic = parseWireSpillingHeuristic(option)) { - wireSpillingHeuristicSet |= *heuristic; - } else { - errorHandler("expected ''spillLargeTermsWithNamehints'"); - } - } else if (option.consume_front("wireSpillingNamehintTermLimit=")) { - if (option.getAsInteger(10, wireSpillingNamehintTermLimit)) { - errorHandler( - "expected integer for number of namehint heurstic term limit"); - wireSpillingNamehintTermLimit = DEFAULT_NAMEHINT_TERM_LIMIT; - } - } else if (option == "emitWireInPorts") { - emitWireInPorts = true; - } else if (option == "emitBindComments") { - emitBindComments = true; - } else if (option == "omitVersionComment") { - omitVersionComment = true; - } else if (option == "caseInsensitiveKeywords") { - caseInsensitiveKeywords = true; - } else { - errorHandler(llvm::Twine("unknown style option \'") + option + "\'"); - // We continue parsing options after a failure. - } - } -} - -std::string LoweringOptions::toString() const { - std::string options = ""; - // All options should add a trailing comma to simplify the code. - if (noAlwaysComb) options += "noAlwaysComb,"; - if (allowExprInEventControl) options += "exprInEventControl,"; - if (disallowPackedArrays) options += "disallowPackedArrays,"; - if (disallowPackedStructAssignments) - options += "disallowPackedStructAssignments,"; - if (disallowLocalVariables) options += "disallowLocalVariables,"; - if (enforceVerifLabels) options += "verifLabels,"; - if (explicitBitcast) options += "explicitBitcast,"; - if (emitReplicatedOpsToHeader) options += "emitReplicatedOpsToHeader,"; - if (locationInfoStyle == LocationInfoStyle::WrapInAtSquareBracket) - options += "locationInfoStyle=wrapInAtSquareBracket,"; - if (locationInfoStyle == LocationInfoStyle::None) - options += "locationInfoStyle=none,"; - if (disallowPortDeclSharing) options += "disallowPortDeclSharing,"; - if (printDebugInfo) options += "printDebugInfo,"; - if (isWireSpillingHeuristicEnabled( - WireSpillingHeuristic::SpillLargeTermsWithNamehints)) - options += "wireSpillingHeuristic=spillLargeTermsWithNamehints,"; - if (disallowExpressionInliningInPorts) - options += "disallowExpressionInliningInPorts,"; - if (disallowMuxInlining) options += "disallowMuxInlining,"; - if (mitigateVivadoArrayIndexConstPropBug) - options += "mitigateVivadoArrayIndexConstPropBug,"; - - if (emittedLineLength != DEFAULT_LINE_LENGTH) - options += "emittedLineLength=" + std::to_string(emittedLineLength) + ','; - if (maximumNumberOfTermsPerExpression != DEFAULT_TERM_LIMIT) - options += "maximumNumberOfTermsPerExpression=" + - std::to_string(maximumNumberOfTermsPerExpression) + ','; - if (emitWireInPorts) options += "emitWireInPorts,"; - if (emitBindComments) options += "emitBindComments,"; - if (omitVersionComment) options += "omitVersionComment,"; - if (caseInsensitiveKeywords) options += "caseInsensitiveKeywords,"; - - // Remove a trailing comma if present. - if (!options.empty()) { - assert(options.back() == ',' && "all options should add a trailing comma"); - options.pop_back(); - } - return options; -} - -StringAttr LoweringOptions::getAttributeFrom(ModuleOp module) { - return module->getAttrOfType("circt.loweringOptions"); -} - -void LoweringOptions::setAsAttribute(ModuleOp module) { - module->setAttr("circt.loweringOptions", - StringAttr::get(module.getContext(), toString())); -} - -void LoweringOptions::parseFromAttribute(ModuleOp module) { - if (auto styleAttr = getAttributeFrom(module)) - parse(styleAttr.getValue(), [&](Twine error) { module.emitError(error); }); -} diff --git a/lib/circt/Support/ParsingUtils.cpp b/lib/circt/Support/ParsingUtils.cpp deleted file mode 100644 index 4cee0240de..0000000000 --- a/lib/circt/Support/ParsingUtils.cpp +++ /dev/null @@ -1,50 +0,0 @@ -//===- ParsingUtils.cpp - CIRCT parsing common functions ------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/ParsingUtils.h" - -using namespace circt; - -ParseResult circt::parsing_util::parseInitializerList( - OpAsmParser &parser, - llvm::SmallVector &inputArguments, - llvm::SmallVector &inputOperands, - llvm::SmallVector &inputTypes, ArrayAttr &inputNames) { - llvm::SmallVector names; - if (failed(parser.parseCommaSeparatedList( - OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { - OpAsmParser::UnresolvedOperand inputOperand; - Type type; - auto &arg = inputArguments.emplace_back(); - if (parser.parseArgument(arg) || parser.parseColonType(type) || - parser.parseEqual() || parser.parseOperand(inputOperand)) - return failure(); - - inputOperands.push_back(inputOperand); - inputTypes.push_back(type); - arg.type = type; - names.push_back(StringAttr::get( - parser.getContext(), - /*drop leading %*/ arg.ssaName.name.drop_front())); - return success(); - }))) - return failure(); - - inputNames = ArrayAttr::get(parser.getContext(), names); - return success(); -} - -void circt::parsing_util::printInitializerList(OpAsmPrinter &p, ValueRange ins, - ArrayRef args) { - p << "("; - llvm::interleaveComma(llvm::zip(ins, args), p, [&](auto it) { - auto [in, arg] = it; - p << arg << " : " << in.getType() << " = " << in; - }); - p << ")"; -} diff --git a/lib/circt/Support/Passes.cpp b/lib/circt/Support/Passes.cpp deleted file mode 100644 index d594564484..0000000000 --- a/lib/circt/Support/Passes.cpp +++ /dev/null @@ -1,21 +0,0 @@ -//===- Passes.cpp - Pass Utilities ------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/Passes.h" - -#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project - -using namespace circt; - -std::unique_ptr circt::createSimpleCanonicalizerPass() { - mlir::GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = false; - return mlir::createCanonicalizerPass(config); -} diff --git a/lib/circt/Support/Path.cpp b/lib/circt/Support/Path.cpp deleted file mode 100644 index 3cd30babdc..0000000000 --- a/lib/circt/Support/Path.cpp +++ /dev/null @@ -1,32 +0,0 @@ -//===- Path.cpp - Path Utilities --------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Utilities for file system path handling, supplementing the ones from -// llvm::sys::path. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/Path.h" - -#include "llvm/include/llvm/Support/Path.h" // from @llvm-project - -using namespace circt; - -/// Append a path to an existing path, replacing it if the other path is -/// absolute. This mimicks the behaviour of `foo/bar` and `/foo/bar` being used -/// in a working directory `/home`, resulting in `/home/foo/bar` and `/foo/bar`, -/// respectively. -void circt::appendPossiblyAbsolutePath(llvm::SmallVectorImpl &base, - const llvm::Twine &suffix) { - if (llvm::sys::path::is_absolute(suffix)) { - base.clear(); - suffix.toVector(base); - } else { - llvm::sys::path::append(base, suffix); - } -} diff --git a/lib/circt/Support/PrettyPrinter.cpp b/lib/circt/Support/PrettyPrinter.cpp deleted file mode 100644 index 21f94137ba..0000000000 --- a/lib/circt/Support/PrettyPrinter.cpp +++ /dev/null @@ -1,305 +0,0 @@ -//===- PrettyPrinter.cpp - Pretty printing --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This implements a pretty-printer. -// "PrettyPrinting", Derek C. Oppen, 1980. -// https://dx.doi.org/10.1145/357114.357115 -// -// This was selected as it is linear in number of tokens O(n) and requires -// memory O(linewidth). -// -// This has been adjusted from the paper: -// * Deque for tokens instead of ringbuffer + left/right cursors. -// This is simpler to reason about and allows us to easily grow the buffer -// to accommodate longer widths when needed (and not reserve 3*linewidth). -// Since scanStack references buffered tokens by index, we track an offset -// that we increase when dropping off the front. -// When the scan stack is cleared the buffer is reset, including this offset. -// * Indentation tracked from left not relative to margin (linewidth). -// * Indentation emitted lazily, avoid trailing whitespace. -// * Group indentation styles: Visual and Block, set on 'begin' tokens. -// "Visual" is the style in the paper, offset relative to current column. -// "Block" is relative to current base indentation. -// * Break: Add "Neverbreak": acts like a break re:sizing previous range, -// but will never be broken. Useful for adding content to end of line -// that may go over margin but should not influence layout. -// * Begin: Add "Never" breaking style, for forcing no breaks including -// within nested groups. Use sparingly. It is an error to insert -// a newline (Break with spaces==kInfinity) within such a group. -// * If leftTotal grows too large, "rebase" our datastructures by -// walking the tokens with pending sizes (scanStack) and adjusting -// them by `leftTotal - 1`. Also reset tokenOffset while visiting. -// This is mostly needed due to use of tokens/groups that 'never' break -// which can greatly increase times between `clear()`. -// * Optionally, minimum amount of space is granted regardless of indentation. -// To avoid forcing expressions against the line limit, never try to print -// an expression in, say, 2 columns, as this is unlikely to produce good -// output. -// (TODO) -// -// There are many pretty-printing implementations based on this paper, -// and research literature is rich with functional formulations based originally -// on this algorithm. -// -// Implementations of note that have interesting modifications for their -// languages and modernization of the paper's algorithm: -// * prettyplease / rustc_ast_pretty -// Pretty-printers for rust, the first being useful for rustfmt-like output. -// These have largely the same code and were based on one another. -// https://github.com/dtolnay/prettyplease -// https://github.com/rust-lang/rust/tree/master/compiler/rustc_ast_pretty -// This is closest to the paper's algorithm with modernizations, -// and most of the initial tweaks have also been implemented here (thanks!). -// * swift-format: https://github.com/apple/swift-format/ -// -// If we want fancier output or need to handle more complicated constructs, -// both are good references for lessons and ideas. -// -// FWIW, at time of writing these have compatible licensing (Apache 2.0). -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/PrettyPrinter.h" - -#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project -#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project - -namespace circt { -namespace pretty { - -/// Destructor, anchor. -PrettyPrinter::Listener::~Listener() = default; - -/// Add token for printing. In Oppen, this is "scan". -void PrettyPrinter::add(Token t) { - // Add token to tokens, and add its index to scanStack. - auto addScanToken = [&](auto offset) { - auto right = tokenOffset + tokens.size(); - scanStack.push_back(right); - tokens.push_back({t, offset}); - }; - llvm::TypeSwitch(&t) - .Case([&](StringToken *s) { - // If nothing on stack, directly print - FormattedToken f{t, (int32_t)s->text().size()}; - // Empty string token isn't /wrong/ but can have unintended effect. - assert(!s->text().empty() && "empty string token"); - if (scanStack.empty()) return print(f); - tokens.push_back(f); - rightTotal += f.size; - assert(rightTotal > 0); - checkStream(); - }) - .Case([&](BreakToken *b) { - if (scanStack.empty()) - clear(); - else - checkStack(); - addScanToken(-rightTotal); - rightTotal += b->spaces(); - assert(rightTotal > 0); - }) - .Case([&](BeginToken *b) { - if (scanStack.empty()) clear(); - addScanToken(-rightTotal); - }) - .Case([&](EndToken *end) { - if (scanStack.empty()) return print({t, 0}); - addScanToken(-1); - }) - .Case([&](CallbackToken *c) { - // Callbacktoken must be associated with a listener, it doesn't have any - // meaning without it. - assert(listener); - if (scanStack.empty()) return print({t, 0}); - tokens.push_back({t, 0}); - }); - rebaseIfNeeded(); -} - -void PrettyPrinter::rebaseIfNeeded() { - // Check for too-large totals, reset. - // This can happen if we have an open group and emit - // many tokens, especially newlines which have artificial size. - if (tokens.empty()) return; - assert(leftTotal >= 0); - assert(rightTotal >= 0); - if (uint32_t(leftTotal) > rebaseThreshold) { - // Plan: reset leftTotal to '1', adjust all accordingly. - auto adjust = leftTotal - 1; - for (auto &scanIndex : scanStack) { - assert(scanIndex >= tokenOffset); - auto &t = tokens[scanIndex - tokenOffset]; - if (isa(&t.token)) { - if (t.size < 0) { - assert(t.size + adjust < 0); - t.size += adjust; - } - } - // While walking, reset tokenOffset too. - scanIndex -= tokenOffset; - } - leftTotal -= adjust; - rightTotal -= adjust; - tokenOffset = 0; - } -} - -void PrettyPrinter::eof() { - if (!scanStack.empty()) { - checkStack(); - advanceLeft(); - } - assert(scanStack.empty() && "unclosed groups at EOF"); - if (scanStack.empty()) clear(); -} - -void PrettyPrinter::clear() { - assert(scanStack.empty() && "clearing tokens while still on scan stack"); - assert(tokens.empty()); - leftTotal = rightTotal = 1; - tokens.clear(); - tokenOffset = 0; - if (listener && !donotClear) listener->clear(); -} - -/// Break encountered, set sizes of begin/breaks in scanStack that we now know. -void PrettyPrinter::checkStack() { - unsigned depth = 0; - while (!scanStack.empty()) { - auto x = scanStack.back(); - assert(x >= tokenOffset && tokens.size() + tokenOffset > x); - auto &t = tokens[x - tokenOffset]; - if (llvm::isa(&t.token)) { - if (depth == 0) break; - scanStack.pop_back(); - t.size += rightTotal; - --depth; - } else if (llvm::isa(&t.token)) { - scanStack.pop_back(); - t.size = 1; - ++depth; - } else { - scanStack.pop_back(); - t.size += rightTotal; - if (depth == 0) break; - } - } -} - -/// Check if there are enough tokens to hit width, if so print. -/// If scan size is wider than line, it's infinity. -void PrettyPrinter::checkStream() { - // While buffer needs more than 1 line to print, print and consume. - assert(!tokens.empty()); - assert(leftTotal >= 0); - assert(rightTotal >= 0); - while (rightTotal - leftTotal > space && !tokens.empty()) { - // Ran out of space, set size to infinity and take off scan stack. - // No need to keep track as we know enough to know this won't fit. - if (!scanStack.empty() && tokenOffset == scanStack.front()) { - tokens.front().size = kInfinity; - scanStack.pop_front(); - } - advanceLeft(); - } -} - -/// Print out tokens we know sizes for, and drop from token buffer. -void PrettyPrinter::advanceLeft() { - assert(!tokens.empty()); - - while (!tokens.empty() && tokens.front().size >= 0) { - const auto &f = tokens.front(); - print(f); - leftTotal += - llvm::TypeSwitch(&f.token) - .Case([&](const BreakToken *b) { return b->spaces(); }) - .Case([&](const StringToken *s) { return s->text().size(); }) - .Default([](const auto *) { return 0; }); - tokens.pop_front(); - ++tokenOffset; - } -} - -/// Compute indentation w/o overflow, clamp to [0,maxStartingIndent]. -/// Output looks better if we don't stop indenting entirely at target width, -/// but don't do this indefinitely. -static uint32_t computeNewIndent(ssize_t newIndent, int32_t offset, - uint32_t maxStartingIndent) { - return std::clamp(newIndent + offset, 0, maxStartingIndent); -} - -/// Print a token, maintaining printStack for context. -void PrettyPrinter::print(const FormattedToken &f) { - llvm::TypeSwitch(&f.token) - .Case([&](const StringToken *s) { - space -= f.size; - os.indent(pendingIndentation); - pendingIndentation = 0; - os << s->text(); - }) - .Case([&](const BreakToken *b) { - auto &frame = getPrintFrame(); - assert((b->spaces() != kInfinity || alwaysFits == 0) && - "newline inside never group"); - bool fits = - (alwaysFits > 0) || b->neverbreak() || - frame.breaks == PrintBreaks::Fits || - (frame.breaks == PrintBreaks::Inconsistent && f.size <= space); - if (fits) { - space -= b->spaces(); - pendingIndentation += b->spaces(); - } else { - os << "\n"; - pendingIndentation = - computeNewIndent(indent, b->offset(), maxStartingIndent); - space = margin - pendingIndentation; - } - }) - .Case([&](const BeginToken *b) { - if (b->breaks() == Breaks::Never) { - printStack.push_back({0, PrintBreaks::AlwaysFits}); - ++alwaysFits; - } else if (f.size > space && alwaysFits == 0) { - auto breaks = b->breaks() == Breaks::Consistent - ? PrintBreaks::Consistent - : PrintBreaks::Inconsistent; - ssize_t newIndent = indent; - if (b->style() == IndentStyle::Visual) - newIndent = ssize_t{margin} - space; - indent = computeNewIndent(newIndent, b->offset(), maxStartingIndent); - printStack.push_back({indent, breaks}); - } else { - printStack.push_back({0, PrintBreaks::Fits}); - } - }) - .Case([&](const EndToken *) { - assert(!printStack.empty() && "more ends than begins?"); - // Try to tolerate this when assertions are disabled. - if (printStack.empty()) return; - if (getPrintFrame().breaks == PrintBreaks::AlwaysFits) --alwaysFits; - printStack.pop_back(); - auto &frame = getPrintFrame(); - if (frame.breaks != PrintBreaks::Fits && - frame.breaks != PrintBreaks::AlwaysFits) - indent = frame.offset; - }) - .Case([&](const CallbackToken *c) { - if (pendingIndentation) { - // This is necessary to get the correct location on the stream for the - // callback invocation. - os.indent(pendingIndentation); - pendingIndentation = 0; - } - listener->print(); - }); -} -} // end namespace pretty -} // end namespace circt diff --git a/lib/circt/Support/PrettyPrinterHelpers.cpp b/lib/circt/Support/PrettyPrinterHelpers.cpp deleted file mode 100644 index 7c820c796f..0000000000 --- a/lib/circt/Support/PrettyPrinterHelpers.cpp +++ /dev/null @@ -1,50 +0,0 @@ -//===- PrettyPrinterHelpers.cpp - Pretty printing helpers -----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Helper classes for using PrettyPrinter. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/PrettyPrinterHelpers.h" - -#include - -#include "llvm/include/llvm/ADT/SmallString.h" // from @llvm-project -#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project - -namespace circt { -namespace pretty { - -//===----------------------------------------------------------------------===// -// Convenience builders. -//===----------------------------------------------------------------------===// - -void TokenStringSaver::clear() { alloc.Reset(); } - -/// Add multiple non-breaking spaces as a single token. -void detail::emitNBSP(unsigned n, llvm::function_ref add) { - static const std::array spaces = ([]() constexpr { - std::array s = {}; - for (auto &c : s) c = ' '; - return s; - })(); - - const auto size = spaces.size(); - if (n <= size) { - if (n != 0) add(StringToken({spaces.data(), n})); - return; - } - while (n) { - auto chunk = std::min(n, size); - add(StringToken({spaces.data(), chunk})); - n -= chunk; - } -} - -} // end namespace pretty -} // end namespace circt diff --git a/lib/circt/Support/SymCache.cpp b/lib/circt/Support/SymCache.cpp deleted file mode 100644 index 1bd6546790..0000000000 --- a/lib/circt/Support/SymCache.cpp +++ /dev/null @@ -1,29 +0,0 @@ -//===- SymCache.cpp - Declare Symbol Cache ----------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines a Symbol Cache. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/SymCache.h" - -using namespace mlir; -using namespace circt; - -namespace circt { - -/// Virtual method anchor. -SymbolCacheBase::~SymbolCacheBase() {} - -void SymbolCacheBase::addDefinitions(mlir::Operation *top) { - for (auto ®ion : top->getRegions()) - for (auto &block : region.getBlocks()) - for (auto symOp : block.getOps()) - addSymbol(symOp); -} -} // namespace circt diff --git a/lib/circt/Support/ValueMapper.cpp b/lib/circt/Support/ValueMapper.cpp deleted file mode 100644 index d923e5079a..0000000000 --- a/lib/circt/Support/ValueMapper.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===- ValueMapper.cpp - Support for mapping SSA values ---------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides support for mapping SSA values between two domains. -// Provided a BackedgeBuilder, the ValueMapper supports mappings between -// GraphRegions, creating Backedges in cases of 'get'ing mapped values which are -// yet to be 'set'. -// -//===----------------------------------------------------------------------===// - -#include "include/circt/Support/ValueMapper.h" - -using namespace mlir; -using namespace circt; -mlir::Value ValueMapper::get(Value from, TypeTransformer typeTransformer) { - if (mapping.count(from) == 0) { - assert(bb && - "Trying to 'get' a mapped value without any value set. No " - "BackedgeBuilder was provided, so cannot provide any mapped " - "SSA value!"); - // Create a backedge which will be resolved at a later time once all - // operands are created. - mapping[from] = bb->get(typeTransformer(from.getType())); - } - auto operandMapping = mapping[from]; - Value mappedOperand; - if (auto *v = std::get_if(&operandMapping)) - mappedOperand = *v; - else - mappedOperand = std::get(operandMapping); - return mappedOperand; -} - -llvm::SmallVector ValueMapper::get(ValueRange from, - TypeTransformer typeTransformer) { - llvm::SmallVector to; - for (auto f : from) to.push_back(get(f, typeTransformer)); - return to; -} - -void ValueMapper::set(Value from, Value to, bool replace) { - auto it = mapping.find(from); - if (it != mapping.end()) { - if (auto *backedge = std::get_if(&it->second)) - backedge->setValue(to); - else if (!replace) - assert(false && "'from' was already mapped to a final value!"); - } - // Register the new mapping - mapping[from] = to; -} - -void ValueMapper::set(ValueRange from, ValueRange to, bool replace) { - assert(from.size() == to.size() && - "Expected # of 'from' values and # of 'to' values to be identical."); - for (auto [f, t] : llvm::zip(from, to)) set(f, t, replace); -} diff --git a/lib/circt/Support/Version.cpp.in b/lib/circt/Support/Version.cpp.in deleted file mode 100644 index f813fd102a..0000000000 --- a/lib/circt/Support/Version.cpp.in +++ /dev/null @@ -1,6 +0,0 @@ -#include "include/circt/Support/Version.h" - -const char *circt::getCirctVersion() { return "CIRCT @GIT_DESCRIBE_OUTPUT@"; } -const char *circt::getCirctVersionComment() { - return "// Generated by CIRCT @GIT_DESCRIBE_OUTPUT@\n"; -}