diff --git a/v2/include/kyber.h b/v2/include/kyber.h index 28039c2..ef4df31 100644 --- a/v2/include/kyber.h +++ b/v2/include/kyber.h @@ -154,12 +154,12 @@ bool make_module_vector_zero(module_vector* out); bool module_vector_equal(module_vector& in1, module_vector& in2); void print_module_vector(module_vector& mv); -bool ntt_module_apply_array(module_array& A, module_vector& v, module_vector* out); +bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vector* out); void print_kyber_parameters(kyber_parameters& p); -bool kyber_keygen(kyber_parameters& p, int* ek_len, byte* ek, - int* dk_len, byte* dk); +bool kyber_keygen(kyber_parameters& p, int b_r_len, byte* b_r, + int* ek_len, byte* ek, int* dk_len, byte* dk); bool kyber_encrypt(kyber_parameters& p, int ek_len, byte* ek, int m_len, byte* m, int* c_len, byte* c); bool kyber_decrypt(kyber_parameters& p, int dk_len, byte* dk, diff --git a/v2/kyber/kyber.cc b/v2/kyber/kyber.cc index 2104592..897a954 100644 --- a/v2/kyber/kyber.cc +++ b/v2/kyber/kyber.cc @@ -283,7 +283,7 @@ bool module_apply_array(module_array& A, module_vector& v, module_vector* out) { return true; } -bool ntt_module_apply_array(module_array& A, module_vector& v, module_vector* out) { +bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vector* out) { if ((A.nc_ != v.dim_) || A.nr_ != out->dim_) { printf("mismatch, nc: %d, v: %d, nr: %d, out: %d\n", A.nc_, v.dim_, A.nr_, out->dim_); return false; @@ -298,9 +298,9 @@ bool ntt_module_apply_array(module_array& A, module_vector& v, module_vector* ou for (int j = 0; j < v.dim_; j++) { if (!coefficient_vector_zero(&t)) return false; - // change this to an multiply_ntt - if (!coefficient_mult(*A.c_[A.index(i,j)], *v.c_[j], &t)) + if (!ntt_mult(g, *A.c_[A.index(i,j)], *v.c_[j], &t)) { return false; + } if (!coefficient_vector_add_to(t, &acc)) return false; } @@ -815,8 +815,8 @@ bool byte_decode_from_vector(int d, int n, int in_len, byte* in, vector& v) // ek := byte_encode(12) (t^) || rho // dk := byte_encode(12) (s^) // return (ek, dk) -bool kyber_keygen(kyber_parameters& p, int* ek_len, byte* ek, - int* dk_len, byte* dk) { +bool kyber_keygen(kyber_parameters& p, int r_len, byte* b_r, + int* ek_len, byte* ek, int* dk_len, byte* dk) { #if 1 int g = 17; @@ -858,7 +858,8 @@ bool kyber_keygen(kyber_parameters& p, int* ek_len, byte* ek, int b_xof_len = 384; byte b_xof[b_xof_len]; memset(b_xof, 0, b_xof_len); - if (!xof(p.eta1_, b_xof_len, b_xof, i, j, b_xof_len * NBITSINBYTE, b_xof)) { + + if (!xof(p.eta1_, 32, parameters, i, j, b_xof_len * NBITSINBYTE, b_xof)) { printf("kyber_keygen: xof failed\n"); return false; } @@ -873,8 +874,16 @@ bool kyber_keygen(kyber_parameters& p, int* ek_len, byte* ek, int b_prf_len = 64 * p.eta1_; byte b_prf[b_prf_len]; memset(b_prf, 0, b_prf_len); + + if (!prf(p.eta1_, 32, ¶meters[32], sizeof(int), (byte*)&N, + NBITSINBYTE * 64 * p.eta1_, b_prf)) { + return false; + } + // s[i] := sample_poly_cbd(eta1, PRF(eta1, sigma, N)) - // if (!sample_poly_cbd(p.q_, p.eta1_, p.k_, b_prf_len, b_prf, int* out)) + //if (!sample_poly_cbd(p.q_, p.eta1_, p.k_, b_prf_len, b_prf, s[])) { + //return false; + //} N++; } for (int i = 0; i < p.k_; i++) { diff --git a/v2/kyber/test_kyber.cc b/v2/kyber/test_kyber.cc index 36e876b..4b4b54c 100644 --- a/v2/kyber/test_kyber.cc +++ b/v2/kyber/test_kyber.cc @@ -41,7 +41,17 @@ bool test_kyber1() { module_vector e(p.q_, p.n_, p.k_); module_vector s(p.q_, p.n_, p.k_); module_vector t(p.q_, p.n_, p.k_); - if (!kyber_keygen(p, &ek_len, ek, &dk_len, dk)) { + int b_r_len = 32; + byte b_r[b_r_len]; + memset(b_r, 0, b_r_len); + int n_b = crypto_get_random_bytes(b_r_len, b_r); + if (n_b != b_r_len) { + printf("wrong return from crypto_get_random_bytes\n"); + return false; + } + + + if (!kyber_keygen(p, b_r_len, b_r, &ek_len, ek, &dk_len, dk)) { printf("Could not init kyber_keygen\n"); return false; }