Skip to content

Commit

Permalink
migrate RingAttr to use Builtin_IntegerAttr
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 2, 2023
1 parent 64e4d61 commit d60b8a6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 22 deletions.
10 changes: 8 additions & 2 deletions include/Dialect/Polynomial/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,17 @@ gentbl_cc_library(
name = "attributes_inc_gen",
tbl_outs = [
(
["-gen-attrdef-decls"],
[
"-gen-attrdef-decls",
"-attrdefs-dialect=polynomial",
],
"PolynomialAttributes.h.inc",
),
(
["-gen-attrdef-defs"],
[
"-gen-attrdef-defs",
"-attrdefs-dialect=polynomial",
],
"PolynomialAttributes.cpp.inc",
),
(
Expand Down
34 changes: 19 additions & 15 deletions include/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

include "include/Dialect/Polynomial/IR/PolynomialDialect.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/DialectBase.td"

class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Polynomial_Dialect, name, traits> {
Expand Down Expand Up @@ -46,26 +47,29 @@ def Ring_Attr : Polynomial_Attr<"Ring", "ring"> {
#ring = #polynomial.ring<cmod=1234, ideal=#polynomial.polynomial<x**1024 + 1>>
}];

// The extra cwidth parameter is required because, when the MLIR framework's
// StorageUniquer is determining whether to build a new instance of this
// attribute or reuse an existing one, it does an == comparison on the
// parameters, and comparing two APInts with different underlying bit widths
// leads to an assertion failure. By adding the container bit width as a
// parameter, that bit width is compared first and the rest of the check is
// skipped if the bit widths disagree. The custom builder and custom
// parser/printer ensures that cwidth is automatically inferred from cmod, and
// not printed or parsed, so that the cwidth parameter need not be known to the
// user.
let parameters = (ins "int64_t": $cwidth, "APInt": $cmod, "Polynomial":$ideal);
let parameters = (ins "IntegerAttr": $cmod, "Polynomial":$ideal);

let builders = [
AttrBuilderWithInferredContext<(ins "APInt": $cmod, "Polynomial":$ideal), [{
return $_get(ideal.getContext(), cmod.getBitWidth(), cmod, ideal);
AttrBuilderWithInferredContext<
(ins "IntegerType":$type, "const APInt &": $cmod, "Polynomial":$ideal), [{
return $_get(
type.getContext(),
IntegerAttr::get(type, cmod.zextOrTrunc(type.getWidth())),
ideal
);
}]>,
AttrBuilderWithInferredContext<
(ins "const APInt &": $cmod, "Polynomial":$ideal), [{
return $_get(
ideal.getContext(),
IntegerAttr::get(IntegerType::get(ideal.getContext(), cmod.getBitWidth()), cmod),
ideal
);
}]>
];
let extraClassDeclaration = [{
Polynomial ideal() const { return getIdeal(); }
APInt coefficientModulus() const { return getCmod(); }
APInt coefficientModulus() const { return getCmod().getValue(); }
}];

let skipDefaultBuilders = 1;
Expand Down
7 changes: 4 additions & 3 deletions lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType,
// division. The lowering must fail:
auto divisor = ring.getIdeal();
auto leadingCoef = divisor.getTerms().back().coefficient;
auto leadingCoefInverse = leadingCoef.multiplicativeInverse(ring.getCmod());
auto leadingCoefInverse =
leadingCoef.multiplicativeInverse(ring.coefficientModulus());
// APInt signals no inverse by returning zero.
if (leadingCoefInverse.isZero()) {
signalPassFailure();
Expand Down Expand Up @@ -686,8 +687,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType,
auto toTensorOp = builder.create<ToTensorOp>(inputType, whileOp.getResult(0));
auto remainderModArg = builder.create<arith::ConstantOp>(
inputType, DenseElementsAttr::get(
inputType, APInt(inputType.getElementTypeBitWidth(),
ring.getCmod().getSExtValue())));
inputType, ring.coefficientModulus().sextOrTrunc(
inputType.getElementTypeBitWidth())));
auto remainderCoeffsRemOp = builder.create<arith::RemSIOp>(
toTensorOp.getResult(), remainderModArg.getResult());

Expand Down
12 changes: 10 additions & 2 deletions lib/Dialect/PolyExt/IR/PolyExtDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,21 @@ polynomial::PolynomialType getPolynomialType(Type t) {
LogicalResult CModSwitchOp::verify() {
auto xRing = getPolynomialType(getX().getType()).getRing();
auto outRing = getPolynomialType(getOutput().getType()).getRing();
auto outRingCmod = outRing.getCmod();
auto outRingCmod = outRing.getCmod().getValue();
auto xRingCmod = xRing.getCmod().getValue();

if (xRing.getIdeal() != outRing.getIdeal()) {
return emitOpError("input and output rings ideals must be the same");
}

if (xRing.getCmod().ule(outRingCmod)) {
if (xRingCmod.getBitWidth() != outRingCmod.getBitWidth()) {
return emitOpError(
"input ring cmod and output ring cmod's have different bit widths; "
"consider annotating the types with `: i64` or similar, or using "
"the relevant builder on Ring_Attr");
}

if (xRingCmod.ule(outRingCmod)) {
return emitOpError("input ring cmod must be larger than output ring cmod");
}

Expand Down

0 comments on commit d60b8a6

Please sign in to comment.