Skip to content

Commit

Permalink
Merge pull request #134 from asraa:lower-poly-add-mul
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565391673
  • Loading branch information
copybara-github committed Sep 14, 2023
2 parents 94a5d2e + 5d8cc8f commit 183fc81
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 7 deletions.
74 changes: 68 additions & 6 deletions lib/Conversion/PolyToStandard/PolyToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
#include "include/Dialect/Poly/IR/PolyOps.h"
#include "include/Dialect/Poly/IR/PolyTypes.h"
#include "lib/Conversion/Utils.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand All @@ -28,6 +29,9 @@ class PolyToStandardTypeConverter : public TypeConverter {
IntegerType elementTy =
IntegerType::get(ctx, attr.coefficientModulus().getBitWidth(),
IntegerType::SignednessSemantics::Signless);
// We must remove the ring attribute on the tensor, since the
// unrealized_conversion_casts cannot carry the poly.ring attribute
// through.
return RankedTensorType::get({idealDegree}, elementTy);
});

Expand Down Expand Up @@ -109,11 +113,66 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {

using OpConversionPattern::OpConversionPattern;

// ConvertAdd lowers a poly.add operation to arith operations. A poly.add
// operation is defined within the polynomial ring. Coefficients are added
// element-wise as elements of the ring, so the additions are performed modulo
// the coefficient modulus.
//
// To perform modular addition, assume that `cmod` is the coefficient modulus
// of the ring, and that `N` is the bitwidth used to store the ring elements.
// `N` may be much larger than `log_2(cmod)`.
//
// Let `x` and `y` be the inputs to modular addition, then:
// c1, n1 = addui_extended(x, y)
// If the coefficient modulus divides `2^N`, then return
// c0 = c1 % cmod
// Otherwise, compute the adjusted result:
// c0 = ((c1 % cmod) + (n1 * 2^N % cmod)) % cmod
//
// Note that `(c1 % cmod) + (n1 * 2^N % cmod)` will not overflow mod `2^N`.
// Both summands are less than `cmod`, so an overflow would require that
// `cmod > (2^N) / 2`. This would imply that `2^N % cmod = 2^N - cmod`.
// If the sum overflowed, then we would have
// (c1 % cmod) + (2^N % cmod) >= 2^N
// (c1 % cmod) + (2^N - cmod) >= 2^N
// (c1 % cmod) >= cmod
// Which is a contradiction.
// Rewrites `op` (a poly.add) into arith ops on the already-converted tensor
// operands supplied by `adaptor`, emitting the new ops through `rewriter`.
//
// Emitted sequence, with `cmod` the ring's coefficient modulus and `N` the
// bit width of the APInt that stores it:
//   sum, overflow = arith.addui_extended(lhs, rhs)
//   rem           = arith.remui(sum, cmod)
// If `cmod` is a power of two, `rem` replaces the op. Otherwise:
//   adjusted = arith.addi(rem, 2^N % cmod)
//   picked   = arith.select(overflow, adjusted, rem)
//   result   = arith.remui(picked, cmod)
//
// Always returns success().
LogicalResult matchAndRewrite(
    AddOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const override {
  ImplicitLocOpBuilder b(op.getLoc(), rewriter);

  // Converted operand type; every constant below is splatted to this shape.
  auto type = adaptor.getLhs().getType();

  APInt mod =
      cast<PolyType>(op.getResult().getType()).getRing().coefficientModulus();
  auto cmod = b.create<arith::ConstantOp>(
      DenseElementsAttr::get(cast<ShapedType>(type), {mod}));

  // Result 0 is the wrapped sum, result 1 is the per-element overflow bit.
  auto addExtendedOp =
      b.create<arith::AddUIExtendedOp>(adaptor.getLhs(), adaptor.getRhs());
  auto c1ModOp = b.create<arith::RemUIOp>(addExtendedOp->getResult(0), cmod);

  // If mod divides 2^N, the dropped carry is congruent to zero and c1ModOp is
  // already the result.
  if (mod.isPowerOf2()) {
    rewriter.replaceOp(op, c1ModOp.getResult());
    return success();
  }

  // Otherwise compute 2^N % cmod at compile time. 2^N does not fit in N bits,
  // so the division is carried out in N + 1 bits and truncated afterwards
  // (the remainder is smaller than mod, so the truncation is lossless).
  APInt quotient, remainder;
  APInt bigMod = APInt(mod.getBitWidth() + 1, 2) << (mod.getBitWidth() - 1);
  APInt::udivrem(bigMod, mod.zext(bigMod.getBitWidth()), quotient, remainder);
  remainder = remainder.trunc(mod.getBitWidth());

  auto twoPowNModCmod = b.create<arith::ConstantOp>(
      DenseElementsAttr::get(cast<ShapedType>(type), {remainder}));
  auto adjustOp = b.create<arith::AddIOp>(c1ModOp, twoPowNModCmod);

  // arith.select yields its first value operand where the condition is true:
  // the adjusted sum is needed exactly where the addition overflowed, and the
  // plain remainder everywhere else.
  auto selectOp = b.create<arith::SelectOp>(addExtendedOp.getResult(1),
                                            adjustOp, c1ModOp);
  // Mod the final result.
  rewriter.replaceOp(op, b.create<arith::RemUIOp>(selectOp, cmod));

  return success();
}
};

Expand All @@ -140,16 +199,19 @@ struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
ConversionTarget target(*context);
PolyToStandardTypeConverter typeConverter(context);

target.addLegalDialect<arith::ArithDialect>();

// target.addIllegalDialect<PolyDialect>();
target.addIllegalOp<FromTensorOp, ToTensorOp>();
target.addIllegalOp<FromTensorOp, ToTensorOp, AddOp>();
// target.addIllegalOp<AddOp>();
// target.addIllegalOp<MulOp>();

RewritePatternSet patterns(context);
patterns.add<ConvertFromTensor, ConvertToTensor>(typeConverter, context);

patterns.add<ConvertFromTensor, ConvertToTensor, ConvertAdd>(typeConverter,
context);
addStructuralConversionPatterns(typeConverter, patterns, target);

// TODO(https://github.com/google/heir/issues/143): Handle tensor of polys.
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/Poly/IR/PolyOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ LogicalResult FromTensorOp::verify() {
}

APInt coefficientModulus = ring.coefficientModulus();
unsigned cmodBitWidth = coefficientModulus.logBase2();
unsigned cmodBitWidth = coefficientModulus.ceilLogBase2();
unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();

if (inputBitWidth > cmodBitWidth) {
Expand Down
54 changes: 54 additions & 0 deletions tests/poly/lower_poly.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#cycl_2048 = #poly.polynomial<1 + x**1024>
#ring = #poly.ring<cmod=4294967296, ideal=#cycl_2048>
#ring_prime = #poly.ring<cmod=4294967291, ideal=#cycl_2048>
module {
// CHECK-label: test_lower_from_tensor
func.func @test_lower_from_tensor() {
Expand Down Expand Up @@ -59,4 +60,57 @@ module {
func.call @f0(%arg) : (!poly.poly<#ring>) -> !poly.poly<#ring>
return
}

// Lowering of poly.add over #ring, whose coefficient modulus is
// 4294967296 (2^32, a power of two). The expected output is a single
// addui_extended followed by one remui against the modulus constant, and the
// function returns that remainder directly.
func.func @test_lower_add_power_of_two_cmod() -> !poly.poly<#ring> {
// 2 + 2x + 2x^2 + ... + 2x^{1023}
// CHECK: [[X:%.+]] = arith.constant dense<2> : [[T:tensor<1024xi32>]]
%coeffs1 = arith.constant dense<2> : tensor<1024xi32>
// CHECK: [[Y:%.+]] = arith.constant dense<3> : [[T]]
%coeffs2 = arith.constant dense<3> : tensor<1024xi32>
// CHECK-NOT: poly.from_tensor
// CHECK: [[XEXT:%.+]] = arith.extui [[X]] : [[T]] to [[TPOLY:tensor<1024xi64>]]
// CHECK: [[YEXT:%.+]] = arith.extui [[Y]] : [[T]] to [[TPOLY:tensor<1024xi64>]]
// The i32 coefficient tensors are expected to be zero-extended to the i64
// tensor type that stands in for the polynomial after conversion.
%poly0 = poly.from_tensor %coeffs1 : tensor<1024xi32> -> !poly.poly<#ring>
%poly1 = poly.from_tensor %coeffs2 : tensor<1024xi32> -> !poly.poly<#ring>
// CHECK: [[MOD:%.+]] = arith.constant dense<4294967296> : [[TPOLY]]
// CHECK-NEXT: [[ADD:%.+]], [[OVERFLOW:%.+]] = arith.addui_extended [[XEXT]], [[YEXT]] : [[TPOLY]], tensor<1024xi1>
// CHECK-NEXT: [[REM:%.+]] = arith.remui [[ADD]], [[MOD]] : [[TPOLY]]
%poly2 = poly.add(%poly0, %poly1) {ring = #ring} : !poly.poly<#ring>
// CHECK: return [[REM]] : [[TPOLY]]
return %poly2 : !poly.poly<#ring>
}

// Lowering of poly.add over #ring_prime, whose coefficient modulus is
// 4294967291 (2^32 - 5, not a power of two). In addition to the
// addui_extended/remui pair, the expected output contains the overflow
// adjustment: a constant, an addi, a select on the overflow bit, and a final
// remui.
func.func @test_lower_add_prime_cmod() -> !poly.poly<#ring_prime> {
// CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi31>]]
%coeffs1 = arith.constant dense<2> : tensor<1024xi31>
// CHECK: [[Y:%.+]] = arith.constant dense<3> : [[TCOEFF]]
%coeffs2 = arith.constant dense<3> : tensor<1024xi31>
// CHECK-NOT: poly.from_tensor
// CHECK: [[XEXT:%.+]] = arith.extui [[X]] : [[TCOEFF]] to [[T:tensor<1024xi64>]]
// CHECK: [[YEXT:%.+]] = arith.extui [[Y]] : [[TCOEFF]] to [[T:tensor<1024xi64>]]
%poly0 = poly.from_tensor %coeffs1 : tensor<1024xi31> -> !poly.poly<#ring_prime>
%poly1 = poly.from_tensor %coeffs2 : tensor<1024xi31> -> !poly.poly<#ring_prime>
// The adjustment constant 25 below equals 2^64 mod 4294967291: since
// 2^32 = 4294967291 + 5, 2^64 is congruent to 5 * 5.
// NOTE(review): the select expectation yields the plain remainder where the
// overflow bit is set and the adjusted value where it is clear. The formula
// in the ConvertAdd comment in PolyToStandard.cpp adds the 2^N % cmod term
// only on overflow, which suggests the two value operands are swapped here;
// confirm against the lowering.
// CHECK: [[MOD:%.+]] = arith.constant dense<4294967291> : [[T]]
// CHECK-NEXT: [[ADD:%.+]], [[OVERFLOW:%.+]] = arith.addui_extended [[XEXT]], [[YEXT]] : [[T]], tensor<1024xi1>
// CHECK-NEXT: [[REM:%.+]] = arith.remui [[ADD]], [[MOD]] : [[T]]
// CHECK-NEXT: [[NMOD:%.+]] = arith.constant dense<25> : [[T]]
// CHECK-NEXT: [[REMPLUS2N:%.+]] = arith.addi [[REM]], [[NMOD]] : [[T]]
// CHECK-NEXT: [[RES:%.+]] = arith.select [[OVERFLOW]], [[REM]], [[REMPLUS2N]] : tensor<1024xi1>, [[T]]
// CHECK-NEXT: [[RESMOD:%.+]] = arith.remui [[RES]], [[MOD]] : [[T]]
%poly2 = poly.add(%poly0, %poly1) {ring = #ring_prime} : !poly.poly<#ring_prime>
// CHECK: return [[RESMOD]] : [[T]]
return %poly2 : !poly.poly<#ring_prime>
}

// Builds polynomials over #ring_prime (cmod = 4294967291) from i32
// coefficient tensors and discards them; the expectations only require that
// no poly.from_tensor survives and that the function still returns.
// NOTE(review): presumably this guards the FromTensorOp verifier's bit-width
// bound, since 4294967291 needs 32 bits to represent - confirm.
func.func @test_i32_coeff_with_i32_mod() -> () {
// CHECK: [[X:%.+]] = arith.constant dense<2> : [[TCOEFF:tensor<1024xi32>]]
%coeffs1 = arith.constant dense<2> : tensor<1024xi32>
// CHECK: [[Y:%.+]] = arith.constant dense<3> : [[TCOEFF]]
%coeffs2 = arith.constant dense<3> : tensor<1024xi32>
// CHECK-NOT: poly.from_tensor
%poly0 = poly.from_tensor %coeffs1 : tensor<1024xi32> -> !poly.poly<#ring_prime>
%poly1 = poly.from_tensor %coeffs2 : tensor<1024xi32> -> !poly.poly<#ring_prime>
// CHECK: return
return
}
}

0 comments on commit 183fc81

Please sign in to comment.