From 71ff58234e64c814047ad1b29b9ac281ac69c448 Mon Sep 17 00:00:00 2001 From: Wei Dai Date: Wed, 29 Apr 2020 23:17:46 -0700 Subject: [PATCH] Fixed a bug in multiply_plain_inplace when plain is a monomial and using_fast_plain_lift is enabled. --- native/src/seal/evaluator.cpp | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/native/src/seal/evaluator.cpp b/native/src/seal/evaluator.cpp index 17a80814b..cb7e2c2ab 100644 --- a/native/src/seal/evaluator.cpp +++ b/native/src/seal/evaluator.cpp @@ -1697,18 +1697,6 @@ namespace seal // Multiplying by a monomial? size_t mono_exponent = plain.significant_coeff_count() - 1; - // RNS monomial multiplication: monomial and multiplicand polynomial are in RNS form - auto rns_monomial_multiply = [&](PolyIter in_iter, ConstCoeffIter mono_iter) { - for_each_n(in_iter, encrypted_size, [&](auto I) { - for_each_n( - IterTuple>(I, mono_iter, coeff_modulus), - coeff_modulus_size, [&](auto J) { - negacyclic_multiply_poly_mono_coeffmod( - get<0>(J), coeff_count, *get<1>(J), mono_exponent, *get<2>(J), get<0>(J), pool); - }); - }); - }; - if (plain[mono_exponent] >= plain_upper_half_threshold) { if (!context_data.qualifiers().using_fast_plain_lift) @@ -1723,13 +1711,28 @@ namespace seal // addition we decompose the multi-precision integer into RNS components, and then multiply. add_uint_uint64(plain_upper_half_increment, plain[mono_exponent], coeff_modulus_size, temp.get()); context_data.rns_tool()->base_q()->decompose(temp.get(), pool); - rns_monomial_multiply(encrypted, temp.get()); + for_each_n(PolyIter(encrypted), encrypted_size, [&](auto I) { + for_each_n( + IterTuple>(I, temp.get(), coeff_modulus), + coeff_modulus_size, [&](auto J) { + negacyclic_multiply_poly_mono_coeffmod( + get<0>(J), coeff_count, *get<1>(J), mono_exponent, *get<2>(J), get<0>(J), pool); + }); + }); } else { // Every coeff_modulus prime is larger than plain_modulus, so there is no need to adjust the // monomial. Instead, just do an RNS multiplication. - rns_monomial_multiply(encrypted, plain.data()); + for_each_n(PolyIter(encrypted), encrypted_size, [&](auto I) { + for_each_n( + IterTuple>(I, coeff_modulus), coeff_modulus_size, + [&](auto J) { + negacyclic_multiply_poly_mono_coeffmod( + get<0>(J), coeff_count, plain[mono_exponent], mono_exponent, *get<1>(J), get<0>(J), + pool); + }); + }); } } else