Skip to content

Commit

Permalink
Merge pull request #330 from cryspen/karthik/unpacked-api
Browse files Browse the repository at this point in the history
Unpacked API for ML-KEM
  • Loading branch information
karthikbhargavan authored Jul 10, 2024
2 parents 4a72136 + 86a5cba commit b345add
Show file tree
Hide file tree
Showing 72 changed files with 15,909 additions and 1,950 deletions.
162 changes: 160 additions & 2 deletions libcrux-ml-kem/benches/ml-kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,41 @@ pub fn key_generation(c: &mut Criterion) {
init!(mlkem512, "Key Generation", c);
init!(mlkem768, "Key Generation", c);
init!(mlkem1024, "Key Generation", c);

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd256"
))]
c.bench_function("libcrux avx2 unpacked (external random)", |b| {
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
b.iter(|| {
let _kp = mlkem768::avx2::generate_key_pair_unpacked(seed);
})
});

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd128"
))]
c.bench_function("libcrux neon unpacked (external random)", |b| {
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
b.iter(|| {
let _kp = mlkem768::neon::generate_key_pair_unpacked(seed);
})
});

#[cfg(all(feature = "mlkem768", feature = "pre-verification"))]
c.bench_function("libcrux portable unpacked (external random)", |b| {
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
b.iter(|| {
let _kp = mlkem768::portable::generate_key_pair_unpacked(seed);
})
});
}

pub fn pk_validation(c: &mut Criterion) {
Expand Down Expand Up @@ -80,7 +115,6 @@ pub fn encapsulation(c: &mut Criterion) {
($name:expr, $p:path, $group:expr) => {
$group.bench_function(format!("libcrux {} (external random)", $name), |b| {
use $p as p;

let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
Expand All @@ -100,14 +134,69 @@ pub fn encapsulation(c: &mut Criterion) {
init!(mlkem512, "Encapsulation", c);
init!(mlkem768, "Encapsulation", c);
init!(mlkem1024, "Encapsulation", c);

#[cfg(all(feature = "mlkem768", feature = "pre-verification"))]
c.bench_function("libcrux unpacked portable (external random)", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| mlkem768::portable::generate_key_pair_unpacked(seed1),
|keypair| {
let (_shared_secret, _ciphertext) =
mlkem768::portable::encapsulate_unpacked(&keypair.public_key, seed2);
},
BatchSize::SmallInput,
)
});

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd128"
))]
c.bench_function("libcrux unpacked neon (external random)", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| mlkem768::neon::generate_key_pair_unpacked(seed1),
|keypair| {
let (_shared_secret, _ciphertext) =
mlkem768::neon::encapsulate_unpacked(&keypair.public_key, seed2);
},
BatchSize::SmallInput,
)
});

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd256"
))]
c.bench_function("libcrux unpacked avx2 (external random)", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| mlkem768::avx2::generate_key_pair_unpacked(seed1),
|keypair| {
let (_shared_secret, _ciphertext) =
mlkem768::avx2::encapsulate_unpacked(&keypair.public_key, seed2);
},
BatchSize::SmallInput,
)
});
}

pub fn decapsulation(c: &mut Criterion) {
macro_rules! fun {
($name:expr, $p:path, $group:expr) => {
$group.bench_function(format!("libcrux {}", $name), |b| {
use $p as p;

let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
Expand All @@ -132,6 +221,75 @@ pub fn decapsulation(c: &mut Criterion) {
init!(mlkem512, "Decapsulation", c);
init!(mlkem768, "Decapsulation", c);
init!(mlkem1024, "Decapsulation", c);

#[cfg(all(feature = "mlkem768", feature = "pre-verification"))]
c.bench_function("libcrux unpacked portable", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| {
let keypair = mlkem768::portable::generate_key_pair_unpacked(seed1);
let (ciphertext, _shared_secret) =
mlkem768::portable::encapsulate_unpacked(&keypair.public_key, seed2);
(keypair, ciphertext)
},
|(keypair, ciphertext)| {
let _shared_secret =
mlkem768::portable::decapsulate_unpacked(&keypair, &ciphertext);
},
BatchSize::SmallInput,
)
});

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd128"
))]
c.bench_function("libcrux unpacked neon", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| {
let keypair = mlkem768::neon::generate_key_pair_unpacked(seed1);
let (ciphertext, _shared_secret) =
mlkem768::neon::encapsulate_unpacked(&keypair.public_key, seed2);
(keypair, ciphertext)
},
|(keypair, ciphertext)| {
let _shared_secret = mlkem768::neon::decapsulate_unpacked(&keypair, &ciphertext);
},
BatchSize::SmallInput,
)
});

