Skip to content

Commit

Permalink
Merge pull request google#222 from j2kun:reproduce-apint-failure
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579184490
  • Loading branch information
copybara-github committed Nov 3, 2023
2 parents 880d682 + 38eee4a commit da3125a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 25 deletions.
3 changes: 2 additions & 1 deletion include/Dialect/PolyExt/IR/PolyExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
include "PolyExtDialect.td"
include "include/Dialect/Polynomial/IR/PolynomialDialect.td"
include "include/Dialect/Polynomial/IR/PolynomialTypes.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -25,7 +26,7 @@ def PolyExt_CModSwitchOp : PolyExt_Op<"cmod_switch", traits = [Pure, Elementwise
let arguments = (ins
PolynomialLike:$x,
// TODO: make congruence_modulus optional with default value 1
APIntAttr:$congruence_modulus
Builtin_IntegerAttr:$congruence_modulus
);

let results = (outs
Expand Down
10 changes: 8 additions & 2 deletions include/Dialect/Polynomial/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,17 @@ gentbl_cc_library(
name = "attributes_inc_gen",
tbl_outs = [
(
["-gen-attrdef-decls"],
[
"-gen-attrdef-decls",
"-attrdefs-dialect=polynomial",
],
"PolynomialAttributes.h.inc",
),
(
["-gen-attrdef-defs"],
[
"-gen-attrdef-defs",
"-attrdefs-dialect=polynomial",
],
"PolynomialAttributes.cpp.inc",
),
(
Expand Down
26 changes: 11 additions & 15 deletions include/Dialect/Polynomial/IR/PolynomialAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

include "include/Dialect/Polynomial/IR/PolynomialDialect.td"

include "mlir/IR/DialectBase.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/DialectBase.td"

class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<Polynomial_Dialect, name, traits> {
Expand Down Expand Up @@ -46,26 +47,21 @@ def Ring_Attr : Polynomial_Attr<"Ring", "ring"> {
#ring = #polynomial.ring<cmod=1234, ideal=#polynomial.polynomial<x**1024 + 1>>
}];

// The extra cwidth parameter is required because, when the MLIR framework's
// StorageUniquer is determining whether to build a new instance of this
// attribute or reuse an existing one, it does an == comparison on the
// parameters, and comparing two APInts with different underlying bit widths
// leads to an assertion failure. By adding the container bit width as a
// parameter, that bit width is compared first and the rest of the check is
// skipped if the bit widths disagree. The custom builder and custom
// parser/printer ensures that cwidth is automatically inferred from cmod, and
// not printed or parsed, so that the cwidth parameter need not be known to the
// user.
let parameters = (ins "int64_t": $cwidth, "APInt": $cmod, "Polynomial":$ideal);
let parameters = (ins "IntegerAttr": $cmod, "Polynomial":$ideal);

let builders = [
AttrBuilderWithInferredContext<(ins "APInt": $cmod, "Polynomial":$ideal), [{
return $_get(ideal.getContext(), cmod.getBitWidth(), cmod, ideal);
AttrBuilderWithInferredContext<
(ins "const APInt &": $cmod, "Polynomial":$ideal), [{
return $_get(
ideal.getContext(),
IntegerAttr::get(IntegerType::get(ideal.getContext(), cmod.getBitWidth()), cmod),
ideal
);
}]>
];
let extraClassDeclaration = [{
Polynomial ideal() const { return getIdeal(); }
APInt coefficientModulus() const { return getCmod(); }
APInt coefficientModulus() const { return getCmod().getValue(); }
}];

let skipDefaultBuilders = 1;
Expand Down
7 changes: 4 additions & 3 deletions lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType,
// division. The lowering must fail:
auto divisor = ring.getIdeal();
auto leadingCoef = divisor.getTerms().back().coefficient;
auto leadingCoefInverse = leadingCoef.multiplicativeInverse(ring.getCmod());
auto leadingCoefInverse =
leadingCoef.multiplicativeInverse(ring.coefficientModulus());
// APInt signals no inverse by returning zero.
if (leadingCoefInverse.isZero()) {
signalPassFailure();
Expand Down Expand Up @@ -686,8 +687,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType,
auto toTensorOp = builder.create<ToTensorOp>(inputType, whileOp.getResult(0));
auto remainderModArg = builder.create<arith::ConstantOp>(
inputType, DenseElementsAttr::get(
inputType, APInt(inputType.getElementTypeBitWidth(),
ring.getCmod().getSExtValue())));
inputType, ring.coefficientModulus().sextOrTrunc(
inputType.getElementTypeBitWidth())));
auto remainderCoeffsRemOp = builder.create<arith::RemSIOp>(
toTensorOp.getResult(), remainderModArg.getResult());

Expand Down
16 changes: 13 additions & 3 deletions lib/Dialect/PolyExt/IR/PolyExtDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,30 @@ polynomial::PolynomialType getPolynomialType(Type t) {
LogicalResult CModSwitchOp::verify() {
auto xRing = getPolynomialType(getX().getType()).getRing();
auto outRing = getPolynomialType(getOutput().getType()).getRing();
auto outRingCmod = outRing.getCmod().getValue();
auto xRingCmod = xRing.getCmod().getValue();

if (xRing.getIdeal() != outRing.getIdeal()) {
return emitOpError("input and output rings ideals must be the same");
}

if (xRing.getCmod().ule(outRing.getCmod())) {
if (xRingCmod.getBitWidth() != outRingCmod.getBitWidth()) {
return emitOpError(
"input ring cmod and output ring cmod's have different bit widths; "
"consider annotating the types with `: i64` or similar, or using "
"the relevant builder on Ring_Attr");
}

if (xRingCmod.ule(outRingCmod)) {
return emitOpError("input ring cmod must be larger than output ring cmod");
}

if (getCongruenceModulus().ule(0)) {
APInt congMod = getCongruenceModulus().getValue();
if (congMod.ule(APInt::getZero(congMod.getBitWidth()))) {
return emitOpError("congruence modulus must be positive");
}

if (outRing.getCmod().ule(getCongruenceModulus())) {
if (outRingCmod.ule(congMod.zextOrTrunc(outRingCmod.getBitWidth()))) {
return emitOpError(
"output ring cmod must be larger than congruence modulus");
}
Expand Down
2 changes: 1 addition & 1 deletion tests/poly_ext/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

module {
func.func @test_ops(%p0 : !polynomial.polynomial<#ring1>) {
%cmod_switch = poly_ext.cmod_switch %p0 {congruence_modulus=117} : !polynomial.polynomial<#ring1> -> !polynomial.polynomial<#ring2>
%cmod_switch = poly_ext.cmod_switch %p0 {congruence_modulus=117 : i16} : !polynomial.polynomial<#ring1> -> !polynomial.polynomial<#ring2>
return
}

Expand Down

0 comments on commit da3125a

Please sign in to comment.