Skip to content

Commit

Permalink
more tests, partway through keygen
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 23, 2024
1 parent 6cb9a36 commit f738ac4
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 20 deletions.
2 changes: 2 additions & 0 deletions v2/include/kyber.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ bool coefficient_vector_add_to(coefficient_vector& in, coefficient_vector* out);
bool coefficient_equal(coefficient_vector& in1, coefficient_vector& in2);
bool coefficient_apply_array(coefficient_array& A, coefficient_vector& v, coefficient_vector* out);

bool make_module_array_zero(module_array& B);
void print_module_array(module_array& ma);
bool module_vector_mult_by_scalar(coefficient_vector& in1, module_vector& in2,
module_vector* out);
Expand All @@ -149,6 +150,7 @@ bool module_vector_add(module_vector& in1, module_vector& in2,
bool module_vector_subtract(module_vector& in1, module_vector& in2,
module_vector* out);
bool module_apply_array(module_array& A, module_vector& v, module_vector* out);
bool module_apply_transposed_array(module_array& A, module_vector& v, module_vector* out);
bool module_vector_is_zero(module_vector& in);
bool make_module_vector_zero(module_vector* out);
bool module_vector_equal(module_vector& in1, module_vector& in2);
Expand Down
92 changes: 79 additions & 13 deletions v2/kyber/kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,26 @@ bool module_vector_subtract(module_vector& in1, module_vector& in2, module_vecto
return module_vector_add(in1, neg_in2, out);
}

bool make_module_vector_zero(module_vector* v) {
for( int i = 0; i < v->dim_; i++) {
if (!coefficient_vector_zero(v->c_[i])) {
return false;
}
}
return true;
}

bool make_module_array_zero(module_array& B) {
for( int i = 0; i < B.nr_; i++) {
for( int j = 0; j < B.nc_; j++) {
if (!coefficient_vector_zero(B.c_[B.index(i,j)])) {
return false;
}
}
}
return true;
}

bool module_apply_array(module_array& A, module_vector& v, module_vector* out) {
if ((A.nc_ != v.dim_) || A.nr_ != out->dim_) {
printf("mismatch, nc: %d, v: %d, nr: %d, out: %d\n", A.nc_, v.dim_, A.nr_, out->dim_);
Expand All @@ -286,6 +306,33 @@ bool module_apply_array(module_array& A, module_vector& v, module_vector* out) {
return true;
}

bool module_apply_transposed_array(module_array& A, module_vector& v, module_vector* out) {
if ((A.nr_ != v.dim_) || A.nc_ != out->dim_) {
printf("mismatch, nc: %d, v: %d, nr: %d, out: %d\n", A.nc_, v.dim_, A.nr_, out->dim_);
return false;
}

coefficient_vector acc(v.q_, v.n_);
coefficient_vector t(v.q_, v.n_);

for (int i = 0; i < A.nr_; i++) {
if (!coefficient_vector_zero(&acc))
return false;
for (int j = 0; j < v.dim_; j++) {
if (!coefficient_vector_zero(&t))
return false;
if (!coefficient_mult(*A.c_[A.index(j,i)], *v.c_[j], &t))
return false;
if (!coefficient_vector_add_to(t, &acc))
return false;
}
if (!coefficient_set_vector(acc, out->c_[i]))
return false;
}
return true;
}


bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vector* out) {
if ((A.nc_ != v.dim_) || A.nr_ != out->dim_) {
printf("mismatch, nc: %d, v: %d, nr: %d, out: %d\n", A.nc_, v.dim_, A.nr_, out->dim_);
Expand Down Expand Up @@ -770,10 +817,11 @@ bool prf(int eta, int in1_len, byte* in1, int in2_len, byte* in2, int bit_out_le

// XOF(ρ, i, j) := SHAKE128(ρ||i|| j)
bool xof(int eta, int in1_len, byte* in1, int i, int j, int bit_out_len, byte* out) {
#if 0
sha3 h;

if (!h.init(256, bit_out_len)) {
printf("xof init failed\n");
printf("xof init failed %d\n", bit_out_len);
return false;
}
h.add_to_hash(in1_len, in1);
Expand All @@ -784,6 +832,10 @@ bool xof(int eta, int in1_len, byte* in1, int i, int j, int bit_out_len, byte* o
printf("xof failed\n");
return false;
}
#else
// test only
int l = crypto_get_random_bytes(bit_out_len / NBITSINBYTE, out);
#endif
return true;
}

Expand Down Expand Up @@ -948,12 +1000,14 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
printf("kyber_keygen: crypto_get_random_bytes failed\n");
return false;
}
#if 1
printf("\n kyber_keygen\n");
printf("d: ");
print_bytes(32, d);
printf("sigma || rho: ");
print_bytes(64, parameters);
printf("\n");
#endif

module_vector e(p.q_, p.n_, p.k_);
module_vector s(p.q_, p.n_, p.k_);
Expand All @@ -964,7 +1018,6 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
module_vector t_ntt(p.q_, p.n_, p.k_);
module_vector r_ntt(p.q_, p.n_, p.k_);

return true;
int N = 0;
for (int i = 0; i < p.k_; i++) {
for (int j = 0; j < p.k_; j++) {
Expand Down Expand Up @@ -999,13 +1052,14 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
}
N++;
}
for (int i = 0; i < p.k_; i++) {
return true; // FIX
for (int i = 0; i < e.dim_; i++) {
int b_prf_len = 64;
byte b_prf[b_prf_len];
memset(b_prf, 0, b_prf_len);
if (!prf(p.eta1_, 32, &parameters[32], sizeof(int), (byte*)&N,
NBITSINBYTE * 64 * p.eta1_, b_prf)) {
printf("kyber_keygen: prf (1) failed\n");
printf("kyber_keygen: prf (2) failed\n");
return false;
}
if (!sample_poly_cbd(p.q_, p.eta1_, p.k_, b_prf_len, b_prf, e.c_[i]->c_)) {
Expand All @@ -1014,18 +1068,24 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
}
N++;
}
return true; // FIX

for (int i = 0; i < s_ntt.dim_; i++) {
#if 1
printf("s: ");
print_module_vector(s);
printf("\n");
printf("e: ");
print_module_vector(e);
printf("\n");
#endif

for (int i = 0; i < s.dim_; i++) {
if (!ntt(g, *s.c_[i], s_ntt.c_[i])) {
printf("kyber_keygen: ntt (1) failed\n");
return false;
}
}
for (int i = 0; i < e_ntt.dim_; i++) {
return false;
}
if (!ntt(g, *e.c_[i], e_ntt.c_[i])) {
printf("kyber_keygen: ntt (2) failed\n");
return false;
}
return false;
}
}

// t^ := A^(s^)+e^
Expand All @@ -1044,6 +1104,12 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
return false;
}

#if 1
printf("t^: ");
print_module_vector(t_ntt);
printf("\n");
#endif

// ek := byte_encode(12) (t^) || rho
for (int i = 0; i < t_ntt.dim_; i++) {
if (!byte_encode_from_vector(12, p.n_, t_ntt.c_[i]->c_, ek)) {
Expand Down
102 changes: 95 additions & 7 deletions v2/kyber/test_kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ bool test_kyber1() {
printf("Could not init kyber_keygen\n");
return false;
}
if (FLAGS_print_all) {
printf("ek:\n");
print_bytes(ek_len, ek);
printf("\n");
printf("dk:\n");
print_bytes(dk_len, dk);
printf("\n");
}
return true;

int m_len = 32;
Expand Down Expand Up @@ -490,13 +498,93 @@ bool test_kyber_support() {
printf("Could not inverse ntt_mult\n");
return false;
}
printf("\n");
print_coefficient_vector(ntt_in);
printf(" x_ntt\n");
print_coefficient_vector(ntt_in);
printf(" =\n");
print_coefficient_vector(m_out);
printf("\n");
if (FLAGS_print_all) {
printf("\n");
print_coefficient_vector(ntt_in);
printf(" x_ntt\n");
print_coefficient_vector(ntt_in);
printf(" =\n");
print_coefficient_vector(m_out);
printf("\n");
}

module_array B(p.q_, p.n_, 4, 4);
module_vector vb1(p.q_, p.n_, 4);
module_vector vb2(p.q_, p.n_, 4);

if (!make_module_array_zero(B)) {
return false;
}
if (!make_module_vector_zero(&vb1)) {
return false;
}
if (!make_module_vector_zero(&vb2)) {
return false;
}
for (int i = 0; i < 4; i++) {
B.c_[B.index(i,i)]->c_[0] = 1;
}
B.c_[B.index(0,1)]->c_[0] = 1;

vb1.c_[0]->c_[0] = 1;
vb1.c_[1]->c_[0] = 1;
vb1.c_[2]->c_[0] = 1;
vb1.c_[3]->c_[0] = 1;
if (!module_apply_array(B, vb1, &vb2)) {
return false;
}
if (FLAGS_print_all) {
printf("First apply:\n");
print_module_vector(vb2);
printf("\n");
}
if (vb2.c_[0]->c_[0] != 2 || vb2.c_[1]->c_[0] != 1) {
printf("module_apply_array failed\n");
return false;
}
if (!make_module_vector_zero(&vb2)) {
return false;
}
if (!module_apply_transposed_array(B, vb1, &vb2)) {
return false;
}
if (FLAGS_print_all) {
printf("Second apply:\n");
print_module_vector(vb2);
printf("\n");
}
if (vb2.c_[0]->c_[0] != 1 || vb2.c_[1]->c_[0] != 2) {
printf("module_apply_transposed_array failed\n");
return false;
}

vb1.c_[0]->c_[0] = 1;
vb1.c_[1]->c_[0] = 1;
vb1.c_[2]->c_[0] = 1;
vb1.c_[3]->c_[0] = 1;
vb2.c_[0]->c_[0] = 1;
vb2.c_[1]->c_[0] = -1;
vb2.c_[2]->c_[0] = 1;
vb2.c_[3]->c_[0] = 1;
coefficient_vector cv1(p.q_, p.n_);
coefficient_vector_zero(&cv1);
if (!module_vector_dot_product(vb1, vb2, &cv1)) {
return false;
}
if (FLAGS_print_all) {
printf("Dot product:\n");
print_coefficient_vector(cv1);
printf("\n");
}
coefficient_vector_zero(&cv1);
if (!module_vector_dot_product_first_transposed(vb1, vb2, &cv1)) {
return false;
}
if (FLAGS_print_all) {
printf("Dot product transposed:\n");
print_coefficient_vector(cv1);
printf("\n");
}

return true;
}
Expand Down

0 comments on commit f738ac4

Please sign in to comment.