#[cfg(all(
feature = "mlkem768",
feature = "pre-verification",
feature = "simd256"
))]
c.bench_function("libcrux unpacked avx2", |b| {
let mut seed1 = [0; 64];
OsRng.fill_bytes(&mut seed1);
let mut seed2 = [0; 32];
OsRng.fill_bytes(&mut seed2);
b.iter_batched(
|| {
let keypair = mlkem768::avx2::generate_key_pair_unpacked(seed1);
let (ciphertext, _shared_secret) =
mlkem768::avx2::encapsulate_unpacked(&keypair.public_key, seed2);
(keypair, ciphertext)
},
|(keypair, ciphertext)| {
let _shared_secret = mlkem768::avx2::decapsulate_unpacked(&keypair, &ciphertext);
},
BatchSize::SmallInput,
)
});
}

pub fn comparisons(c: &mut Criterion) {
Expand Down
103 changes: 103 additions & 0 deletions libcrux-ml-kem/c/benches/mlkem768.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,57 @@ BENCHMARK(kyber768_key_generation);
BENCHMARK(kyber768_encapsulation);
BENCHMARK(kyber768_decapsulation);

static void
kyber768_key_generation_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);
auto key_pair = libcrux_ml_kem_mlkem768_portable_generate_key_pair_unpacked(randomness);

for (auto _ : state)
{
key_pair = libcrux_ml_kem_mlkem768_portable_generate_key_pair_unpacked(randomness);
}
}

static void
kyber768_encapsulation_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);

auto key_pair = libcrux_ml_kem_mlkem768_portable_generate_key_pair_unpacked(randomness);
generate_random(randomness, 32);
auto ctxt = libcrux_ml_kem_mlkem768_portable_encapsulate_unpacked(&key_pair.public_key, randomness);

for (auto _ : state)
{
ctxt = libcrux_ml_kem_mlkem768_portable_encapsulate_unpacked(&key_pair.public_key, randomness);
}
}

static void
kyber768_decapsulation_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);

auto key_pair = libcrux_ml_kem_mlkem768_portable_generate_key_pair_unpacked(randomness);
generate_random(randomness, 32);
auto ctxt = libcrux_ml_kem_mlkem768_portable_encapsulate_unpacked(&key_pair.public_key, randomness);

uint8_t sharedSecret2[LIBCRUX_ML_KEM_CONSTANTS_SHARED_SECRET_SIZE];

for (auto _ : state)
{
libcrux_ml_kem_mlkem768_portable_decapsulate_unpacked(&key_pair, &ctxt.fst, sharedSecret2);
}
}

BENCHMARK(kyber768_key_generation_unpacked);
BENCHMARK(kyber768_encapsulation_unpacked);
BENCHMARK(kyber768_decapsulation_unpacked);

#ifdef LIBCRUX_AARCH64
#include "libcrux_mlkem768_neon.h"

Expand Down Expand Up @@ -177,6 +228,58 @@ kyber768_decapsulation_avx2(benchmark::State &state)
BENCHMARK(kyber768_key_generation_avx2);
BENCHMARK(kyber768_encapsulation_avx2);
BENCHMARK(kyber768_decapsulation_avx2);

static void
kyber768_key_generation_avx2_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);
auto key_pair = libcrux_ml_kem_mlkem768_avx2_generate_key_pair_unpacked(randomness);

for (auto _ : state)
{
key_pair = libcrux_ml_kem_mlkem768_avx2_generate_key_pair_unpacked(randomness);
}
}

static void
kyber768_encapsulation_avx2_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);

auto key_pair = libcrux_ml_kem_mlkem768_avx2_generate_key_pair_unpacked(randomness);
generate_random(randomness, 32);
auto ctxt = libcrux_ml_kem_mlkem768_avx2_encapsulate_unpacked(&key_pair.public_key, randomness);

for (auto _ : state)
{
ctxt = libcrux_ml_kem_mlkem768_avx2_encapsulate_unpacked(&key_pair.public_key, randomness);
}
}

static void
kyber768_decapsulation_avx2_unpacked(benchmark::State &state)
{
uint8_t randomness[64];
generate_random(randomness, 64);

auto key_pair = libcrux_ml_kem_mlkem768_avx2_generate_key_pair_unpacked(randomness);
generate_random(randomness, 32);
auto ctxt = libcrux_ml_kem_mlkem768_avx2_encapsulate_unpacked(&key_pair.public_key, randomness);

uint8_t sharedSecret2[LIBCRUX_ML_KEM_CONSTANTS_SHARED_SECRET_SIZE];

for (auto _ : state)
{
libcrux_ml_kem_mlkem768_avx2_decapsulate_unpacked(&key_pair, &ctxt.fst, sharedSecret2);
}
}

