From c1ab041702d6d99562e045980d716dbea6c5ac7f Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 3 Oct 2024 21:09:20 +0200 Subject: [PATCH] [OM] Add integer shift left op --- include/circt/Dialect/OM/OMOps.td | 14 +++ lib/Dialect/OM/OMOps.cpp | 16 +++ test/Dialect/OM/round-trip.mlir | 3 + .../Dialect/OM/Evaluator/EvaluatorTests.cpp | 102 ++++++++++++++++++ 4 files changed, 135 insertions(+) diff --git a/include/circt/Dialect/OM/OMOps.td b/include/circt/Dialect/OM/OMOps.td index 6338218c2ce6..9155d821b2cc 100644 --- a/include/circt/Dialect/OM/OMOps.td +++ b/include/circt/Dialect/OM/OMOps.td @@ -516,4 +516,18 @@ def IntegerShrOp : IntegerBinaryArithmeticOp<"integer.shr"> { }]; } +def IntegerShlOp : IntegerBinaryArithmeticOp<"integer.shl"> { + let summary = "Shift an OMIntegerType value left by an OMIntegerType value"; + let description = [{ + Perform arbitrary precision signed integer arithmetic shift left of the lhs + OMIntegerType value by the rhs OMIntegerType value. The rhs value must be + non-negative. + + Example: + ```mlir + %2 = om.integer.shl %0, %1 : !om.integer + ``` + }]; +} + #endif // CIRCT_DIALECT_OM_OMOPS_TD diff --git a/lib/Dialect/OM/OMOps.cpp b/lib/Dialect/OM/OMOps.cpp index 7c1861906daa..e635a37a9f42 100644 --- a/lib/Dialect/OM/OMOps.cpp +++ b/lib/Dialect/OM/OMOps.cpp @@ -508,6 +508,22 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs, return success(lhs >> rhs.getExtValue()); } +//===----------------------------------------------------------------------===// +// IntegerShlOp +//===----------------------------------------------------------------------===// + +FailureOr +IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs, + const llvm::APSInt &rhs) { + // Check non-negative constraint from operation semantics. + if (!rhs.isNonNegative()) + return emitOpError("shift amount must be non-negative"); + // Check size constraint from implementation detail of using getExtValue. + if (!rhs.isRepresentableByInt64()) + return emitOpError("shift amount must be representable in 64 bits"); + return success(lhs << rhs.getExtValue()); +} + //===----------------------------------------------------------------------===// // TableGen generated logic. //===----------------------------------------------------------------------===// diff --git a/test/Dialect/OM/round-trip.mlir b/test/Dialect/OM/round-trip.mlir index 738e895e434c..575774c93f5c 100644 --- a/test/Dialect/OM/round-trip.mlir +++ b/test/Dialect/OM/round-trip.mlir @@ -310,4 +310,7 @@ om.class @IntegerArithmetic() { // CHECK: om.integer.shr %0, %1 : !om.integer %4 = om.integer.shr %0, %1 : !om.integer + + // CHECK: om.integer.shl %0, %1 : !om.integer + %5 = om.integer.shl %0, %1 : !om.integer } diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index abd9547c4632..25c2b920a84c 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -783,6 +783,108 @@ TEST(EvaluatorTests, IntegerBinaryArithmeticShrTooLarge) { ASSERT_TRUE(failed(result)); } +TEST(EvaluatorTests, IntegerBinaryArithmeticShl) { + StringRef mod = "om.class @IntegerBinaryArithmeticShl() {" + " %0 = om.constant #om.integer<8 : si7> : !om.integer" + " %1 = om.constant #om.integer<2 : si3> : !om.integer" + " %2 = om.integer.shl %0, %1 : !om.integer" + " om.class.field @result, %2 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticShl"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(32, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticShlNegative) { + StringRef mod = "om.class @IntegerBinaryArithmeticShlNegative() {" + " %0 = om.constant #om.integer<8 : si5> : !om.integer" + " %1 = om.constant #om.integer<-2 : si3> : !om.integer" + " %2 = om.integer.shl %0, %1 : !om.integer" + " om.class.field @result, %2 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + if (StringRef(diag.str()).starts_with("'om.integer.shl'")) + ASSERT_EQ(diag.str(), + "'om.integer.shl' op shift amount must be non-negative"); + if (StringRef(diag.str()).starts_with("failed")) + ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + }); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticShlNegative"), {}); + + ASSERT_TRUE(failed(result)); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticShlTooLarge) { + StringRef mod = "om.class @IntegerBinaryArithmeticShlTooLarge() {" + " %0 = om.constant #om.integer<8 : si5> : !om.integer" + " %1 = om.constant #om.integer<36893488147419100000 : si66> " + ": !om.integer" + " %2 = om.integer.shl %0, %1 : !om.integer" + " om.class.field @result, %2 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + if (StringRef(diag.str()).starts_with("'om.integer.shl'")) + ASSERT_EQ( + diag.str(), + "'om.integer.shl' op shift amount must be representable in 64 bits"); + if (StringRef(diag.str()).starts_with("failed")) + ASSERT_EQ(diag.str(), "failed to evaluate integer operation"); + }); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticShlTooLarge"), {}); + + ASSERT_TRUE(failed(result)); +} + TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) { StringRef mod = "om.class @Class1() {" " %0 = om.constant #om.integer<1 : si3> : !om.integer"