Skip to content

Commit

Permalink
reorganized tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmucb committed Apr 28, 2024
1 parent 0f139ca commit 4d42ecc
Showing 1 changed file with 96 additions and 50 deletions.
146 changes: 96 additions & 50 deletions v2/kyber/kyber.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1112,36 +1112,13 @@ bool kyber_keygen(int g, kyber_parameters& p, int* ek_len, byte* ek,
printf("r_ntt:\n");
print_module_vector(r_ntt);
printf("\n");

// Compare s_ntt dot (A_ntt^T r_ntt) and r_ntt dot (A_ntt s_ntt)
module_vector s1(p.q_, p.n_, p.k_);
module_vector s2(p.q_, p.n_, p.k_);
coefficient_vector r1(p.q_, p.n_);
coefficient_vector r2(p.q_, p.n_);
coefficient_vector r3(p.q_, p.n_);

if (!ntt_module_apply_transposed_array(g, A_ntt, r_ntt, &s1)) {
printf("test ntt_module_apply_transposed_array fail\n");
return false;
}
if (!ntt_module_apply_array(g, A_ntt, s_ntt, &s2)) {
printf("test ntt_module_apply_array fail\n");
return false;
}
if (!ntt_module_vector_dot_product(s_ntt, s1, &r1)) {
printf("test module_vector_dot_product (1) fail\n");
extern bool special_test_2(kyber_parameters& p, int g, module_array& A_ntt,
module_vector& s_ntt, module_vector& r_ntt);
if (!special_test_2(p, g, A_ntt, s_ntt, r_ntt)) {
printf("********special_test_2 failed\n");
return false;
}
if (!ntt_module_vector_dot_product(r_ntt, s2, &r2)) {
printf("test module_vector_dot_product (3) fail\n");
return false;
}
for (int j = 0; j < 256; j++) {
r3.c_[j] = (p.q_ + r1.c_[j] - r1.c_[j]) % p.q_;
}
printf("COMPARISON test\n");
print_coefficient_vector(r3);
printf("\n");

#endif
return true;
}
Expand Down Expand Up @@ -1436,30 +1413,13 @@ print_coefficient_vector(nu);
print_bytes(c2_b_len, b_c2);
printf("\n");

printf("\n\ntest, decompressed mu\n");
coefficient_vector t_compressed_mu(p.q_, p.n_);
byte checked_m[32];
memset(checked_m, 0, 32);
for (int j = 0; j < p.n_; j++) {
t_compressed_mu.c_[j] = compress(p.q_, mu.c_[j], 1);
}
if (!byte_encode_from_vector(1, p.n_, t_compressed_mu.c_, checked_m)) {
extern bool special_test_1(kyber_parameters& p, coefficient_vector& mu,
coefficient_vector& nu,
int len_m, byte* m, int len_c2, byte*c2);
if (!special_test_1(p, mu, nu, m_len, m, c2_b_len, b_c2)) {
printf("****special_test_1 failed\n");
return false;
}
printf("recovered m from mu: ");
print_bytes(32, checked_m);

printf("\n\ntest, decompressed nu\n");
coefficient_vector t_nu(p.q_, p.n_);
if (!byte_decode_to_vector(p.dv_, p.n_, c2_b_len, b_c2, t_nu.c_)) {
return false;
}
for (int j = 0; j < p.n_; j++) {
t_nu.c_[j] = decompress(p.q_, t_nu.c_[j], p.dv_);
}
printf("Recovered nu\n");
print_coefficient_vector(t_nu);
printf("\n");
#endif
return true;
}
Expand Down Expand Up @@ -1738,3 +1698,89 @@ bool kyber_kem_decaps(int g, kyber_parameters& p, int kem_dk_len, byte* kem_dk,
return true;
}


// ------------------------------------------------------------------------------------

bool special_test_1(kyber_parameters& p, coefficient_vector& mu,
coefficient_vector& nu,
int len_m, byte* m, int len_c2, byte*c2) {

printf("\n\ntest, decompressed mu\n");
coefficient_vector t_compressed_mu(p.q_, p.n_);
byte checked_m[32];
memset(checked_m, 0, 32);

for (int j = 0; j < p.n_; j++) {
t_compressed_mu.c_[j] = compress(p.q_, mu.c_[j], 1);
}
if (!byte_encode_from_vector(1, p.n_, t_compressed_mu.c_, checked_m)) {
return false;
}
printf("m : ");
print_bytes(32, m);
printf("recovered m from mu: ");
print_bytes(32, checked_m);
if (memcmp(m, checked_m, 32) != 0) {
printf("********m doesnt match checked m\n");
}

printf("\n\ntest, decompressed nu\n");
coefficient_vector t_nu(p.q_, p.n_);
if (!byte_decode_to_vector(12, p.n_, len_c2, c2, t_nu.c_)) {
return false;
}
for (int j = 0; j < p.n_; j++) {
t_nu.c_[j] = decompress(p.q_, t_nu.c_[j], p.dv_);
}
printf("Recovered nu\n");
print_coefficient_vector(t_nu);
printf("\n");
if (!coefficient_equal(nu, t_nu)) {
printf("********m doesnt match checked m\n");
}
return true;
}

bool special_test_2(kyber_parameters& p, int g, module_array& A_ntt,
module_vector& s_ntt, module_vector& r_ntt) {

printf("special_test_2\n");
// Compare s_ntt dot (A_ntt^T r_ntt) and r_ntt dot (A_ntt s_ntt)
module_vector s1(p.q_, p.n_, p.k_);
module_vector s2(p.q_, p.n_, p.k_);
coefficient_vector r1(p.q_, p.n_);
coefficient_vector r2(p.q_, p.n_);
coefficient_vector r3(p.q_, p.n_);

if (!module_apply_transposed_array(A_ntt, r_ntt, &s1)) {
printf("test ntt_module_apply_transposed_array fail\n");
return false;
}
if (!module_apply_array(A_ntt, s_ntt, &s2)) {
printf("test ntt_module_apply_array fail\n");
return false;
}
if (!module_vector_dot_product(s_ntt, s1, &r1)) {
printf("test module_vector_dot_product (1) fail\n");
return false;
}
if (!module_vector_dot_product(r_ntt, s2, &r2)) {
printf("test module_vector_dot_product (3) fail\n");
return false;
}

for (int j = 0; j < 256; j++) {
r3.c_[j] = (p.q_ + r1.c_[j] - r2.c_[j]) % p.q_;
}
printf("COMPARISON test\n");
print_coefficient_vector(r3);
printf("\n");
if (!coefficient_equal(r1, r2)) {
printf("r1 != r2\n");
return false;
}
return true;
}

// ------------------------------------------------------------------------------------

0 comments on commit 4d42ecc

Please sign in to comment.