BENCHMARK(kyber768_key_generation_avx2_unpacked);
BENCHMARK(kyber768_encapsulation_avx2_unpacked);
BENCHMARK(kyber768_decapsulation_avx2_unpacked);

#endif

#ifdef LIBCRUX_SYMCRYPT
Expand Down
6 changes: 3 additions & 3 deletions libcrux-ml-kem/c/code_gen.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
This code was generated with the following tools:
Charon: 4bc2a90d4dab2efeb7f6db3fb61f850440d1b9e8
Charon: aeeae1d46704810bf498db552a75dff15aa3abcc
Eurydice: ffeb01ce4cf0646e5cadec836bc042f98b8a16a8
Karamel: 285552497829dd57fc019f946dce21c70ab35a0b
F*: a32b316e521fa4f239b610ec8f1d15e78d62cbe8-dirty
Karamel: 42a431696cd32d41155d7e484720eb71fd5dc7b1
F*: f09228ef9a64ac4ef383ee0e10656ccb612db2ee
25 changes: 19 additions & 6 deletions libcrux-ml-kem/c/internal/libcrux_core.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
/*
This file was generated by KaRaMeL <https://github.com/FStarLang/karamel>
KaRaMeL invocation: /home/franziskus/eurydice//eurydice --config ../c.yaml
KaRaMeL invocation: /home/karthik/eurydice/eurydice --config ../c.yaml
-funroll-loops 16 ../../libcrux_ml_kem.llbc ../../libcrux_sha3.llbc F*
version: <unknown> KaRaMeL version: 28555249
version: f09228ef KaRaMeL version: 42a43169
*/

#ifndef __internal_libcrux_core_H
Expand Down Expand Up @@ -136,6 +136,19 @@ libcrux_ml_kem_types__libcrux_ml_kem__types__MlKemPublicKey_SIZE__18__as_slice__
void libcrux_ml_kem_utils_into_padded_array___33size_t(Eurydice_slice slice,
uint8_t ret[33U]);

typedef struct
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_s {
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_tags tag;
union {
uint8_t case_Ok[32U];
core_array_TryFromSliceError case_Err;
} val;
} core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError;

void core_result__core__result__Result_T__E___unwrap__uint8_t_32size_t__core_array_TryFromSliceError(
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError self,
uint8_t ret[32U]);

void libcrux_ml_kem_utils_into_padded_array___34size_t(Eurydice_slice slice,
uint8_t ret[34U]);

Expand All @@ -151,7 +164,7 @@ void libcrux_ml_kem_utils_into_padded_array___64size_t(Eurydice_slice slice,

typedef struct
core_result_Result__uint8_t_24size_t__core_array_TryFromSliceError_s {
core_result_Result__uint8_t_24size_t__core_array_TryFromSliceError_tags tag;
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_tags tag;
union {
uint8_t case_Ok[24U];
core_array_TryFromSliceError case_Err;
Expand All @@ -164,7 +177,7 @@ void core_result__core__result__Result_T__E___unwrap__uint8_t_24size_t__core_arr

typedef struct
core_result_Result__uint8_t_20size_t__core_array_TryFromSliceError_s {
core_result_Result__uint8_t_24size_t__core_array_TryFromSliceError_tags tag;
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_tags tag;
union {
uint8_t case_Ok[20U];
core_array_TryFromSliceError case_Err;
Expand All @@ -177,7 +190,7 @@ void core_result__core__result__Result_T__E___unwrap__uint8_t_20size_t__core_arr

typedef struct
core_result_Result__uint8_t_10size_t__core_array_TryFromSliceError_s {
core_result_Result__uint8_t_24size_t__core_array_TryFromSliceError_tags tag;
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_tags tag;
union {
uint8_t case_Ok[10U];
core_array_TryFromSliceError case_Err;
Expand All @@ -190,7 +203,7 @@ void core_result__core__result__Result_T__E___unwrap__uint8_t_10size_t__core_arr

typedef struct
core_result_Result__int16_t_16size_t__core_array_TryFromSliceError_s {
core_result_Result__uint8_t_24size_t__core_array_TryFromSliceError_tags tag;
core_result_Result__uint8_t_32size_t__core_array_TryFromSliceError_tags tag;
union {
int16_t case_Ok[16U];
core_array_TryFromSliceError case_Err;
Expand Down
Loading

0 comments on commit b345add

Please sign in to comment.