Skip to content

Commit

Permalink
[OM] Add integer shift left op (#7658)
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart authored Oct 4, 2024
1 parent 001d806 commit f287ba2
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 0 deletions.
14 changes: 14 additions & 0 deletions include/circt/Dialect/OM/OMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -516,4 +516,18 @@ def IntegerShrOp : IntegerBinaryArithmeticOp<"integer.shr"> {
}];
}

def IntegerShlOp : IntegerBinaryArithmeticOp<"integer.shl"> {
let summary = "Shift an OMIntegerType value left by an OMIntegerType value";
let description = [{
Perform arbitrary precision signed integer arithmetic shift left of the lhs
OMIntegerType value by the rhs OMIntegerType value. The rhs value must be
non-negative.

Example:
```mlir
%2 = om.integer.shl %0, %1 : !om.integer
```
}];
}

#endif // CIRCT_DIALECT_OM_OMOPS_TD
16 changes: 16 additions & 0 deletions lib/Dialect/OM/OMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,22 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs >> rhs.getExtValue());
}

//===----------------------------------------------------------------------===//
// IntegerShlOp
//===----------------------------------------------------------------------===//

FailureOr<llvm::APSInt>
IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
const llvm::APSInt &rhs) {
// Check non-negative constraint from operation semantics.
if (!rhs.isNonNegative())
return emitOpError("shift amount must be non-negative");
// Check size constraint from implementation detail of using getExtValue.
if (!rhs.isRepresentableByInt64())
return emitOpError("shift amount must be representable in 64 bits");
return success(lhs << rhs.getExtValue());
}

//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions test/Dialect/OM/round-trip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,7 @@ om.class @IntegerArithmetic() {

// CHECK: om.integer.shr %0, %1 : !om.integer
%4 = om.integer.shr %0, %1 : !om.integer

// CHECK: om.integer.shl %0, %1 : !om.integer
%5 = om.integer.shl %0, %1 : !om.integer
}
102 changes: 102 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,108 @@ TEST(EvaluatorTests, IntegerBinaryArithmeticShrTooLarge) {
ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShl) {
StringRef mod = "om.class @IntegerBinaryArithmeticShl() {"
" %0 = om.constant #om.integer<8 : si7> : !om.integer"
" %1 = om.constant #om.integer<2 : si3> : !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShl"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

ASSERT_EQ(32, llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShlNegative) {
StringRef mod = "om.class @IntegerBinaryArithmeticShlNegative() {"
" %0 = om.constant #om.integer<8 : si5> : !om.integer"
" %1 = om.constant #om.integer<-2 : si3> : !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
if (StringRef(diag.str()).starts_with("'om.integer.shl'"))
ASSERT_EQ(diag.str(),
"'om.integer.shl' op shift amount must be non-negative");
if (StringRef(diag.str()).starts_with("failed"))
ASSERT_EQ(diag.str(), "failed to evaluate integer operation");
});

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShlNegative"), {});

ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShlTooLarge) {
StringRef mod = "om.class @IntegerBinaryArithmeticShlTooLarge() {"
" %0 = om.constant #om.integer<8 : si5> : !om.integer"
" %1 = om.constant #om.integer<36893488147419100000 : si66> "
": !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
if (StringRef(diag.str()).starts_with("'om.integer.shl'"))
ASSERT_EQ(
diag.str(),
"'om.integer.shl' op shift amount must be representable in 64 bits");
if (StringRef(diag.str()).starts_with("failed"))
ASSERT_EQ(diag.str(), "failed to evaluate integer operation");
});

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShlTooLarge"), {});

ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) {
StringRef mod = "om.class @Class1() {"
" %0 = om.constant #om.integer<1 : si3> : !om.integer"
Expand Down

0 comments on commit f287ba2

Please sign in to comment.