Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix logQ(P) estimation used for finding the minimum ring dimension #796

Merged
merged 8 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/pke/include/schemerns/rns-cryptoparameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,30 @@ class CryptoParametersRNS : public CryptoParametersRLWE<DCRTPoly> {

virtual uint64_t FindAuxPrimeStep() const;

/*
* Estimates the extra modulus bitsize needed for hybrid key swithing (used for finding the minimum secure ring dimension).
*
* @param numPartQ number of digits in hybrid key switching
* @param firstModulusSize bit size of first modulus
* @param dcrtBits bit size for other moduli
* @param extraModulusSize bit size for extra modulus in FLEXIBLEAUTOEXT (CKKS and BGV only)
* @param numPrimes number of moduli witout extraModulus
* @param auxBits size of auxiliar moduli used for hybrid key switching
*
* @return log2 of the modulus and number of RNS limbs.
*/
static std::pair<double, uint32_t> EstimateLogP(uint32_t numPartQ, double firstModulusSize, double dcrtBits,
double extraModulusSize, uint32_t numPrimes, uint32_t auxBits);

/*
* Estimates the extra modulus bitsize needed for threshold FHE noise flooding (only for BGV and BFV)
*
* @return number of extra bits needed for noise flooding
*/
static constexpr double EstimateMultipartyFloodingLogQ() {
return static_cast<double>(NoiseFlooding::MULTIPARTY_MOD_SIZE * NoiseFlooding::NUM_MODULI_MULTIPARTY);
}

/**
* == operator to compare to this instance of CryptoParametersBase object.
*
Expand Down
87 changes: 60 additions & 27 deletions src/pke/lib/scheme/bfvrns/bfvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
double p = static_cast<double>(cryptoParamsBFVRNS->GetPlaintextModulus());
uint32_t digitSize = cryptoParamsBFVRNS->GetDigitSize();
SecurityLevel stdLevel = cryptoParamsBFVRNS->GetStdLevel();
uint32_t auxBits = DCRT_MODULUS::MAX_SIZE;

// Bound of the Gaussian error polynomial
double Berr = sigma * std::sqrt(alpha);
Expand Down Expand Up @@ -115,8 +116,20 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
return 0;
}
else {
// takes into account the noise added during the threshold FHE instantiation of BFV
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
logq += cryptoParamsBFVRNS->EstimateMultipartyFloodingLogQ();
// adds logP in the case of HYBRID key switching
if (ksTech == HYBRID) {
// number of RNS limbs
uint32_t k = static_cast<uint32_t>(std::ceil(std::ceil(logq) / dcrtBits));
// set the number of digits
uint32_t numPartQ = ComputeNumLargeDigits(numDigits, k - 1);
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(numPartQ, dcrtBits, dcrtBits, 0, k, auxBits);
logq += std::get<0>(hybridKSInfo);
}
return static_cast<double>(
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<usint>(std::ceil(logq / std::log(2)))));
StdLatticeParm::FindRingDim(distType, stdLevel, static_cast<uint32_t>(std::ceil(logq))));
}
};

Expand All @@ -125,12 +138,12 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
// conservative estimate for HYBRID to avoid the use of method of
// iterative approximations; we do not know the number
// of digits and moduli at this point and use upper bounds
double numTowers = ceil(static_cast<double>(logqPrev) / dcrtBits);
double numTowers = ceil(logqPrev / dcrtBits);
pascoec marked this conversation as resolved.
Show resolved Hide resolved
return numTowers * (delta(n) * Berr + delta(n) * Bkey + 1.0) / 2.0;
}
else {
double numDigitsPerTower = (digitSize == 0) ? 1 : ((dcrtBits / digitSize) + 1);
return delta(n) * numDigitsPerTower * (floor(logqPrev / (std::log(2) * dcrtBits)) + 1) * w * Berr / 2.0;
return delta(n) * numDigitsPerTower * (floor(logqPrev / (dcrtBits)) + 1) * w * Berr / 2.0;
}
};

Expand All @@ -147,7 +160,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
if ((multiplicativeDepth == 0) && (keySwitchCount == 0)) {
// Correctness constraint
auto logqBFV = [&](uint32_t n) -> double {
return std::log(p * (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount) + p));
return std::log2(p * (4 * ((evalAddCount + 1) * Vnorm(n) + evalAddCount) + p));
yspolyakov marked this conversation as resolved.
Show resolved Hide resolved
};

// initial value
Expand All @@ -160,15 +173,15 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// this code updates n and q to account for the discrete size of CRT moduli
// = dcrtBits
int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not adding the +1 anymore? The same comment for the occurrences below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see why + 1 is needed in the current version of OpenFHE. Maybe in an old version (of PALISADE), we sometimes had larger moduli (NextPrime might have been used). So in my mind, ceil(logq/dcrtBits) is sufficient and guarantees that that logq<=k*dcrtBits. Moreover, adding + 1 would cause problems in the new estimation logic (getting an extra RNS limb when it is not needed).

k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
}
}
else if ((multiplicativeDepth == 0) && (keySwitchCount > 0) && (evalAddCount == 0)) {
Expand All @@ -179,11 +192,11 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// Correctness constraint
auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
return std::log(p * (4 * (Vnorm(n) + keySwitchCount * noiseKS(n, logqPrev, w, false)) + p));
return std::log2(p * (4 * (Vnorm(n) + keySwitchCount * noiseKS(n, logqPrev, w, false)) + p));
};

// initial values
double logqPrev = 6. * std::log(10);
double logqPrev = 6. * std::log2(10);
logq = logqBFV(n, logqPrev);
logqPrev = logq;

Expand All @@ -197,23 +210,23 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
logq = logqBFV(n, logqPrev);

// let logq converge with prescribed accuracy
while (std::fabs(logq - logqPrev) > std::log(1.001)) {
while (std::fabs(logq - logqPrev) > std::log2(1.001)) {
logqPrev = logq;
logq = logqBFV(n, logqPrev);
}

// this code updates n and q to account for the discrete size of CRT
// moduli = dcrtBits
int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;
logqPrev = logqCeil;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n, logqPrev);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
logqPrev = logqCeil;
}
}
Expand All @@ -237,12 +250,12 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

// main correctness constraint
auto logqBFV = [&](uint32_t n, double logqPrev) -> double {
return log(4 * p) + (multiplicativeDepth - 1) * log(C1(n)) +
log(C1(n) * Vnorm(n) + multiplicativeDepth * C2(n, logqPrev));
return log2(4 * p) + (multiplicativeDepth - 1) * log2(C1(n)) +
log2(C1(n) * Vnorm(n) + multiplicativeDepth * C2(n, logqPrev));
};

// initial values
double logqPrev = 6. * std::log(10);
double logqPrev = 6. * std::log2(10);
logq = logqBFV(n, logqPrev);
logqPrev = logq;

Expand All @@ -256,24 +269,24 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
logq = logqBFV(n, logqPrev);

// let logq converge with prescribed accuracy
while (std::fabs(logq - logqPrev) > std::log(1.001)) {
while (std::fabs(logq - logqPrev) > std::log2(1.001)) {
logqPrev = logq;
logq = logqBFV(n, logqPrev);
}

// this code updates n and q to account for the discrete size of CRT
// moduli = dcrtBits

int32_t k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
int32_t k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));

double logqCeil = k * dcrtBits * std::log(2);
double logqCeil = k * dcrtBits;
logqPrev = logqCeil;

while (nRLWE(logqCeil) > n) {
n = 2 * n;
logq = logqBFV(n, logqPrev);
k = static_cast<int32_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
logqCeil = k * dcrtBits * std::log(2);
k = static_cast<int32_t>(std::ceil(std::ceil(logq) / dcrtBits));
logqCeil = k * dcrtBits;
logqPrev = logqCeil;
}
}
Expand All @@ -293,7 +306,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
"security requirement. Please increase it to " +
std::to_string(n) + ".");

const size_t numInitialModuli = static_cast<size_t>(std::ceil((std::ceil(logq / std::log(2)) + 1.0) / dcrtBits));
const size_t numInitialModuli = static_cast<size_t>(std::ceil(std::ceil(logq) / dcrtBits));
if (numInitialModuli < 1)
OPENFHE_THROW("numInitialModuli must be greater than 0.");
const size_t sizeQ = multipartyMode == NOISE_FLOODING_MULTIPARTY ?
Expand All @@ -303,8 +316,7 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters
std::vector<NativeInteger> moduliQ(sizeQ);
std::vector<NativeInteger> rootsQ(sizeQ);

// makes sure the first integer is less than 2^60-1 to take advantage of NTL
// optimizations
// makes sure the first integer is less than 2^60-1
moduliQ[0] = LastPrime<NativeInteger>(dcrtBits, 2 * n);
rootsQ[0] = RootOfUnity<NativeInteger>(2 * n, moduliQ[0]);
NativeInteger lastModulus = moduliQ[0];
Expand Down Expand Up @@ -359,7 +371,28 @@ bool ParameterGenerationBFVRNS::ParamsGenBFVRNS(std::shared_ptr<CryptoParameters

uint32_t numPartQ = ComputeNumLargeDigits(numDigits, sizeQ - 1);

cryptoParamsBFVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, 60, 0);
cryptoParamsBFVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, 0);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsBFVRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsBFVRNS->GetElementParams()->GetModulus().GetMSB();
}

uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);
if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}
Expand Down
62 changes: 53 additions & 9 deletions src/pke/lib/scheme/bgvrns/bgvrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,19 +444,26 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
usint extraModSize = (scalTech == FLEXIBLEAUTOEXT) ? DCRT_MODULUS::DEFAULT_EXTRA_MOD_SIZE : 0;
uint32_t qBound = firstModSize + (numPrimes - 1) * dcrtBits + extraModSize;

// Number of RNS limbs in P
// estimate the extra modulus Q needed for threshold FHE flooding
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
qBound += cryptoParamsBGVRNS->EstimateMultipartyFloodingLogQ();
uint32_t auxTowers = 0;
if (ksTech == HYBRID) {
auxTowers = ceil(ceil(static_cast<double>(qBound) / numPartQ) / auxBits);
qBound += auxTowers * auxBits;
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, dcrtBits, extraModSize, numPrimes, auxBits);
qBound += std::get<0>(hybridKSInfo);
auxTowers = std::get<1>(hybridKSInfo);
}

// when the scaling technique is not FIXED_MANUAL, set a small value so that the rest of the logic could go through
// when the scaling technique is not FIXEDMANUAL (and not FLEXIBLEAUTOEXT),
// set a small value so that the rest of the logic could go through (this is a workaround)
// TODO we should uncouple the logic of FIXEDMANUAL and all FLEXIBLE MODES; some of the code above should be moved
// to the branch for FIXEDMANUAL
if (qBound == 0)
qBound = 20;

// HE Standards compliance logic/check
uint32_t n = computeRingDimension(cryptoParams, qBound, cyclOrder);
// End HE Standards compliance logic/check

uint32_t vecSize = (scalTech != FLEXIBLEAUTOEXT) ? numPrimes : numPrimes + 1;
std::vector<NativeInteger> moduliQ(vecSize);
Expand All @@ -467,15 +474,26 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
auto moduliInfo = computeModuli(cryptoParams, n, evalAddCount, keySwitchCount, auxTowers, numPrimes);
moduliQ = std::get<0>(moduliInfo);
uint32_t newQBound = std::get<1>(moduliInfo);
while (qBound < newQBound) {

// the loop must be executed at least once
do {
qBound = newQBound;
n = computeRingDimension(cryptoParams, newQBound, cyclOrder);
auto moduliInfo = computeModuli(cryptoParams, n, evalAddCount, keySwitchCount, auxTowers, numPrimes);
moduliQ = std::get<0>(moduliInfo);
newQBound = std::get<1>(moduliInfo);
if (ksTech == HYBRID)
newQBound += ceil(ceil(static_cast<double>(newQBound) / numPartQ) / auxBits) * auxBits;
}
if (multipartyMode == NOISE_FLOODING_MULTIPARTY)
newQBound += cryptoParamsBGVRNS->EstimateMultipartyFloodingLogQ();
if (ksTech == HYBRID) {
auto hybridKSInfo = CryptoParametersRNS::EstimateLogP(
numPartQ, std::log2(moduliQ[0].ConvertToDouble()),
(moduliQ.size() > 1) ? std::log2(moduliQ[1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? std::log2(moduliQ[moduliQ.size() - 1].ConvertToDouble()) : 0,
(scalTech == FLEXIBLEAUTOEXT) ? moduliQ.size() - 1 : moduliQ.size(), auxBits);
newQBound += std::get<0>(hybridKSInfo);
}
} while (qBound < newQBound);

cyclOrder = 2 * n;
modulusOrder = getCyclicOrder(n, ptm, scalTech);

Expand All @@ -484,6 +502,7 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
}
}
else {
// FIXEDMANUAL mode
cyclOrder = 2 * n;
// For ModulusSwitching to work we need the moduli to be also congruent to 1 modulo ptm
usint plaintextModulus = ptm;
Expand Down Expand Up @@ -596,6 +615,31 @@ bool ParameterGenerationBGVRNS::ParamsGenBGVRNS(std::shared_ptr<CryptoParameters
cryptoParamsBGVRNS->SetEncodingParams(encodingParamsNew);
}
cryptoParamsBGVRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, 0);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
SecurityLevel stdLevel = cryptoParamsBGVRNS->GetStdLevel();
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsBGVRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsBGVRNS->GetElementParams()->GetModulus().GetMSB();
}

DistributionType distType = (cryptoParamsBGVRNS->GetSecretKeyDist() == GAUSSIAN) ? HEStd_error : HEStd_ternary;
uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);

if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}

Expand Down
28 changes: 26 additions & 2 deletions src/pke/lib/scheme/ckksrns/ckksrns-parametergeneration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete
uint32_t auxBits = AUXMODSIZE;
uint32_t n = cyclOrder / 2;
uint32_t qBound = firstModSize + (numPrimes - 1) * scalingModSize + extraModSize;
// Estimate ciphertext modulus Q bound (in case of GHS/HYBRID P*Q)

// Estimate ciphertext modulus Q*P bound (in case of HYBRID P*Q)
if (ksTech == HYBRID) {
qBound += ceil(ceil(static_cast<double>(qBound) / numPartQ) / auxBits) * auxBits;
auto hybridKSInfo =
CryptoParametersRNS::EstimateLogP(numPartQ, firstModSize, scalingModSize, extraModSize, numPrimes, auxBits);
qBound += std::get<0>(hybridKSInfo);
}

// GAUSSIAN security constraint
Expand Down Expand Up @@ -232,6 +235,27 @@ bool ParameterGenerationCKKSRNS::ParamsGenCKKSRNS(std::shared_ptr<CryptoParamete

cryptoParamsCKKSRNS->PrecomputeCRTTables(ksTech, scalTech, encTech, multTech, numPartQ, auxBits, extraModSize);

// Validate the ring dimension found using estimated logQ(P) against actual logQ(P)
if (stdLevel != HEStd_NotSet) {
uint32_t logActualQ = 0;
if (ksTech == HYBRID) {
logActualQ = cryptoParamsCKKSRNS->GetParamsQP()->GetModulus().GetMSB();
}
else {
logActualQ = cryptoParamsCKKSRNS->GetElementParams()->GetModulus().GetMSB();
}

uint32_t nActual = StdLatticeParm::FindRingDim(distType, stdLevel, logActualQ);
if (n < nActual) {
std::string errMsg("The ring dimension found using estimated logQ(P) [");
errMsg += std::to_string(n) + "] does does not meet security requirements. ";
errMsg += "Report this problem to OpenFHE developers and set the ring dimension manually to ";
errMsg += std::to_string(nActual) + ".";

OPENFHE_THROW(errMsg);
}
}

return true;
}

Expand Down
Loading