Skip to content

Commit

Permalink
Merge pull request #55 from japaric/rm-rwlock
Browse files Browse the repository at this point in the history
rm `RwLock` from `Hpke` and no-std-ify the `hpke-rs` library
  • Loading branch information
franziskuskiefer authored Nov 29, 2023
2 parents f2a5b1a + fcb8a45 commit 1d9baa4
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 48 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ repository = "https://github.com/franziskuskiefer/hpke-rs"

[dependencies]
log = "0.4"
serde_json = { version = "1.0", optional = true }
serde = { version = "1.0", features = ["derive"], optional = true }
tls_codec = { version = "0.4.0", features = ["derive"], optional = true }
zeroize = { version = "1.5", features = ["zeroize_derive"] }
hpke-rs-crypto = { version = "0.1.3", path = "./traits" }

[features]
default = []
serialization = ["serde", "serde_json", "tls_codec", "tls_codec/serde"]
std = []
serialization = ["serde", "tls_codec", "tls_codec/serde", "std"]
hazmat = []
hpke-test = []
hpke-test = ["std"]
hpke-test-prng = [] # ⚠️ Enable testing PRNG - DO NOT USE

[dev-dependencies]
Expand Down
12 changes: 6 additions & 6 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
if Crypto::supports_kem(kem_mode).is_err() {
continue;
}
let hpke = Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let mut hpke = Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let label = format!("{} {}", Crypto::name(), hpke);
let kp = hpke.generate_key_pair().unwrap();
let enc = kp.public_key().as_slice();
Expand Down Expand Up @@ -83,7 +83,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
let mut group = c.benchmark_group(format!("{}", label));
group.bench_function("Setup Sender", |b| {
b.iter(|| {
let hpke =
let mut hpke =
Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
hpke.setup_sender(
&pk_rm,
Expand Down Expand Up @@ -114,7 +114,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
group.bench_function(&format!("Seal {}({})", AEAD_PAYLOAD, AEAD_AAD), |b| {
b.iter_batched(
|| {
let hpke =
let mut hpke =
Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let (_enc, context) = hpke
.setup_sender(
Expand All @@ -140,7 +140,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
group.bench_function(&format!("Open {}({})", AEAD_PAYLOAD, AEAD_AAD), |b| {
b.iter_batched(
|| {
let hpke =
let mut hpke =
Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let (enc, mut sender_context) = hpke
.setup_sender(
Expand Down Expand Up @@ -190,7 +190,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
OsRng.fill_bytes(&mut ptxt);
(hpke, aad, ptxt)
},
|(hpke, aad, ptxt)| {
|(mut hpke, aad, ptxt)| {
let _ctxt = hpke
.seal(
&pk_rm,
Expand All @@ -212,7 +212,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>(c: &mut Criterion) {
|b| {
b.iter_batched(
|| {
let hpke = Hpke::<Crypto>::new(
let mut hpke = Hpke::<Crypto>::new(
hpke_mode, kem_mode, kdf_mode, aead_mode,
);
let (enc, mut sender_context) = hpke
Expand Down
5 changes: 3 additions & 2 deletions benches/manual_benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn benchmark<Crypto: HpkeCrypto + 'static>() {
if Crypto::supports_kem(kem_mode).is_err() {
continue;
}
let hpke = Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let mut hpke = Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let label = format!(
"{} {} {} {} {}",
Crypto::name(),
Expand Down Expand Up @@ -98,7 +98,8 @@ fn benchmark<Crypto: HpkeCrypto + 'static>() {

let start = Instant::now();
for _ in 0..ITERATIONS {
let hpke = Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let mut hpke =
Hpke::<Crypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);
let _sender = hpke
.setup_sender(
&pk_rm,
Expand Down
2 changes: 1 addition & 1 deletion fuzz/fuzz_targets/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hpke_rs::prelude::*;
use hpke_rs_crypto::types::*;

fuzz_target!(|data: &[u8]| {
let hpke = Hpke::<hpke_rs_rust_crypto::HpkeRustCrypto>::new(
let mut hpke = Hpke::<hpke_rs_rust_crypto::HpkeRustCrypto>::new(
HpkeMode::Base,
KemAlgorithm::DhKemP256,
KdfAlgorithm::HkdfSha256,
Expand Down
2 changes: 2 additions & 0 deletions src/dh_kem.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! DH KEM as described in §4.1. DH-Based KEM.
use alloc::{string::ToString, vec::Vec};

use hpke_rs_crypto::{error::Error, types::KemAlgorithm, HpkeCrypto};

use crate::util::*;
Expand Down
2 changes: 2 additions & 0 deletions src/kdf.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use alloc::vec::Vec;

use hpke_rs_crypto::{error::Error, types::KdfAlgorithm, HpkeCrypto};

use crate::util::concat;
Expand Down
2 changes: 2 additions & 0 deletions src/kem.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use alloc::vec::Vec;

use hpke_rs_crypto::{error::Error, types::KemAlgorithm, HpkeCrypto};

use crate::dh_kem;
Expand Down
72 changes: 40 additions & 32 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,18 @@
unused_extern_crates,
unused_qualifications
)]
#![cfg_attr(not(test), no_std)]

use std::sync::RwLock;
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;

use alloc::{
format,
string::{String, ToString},
vec,
vec::Vec,
};

#[cfg(feature = "hpke-test-prng")]
use hpke_rs_crypto::HpkeTestRng;
Expand Down Expand Up @@ -79,15 +89,13 @@ pub enum HpkeError {

/// Unable to collect enough randomness.
InsufficientRandomness,

/// A concurrency issue with an [`RwLock`].
LockPoisoned,
}

#[cfg(feature = "std")]
impl std::error::Error for HpkeError {}

impl std::fmt::Display for HpkeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl core::fmt::Display for HpkeError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "HPKE Error: {:?}", self)
}
}
Expand Down Expand Up @@ -159,8 +167,8 @@ pub enum Mode {
AuthPsk = 0x03,
}

impl std::fmt::Display for Mode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
impl core::fmt::Display for Mode {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "{:?}", self)
}
}
Expand Down Expand Up @@ -202,8 +210,8 @@ pub struct Context<Crypto: 'static + HpkeCrypto> {
}

#[cfg(feature = "hazmat")]
impl<Crypto: HpkeCrypto> std::fmt::Debug for Context<Crypto> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl<Crypto: HpkeCrypto> core::fmt::Debug for Context<Crypto> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"Context {{\n key: {:?}\n nonce: {:?}\n exporter_secret: {:?}\n seq no: {:?}\n}}",
Expand All @@ -213,8 +221,8 @@ impl<Crypto: HpkeCrypto> std::fmt::Debug for Context<Crypto> {
}

#[cfg(not(feature = "hazmat"))]
impl<Crypto: HpkeCrypto> std::fmt::Debug for Context<Crypto> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
impl<Crypto: HpkeCrypto> core::fmt::Debug for Context<Crypto> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(
f,
"Context {{\n key: {:?}\n nonce: {:?}\n exporter_secret: {:?}\n seq no: {:?}\n}}",
Expand Down Expand Up @@ -331,7 +339,7 @@ pub struct Hpke<Crypto: 'static + HpkeCrypto> {
kem_id: KemAlgorithm,
kdf_id: KdfAlgorithm,
aead_id: AeadAlgorithm,
prng: RwLock<Crypto::HpkePrng>,
prng: Crypto::HpkePrng,
}

impl<Crypto: 'static + HpkeCrypto> Clone for Hpke<Crypto> {
Expand All @@ -341,13 +349,13 @@ impl<Crypto: 'static + HpkeCrypto> Clone for Hpke<Crypto> {
kem_id: self.kem_id,
kdf_id: self.kdf_id,
aead_id: self.aead_id,
prng: RwLock::new(Crypto::prng()),
prng: Crypto::prng(),
}
}
}

impl<Crypto: HpkeCrypto> std::fmt::Display for Hpke<Crypto> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
impl<Crypto: HpkeCrypto> core::fmt::Display for Hpke<Crypto> {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(
f,
"{}_{}_{}_{}",
Expand All @@ -372,7 +380,7 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
kem_id,
kdf_id,
aead_id,
prng: RwLock::new(Crypto::prng()),
prng: Crypto::prng(),
}
}

Expand All @@ -391,7 +399,7 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
/// The encapsulated secret is returned together with the context.
/// If the secret key is missing in an authenticated mode, an error is returned.
pub fn setup_sender(
&self,
&mut self,
pk_r: &HpkePublicKey,
info: &[u8],
psk: Option<&[u8]>,
Expand Down Expand Up @@ -478,7 +486,7 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
/// Returns the encapsulated secret and the ciphertext, or an error.
#[allow(clippy::too_many_arguments)]
pub fn seal(
&self,
&mut self,
pk_r: &HpkePublicKey,
info: &[u8],
aad: &[u8],
Expand Down Expand Up @@ -534,7 +542,7 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
/// exporter context and length.
#[allow(clippy::too_many_arguments)]
pub fn send_export(
&self,
&mut self,
pk_r: &HpkePublicKey,
info: &[u8],
psk: Option<&[u8]>,
Expand Down Expand Up @@ -674,9 +682,8 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
/// This is equivalent to `derive_key_pair(random_vector(sk.len()))`
///
/// Returns an `HpkeKeyPair`.
pub fn generate_key_pair(&self) -> Result<HpkeKeyPair, HpkeError> {
let mut prng = self.prng.write().map_err(|_| HpkeError::LockPoisoned)?;
let (sk, pk) = kem::key_gen::<Crypto>(self.kem_id, &mut prng)?;
pub fn generate_key_pair(&mut self) -> Result<HpkeKeyPair, HpkeError> {
let (sk, pk) = kem::key_gen::<Crypto>(self.kem_id, &mut self.prng)?;
Ok(HpkeKeyPair::new(sk, pk))
}

Expand All @@ -690,8 +697,8 @@ impl<Crypto: HpkeCrypto> Hpke<Crypto> {
}

#[inline]
pub(crate) fn random(&self, len: usize) -> Result<Vec<u8>, HpkeError> {
let mut prng = self.prng.write().map_err(|_| HpkeError::LockPoisoned)?;
pub(crate) fn random(&mut self, len: usize) -> Result<Vec<u8>, HpkeError> {
let prng = &mut self.prng;
let mut out = vec![0u8; len];

#[cfg(feature = "hpke-test-prng")]
Expand Down Expand Up @@ -794,17 +801,17 @@ impl PartialEq for HpkePrivateKey {
}

#[cfg(not(feature = "hazmat"))]
impl std::fmt::Debug for HpkePrivateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
impl core::fmt::Debug for HpkePrivateKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
f.debug_struct("HpkePrivateKey")
.field("value", &"***")
.finish()
}
}

#[cfg(feature = "hazmat")]
impl std::fmt::Debug for HpkePrivateKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
impl core::fmt::Debug for HpkePrivateKey {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> Result<(), core::fmt::Error> {
f.debug_struct("HpkePrivateKey")
.field("value", &self.value)
.finish()
Expand Down Expand Up @@ -891,14 +898,15 @@ impl tls_codec::Deserialize for &HpkePublicKey {
/// Test util module. Should be moved really.
#[cfg(feature = "hpke-test")]
pub mod test_util {
use alloc::{format, string::String, vec, vec::Vec};

use crate::HpkeError;
use hpke_rs_crypto::{HpkeCrypto, HpkeTestRng};

impl<Crypto: HpkeCrypto> super::Hpke<Crypto> {
/// Set PRNG state for testing.
pub fn seed(&self, seed: &[u8]) -> Result<(), HpkeError> {
let mut prng = self.prng.write().map_err(|_| HpkeError::LockPoisoned)?;
prng.seed(seed);
pub fn seed(&mut self, seed: &[u8]) -> Result<(), HpkeError> {
self.prng.seed(seed);
Ok(())
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
//! Include this to get access to all the public functions of HPKE.
pub use super::{Mode as HpkeMode, *};
pub use std::convert::TryFrom;
pub use core::convert::TryFrom;
2 changes: 2 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use alloc::vec::Vec;

#[inline]
pub(crate) fn concat(values: &[&[u8]]) -> Vec<u8> {
values.join(&[][..])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ macro_rules! generate_test_case {
($name:ident, $hpke_mode:expr, $kem_mode:expr, $kdf_mode:expr, $aead_mode:expr, $provider:ident) => {
#[test]
fn $name() {
let hpke = Hpke::<$provider>::new($hpke_mode, $kem_mode, $kdf_mode, $aead_mode);
let mut hpke = Hpke::<$provider>::new($hpke_mode, $kem_mode, $kdf_mode, $aead_mode);
println!("Self test {}", hpke);

// Self test seal and open with random keys.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hpke_kat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ fn kat<Crypto: HpkeCrypto + 'static>(tests: Vec<HpkeTestVector>) {
#[cfg(feature = "hpke-test-prng")]
{
log::trace!("Testing with known ikmE ...");
let hpke_sender = Hpke::<Crypto>::new(mode, kem_id, kdf_id, aead_id);
let mut hpke_sender = Hpke::<Crypto>::new(mode, kem_id, kdf_id, aead_id);
// This only works when seeding the PRNG with ikmE.
hpke_sender.seed(&ikm_e).expect("Error injecting ikm_e");
let (enc, _sender_context_kat) = hpke_sender
Expand Down Expand Up @@ -298,7 +298,7 @@ fn test_serialization() {
for &kem_mode in &[0x10u16, 0x20] {
let kem_mode = KemAlgorithm::try_from(kem_mode).unwrap();

let hpke =
let mut hpke =
Hpke::<HpkeRustCrypto>::new(hpke_mode, kem_mode, kdf_mode, aead_mode);

// JSON: Public, Private, KeyPair
Expand Down

0 comments on commit 1d9baa4

Please sign in to comment.