diff --git a/include/Dialect/Polynomial/IR/BUILD b/include/Dialect/Polynomial/IR/BUILD index d7acbe6b67..3d43d74d65 100644 --- a/include/Dialect/Polynomial/IR/BUILD +++ b/include/Dialect/Polynomial/IR/BUILD @@ -64,11 +64,17 @@ gentbl_cc_library( name = "attributes_inc_gen", tbl_outs = [ ( - ["-gen-attrdef-decls"], + [ + "-gen-attrdef-decls", + "-attrdefs-dialect=polynomial", + ], "PolynomialAttributes.h.inc", ), ( - ["-gen-attrdef-defs"], + [ + "-gen-attrdef-defs", + "-attrdefs-dialect=polynomial", + ], "PolynomialAttributes.cpp.inc", ), ( diff --git a/include/Dialect/Polynomial/IR/PolynomialAttributes.td b/include/Dialect/Polynomial/IR/PolynomialAttributes.td index 06a7f45e43..22b26356bd 100644 --- a/include/Dialect/Polynomial/IR/PolynomialAttributes.td +++ b/include/Dialect/Polynomial/IR/PolynomialAttributes.td @@ -3,8 +3,9 @@ include "include/Dialect/Polynomial/IR/PolynomialDialect.td" -include "mlir/IR/DialectBase.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/DialectBase.td" class Polynomial_Attr traits = []> : AttrDef { @@ -46,26 +47,29 @@ def Ring_Attr : Polynomial_Attr<"Ring", "ring"> { #ring = #polynomial.ring> }]; - // The extra cwidth parameter is required because, when the MLIR framework's - // StorageUniquer is determining whether to build a new instance of this - // attribute or reuse an existing one, it does an == comparison on the - // parameters, and comparing two APInts with different underlying bit widths - // leads to an assertion failure. By adding the container bit width as a - // parameter, that bit width is compared first and the rest of the check is - // skipped if the bit widths disagree. The custom builder and custom - // parser/printer ensures that cwidth is automatically inferred from cmod, and - // not printed or parsed, so that the cwidth parameter need not be known to the - // user. - let parameters = (ins "int64_t": $cwidth, "APInt": $cmod, "Polynomial":$ideal); + let parameters = (ins "IntegerAttr": $cmod, "Polynomial":$ideal); let builders = [ - AttrBuilderWithInferredContext<(ins "APInt": $cmod, "Polynomial":$ideal), [{ - return $_get(ideal.getContext(), cmod.getBitWidth(), cmod, ideal); + AttrBuilderWithInferredContext< + (ins "IntegerType":$type, "const APInt &": $cmod, "Polynomial":$ideal), [{ + return $_get( + type.getContext(), + IntegerAttr::get(type, cmod.zextOrTrunc(type.getWidth())), + ideal + ); + }]>, + AttrBuilderWithInferredContext< + (ins "const APInt &": $cmod, "Polynomial":$ideal), [{ + return $_get( + ideal.getContext(), + IntegerAttr::get(IntegerType::get(ideal.getContext(), cmod.getBitWidth()), cmod), + ideal + ); }]> ]; let extraClassDeclaration = [{ Polynomial ideal() const { return getIdeal(); } - APInt coefficientModulus() const { return getCmod(); } + APInt coefficientModulus() const { return getCmod().getValue(); } }]; let skipDefaultBuilders = 1; diff --git a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp index cf85cc0d53..49b22a4aaa 100644 --- a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp +++ b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp @@ -621,7 +621,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType, // division. The lowering must fail: auto divisor = ring.getIdeal(); auto leadingCoef = divisor.getTerms().back().coefficient; - auto leadingCoefInverse = leadingCoef.multiplicativeInverse(ring.getCmod()); + auto leadingCoefInverse = + leadingCoef.multiplicativeInverse(ring.coefficientModulus()); // APInt signals no inverse by returning zero. if (leadingCoefInverse.isZero()) { signalPassFailure(); @@ -686,8 +687,8 @@ func::FuncOp PolynomialToStandard::buildPolynomialModFunc(FunctionType funcType, auto toTensorOp = builder.create(inputType, whileOp.getResult(0)); auto remainderModArg = builder.create( inputType, DenseElementsAttr::get( - inputType, APInt(inputType.getElementTypeBitWidth(), - ring.getCmod().getSExtValue()))); + inputType, ring.coefficientModulus().sextOrTrunc( + inputType.getElementTypeBitWidth()))); auto remainderCoeffsRemOp = builder.create( toTensorOp.getResult(), remainderModArg.getResult()); diff --git a/lib/Dialect/PolyExt/IR/PolyExtDialect.cpp b/lib/Dialect/PolyExt/IR/PolyExtDialect.cpp index 9a072bb8cb..71b7d7aed2 100644 --- a/lib/Dialect/PolyExt/IR/PolyExtDialect.cpp +++ b/lib/Dialect/PolyExt/IR/PolyExtDialect.cpp @@ -33,13 +33,21 @@ polynomial::PolynomialType getPolynomialType(Type t) { LogicalResult CModSwitchOp::verify() { auto xRing = getPolynomialType(getX().getType()).getRing(); auto outRing = getPolynomialType(getOutput().getType()).getRing(); - auto outRingCmod = outRing.getCmod(); + auto outRingCmod = outRing.getCmod().getValue(); + auto xRingCmod = xRing.getCmod().getValue(); if (xRing.getIdeal() != outRing.getIdeal()) { return emitOpError("input and output rings ideals must be the same"); } - if (xRing.getCmod().ule(outRingCmod)) { + if (xRingCmod.getBitWidth() != outRingCmod.getBitWidth()) { + return emitOpError( + "input ring cmod and output ring cmod's have different bit widths; " + "consider annotating the types with `: i64` or similar, or using " + "the relevant builder on Ring_Attr"); + } + + if (xRingCmod.ule(outRingCmod)) { return emitOpError("input ring cmod must be larger than output ring cmod"); }