Skip to content

Commit

Permalink
Fix logQ(P) estimation used for finding the minimum ring dimension (#796
Browse files Browse the repository at this point in the history
)

* fixed the estimation for LWE check for BGV

* fixed estimation for LWE check in CKKS parameter generation

* cleaned up the BFV parameterization code

* fixed the logq estimation for BFV

* added doxygen comments for new methods

* added ring dimension validators for BGV, BFV, and CKKS

* fixed a bug in BGV parameter estimation logic

* Addressed review comments

---------

Co-authored-by: Yuriy Polyakov <[email protected]>
Co-authored-by: Dmitriy Suponitskiy <[email protected]>
  • Loading branch information
3 people authored Jun 11, 2024
1 parent 5c7164c commit 4eed808
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 55 deletions.
24 changes: 24 additions & 0 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,30 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {

virtual uint64_t FindAuxPrimeStep() const;

/*
* Estimates the extra modulus bitsize needed for hybrid key swithing (used for finding the minimum secure ring dimension).
*
* @param numPartQ number of digits in hybrid key switching
* @param firstModulusSize bit size of first modulus
* @param dcrtBits bit size for other moduli
* @param extraModulusSize bit size for extra modulus in FLEXIBLEAUTOEXT (CKKS and BGV only)
* @param numPrimes number of moduli witout extraModulus
* @param auxBits size of auxiliar moduli used for hybrid key switching
*
* @return log2 of the modulus and number of RNS limbs.
*/
static std::pair<double, uint32_t> EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits,
double extraModulusSize, uint32_t numPrimes, uint32_t auxBits);

/*
* Estimates the extra modulus bitsize needed for threshold FHE noise flooding (only for BGV and BFV)
*
* @return number of extra bits needed for noise flooding
*/
static constexpr double EstimateMultipartyFloodingLogQ() {
return static_cast<double>(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY);
}

