Skip to content

Commit

Permalink
Fixed a bug in multiply_plain_inplace when plain is a monomial and us…
Browse files Browse the repository at this point in the history
…ing_fast_plain_lift is enabled.
  • Loading branch information
Wei Dai authored and kimlaine committed Apr 30, 2020
1 parent 3514838 commit 71ff582
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions native/src/seal/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1697,18 +1697,6 @@ namespace seal
// Multiplying by a monomial?
size_t mono_exponent = plain.significant_coeff_count() - 1;

// RNS monomial multiplication: monomial and multiplicand polynomial are in RNS form
auto rns_monomial_multiply = [&](PolyIter in_iter, ConstCoeffIter mono_iter) {
for_each_n(in_iter, encrypted_size, [&](auto I) {
for_each_n(
IterTuple<RNSIter, ConstCoeffIter, PtrIter<const Modulus *>>(I, mono_iter, coeff_modulus),
coeff_modulus_size, [&](auto J) {
negacyclic_multiply_poly_mono_coeffmod(
get<0>(J), coeff_count, *get<1>(J), mono_exponent, *get<2>(J), get<0>(J), pool);
});
});
};

if (plain[mono_exponent] >= plain_upper_half_threshold)
{
if (!context_data.qualifiers().using_fast_plain_lift)
Expand All @@ -1723,13 +1711,28 @@ namespace seal
// addition we decompose the multi-precision integer into RNS components, and then multiply.
add_uint_uint64(plain_upper_half_increment, plain[mono_exponent], coeff_modulus_size, temp.get());
context_data.rns_tool()->base_q()->decompose(temp.get(), pool);
rns_monomial_multiply(encrypted, temp.get());
for_each_n(PolyIter(encrypted), encrypted_size, [&](auto I) {
for_each_n(
IterTuple<RNSIter, ConstCoeffIter, PtrIter<const Modulus *>>(I, temp.get(), coeff_modulus),
coeff_modulus_size, [&](auto J) {
negacyclic_multiply_poly_mono_coeffmod(
get<0>(J), coeff_count, *get<1>(J), mono_exponent, *get<2>(J), get<0>(J), pool);
});
});
}
else
{
// Every coeff_modulus prime is larger than plain_modulus, so there is no need to adjust the
// monomial. Instead, just do an RNS multiplication.
rns_monomial_multiply(encrypted, plain.data());
for_each_n(PolyIter(encrypted), encrypted_size, [&](auto I) {
for_each_n(
IterTuple<RNSIter, PtrIter<const Modulus *>>(I, coeff_modulus), coeff_modulus_size,
[&](auto J) {
negacyclic_multiply_poly_mono_coeffmod(
get<0>(J), coeff_count, plain[mono_exponent], mono_exponent, *get<1>(J), get<0>(J),
pool);
});
});
}
}
else
Expand Down

0 comments on commit 71ff582

Please sign in to comment.