Skip to content

Commit

Permalink
Addressed review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Jun 11, 2024
1 parent 0cd4df6 commit 705bc25
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {
*
* @return number of extra bits needed for noise flooding
*/
static double EstimateMultipartyFloodingLogQ() {
return NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY;
static constexpr double EstimateMultipartyFloodingLogQ() {
return static_cast<double>(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY);
}

/**
Expand Down
6 changes: 3 additions & 3 deletions src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
}
else {
// takes into account the noise added during the threshold FHE instantiation of BFV
if ((multipartyMode == NOISE_FLOODING_MULTIPARTY))
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
logq += cryptoParamsBFVRNS->EstimateMultipartyFloodingLogQ();
// adds logP in the case of HYBRID key switching
if (ksTech == HYBRID) {
Expand All @@ -129,7 +129,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
logq += std::get<0>(hybridKSInfo);
}
return static_cast<double>(
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<usint>(std::ceil(logq))));
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<uint32_t>(std::ceil(logq))));
}
};

Expand All @@ -138,7 +138,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
// conservative estimate for HYBRID to avoid the use of method of
// iterative approximations; we do not know the number
// of digits and moduli at this point and use upper bounds
double numTowers = ceil(static_cast<double>(logqPrev) / dcrtBits);
double numTowers = ceil(logqPrev / dcrtBits);
return numTowers * (delta(n) * Berr + delta(n) * Bkey + 1.0) / 2.0;
}
else {
Expand Down
9 changes: 4 additions & 5 deletions src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,8 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
moduliQ = std::get<0>(moduliInfo);
uint32_t newQBound = std::get<1>(moduliInfo);

// the counter makes sure the first iteration of the while loop is always run
uint32_t counter = 0;
while ((counter == 0) || (qBound < newQBound)) {
// the loop must be executed at least once
do {
qBound = newQBound;
n = computeRingDimension(cryptoParams, newQBound, cyclOrder);
auto moduliInfo = computeModuli(cryptoParams, n, evalAddCount, keySwitchCount, auxTowers, numPrimes);
Expand All @@ -493,8 +492,8 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits);
newQBound += std::get<0>(hybridKSInfo);
}
counter++;
}
} while (qBound < newQBound);

cyclOrder = 2 * n;
modulusOrder = getCyclicOrder(n, ptm, scalTech);

Expand Down
73 changes: 35 additions & 38 deletions src/pke/lib/schemerns/rns-cryptoparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,16 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
DiscreteFourierTransform::Initialize(n * 2, n / 2);
ChineseRemainderTransformFTT<NativeVector>().PreCompute(rootsQ, 2 * n, moduliQ);
if (m_ksTechnique == HYBRID) {
// Compute ceil(sizeQ/m_numPartQ), the # of towers per digit
uint32_t a = ceil(static_cast<double>(sizeQ) / numPartQ);
if ((int32_t)(sizeQ - a * (numPartQ - 1)) <= 0) {
auto str =
"CryptoParametersRNS::PrecomputeCRTTables - HYBRID key "
"switching parameters: Can't appropriately distribute " +
std::to_string(sizeQ) + " towers into " + std::to_string(numPartQ) +
" digits. Please select different number of digits.";
// numPartQ can not be zero as there is a division by numPartQ
if (numPartQ == 0)
OPENFHE_THROW("numPartQ is zero");

// Compute ceil(sizeQ/numPartQ), the # of towers per digit
uint32_t a = static_cast<uint32_t>(ceil(static_cast<double>(sizeQ) / numPartQ));
if (sizeQ <= (a * (numPartQ - 1))) {
auto str = "HYBRID key switching parameters: Can't appropriately distribute " + std::to_string(sizeQ) +
" towers into " + std::to_string(numPartQ) +
" digits. Please select different number of digits.";
OPENFHE_THROW(str);
}

Expand Down Expand Up @@ -122,16 +124,15 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
std::make_shared<ILDCRTParams<BigInteger>>(params[0]->GetCyclotomicOrder(), moduli, roots);
}

uint32_t sizeP;
// Find number and size of individual special primes.
uint32_t maxBits = moduliPartQ[0].GetLengthForBase(2);
for (usint j = 1; j < m_numPartQ; j++) {
for (uint32_t j = 1; j < m_numPartQ; j++) {
uint32_t bits = moduliPartQ[j].GetLengthForBase(2);
if (bits > maxBits)
maxBits = bits;
}
// Select number of primes in auxiliary CRT basis
sizeP = ceil(static_cast<double>(maxBits) / auxBits);
uint32_t sizeP = static_cast<uint32_t>(ceil(static_cast<double>(maxBits) / auxBits));
uint64_t primeStep = FindAuxPrimeStep();

// Choose special primes in auxiliary basis and compute their roots
Expand All @@ -143,7 +144,7 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
NativeInteger firstP = FirstPrime<NativeInteger>(auxBits, primeStep);
NativeInteger pPrev = firstP;
BigInteger modulusP(1);
for (usint i = 0; i < sizeP; i++) {
for (uint32_t i = 0; i < sizeP; i++) {
// The following loop makes sure that moduli in
// P and Q are different
bool foundInQ = false;
Expand Down Expand Up @@ -234,11 +235,11 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
}

// Pre-compute compementary partitions for ModUp
uint32_t alpha = ceil(static_cast<double>(sizeQ) / m_numPartQ);
uint32_t alpha = static_cast<uint32_t>(ceil(static_cast<double>(sizeQ) / m_numPartQ));
m_paramsComplPartQ.resize(sizeQ);
m_modComplPartqBarrettMu.resize(sizeQ);
for (int32_t l = sizeQ - 1; l >= 0; l--) {
uint32_t beta = ceil(static_cast<double>(l + 1) / alpha);
uint32_t beta = static_cast<uint32_t>(ceil(static_cast<double>(l + 1) / alpha));
m_paramsComplPartQ[l].resize(beta);
m_modComplPartqBarrettMu[l].resize(beta);
for (uint32_t j = 0; j < beta; j++) {
Expand Down Expand Up @@ -305,8 +306,8 @@ void CryptoParametersRNS::PrecomputeCRTTables(KeySwitchTechnique ksTech, Scaling
// Pre-compute QHat mod complementary partition qi's
m_PartQlHatModp.resize(sizeQ);
for (uint32_t l = 0; l < sizeQ; l++) {
uint32_t alpha = ceil(static_cast<double>(sizeQ) / m_numPartQ);
uint32_t beta = ceil(static_cast<double>(l + 1) / alpha);
uint32_t alpha = static_cast<uint32_t>(ceil(static_cast<double>(sizeQ) / m_numPartQ));
uint32_t beta = static_cast<uint32_t>(ceil(static_cast<double>(l + 1) / alpha));
m_PartQlHatModp[l].resize(beta);
for (uint32_t k = 0; k < beta; k++) {
auto paramsPartQ = GetParamsPartQ(k)->GetParams();
Expand Down Expand Up @@ -391,46 +392,42 @@ uint64_t CryptoParametersRNS::FindAuxPrimeStep() const {
std::pair<double, uint32_t> CryptoParametersRNS::EstimateLogP(uint32_t numPartQ, double firstModulusSize,
double dcrtBits, double extraModulusSize,
uint32_t numPrimes, uint32_t auxBits) {
uint32_t sizeQ = numPrimes;
// numPartQ can not be zero as there is a division by numPartQ
if (numPartQ == 0)
OPENFHE_THROW("numPartQ is zero");

size_t sizeQ = numPrimes;
if (extraModulusSize > 0)
sizeQ++;

// Compute ceil(sizeQ/numPartQ), the # of towers per digit
uint32_t numPerPartQ = ceil(static_cast<double>(sizeQ) / numPartQ);
if ((uint32_t)(sizeQ - numPerPartQ * (numPartQ - 1)) <= 0) {
auto str =
"CryptoParametersRNS::EstimateLogP - HYBRID key "
"switching parameters: Can't appropriately distribute " +
std::to_string(sizeQ) + " towers into " + std::to_string(numPartQ) +
" digits. Please select different number of digits.";
size_t numPerPartQ = static_cast<size_t>(ceil(static_cast<double>(sizeQ) / numPartQ));
if (sizeQ <= (numPerPartQ * (numPartQ - 1))) {
auto str = "HYBRID key switching parameters: Can't appropriately distribute " + std::to_string(sizeQ) +
" towers into " + std::to_string(numPartQ) + " digits. Please select different number of digits.";
OPENFHE_THROW(str);
}

// create a vector with bit sizes
std::vector<double> qi(sizeQ);
// create a vector with the same value of bit sizes
std::vector<double> qi(sizeQ, dcrtBits);
qi[0] = firstModulusSize;
for (uint32_t i = 1; i < numPrimes; i++) {
qi[i] = dcrtBits;
}
if (extraModulusSize > 0)
qi[sizeQ - 1] = extraModulusSize;

// Compute partitions of Q into numPartQ digits
double maxBits = 0;
for (uint32_t j = 0; j < numPartQ; j++) {
auto startTower = j * numPerPartQ;
auto endTower = ((j + 1) * numPerPartQ - 1 < sizeQ) ? (j + 1) * numPerPartQ - 1 : sizeQ - 1;
double bits = 0.0;
for (uint32_t i = startTower; i <= endTower; i++) {
bits += qi[i];
}
for (size_t j = 0; j < numPartQ; ++j) {
size_t startTower = j * numPerPartQ;
size_t endTower = ((j + 1) * numPerPartQ - 1 < sizeQ) ? (j + 1) * numPerPartQ - 1 : sizeQ - 1;

// sum qi elements qi[startTower] + ... + qi[endTower] inclusive. the end element should be qi.begin()+(endTower+1)
double bits = std::accumulate(qi.begin() + startTower, qi.begin() + (endTower + 1), 0.0);
if (bits > maxBits)
maxBits = bits;
}

// Select number of primes in auxiliary CRT basis
uint32_t sizeP;
sizeP = std::ceil(static_cast<double>(maxBits) / auxBits);
auto sizeP = static_cast<uint32_t>(std::ceil(maxBits / auxBits));

return std::make_pair(sizeP * auxBits, sizeP);
}
Expand Down

0 comments on commit 705bc25

Please sign in to comment.