/**
* == operator to compare to this instance of CryptoParametersBase object.
*
Expand Down
87 changes: 60 additions & 27 deletions src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
double p = static_cast<double>(cryptoParamsBFVRNS->GetPlaintextModulus());
uint32_t digitSize = cryptoParamsBFVRNS->GetDigitSize();
SecurityLevel stdLevel = cryptoParamsBFVRNS->GetStdLevel();
uint32_t auxBits = DCRT_MODULUS::MAX_SIZE;

// Bound of the Gaussian error polynomial
double Berr = sigma * std::sqrt(alpha);
Expand Down Expand Up @@ -115,8 +116,20 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
return 0;
}
else {
// takes into account the noise added during the threshold FHE instantiation of BFV
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
logq += cryptoParamsBFVRNS->EstimateMultipartyFloodingLogQ();
// adds logP in the case of HYBRID key switching
if (ksTech == HYBRID) {
// number of RNS limbs
uint32_t k = static_cast<uint32_t>(std::ceil(std::ceil(logq) / dcrtBits));
// set the number of digits
uint32_t numPartQ = ComputeNumLargeDigits(numDigits, k - 1);
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits);
logq += std::get<0>(hybridKSInfo);
}
return static_cast<double>(
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<usint>(std::ceil(logq / std::log(2)))));
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<uint32_t>(std::ceil(logq))));
}
};

Expand All @@ -125,12 +138,12 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
// conservative estimate for HYBRID to avoid the use of method of
// iterative approximations; we do not know the number
// of digits and moduli at this point and use upper bounds
double numTowers = ceil(static_cast<double>(logqPrev) / dcrtBits);
double numTowers = ceil(logqPrev / dcrtBits);
return numTowers * (delta(n) * Berr + delta(n) * Bkey + 1.0) / 2.0;
}
else {
double numDigitsPerTower = (digitSize == 0) ? 1 : ((dcrtBits / digitSize) + 1);
return delta(n) * numDigitsPerTower * (floor(logqPrev / (std::log(2) * dcrtBits)) + 1) * w * Berr / 2.0;
return delta(n) * numDigitsPerTower * (floor(logqPrev / (dcrtBits)) + 1) * w * Berr / 2.0;
}
};

Expand All @@ -147,7 +160,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
if ((multiplicativeDepth == 0) && (keySwitchCount == 0)) {
// Correctness constraint
auto logqBFV = [&](uint32_t n) -> double {
return std::log(p * (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount) + p));
return std::log2(p * (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount) + p));
};

// initial value
Expand All @@ -160,15 +173,15 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// this code updates n and q to account for the discrete size of CRT moduli
// = dcrtBits
int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
}
}
else if ((multiplicativeDepth == 0) && (keySwitchCount > 0) && (evalAddCount == 0)) {
Expand All @@ -179,11 +192,11 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// Correctness constraint
auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
return std::log(p * (4 * (Vnorm(n) + keySwitchCount * noiseKS(n, logqPrev, w, false)) + p));
return std::log2(p * (4 * (Vnorm(n) + keySwitchCount * noiseKS(n, logqPrev, w, false)) + p));
};

// initial values
double logqPrev = 6. * std::log(10);
double logqPrev = 6. * std::log2(10);
logq = logqBFV(n, logqPrev);
logqPrev = logq;

Expand All @@ -197,23 +210,23 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
logq = logqBFV(n, logqPrev);

// let logq converge with prescribed accuracy
while (std::fabs(logq - logqPrev) > std::log(1.001)) {
while (std::fabs(logq - logqPrev) > std::log2(1.001)) {
logqPrev = logq;
logq = logqBFV(n, logqPrev);
}

// this code updates n and q to account for the discrete size of CRT
// moduli = dcrtBits
int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;
logqPrev = logqCeil;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n, logqPrev);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
logqPrev = logqCeil;
}
}
Expand All @@ -237,12 +250,12 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// main correctness constraint
auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
return log(4 * p) + (multiplicativeDepth - 1) * log(C1(n)) +
log(C1(n) * Vnorm(n) + multiplicativeDepth * C2(n, logqPrev));
return log2(4 * p) + (multiplicativeDepth - 1) * log2(C1(n)) +
log2(C1(n) * Vnorm(n) + multiplicativeDepth * C2(n, logqPrev));
};

// initial values
double logqPrev = 6. * std::log(10);
double logqPrev = 6. * std::log2(10);
logq = logqBFV(n, logqPrev);
logqPrev = logq;

Expand All @@ -256,24 +269,24 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
logq = logqBFV(n, logqPrev);

// let logq converge with prescribed accuracy
while (std::fabs(logq - logqPrev) > std::log(1.001)) {
while (std::fabs(logq - logqPrev) > std::log2(1.001)) {
logqPrev = logq;
logq = logqBFV(n, logqPrev);
}

// this code updates n and q to account for the discrete size of CRT
// moduli = dcrtBits

int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;
logqPrev = logqCeil;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n, logqPrev);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
logqPrev = logqCeil;
}
}
Expand All @@ -293,7 +306,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
"security requirement. Please increase it to " +
std::to_string(n) + ".");

const size_t numInitialModuli = static_cast<size_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
const size_t numInitialModuli = static_cast<size_t>(std::ceil(std::ceil(logq) / dcrtBits));
if (numInitialModuli < 1)
OPENFHE_THROW("numInitialModuli must be greater than 0.");
const size_t sizeQ = multipartyMode == NOISE_FLOODING_MULTIPARTY ?
Expand All @@ -303,8 +316,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
std::vector<NativeInteger> moduliQ(sizeQ);
std::vector<NativeInteger> rootsQ(sizeQ);

// makes sure the first integer is less than 2^60-1 to take advantage of NTL
// optimizations
// makes sure the first integer is less than 2^60-1
moduliQ[0] = LastPrime<NativeInteger>(dcrtBits, 2 * n);
rootsQ[0] = RootOfUnity<NativeInteger>(2 * n, moduliQ[0]);
NativeInteger lastModulus = moduliQ[0];
Expand Down Expand Up @@ -359,7 +371,28 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

uint32_t numPartQ = ComputeNumLargeDigits(numDigits, sizeQ - 1);

cryptoParamsBFVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, 60, 0);
cryptoParamsBFVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, 0);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsBFVRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsBFVRNS->GetElementParams()->GetModulus().GetMSB();
}

uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);
if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}
Expand Down
62 changes: 53 additions & 9 deletions src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,19 +444,26 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
usint extraModSize = (scalTech == FLEXIBLEAUTOEXT) ? DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE : 0;
uint32_t qBound = firstModSize + (numPrimes - 1) * dcrtBits + extraModSize;

// Number of RNS limbs in P
// estimate the extra modulus Q needed for threshold FHE flooding
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
qBound += cryptoParamsBGVRNS->EstimateMultipartyFloodingLogQ();
uint32_t auxTowers = 0;
if (ksTech == HYBRID) {
auxTowers = ceil(ceil(static_cast<double>(qBound) / numPartQ) / auxBits);
qBound += auxTowers * auxBits;
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes, auxBits);
qBound += std::get<0>(hybridKSInfo);
auxTowers = std::get<1>(hybridKSInfo);
}

// when the scaling technique is not FIXED_MANUAL, set a small value so that the rest of the logic could go through
// when the scaling technique is not FIXEDMANUAL (and not FLEXIBLEAUTOEXT),
// set a small value so that the rest of the logic could go through (this is a workaround)
// TODO we should uncouple the logic of FIXEDMANUAL and all FLEXIBLE MODES; some of the code above should be moved
// to the branch for FIXEDMANUAL
if (qBound == 0)
qBound = 20;

// HE Standards compliance logic/check
uint32_t n = computeRingDimension(cryptoParams, qBound, cyclOrder);
// End HE Standards compliance logic/check

uint32_t vecSize = (scalTech != FLEXIBLEAUTOEXT) ? numPrimes : numPrimes + 1;
std::vector<NativeInteger> moduliQ(vecSize);
Expand All @@ -467,15 +474,26 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
auto moduliInfo = computeModuli(cryptoParams, n, evalAddCount, keySwitchCount, auxTowers, numPrimes);
moduliQ = std::get<0>(moduliInfo);
uint32_t newQBound = std::get<1>(moduliInfo);
while (qBound < newQBound) {

// the loop must be executed at least once
do {
qBound = newQBound;
n = computeRingDimension(cryptoParams, newQBound, cyclOrder);
auto moduliInfo = computeModuli(cryptoParams, n, evalAddCount, keySwitchCount, auxTowers, numPrimes);
moduliQ = std::get<0>(moduliInfo);
newQBound = std::get<1>(moduliInfo);
if (ksTech == HYBRID)
newQBound += ceil(ceil(static_cast<double>(newQBound) / numPartQ) / auxBits) * auxBits;
}
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
newQBound += cryptoParamsBGVRNS->EstimateMultipartyFloodingLogQ();
if (ksTech == HYBRID) {
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(
numPartQ, std::log2(moduliQ[0].ConvertToDouble()),
(moduliQ.size() > 1) ? std::log2(moduliQ[1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? std::log2(moduliQ[moduliQ.size() - 1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits);
newQBound += std::get<0>(hybridKSInfo);
}
} while (qBound < newQBound);

cyclOrder = 2 * n;
modulusOrder = getCyclicOrder(n, ptm, scalTech);

Expand All @@ -484,6 +502,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
}
}
else {
// FIXEDMANUAL mode
cyclOrder = 2 * n;
// For ModulusSwitching to work we need the moduli to be also congruent to 1 modulo ptm
usint plaintextModulus = ptm;
Expand Down Expand Up @@ -596,6 +615,31 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
cryptoParamsBGVRNS->SetEncodingParams(encodingParamsNew);
}
cryptoParamsBGVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, 0);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
SecurityLevel stdLevel = cryptoParamsBGVRNS->GetStdLevel();
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsBGVRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsBGVRNS->GetElementParams()->GetModulus().GetMSB();
}

DistributionType distType = (cryptoParamsBGVRNS->GetSecretKeyDist() == GAUSSIAN) ? HEStd_error : HEStd_ternary;
uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);

if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}

Expand Down
28 changes: 26 additions & 2 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
uint32_t auxBits = AUXMODSIZE;
uint32_t n = cyclOrder / 2;
uint32_t qBound = firstModSize + (numPrimes - 1) * scalingModSize + extraModSize;
// Estimate ciphertext modulus Q bound (in case of GHS/HYBRID P*Q)

// Estimate ciphertext modulus Q*P bound (in case of HYBRID P*Q)
if (ksTech == HYBRID) {
qBound += ceil(ceil(static_cast<double>(qBound) / numPartQ) / auxBits) * auxBits;
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize, extraModSize, numPrimes, auxBits);
qBound += std::get<0>(hybridKSInfo);
}

// GAUSSIAN security constraint
Expand Down Expand Up @@ -232,6 +235,27 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete

cryptoParamsCKKSRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraModSize);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsCKKSRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsCKKSRNS->GetElementParams()->GetModulus().GetMSB();
}

uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);
if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}

Expand Down
Loading

0 comments on commit 4eed808

Please sign in to comment.