diff --git a/.pre-commit-search-and-replace.yaml b/.pre-commit-search-and-replace.yaml index dde262705..9399ee901 100644 --- a/.pre-commit-search-and-replace.yaml +++ b/.pre-commit-search-and-replace.yaml @@ -3,3 +3,13 @@ # filesystem. - search: '/^#include "mlir\/(?!include\/mlir\/)/' replacement: '#include "mlir/include/mlir/' +# Same for llvm paths. +- search: '/^#include "llvm\/(?!include\/llvm\/)/' + replacement: '#include "llvm/include/llvm/' +# Ensure that all C++ mlir include paths include a "// from @llvm-project" +# comment import into Google's internal filesystem. +- search: '/^#include ("mlir\/.*")$/' + replacement: '#include \1 // from @llvm-project' +# Same for llvm paths. +- search: '/^#include ("llvm\/.*")$/' + replacement: '#include \1 // from @llvm-project' diff --git a/include/Dialect/Comb/IR/BUILD b/include/Dialect/Comb/IR/BUILD new file mode 100644 index 000000000..a671418cb --- /dev/null +++ b/include/Dialect/Comb/IR/BUILD @@ -0,0 +1,140 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +exports_files( + glob(["*.h"]), +) + +td_library( + name = "td_files", + srcs = [ + "Comb.td", + "Combinational.td", + ], + includes = ["include"], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", + "@llvm-project//mlir:FunctionInterfacesTdFiles", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "dialect_inc_gen", + includes = ["include"], + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + "-dialect=comb", + ], + "CombDialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect=comb", + ], + "CombDialect.cpp.inc", + ), + ( + ["-gen-dialect-doc"], + "CombDialect.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":td_files", + ":type_inc_gen", + ], +) + +gentbl_cc_library( + name = "ops_inc_gen", + includes = ["include"], + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "Comb.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "Comb.cpp.inc", + ), + ( + ["-gen-op-doc"], + "CombOps.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) + +gentbl_cc_library( + name = "type_inc_gen", + includes = ["include"], + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + ], + "CombTypes.h.inc", + ), + ( + [ + "-gen-typedef-defs", + ], + "CombTypes.cpp.inc", + ), + ( + ["-gen-typedef-doc"], + "CombTypes.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":td_files", + ], +) + +gentbl_cc_library( + name = "enum_inc_gen", + includes = ["include"], + tbl_outs = [ + ( + [ + "-gen-enum-decls", + ], + "CombEnums.h.inc", + ), + ( + [ + "-gen-enum-defs", + ], + "CombEnums.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Comb.td", + deps = [ + ":dialect_inc_gen", + ":td_files", + ], +) diff --git a/include/Dialect/Comb/IR/Comb.td b/include/Dialect/Comb/IR/Comb.td new file mode 100644 index 000000000..2917e954b --- /dev/null +++ b/include/Dialect/Comb/IR/Comb.td @@ -0,0 +1,42 @@ +//===- Comb.td - Comb dialect definition --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the top level file for the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef HEIR_INCLUDE_DIALECT_COMB_COMB_TD +#define HEIR_INCLUDE_DIALECT_COMB_COMB_TD + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" + +def CombDialect : Dialect { + let name = "comb"; + + let summary = "Types and operations for comb dialect"; + let description = [{ + This dialect defines the `comb` dialect, which is intended to be a generic + representation of combinational logic outside of a particular use-case. + }]; + let cppNamespace = "::mlir::heir::comb"; + + // This will be the default after next LLVM bump. + let usePropertiesForAttributes = 1; + +} + +// Base class for the operation in this dialect. +class CombOp traits = []> : + Op; + +include "Combinational.td" + +#endif // HEIR_INCLUDE_DIALECT_COMB_COMB_TD diff --git a/include/Dialect/Comb/IR/CombDialect.h b/include/Dialect/Comb/IR/CombDialect.h new file mode 100644 index 000000000..c6fde96c5 --- /dev/null +++ b/include/Dialect/Comb/IR/CombDialect.h @@ -0,0 +1,25 @@ +//===- CombDialect.h - Comb dialect declaration -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the Combinational MLIR dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef HEIR_INCLUDE_DIALECT_COMB_COMBDIALECT_H +#define HEIR_INCLUDE_DIALECT_COMB_COMBDIALECT_H + +#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project + +// Pull in the Dialect definition. +#include "include/Dialect/Comb/IR/CombDialect.h.inc" + +// Pull in all enum type definitions and utility function declarations. +#include "include/Dialect/Comb/IR/CombEnums.h.inc" + +#endif // HEIR_INCLUDE_DIALECT_COMB_COMBDIALECT_H diff --git a/include/Dialect/Comb/IR/CombOps.h b/include/Dialect/Comb/IR/CombOps.h new file mode 100644 index 000000000..631211bd5 --- /dev/null +++ b/include/Dialect/Comb/IR/CombOps.h @@ -0,0 +1,36 @@ +//===- CombOps.h - Declare Comb dialect operations --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the operation classes for the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef HEIR_INCLUDE_DIALECT_COMB_COMBOPS_H +#define HEIR_INCLUDE_DIALECT_COMB_COMBOPS_H + +#include "include/Dialect/Comb/IR/CombDialect.h" +#include "mlir/include/mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/include/mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/include/mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project + +namespace llvm { +struct KnownBits; +} + +namespace mlir { +class PatternRewriter; +} + +#define GET_OP_CLASSES +#include "include/Dialect/Comb/IR/Comb.h.inc" + +#endif // HEIR_INCLUDE_DIALECT_COMB_COMBOPS_H diff --git a/include/Dialect/Comb/IR/Combinational.td b/include/Dialect/Comb/IR/Combinational.td new file mode 100644 index 000000000..9ca038140 --- /dev/null +++ b/include/Dialect/Comb/IR/Combinational.td @@ -0,0 +1,297 @@ +//===- Combinational.td - combinational logic ops ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This describes the MLIR ops for combinational logic. +// +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Arithmetic and Logical Operations +//===----------------------------------------------------------------------===// + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/IR/EnumAttr.td" + +def HWIntegerType : Type< + CPred<"$_self.isSignlessInteger()">, "signless integer", + "::mlir::IntegerType">; + +// Base class for binary operators. +class BinOp traits = []> : + CombOp { + let arguments = (ins HWIntegerType:$lhs, HWIntegerType:$rhs, UnitAttr:$twoState); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$lhs `,` $rhs (`bin` $twoState^)? attr-dict `:` functional-type($args, $results)"; +} + +// Binary operator with uniform input/result types. +class UTBinOp traits = []> : + BinOp { + let assemblyFormat = "(`bin` $twoState^)? $lhs `,` $rhs attr-dict `:` qualified(type($result))"; +} + +// Base class for variadic operators. +class VariadicOp traits = []> : + CombOp { + let arguments = (ins Variadic:$inputs, UnitAttr:$twoState); + let results = (outs HWIntegerType:$result); +} + +class UTVariadicOp traits = []> : + VariadicOp { + + let hasVerifier = 1; + + let assemblyFormat = "(`bin` $twoState^)? $inputs attr-dict `:` qualified(type($result))"; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs, CArg<"bool", "false">:$twoState), [{ + return build($_builder, $_state, lhs.getType(), + ValueRange{lhs, rhs}, twoState); + }]> + ]; +} + +// Arithmetic and Logical Operations. +def AddOp : UTVariadicOp<"add", [Commutative]>; +def MulOp : UTVariadicOp<"mul", [Commutative]>; + +def AndOp : UTVariadicOp<"and", [Commutative]>; +def OrOp : UTVariadicOp<"or", [Commutative]>; +def XorOp : UTVariadicOp<"xor", [Commutative]> { + let extraClassDeclaration = [{ + /// Return true if this is a two operand xor with an all ones constant as + /// its RHS operand. + bool isBinaryNot(); + }]; +} + +//===----------------------------------------------------------------------===// +// Comparisons +//===----------------------------------------------------------------------===// + +def ICmpPredicateEQ : I64EnumAttrCase<"eq", 0>; +def ICmpPredicateNE : I64EnumAttrCase<"ne", 1>; +def ICmpPredicateSLT : I64EnumAttrCase<"slt", 2>; +def ICmpPredicateSLE : I64EnumAttrCase<"sle", 3>; +def ICmpPredicateSGT : I64EnumAttrCase<"sgt", 4>; +def ICmpPredicateSGE : I64EnumAttrCase<"sge", 5>; +def ICmpPredicateULT : I64EnumAttrCase<"ult", 6>; +def ICmpPredicateULE : I64EnumAttrCase<"ule", 7>; +def ICmpPredicateUGT : I64EnumAttrCase<"ugt", 8>; +def ICmpPredicateUGE : I64EnumAttrCase<"uge", 9>; +// SV case equality +def ICmpPredicateCEQ : I64EnumAttrCase<"ceq", 10>; +def ICmpPredicateCNE : I64EnumAttrCase<"cne", 11>; +// SV wild card equality +def ICmpPredicateWEQ : I64EnumAttrCase<"weq", 12>; +def ICmpPredicateWNE : I64EnumAttrCase<"wne", 13>; +let cppNamespace = "::mlir::heir::comb" in +def ICmpPredicate : I64EnumAttr< + "ICmpPredicate", + "hw.icmp comparison predicate", + [ICmpPredicateEQ, ICmpPredicateNE, ICmpPredicateSLT, ICmpPredicateSLE, + ICmpPredicateSGT, ICmpPredicateSGE, ICmpPredicateULT, ICmpPredicateULE, + ICmpPredicateUGT, ICmpPredicateUGE, ICmpPredicateCEQ, ICmpPredicateCNE, + ICmpPredicateWEQ, ICmpPredicateWNE]>; + +def ICmpOp : CombOp<"icmp", [Pure, SameTypeOperands]> { + let summary = "Compare two integer values"; + let description = [{ + This operation compares two integers using a predicate. If the predicate is + true, returns 1, otherwise returns 0. This operation always returns a one + bit wide result. + + ``` + %r = comb.icmp eq %a, %b : i4 + ``` + }]; + + let arguments = (ins ICmpPredicate:$predicate, + HWIntegerType:$lhs, HWIntegerType:$rhs, UnitAttr:$twoState); + let results = (outs I1:$result); + + let assemblyFormat = "(`bin` $twoState^)? $predicate $lhs `,` $rhs attr-dict `:` qualified(type($lhs))"; + + let extraClassDeclaration = [{ + /// Returns the flipped predicate, reversing the LHS and RHS operands. The + /// lhs and rhs operands should be flipped to match the new predicate. + static ICmpPredicate getFlippedPredicate(ICmpPredicate predicate); + + /// Returns true if the predicate is signed. + static bool isPredicateSigned(ICmpPredicate predicate); + + /// Returns the predicate for a logically negated comparison, e.g. mapping + /// EQ => NE and SLE => SGT. + static ICmpPredicate getNegatedPredicate(ICmpPredicate predicate); + + /// Return true if this is an equality test with -1, which is a "reduction + /// and" operation in Verilog. + bool isEqualAllOnes(); + + /// Return true if this is a not equal test with 0, which is a "reduction + /// or" operation in Verilog. + bool isNotEqualZero(); + }]; +} + +//===----------------------------------------------------------------------===// +// Unary Operations +//===----------------------------------------------------------------------===// + +// Base class for unary reduction operations that produce an i1. +class UnaryI1ReductionOp traits = []> : + CombOp { + let arguments = (ins HWIntegerType:$input, UnitAttr:$twoState); + let results = (outs I1:$result); + + let assemblyFormat = "(`bin` $twoState^)? $input attr-dict `:` qualified(type($input))"; +} + +def ParityOp : UnaryI1ReductionOp<"parity">; + +//===----------------------------------------------------------------------===// +// Integer width modifying operations. +//===----------------------------------------------------------------------===// + +// Extract a range of bits from the specified input. +def ExtractOp : CombOp<"extract", [Pure]> { + let summary = "Extract a range of bits into a smaller value, lowBit " + "specifies the lowest bit included."; + + let arguments = (ins HWIntegerType:$input, I32Attr:$lowBit); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$input `from` $lowBit attr-dict `:` functional-type($input, $result)"; + + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "int32_t":$lowBit, "int32_t":$bitWidth), [{ + auto resultType = $_builder.getIntegerType(bitWidth); + return build($_builder, $_state, resultType, lhs, lowBit); + }]> + ]; +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// +def ConcatOp : CombOp<"concat", [InferTypeOpInterface, Pure]> { + let summary = "Concatenate a variadic list of operands together."; + let description = [{ + See the comb rationale document for details on operand ordering. + }]; + + let arguments = (ins Variadic:$inputs); + let results = (outs HWIntegerType:$result); + + let hasVerifier = 1; + + let assemblyFormat = "$inputs attr-dict `:` qualified(type($inputs))"; + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ + return build($_builder, $_state, ValueRange{lhs, rhs}); + }]>, + OpBuilder<(ins "Value":$hd, "ValueRange":$tl)>, + ]; + + let extraClassDeclaration = [{ + /// Infer the return types of this operation. + static LogicalResult inferReturnTypes(MLIRContext *context, + std::optional loc, + ValueRange operands, + DictionaryAttr attrs, + mlir::OpaqueProperties properties, + mlir::RegionRange regions, + SmallVectorImpl &results); + }]; +} + +def ReplicateOp : CombOp<"replicate", [Pure]> { + let summary = "Concatenate the operand a constant number of times"; + + let arguments = (ins HWIntegerType:$input); + let results = (outs HWIntegerType:$result); + + let assemblyFormat = + "$input attr-dict `:` functional-type($input, $result)"; + + let hasVerifier = 1; + + let builders = [ + OpBuilder<(ins "Value":$operand, "int32_t":$multiple), [{ + auto bitWidth = operand.getType().cast().getWidth(); + auto resultType = $_builder.getIntegerType(bitWidth*multiple); + return build($_builder, $_state, resultType, operand); + }]> + ]; + + let extraClassDeclaration = [{ + /// Returns the number of times the operand is replicated. + size_t getMultiple() { + auto opWidth = getInput().getType().cast().getWidth(); + return getType().cast().getWidth()/opWidth; + } + }]; +} + +// Select one of two values based on a condition. +def MuxOp : CombOp<"mux", + [Pure, AllTypesMatch<["trueValue", "falseValue", "result"]>]> { + let summary = "Return one or the other operand depending on a selector bit"; + let description = [{ + ``` + %0 = mux %pred, %tvalue, %fvalue : i4 + ``` + }]; + + let arguments = (ins I1:$cond, AnyType:$trueValue, + AnyType:$falseValue, UnitAttr:$twoState); + let results = (outs AnyType:$result); + + let assemblyFormat = + "(`bin` $twoState^)? $cond `,` $trueValue `,` $falseValue attr-dict `:` qualified(type($result))"; + +} + +def TruthTableOp : CombOp<"truth_table", [Pure]> { + let summary = "Return a true/false based on a lookup table"; + let description = [{ + ``` + %a = ... : i1 + %b = ... : i1 + %0 = comb.truth_table %a, %b -> [false, true, true, false] + ``` + + This operation assumes a fully elaborated table -- 2^n entries. Inputs are + sorted MSB -> LSB from left to right and the offset into `lookupTable` is + computed from them. The table is sorted from 0 -> (2^n - 1) from left to + right. + + No difference from array_get into an array of constants except for xprop + behavior. If one of the inputs is unknown, but said input doesn't make a + difference in the output (based on the lookup table) the result should not + be 'x' -- it should be the well-known result. + }]; + + let arguments = (ins Variadic:$inputs, BoolArrayAttr:$lookupTable); + let results = (outs I1:$result); + + let assemblyFormat = [{ + $inputs `->` $lookupTable attr-dict + }]; + + let hasVerifier = 1; +} diff --git a/lib/Dialect/Comb/IR/BUILD b/lib/Dialect/Comb/IR/BUILD new file mode 100644 index 000000000..7a40c72ac --- /dev/null +++ b/lib/Dialect/Comb/IR/BUILD @@ -0,0 +1,28 @@ +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "Dialect", + srcs = [ + "CombDialect.cpp", + "CombOps.cpp", + ], + hdrs = [ + "@heir//include/Dialect/Comb/IR:CombDialect.h", + "@heir//include/Dialect/Comb/IR:CombOps.h", + ], + deps = [ + "@heir//include/Dialect/Comb/IR:dialect_inc_gen", + "@heir//include/Dialect/Comb/IR:enum_inc_gen", + "@heir//include/Dialect/Comb/IR:ops_inc_gen", + "@heir//include/Dialect/Comb/IR:type_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Dialect/Comb/IR/CombDialect.cpp b/lib/Dialect/Comb/IR/CombDialect.cpp new file mode 100644 index 000000000..d59230d47 --- /dev/null +++ b/lib/Dialect/Comb/IR/CombDialect.cpp @@ -0,0 +1,42 @@ +//===- CombDialect.cpp - Implement the Comb dialect -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Comb dialect. +// +//===----------------------------------------------------------------------===// + +#include "include/Dialect/Comb/IR/CombDialect.h" + +#include "include/Dialect/Comb/IR/CombOps.h" +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace comb { + +//===----------------------------------------------------------------------===// +// Dialect specification. +//===----------------------------------------------------------------------===// + +void CombDialect::initialize() { + // Register operations. + addOperations< +#define GET_OP_LIST +#include "include/Dialect/Comb/IR/Comb.cpp.inc" + >(); +} + +} // namespace comb +} // namespace heir +} // namespace mlir + +// Provide implementations for the enums we use. +#include "include/Dialect/Comb/IR/CombDialect.cpp.inc" +#include "include/Dialect/Comb/IR/CombEnums.cpp.inc" diff --git a/lib/Dialect/Comb/IR/CombOps.cpp b/lib/Dialect/Comb/IR/CombOps.cpp new file mode 100644 index 000000000..8a25fbee6 --- /dev/null +++ b/lib/Dialect/Comb/IR/CombOps.cpp @@ -0,0 +1,240 @@ +//===- CombOps.cpp - Implement the Comb operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements combinational ops. +// +//===----------------------------------------------------------------------===// + +#include "include/Dialect/Comb/IR/CombOps.h" + +#include "llvm/include/llvm/Support/FormatVariadic.h" // from @llvm-project +#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace comb { + +//===----------------------------------------------------------------------===// +// ICmpOp +//===----------------------------------------------------------------------===// + +ICmpPredicate ICmpOp::getFlippedPredicate(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::eq: + return ICmpPredicate::eq; + case ICmpPredicate::ne: + return ICmpPredicate::ne; + case ICmpPredicate::slt: + return ICmpPredicate::sgt; + case ICmpPredicate::sle: + return ICmpPredicate::sge; + case ICmpPredicate::sgt: + return ICmpPredicate::slt; + case ICmpPredicate::sge: + return ICmpPredicate::sle; + case ICmpPredicate::ult: + return ICmpPredicate::ugt; + case ICmpPredicate::ule: + return ICmpPredicate::uge; + case ICmpPredicate::ugt: + return ICmpPredicate::ult; + case ICmpPredicate::uge: + return ICmpPredicate::ule; + case ICmpPredicate::ceq: + return ICmpPredicate::ceq; + case ICmpPredicate::cne: + return ICmpPredicate::cne; + case ICmpPredicate::weq: + return ICmpPredicate::weq; + case ICmpPredicate::wne: + return ICmpPredicate::wne; + } + llvm_unreachable("unknown comparison predicate"); +} + +bool ICmpOp::isPredicateSigned(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::ult: + case ICmpPredicate::ugt: + case ICmpPredicate::ule: + case ICmpPredicate::uge: + case ICmpPredicate::ne: + case ICmpPredicate::eq: + case ICmpPredicate::cne: + case ICmpPredicate::ceq: + case ICmpPredicate::wne: + case ICmpPredicate::weq: + return false; + case ICmpPredicate::slt: + case ICmpPredicate::sgt: + case ICmpPredicate::sle: + case ICmpPredicate::sge: + return true; + } + llvm_unreachable("unknown comparison predicate"); +} + +/// Returns the predicate for a logically negated comparison, e.g. mapping +/// EQ => NE and SLE => SGT. +ICmpPredicate ICmpOp::getNegatedPredicate(ICmpPredicate predicate) { + switch (predicate) { + case ICmpPredicate::eq: + return ICmpPredicate::ne; + case ICmpPredicate::ne: + return ICmpPredicate::eq; + case ICmpPredicate::slt: + return ICmpPredicate::sge; + case ICmpPredicate::sle: + return ICmpPredicate::sgt; + case ICmpPredicate::sgt: + return ICmpPredicate::sle; + case ICmpPredicate::sge: + return ICmpPredicate::slt; + case ICmpPredicate::ult: + return ICmpPredicate::uge; + case ICmpPredicate::ule: + return ICmpPredicate::ugt; + case ICmpPredicate::ugt: + return ICmpPredicate::ule; + case ICmpPredicate::uge: + return ICmpPredicate::ult; + case ICmpPredicate::ceq: + return ICmpPredicate::cne; + case ICmpPredicate::cne: + return ICmpPredicate::ceq; + case ICmpPredicate::weq: + return ICmpPredicate::wne; + case ICmpPredicate::wne: + return ICmpPredicate::weq; + } + llvm_unreachable("unknown comparison predicate"); +} + +//===----------------------------------------------------------------------===// +// Unary Operations +//===----------------------------------------------------------------------===// + +LogicalResult ReplicateOp::verify() { + // The source must be equal or smaller than the dest type, and an even + // multiple of it. Both are already known to be signless integers. + auto srcWidth = getOperand().getType().cast().getWidth(); + auto dstWidth = getType().cast().getWidth(); + if (srcWidth == 0) + return emitOpError("replicate does not take zero bit integer"); + + if (srcWidth > dstWidth) + return emitOpError("replicate cannot shrink bitwidth of operand"), + failure(); + + if (dstWidth % srcWidth) + return emitOpError("replicate must produce integer multiple of operand"), + failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Variadic operations +//===----------------------------------------------------------------------===// + +static LogicalResult verifyUTBinOp(Operation *op) { + if (op->getOperands().empty()) + return op->emitOpError("requires 1 or more args"); + return success(); +} + +LogicalResult AddOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult MulOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult AndOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult OrOp::verify() { return verifyUTBinOp(*this); } + +LogicalResult XorOp::verify() { return verifyUTBinOp(*this); } + +//===----------------------------------------------------------------------===// +// ConcatOp +//===----------------------------------------------------------------------===// + +static unsigned getTotalWidth(ValueRange inputs) { + unsigned resultWidth = 0; + for (auto input : inputs) { + resultWidth += input.getType().cast().getWidth(); + } + return resultWidth; +} + +LogicalResult ConcatOp::verify() { + unsigned tyWidth = getType().cast().getWidth(); + unsigned operandsTotalWidth = getTotalWidth(getInputs()); + if (tyWidth != operandsTotalWidth) + return emitOpError( + "ConcatOp requires operands total width to " + "match type width. operands " + "totalWidth is") + << operandsTotalWidth << ", but concatOp type width is " << tyWidth; + + return success(); +} + +void ConcatOp::build(OpBuilder &builder, OperationState &result, Value hd, + ValueRange tl) { + result.addOperands(ValueRange{hd}); + result.addOperands(tl); + unsigned hdWidth = hd.getType().cast().getWidth(); + result.addTypes(builder.getIntegerType(getTotalWidth(tl) + hdWidth)); +} + +LogicalResult ConcatOp::inferReturnTypes( + MLIRContext *context, std::optional loc, ValueRange operands, + DictionaryAttr attrs, mlir::OpaqueProperties properties, + mlir::RegionRange regions, SmallVectorImpl &results) { + unsigned resultWidth = getTotalWidth(operands); + results.push_back(IntegerType::get(context, resultWidth)); + return success(); +} + +//===----------------------------------------------------------------------===// +// Other Operations +//===----------------------------------------------------------------------===// + +LogicalResult ExtractOp::verify() { + unsigned srcWidth = getInput().getType().cast().getWidth(); + unsigned dstWidth = getType().cast().getWidth(); + if (getLowBit() >= srcWidth || srcWidth - getLowBit() < dstWidth) + return emitOpError("from bit too large for input"), failure(); + + return success(); +} + +LogicalResult TruthTableOp::verify() { + size_t numInputs = getInputs().size(); + if (numInputs >= sizeof(size_t) * 8) + return emitOpError("Truth tables support a maximum of ") + << sizeof(size_t) * 8 - 1 << " inputs on your platform"; + + ArrayAttr table = getLookupTable(); + if (table.size() != (1ull << numInputs)) + return emitOpError("Expected lookup table of 2^n length"); + return success(); +} + +} // namespace comb +} // namespace heir +} // namespace mlir + +//===----------------------------------------------------------------------===// +// TableGen generated logic. +//===----------------------------------------------------------------------===// + +// Provide the autogenerated implementation guts for the Op classes. +#define GET_OP_CLASSES +#include "include/Dialect/Comb/IR/Comb.cpp.inc" diff --git a/tests/comb.mlir b/tests/comb.mlir new file mode 100644 index 000000000..d8b9684db --- /dev/null +++ b/tests/comb.mlir @@ -0,0 +1,8 @@ +// RUN: heir-opt %s -verify-diagnostics + +module { + func.func @comb(%a: i1, %b: i1) -> () { + %0 = comb.truth_table %a, %b -> [true, false, true, false] + return + } +} diff --git a/tools/BUILD b/tools/BUILD index ef40621cd..0fd299ff1 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -16,6 +16,7 @@ cc_binary( "@heir//lib/Conversion/MemrefToArith:MemrefToArithRegistration", "@heir//lib/Conversion/PolyToStandard", "@heir//lib/Dialect/BGV/IR:Dialect", + "@heir//lib/Dialect/Comb/IR:Dialect", "@heir//lib/Dialect/EncryptedArith/IR:Dialect", "@heir//lib/Dialect/LWE/IR:Dialect", "@heir//lib/Dialect/Poly/IR:Dialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index 70b692f78..21269efa4 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -2,6 +2,7 @@ #include "include/Conversion/MemrefToArith/MemrefToArith.h" #include "include/Conversion/PolyToStandard/PolyToStandard.h" #include "include/Dialect/BGV/IR/BGVDialect.h" +#include "include/Dialect/Comb/IR/CombDialect.h" #include "include/Dialect/EncryptedArith/IR/EncryptedArithDialect.h" #include "include/Dialect/LWE/IR/LWEDialect.h" #include "include/Dialect/Poly/IR/PolyDialect.h" @@ -80,6 +81,7 @@ int main(int argc, char **argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); // Add expected MLIR dialects to the registry. registry.insert();