Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OM] Add integer shift left op #7658

Merged
merged 1 commit into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/circt/Dialect/OM/OMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -516,4 +516,18 @@ def IntegerShrOp : IntegerBinaryArithmeticOp<"integer.shr"> {
}];
}

def IntegerShlOp : IntegerBinaryArithmeticOp<"integer.shl"> {
let summary = "Shift an OMIntegerType value left by an OMIntegerType value";
let description = [{
Perform arbitrary precision signed integer arithmetic shift left of the lhs
OMIntegerType value by the rhs OMIntegerType value. The rhs value must be
non-negative.

Example:
```mlir
%2 = om.integer.shl %0, %1 : !om.integer
```
}];
}

#endif // CIRCT_DIALECT_OM_OMOPS_TD
16 changes: 16 additions & 0 deletions lib/Dialect/OM/OMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,22 @@ IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
return success(lhs >> rhs.getExtValue());
}

//===----------------------------------------------------------------------===//
// IntegerShlOp
//===----------------------------------------------------------------------===//

FailureOr<llvm::APSInt>
IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
const llvm::APSInt &rhs) {
// Check non-negative constraint from operation semantics.
if (!rhs.isNonNegative())
return emitOpError("shift amount must be non-negative");
// Check size constraint from implementation detail of using getExtValue.
if (!rhs.isRepresentableByInt64())
return emitOpError("shift amount must be representable in 64 bits");
return success(lhs << rhs.getExtValue());
}

//===----------------------------------------------------------------------===//
// TableGen generated logic.
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions test/Dialect/OM/round-trip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,7 @@ om.class @IntegerArithmetic() {

// CHECK: om.integer.shr %0, %1 : !om.integer
%4 = om.integer.shr %0, %1 : !om.integer

// CHECK: om.integer.shl %0, %1 : !om.integer
%5 = om.integer.shl %0, %1 : !om.integer
}
102 changes: 102 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,108 @@ TEST(EvaluatorTests, IntegerBinaryArithmeticShrTooLarge) {
ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShl) {
StringRef mod = "om.class @IntegerBinaryArithmeticShl() {"
" %0 = om.constant #om.integer<8 : si7> : !om.integer"
" %1 = om.constant #om.integer<2 : si3> : !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShl"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

ASSERT_EQ(32, llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShlNegative) {
StringRef mod = "om.class @IntegerBinaryArithmeticShlNegative() {"
" %0 = om.constant #om.integer<8 : si5> : !om.integer"
" %1 = om.constant #om.integer<-2 : si3> : !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
if (StringRef(diag.str()).starts_with("'om.integer.shl'"))
ASSERT_EQ(diag.str(),
"'om.integer.shl' op shift amount must be non-negative");
if (StringRef(diag.str()).starts_with("failed"))
ASSERT_EQ(diag.str(), "failed to evaluate integer operation");
});

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShlNegative"), {});

ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticShlTooLarge) {
StringRef mod = "om.class @IntegerBinaryArithmeticShlTooLarge() {"
" %0 = om.constant #om.integer<8 : si5> : !om.integer"
" %1 = om.constant #om.integer<36893488147419100000 : si66> "
": !om.integer"
" %2 = om.integer.shl %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
if (StringRef(diag.str()).starts_with("'om.integer.shl'"))
ASSERT_EQ(
diag.str(),
"'om.integer.shl' op shift amount must be representable in 64 bits");
if (StringRef(diag.str()).starts_with("failed"))
ASSERT_EQ(diag.str(), "failed to evaluate integer operation");
});

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticShlTooLarge"), {});

ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) {
StringRef mod = "om.class @Class1() {"
" %0 = om.constant #om.integer<1 : si3> : !om.integer"
Expand Down
Loading