Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Homo AES compatible #343

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions misc/aes/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,14 @@
cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
## Use -std=c++14 as default.
set(CMAKE_CXX_STANDARD 14)
## Disable C++ extensions
set(CMAKE_CXX_EXTENSIONS OFF)
## Require full C++ standard
set(CMAKE_CXX_STANDARD_REQUIRED ON)

project(Test_AES_example
LANGUAGES CXX)

find_package(helib 1.0.0 EXACT REQUIRED)
add_executable(TEST_AES simpleAES.cpp homAES.cpp Test_AES.cpp)
target_link_libraries(TEST_AES PUBLIC helib)
28 changes: 9 additions & 19 deletions misc/aes/Test_AES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

namespace std {} using namespace std;
namespace NTL {} using namespace NTL;
namespace helib{} using namespace helib;
#include <cstring>
#include "homAES.h"
#include "Ctxt.h"
#include "helib/Ctxt.h"
#include "helib/ArgMap.h"

static long mValues[][14] = {
//{ p, phi(m), m, d, m1, m2, m3, g1, g2, g3,ord1,ord2,ord3, c_m}
Expand Down Expand Up @@ -39,16 +41,16 @@ extern void Cipher(unsigned char out[16],

int main(int argc, char **argv)
{
ArgMapping amap;
ArgMap amap;

long idx = 0;
amap.arg("sz", idx, "parameter-sets: toy=0 through huge=5");

long c=3;
amap.arg("c", c, "number of columns in the key-switching matrices");

long L=0;
amap.arg("L", L, "# of levels in the modulus chain", "heuristic");
long N=0;
amap.arg("N", N, "# of bits of the modulus chain");

long B=23;
amap.arg("B", B, "# of bits per level (only 64-bit machines)");
Expand All @@ -66,17 +68,6 @@ int main(int argc, char **argv)
vector<long> gens;
vector<long> ords;

if (boot) {
if (L<23) L=23;
if (idx<1) idx=1; // the sz=0 params are incompatible with bootstrapping
} else {
#if (NTL_SP_NBITS<50)
if (L<46) L=46;
#else
if (L<42) L=42;
#endif
}

long p = mValues[idx][0];
// long phim = mValues[idx][1];
long m = mValues[idx][2];
Expand All @@ -94,8 +85,7 @@ int main(int argc, char **argv)
if (abs(mValues[idx][12])>1) ords.push_back(mValues[idx][12]);

cout << "*** Test_AES: c=" << c
<< ", L=" << L
<< ", B=" << B
<< ", N=" << N
<< ", boot=" << boot
<< ", packed=" << packed
<< ", m=" << m
Expand All @@ -107,10 +97,10 @@ int main(int argc, char **argv)
cout << "computing key-independent tables..." << std::flush;
Context context(m, p, /*r=*/1, gens, ords);
#if (NTL_SP_NBITS>=50) // 64-bit machines
context.bitsPerLevel = B;
//context.bitsPerLevel = B;
#endif
context.zMStar.set_cM(mValues[idx][13]/100.0); // the ring constant
buildModChain(context, L, c);
buildModChain(context, N, c);

if (boot) context.makeBootstrappable(mvec);
tm += GetTime();
Expand Down
33 changes: 21 additions & 12 deletions misc/aes/homAES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*/
namespace std {} using namespace std;
namespace NTL {} using namespace NTL;
namespace helib {} using namespace helib;
#include <cstring>
#include "homAES.h"

Expand Down Expand Up @@ -92,19 +93,21 @@ static void packCtxt(vector<Ctxt>& to, const vector<Ctxt>& from,
static void unackCtxt(vector<Ctxt>& to, const vector<Ctxt>& from,
const Mat<GF2X>& unpackConsts);

static long findBaseLevel(const Ctxt& c);

// Implementation of the class HomAES

static const uint8_t aesPolyBytes[] = { 0x1B, 0x1 }; // X^8+X^4+X^3+X+1
const GF2X HomAES::aesPoly = GF2XFromBytes(aesPolyBytes, 2);

HomAES::HomAES(const Context& context): ea2(context,aesPoly,context.alMod)
#ifndef USE_ZZX_POLY // initialize DoubleCRT using the context
, affVec(context)
, affVec(context,context.allPrimes())
#endif
{
// Sanity-check: we need the first dimension to be divisible by 16.
//OLD: assert( context.zMStar.OrderOf(0) % 16 == 0 );
helib::assertEq(context.zMStar.OrderOf(0) % 16, 0l);
helib::assertEq(context.zMStar.OrderOf(0) % 16, 0l, "The first dimension need to be divisible by 16");

// Compute the GF2-affine transformation constants
buildAffineEnc(encAffMat, affVec, ea2);
Expand Down Expand Up @@ -171,7 +174,7 @@ void HomAES::setPackingConstants()

long e = ea.getDegree() / 8; // the extension degree
//OLD: assert(ea.getDegree()==e*8 && e<=(long) sizeof(long));
helib::assertEq(ea.getDegree()==e*8, "ea must have degree divisible by 8");
helib::assertEq(ea.getDegree(), e*8, "ea must have degree divisible by 8");
helib::assertTrue(e<=(long) sizeof(long), "extension degree must be at most 8 times sizeof(long)");

GF2EBak bak; bak.save(); // save current modulus (if any)
Expand Down Expand Up @@ -218,14 +221,14 @@ void HomAES::homAESenc(vector<Ctxt>& eData, const vector<Ctxt>& aesKey) const
for (long i=1; i<(long)aesKey.size(); i++) { // apply the AES rounds

// ByteSub
if (eData[0].findBaseLevel() < 4) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 4) batchRecrypt(eData);
invert(eData); // apply Z -> Z^{-1} to all elements of eData
#ifdef DEBUG_PRINTOUT
CheckCtxt(eData[0], "+ After invert");
// cerr << " + After invert ";
// decryptAndPrint(cerr, eData[0], *dbgKey, *dbgEa);
#endif
if (eData[0].findBaseLevel() < 2) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData);
for (long j=0; j<(long)eData.size(); j++) { // GF2 affine transformation
applyLinPolyLL(eData[j], encAffMat, ea2.getDegree());
eData[j].addConstant(affVec);
Expand All @@ -237,7 +240,7 @@ void HomAES::homAESenc(vector<Ctxt>& eData, const vector<Ctxt>& aesKey) const
#endif

// Apply RowShift/ColMix to each ciphertext
if (eData[0].findBaseLevel() < 2) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData);
if (i<(long)aesKey.size()-1) {
for (long j=0; j<(long)eData.size(); j++)
encRowColTran(eData[j], encLinTran, ea2);
Expand Down Expand Up @@ -295,7 +298,7 @@ void HomAES::homAESdec(vector<Ctxt>& eData, const vector<Ctxt>& aesKey) const
for (long j=0; j<(long)eData.size(); j++) eData[j] -= aesKey[i];

// Apply RowShift/ColMix to each ciphertext
if (eData[0].findBaseLevel() < 2) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData);
// if (eData[0].log_of_ratio() > (-lvlBits)) batchRecrypt(eData);
if (i<(long)aesKey.size()-1)
for (long j=0; j<(long)eData.size(); j++)
Expand All @@ -311,7 +314,7 @@ void HomAES::homAESdec(vector<Ctxt>& eData, const vector<Ctxt>& aesKey) const
#endif

// ByteSub
if (eData[0].findBaseLevel() < 2) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 2) batchRecrypt(eData);
for (long j=0; j<(long)eData.size(); j++) { // GF2 affine transformation
eData[j].addConstant(affVec);
applyLinPolyLL(eData[j], decAffMat, ea2.getDegree());
Expand All @@ -321,7 +324,7 @@ void HomAES::homAESdec(vector<Ctxt>& eData, const vector<Ctxt>& aesKey) const
// cerr << " + After affine ";
// decryptAndPrint(cerr, eData[0], *dbgKey, *dbgEa);
#endif
if (eData[0].findBaseLevel() < 4) batchRecrypt(eData);
if (findBaseLevel(eData[0]) < 4) batchRecrypt(eData);
invert(eData); // apply Z -> Z^{-1} to all elements of eData
#ifdef DEBUG_PRINTOUT
CheckCtxt(eData[0], "+ After invert");
Expand Down Expand Up @@ -440,7 +443,7 @@ static void buildAffine(vector<PolyType>& binMat, PolyType* binVec,
ea2.encode(zzxMat[j], scratch); // encode these slots
}
#ifndef USE_ZZX_POLY
binMat.resize(8,DoubleCRT(ea2.getContext()));
binMat.resize(8,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes()));
for (long j=0; j<8; j++) binMat[j] = zzxMat[j]; // convert to DoubleCRT
#endif

Expand Down Expand Up @@ -478,7 +481,7 @@ static void buildLinEnc(vector<PolyType>& encLinTran,
#ifdef USE_ZZX_POLY
encLinTran.resize(6);
#else
encLinTran.resize(6,DoubleCRT(ea2.getContext()));
encLinTran.resize(6,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes()));
#endif
for (long i=0; i<3; i++) { // constants for the RowShift/ColMix trans
for (long j=0; j<blocksPerCtxt; j++) {
Expand Down Expand Up @@ -603,7 +606,7 @@ static void buildLinDec(vector<PolyType>& decLinTran,
#ifdef USE_ZZX_POLY
decLinTran.resize(8);
#else
decLinTran.resize(8,DoubleCRT(ea2.getContext()));
decLinTran.resize(8,DoubleCRT(ea2.getContext(),ea2.getContext().allPrimes()));
#endif
for (long i=0; i<4; i++) { // constants for the RowShift/ColMix trans
for (long j=0; j<blocksPerCtxt; j++) {
Expand Down Expand Up @@ -863,3 +866,9 @@ static void unackCtxt(vector<Ctxt>& to, const vector<Ctxt>& from,
}
}
}

// A hack to get this to compile for now
static long findBaseLevel(const Ctxt& c)
{
return long(c.naturalSize() / 23); // FIXME: replace 23 by something else
}
4 changes: 2 additions & 2 deletions misc/aes/homAES.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <stdint.h>
#include <NTL/ZZX.h>
#include <NTL/GF2X.h>
#include "EncryptedArray.h"
#include "hypercube.h"
#include "helib/EncryptedArray.h"
#include "helib/hypercube.h"

#ifdef USE_ZZX_POLY
#define PolyType ZZX
Expand Down
1 change: 1 addition & 0 deletions misc/aes/simpleAES.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Find the Wikipedia page of AES at:
// Used for giving output to the screen.
#include<cstdlib>
#include<cstdio>
#include<helib/helib.h>

// The number of columns comprising a state in AES. This is a constant in AES.
// Value=4
Expand Down