Skip to content

Commit

Permalink
fixed u encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 28, 2024
1 parent 8d633f1 commit 0f139ca
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 62 deletions.
4 changes: 0 additions & 4 deletions v2/include/kyber.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,8 @@ bool make_module_vector_zero(module_vector* out);
bool module_vector_equal(module_vector& in1, module_vector& in2);
bool module_vector_dot_product(module_vector& in1, module_vector& in2,
coefficient_vector* out);
bool module_vector_dot_product_first_transposed(module_vector& in1,
module_vector& in2, coefficient_vector* out);
bool ntt_module_vector_dot_product(module_vector& in1,
module_vector& in2, coefficient_vector* out);
bool ntt_module_vector_dot_product_first_transposed(module_vector& in1,
module_vector& in2, coefficient_vector* out);
void print_module_vector(module_vector& mv);

bool ntt_module_apply_array(int g, module_array& A, module_vector& v, module_vector* out);
Expand Down
81 changes: 32 additions & 49 deletions v2/kyber/kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ bool kyber_parameters::init_kyber(int ks) {
q_ = 3329;
k_ = 4;
gamma_ = 17;
du_ = 11;
dv_ = 5;
du_ = 12; // remove 11;
dv_ = 12; // remove 5;
eta1_ = 2;
eta2_ = 2;
return true;
Expand Down Expand Up @@ -406,51 +406,12 @@ bool module_vector_dot_product(module_vector& in1, module_vector& in2,
return true;
}

bool module_vector_dot_product_first_transposed(module_vector& in1,
module_vector& in2, coefficient_vector* out) {

if (!coefficient_vector_zero(out)) {
return false;
}
for (int i = 0; i < in1.dim_; i++) {
coefficient_vector t(in1.q_, in1.n_);
if (!coefficient_vector_zero(&t)) {
return false;
}
if (!coefficient_mult(*in1.c_[in1.dim_ - 1 -i], *in2.c_[i], &t)) {
return false;
}
if (!coefficient_vector_add_to(t, out)) {
return false;
}
}
return true;
}

bool ntt_module_vector_dot_product(module_vector& in1, module_vector& in2,
coefficient_vector* out) {

if (!coefficient_vector_zero(out)) {
if (in1.n_ != in2.n_ || out->len_ != in2.n_ || in1.dim_ != in2.dim_) {
return false;
}
for (int i = 0; i < in1.dim_; i++) {
coefficient_vector t(in1.q_, in1.n_);
if (!coefficient_vector_zero(&t)) {
return false;
}
if (!multiply_ntt(17, *in1.c_[i], *in2.c_[i], &t)) {
return false;
}
if (!coefficient_vector_add_to(t, out)) {
return false;
}
}
return true;
}

bool ntt_module_vector_dot_product_first_transposed(module_vector& in1,
module_vector& in2, coefficient_vector* out) {

if (!coefficient_vector_zero(out)) {
return false;
}
Expand All @@ -459,7 +420,7 @@ bool ntt_module_vector_dot_product_first_transposed(module_vector& in1,
if (!coefficient_vector_zero(&t)) {
return false;
}
if (!multiply_ntt(17, *in1.c_[in1.dim_ - 1 - i], *in2.c_[i], &t)) {
if (!multiply_ntt(17, *in1.c_[i], *in2.c_[i], &t)) {
return false;
}
if (!coefficient_vector_add_to(t, out)) {
Expand Down Expand Up @@ -1075,6 +1036,9 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
N++;
}

// remove
make_module_vector_zero(&e);

// Secret and noise to ntt domain
for (int i = 0; i < s.dim_; i++) {
if (!ntt(g, *s.c_[i], s_ntt.c_[i])) {
Expand Down Expand Up @@ -1164,11 +1128,11 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
printf("test ntt_module_apply_array fail\n");
return false;
}
if (!module_vector_dot_product(s_ntt, s1, &r1)) {
if (!ntt_module_vector_dot_product(s_ntt, s1, &r1)) {
printf("test module_vector_dot_product (1) fail\n");
return false;
}
if (!module_vector_dot_product(r_ntt, s2, &r2)) {
if (!ntt_module_vector_dot_product(r_ntt, s2, &r2)) {
printf("test module_vector_dot_product (3) fail\n");
return false;
}
Expand Down Expand Up @@ -1305,6 +1269,9 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek,
N++;
}

//remove
make_module_vector_zero(&e1);

// Generate noise element (e2)
{
int b_prf_len = 64 * p.eta2_;
Expand All @@ -1323,6 +1290,9 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek,
N++;
}

// remove
coefficient_vector_zero(&e2);

module_vector tmp1(p.q_, p.n_, p.k_);
module_vector tmp2(p.q_, p.n_, p.k_);
if (!make_module_vector_zero(&tmp1)) {
Expand Down Expand Up @@ -1368,14 +1338,15 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek,
int c1_b_len = (p.du_ * p.n_ * p.k_) / NBITSINBYTE;
byte b_c1[c1_b_len];
byte* pp = b_c1;
int len = (p.du_ * p.n_) / NBITSINBYTE;

for (int i = 0; i < p.k_; i++) {
for (int j = 0; j < p.n_; j++) {
compressed_u.c_[i]->c_[j] = compress(p.q_, u.c_[i]->c_[j], p.du_);
}
if (!byte_encode_from_vector(p.du_, p.n_, compressed_u.c_[i]->c_, pp)) {
return false;
}
int len = (p.du_ * p.n_) / NBITSINBYTE;
pp += len;
}

Expand All @@ -1394,6 +1365,12 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek,
if (!ntt_inv(g, nu_ntt, &nu)) {
return false;
}

printf("t_ntt dot r_ntt:\n");
print_coefficient_vector(nu_ntt);
printf("t dot r:\n");
print_coefficient_vector(nu);

if (!coefficient_vector_add_to(e2, &nu)) {
return false;
}
Expand All @@ -1403,12 +1380,12 @@ bool kyber_encrypt(int g, kyber_parameters& p, int ek_len, byte* ek,

// Compress and encode nu (c2)
coefficient_vector compressed_nu(p.q_, p.n_);
int c2_b_len = (p.dv_ * 256) / NBITSINBYTE;
int c2_b_len = (p.dv_ * p.n_) / NBITSINBYTE;
byte b_c2[c2_b_len];
for (int j = 0; j < compressed_nu.len_; j++) {
compressed_nu.c_[j] = compress(p.q_, nu.c_[j], p.dv_);
}
if (!byte_encode_from_vector(p.dv_, 256, compressed_nu.c_, b_c2)) {
if (!byte_encode_from_vector(p.dv_, p.n_, compressed_nu.c_, b_c2)) {
return false;
}

Expand Down Expand Up @@ -1512,10 +1489,10 @@ bool kyber_decrypt(int g, kyber_parameters& p, int dk_len, byte* dk,
byte* c2 = &c[32 * p.du_ * p.k_];

byte* p_c1 = c1;
int len = 32;
module_vector u(p.q_, p.n_, p.k_);

// Recover u from c1
int len = (p.du_ *p.n_) / NBITSINBYTE;
for (int i = 0; i < p.k_; i++) {
if (!byte_decode_to_vector(p.du_, p.n_, len, p_c1, u.c_[i]->c_)) {
return false;
Expand Down Expand Up @@ -1566,6 +1543,12 @@ bool kyber_decrypt(int g, kyber_parameters& p, int dk_len, byte* dk,
printf("kyber_decrypt: ntt_inv failed\n");
return false;
}

printf("s_ntt dot u_ntt:\n");
print_coefficient_vector(w_ntt);
printf("s dot u:\n");
print_coefficient_vector(w);

// w = -w
for (int j = 0; j < w.len_; j++) {
w.c_[j] = (w.q_ - w.c_[j]) % w.q_;
Expand Down
9 changes: 0 additions & 9 deletions v2/kyber/test_kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -585,15 +585,6 @@ bool test_kyber_support() {
print_coefficient_vector(cv1);
printf("\n");
}
coefficient_vector_zero(&cv1);
if (!module_vector_dot_product_first_transposed(vb1, vb2, &cv1)) {
return false;
}
if (FLAGS_print_all) {
printf("Dot product transposed:\n");
print_coefficient_vector(cv1);
printf("\n");
}

int a1, a2;
int m1, m2;
Expand Down

0 comments on commit 0f139ca

Please sign in to comment.