diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2eb0aa47fa..1d907ae157 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -67,3 +67,18 @@ jobs: - name: Build if: ${{ matrix.arch != 'wasm32-unknown-unknown' }} run: cargo build $TEST_MODE --verbose --target ${{ matrix.arch }} -p openmls + + # Check feature powerset + check: + strategy: + fail-fast: false + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - uses: taiki-e/install-action@cargo-hack + + - name: Cargo hack + run: cargo hack check --feature-powerset --no-dev-deps --verbose -p openmls diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index b86fea2f88..f373a02a63 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -23,6 +23,7 @@ jobs: CARGO_INCREMENTAL: '0' RUSTFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests' RUSTDOCFLAGS: '-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Clink-dead-code -Coverflow-checks=off -Cpanic=abort -Zpanic_abort_tests' + LIBCRUX_DISABLE_SIMD256: '1' - name: Run grcov id: coverage diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml index 627d9275c3..857cf8e0c7 100644 --- a/.github/workflows/gh-pages.yml +++ b/.github/workflows/gh-pages.yml @@ -19,7 +19,7 @@ jobs: - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - name: Setup mdBook - uses: peaceiris/actions-mdbook@v1 + uses: peaceiris/actions-mdbook@v2 with: mdbook-version: 'latest' - name: Build docs @@ -43,7 +43,7 @@ jobs: EOF - name: Deploy docs to GitHub Pages - uses: peaceiris/actions-gh-pages@v3 + uses: peaceiris/actions-gh-pages@v4 with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: temp_docs diff --git a/CHANGELOG.md b/CHANGELOG.md index 397a52fb42..5c68edd5b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - [#1479](https://github.com/openmls/openmls/pull/1479): Allow the use of extensions with `ExtensionType::Unknown` in group context, key packages and leaf nodes - [#1488](https://github.com/openmls/openmls/pull/1488): Allow unknown credentials. Credentials other than the basic credential or X.509 may be used now as long as they are encoded as variable-sized vectors. - [#1527](https://github.com/openmls/openmls/pull/1527): CredentialType::Unknown is now called CredentialType::Other. +- [#1543](https://github.com/openmls/openmls/pull/1543): PreSharedKeyId.write_to_key_store() no longer requires the cipher suite. +- [#1546](https://github.com/openmls/openmls/pull/1546): Add experimental ciphersuite based on the PQ-secure XWing hybrid KEM (currently supported only by the libcrux crypto provider). +- [#1548](https://github.com/openmls/openmls/pull/1548): CryptoConfig is now replaced by just Ciphersuite. +- [#1542](https://github.com/openmls/openmls/pull/1542): Add support for custom proposals. ProposalType::Unknown is now called ProposalType::Other. Proposal::Unknown is now called Proposal::Other. +- [#1559](https://github.com/openmls/openmls/pull/1559): Remove the `PartialEq` type constraint on the error type of both the `OpenMlsRand` and `OpenMlsKeyStore` traits. Additionally, remove the `Clone` type constraint on the error type of the `OpenMlsRand` trait. ### Fixed diff --git a/Cargo.toml b/Cargo.toml index 86692820d2..04e824d33e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,11 +7,12 @@ members = [ "fuzz", "cli", "interop_client", - "memory_keystore", + "memory_storage", "delivery-service/ds", "delivery-service/ds-lib", "basic_credential", "openmls-wasm", + "openmls_test", ] resolver = "2" diff --git a/basic_credential/src/lib.rs b/basic_credential/src/lib.rs index 24e6be79d0..fe8ea30c26 100644 --- a/basic_credential/src/lib.rs +++ b/basic_credential/src/lib.rs @@ -7,14 +7,15 @@ use std::fmt::Debug; use openmls_traits::{ - key_store::{MlsEntity, MlsEntityId, OpenMlsKeyStore}, signatures::{Signer, SignerError}, + storage::{self, StorageProvider, CURRENT_VERSION}, types::{CryptoError, SignatureScheme}, }; use p256::ecdsa::{signature::Signer as P256Signer, Signature, SigningKey}; use rand::rngs::OsRng; +use serde::{Deserialize, Serialize}; use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; /// A signature key pair for the basic credential. @@ -75,10 +76,6 @@ fn id(public_key: &[u8], signature_scheme: SignatureScheme) -> Vec { id } -impl MlsEntity for SignatureKeyPair { - const ID: MlsEntityId = MlsEntityId::SignatureKeyPair; -} - impl SignatureKeyPair { /// Generates a fresh signature keypair using the [`SignatureScheme`]. pub fn new(signature_scheme: SignatureScheme) -> Result { @@ -112,25 +109,32 @@ impl SignatureKeyPair { } } - fn id(&self) -> Vec { - id(&self.public, self.signature_scheme) + fn id(&self) -> StorageId { + StorageId { + value: id(&self.public, self.signature_scheme), + } } /// Store this signature key pair in the key store. - pub fn store(&self, key_store: &T) -> Result<(), ::Error> + pub fn store(&self, store: &T) -> Result<(), T::Error> where - T: OpenMlsKeyStore, + T: StorageProvider, { - key_store.store(&self.id(), self) + store.write_signature_key_pair(&self.id(), self) } /// Read a signature key pair from the key store. pub fn read( - key_store: &impl OpenMlsKeyStore, + store: &impl StorageProvider, public_key: &[u8], signature_scheme: SignatureScheme, ) -> Option { - key_store.read(&id(public_key, signature_scheme)) + store + .signature_key_pair(&StorageId { + value: id(public_key, signature_scheme), + }) + .ok() + .flatten() } /// Get the public key as byte slice. @@ -153,3 +157,24 @@ impl SignatureKeyPair { &self.private } } + +// Storage + +#[derive(Debug, Serialize, Deserialize)] +pub struct StorageId { + value: Vec, +} + +impl From> for StorageId { + fn from(vec: Vec) -> Self { + StorageId { value: vec } + } +} + +// Implement key traits for the storage id +impl storage::Key for StorageId {} +impl storage::traits::SignaturePublicKey for StorageId {} + +// Implement entity trait for the signature key pair +impl storage::Entity for SignatureKeyPair {} +impl storage::traits::SignatureKeyPair for SignatureKeyPair {} diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 9c21e147aa..3763cdf5ab 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -12,6 +12,7 @@ - [Removing members from a group](user_manual/remove_members.md) - [Updating own key package](user_manual/updates.md) - [Leaving a group](user_manual/leaving.md) + - [Custom proposals](user_manual/custom_proposals.md) - [Creating application messages](user_manual/application_messages.md) - [Committing to pending proposals](user_manual/commit_to_proposals.md) - [Processing incoming messages](user_manual/processing.md) diff --git a/book/src/message_validation.md b/book/src/message_validation.md index 1790ee3f07..b264c6f480 100644 --- a/book/src/message_validation.md +++ b/book/src/message_validation.md @@ -67,6 +67,7 @@ The following is a list of the individual semantic validation steps performed by | `ValSem110` | Update Proposal: Encryption key must be unique among proposals & members | ✅ | ✅ | `openmls/src/group/tests/test_proposal_validation.rs` | | `ValSem111` | Update Proposal: The sender of a full Commit must not include own update proposals | ✅ | ✅ | `openmls/src/group/tests/test_proposal_validation.rs` | | `ValSem112` | Update Proposal: The sender of a standalone update proposal must be of type member | ✅ | ✅ | `openmls/src/group/tests/test_proposal_validation.rs` | +| `ValSem113` | All Proposals: The proposal type must be supported by all members of the group | ✅ | ✅ | `openmls/src/group/tests/test_proposal_validation.rs` | ### Commit message validation diff --git a/book/src/traits/README.md b/book/src/traits/README.md index 95d2eb01af..50d02ebb77 100644 --- a/book/src/traits/README.md +++ b/book/src/traits/README.md @@ -10,31 +10,30 @@ enclaves. - [Traits](./traits.md) - [External Types](./types.md) -## Using the key store +## Using storage -The key store is probably one of the most interesting traits because applications +The store is probably one of the most interesting traits because applications that use OpenMLS will interact with it. -See the [OpenMlsKeyStore trait](./traits.md#openmlskeystore) description for details -but note that the key used to store, read, and delete values in the key store has -to be provided as a byte slice. +See the [StorageProvider trait](./traits.md#storageprovider) description for details. -In the following examples, we have a `ciphersuite` and a `provider` (`OpenMlsCryptoProvider`). +In the following examples, we have a `ciphersuite` and a `provider` (`OpenMlsProvider`). ```rust,no_run,noplayground -{{#include ../../../openmls/tests/key_store.rs:key_store_store}} +{{#include ../../../openmls/tests/store.rs:store_store}} ``` -The `delete` is called with the identifier to delete a value. +Retrieving a value from the store is as simple as calling `read`. +The retrieved key package bundles the private keys for the init and encryption keys +as well. ```rust,no_run,noplayground -{{#include ../../../openmls/tests/key_store.rs:key_store_delete}} +{{#include ../../../openmls/tests/store.rs:store_read}} ``` -Retrieving a value from the key store is as simple as calling `read`. -In this example, we assume we got a `credential` where we want to retrieve the credential bundle, i.e., the private key material. +The `delete` is called with the identifier to delete a value. ```rust,no_run,noplayground -{{#include ../../../openmls/tests/key_store.rs:key_store_read}} +{{#include ../../../openmls/tests/store.rs:store_delete}} ``` [//]: # "links" diff --git a/book/src/traits/traits.md b/book/src/traits/traits.md index 13366cc5d5..deba322fb1 100644 --- a/book/src/traits/traits.md +++ b/book/src/traits/traits.md @@ -52,21 +52,80 @@ This trait defines all cryptographic functions required by OpenMLS. In particula {{#include ../../../traits/src/crypto.rs:10}} ``` -### OpenMlsKeyStore +### StorageProvider -This trait defines a CRUD API for a key store that is used to store long-term -key material from OpenMLS. +This trait defines an API for a storage backend that is used for all OpenMLS +persistence. -The key store provides functions to `store`, `read`, and `delete` values. +The store provides functions to `store`, `read`, and `delete` values. Note that it does not allow updating values. Instead, entries must be deleted and newly stored. ```rust,no_run,noplayground -{{#include ../../../traits/src/key_store.rs:15:40}} +{{#include ../../../traits/src/storage.rs:16:25}} ``` -**NOTE:** Right now, key material must be extracted from the key store. -This will most likely change in the future. +The trait is generic over a `VERSION`, which is used to ensure that the values +that are persisted can be upgraded when OpenMLS changes the stored structs. + +Every function takes `Key` and `Value` arguments. + +```rust,no_run,noplayground +{{#include ../../../traits/src/storage.rs:key_trait}} +``` + +```rust,no_run,noplayground +{{#include ../../../traits/src/storage.rs:entity_trait}} +``` + +To ensure that each function takes the correct input, they use trait bounds. +These are the available traits. + +```rust,no_run,noplayground +{{#include ../../../traits/src/storage.rs:traits}} +``` + +An implementation of the storage trait should ensure that it can address and +efficiently handle values. + +#### Example: Key packages + +This is only an example, but it illustrates that the application may need to do more +when it comes to implementing storage. + +Key packages are only deleted by OpenMLS when they are used and _not_ last resort +key packages (which may be used multiple times). +The application needs to implement some logic to manage last resort key packages. + +```rust,no_run,noplayground +{{#include ../../../traits/src/storage.rs:write_key_package}} +``` + +The application may store the hash references in a separate list with a validity +period. + +```rust,ro_run,noplayground +fn write_key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, +>( + &self, + hash_ref: &HashReference, + key_package: &KeyPackage, +) -> Result<(), Self::Error> { + // Get the validity from the application in some way. + let validity = self.get_validity(hash_ref); + + // Store the reference and its validity period. + self.store_hash_ref(hash_ref, validity); + + // Store the actual key package. + self.store_key_package(hash_ref, key_package); +} +``` + +This allows the application to iterate over the hash references and delete outdated +key packages. ### OpenMlsCryptoProvider diff --git a/book/src/user_manual/custom_proposals.md b/book/src/user_manual/custom_proposals.md new file mode 100644 index 0000000000..4611ca7dba --- /dev/null +++ b/book/src/user_manual/custom_proposals.md @@ -0,0 +1,15 @@ +# Custom proposals + +OpenMLS allows the creation and use of application-defined proposals. To create such a proposal, the application needs to define a Proposal Type in such a way that its value doesn't collide with any Proposal Types defined in Section 17.4. of RFC 9420. If the proposal is meant to be used only inside of a particular application, the value of the Proposal Type is recommended to be in the range between `0xF000` and `0xFFFF`, as that range is reserved for private use. + +Custom proposals can contain arbitrary octet-strings as defined by the application. Any policy decisions based on custom proposals will have to be made by the application, such as the decision to include a given custom proposal in a commit, or whether to accept a commit that includes one or more custom proposals. To decide the latter, applications can inspect the queued proposals in a `ProcessedMessageContent::StagedCommitMesage(staged_commit)`. + +Example on how to use custom proposals: + +```rust,no_run,noplayground +{{#include ../../../openmls/tests/book_code.rs:custom_proposal_type}} +``` + +```rust,no_run,noplayground +{{#include ../../../openmls/tests/book_code.rs:custom_proposal_usage}} +``` diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 376c670de9..86a35804d8 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -15,7 +15,9 @@ openmls = { path = "../openmls", features = ["test-utils"] } ds-lib = { path = "../delivery-service/ds-lib" } openmls_traits = { path = "../traits" } openmls_rust_crypto = { path = "../openmls_rust_crypto" } -openmls_memory_keystore = { path = "../memory_keystore" } +openmls_memory_storage = { path = "../memory_storage", features = [ + "persistence", +] } openmls_basic_credential = { path = "../basic_credential" } serde = { version = "^1.0" } thiserror = "1.0" diff --git a/cli/src/backend.rs b/cli/src/backend.rs index 10c0d2d3d3..170e114fb8 100644 --- a/cli/src/backend.rs +++ b/cli/src/backend.rs @@ -1,12 +1,20 @@ use tls_codec::{Deserialize, TlsVecU16, TlsVecU32}; use url::Url; +use crate::networking::get_with_body; + use super::{ networking::{get, post}, user::User, }; -use ds_lib::*; +use ds_lib::{ + messages::{ + AuthToken, PublishKeyPackagesRequest, RecvMessageRequest, RegisterClientRequest, + RegisterClientSuccessResponse, + }, + *, +}; use openmls::prelude::*; pub struct Backend { @@ -15,30 +23,37 @@ pub struct Backend { impl Backend { /// Register a new client with the server. - pub fn register_client(&self, user: &User) -> Result { + pub fn register_client( + &self, + key_packages: Vec<(Vec, KeyPackage)>, + ) -> Result { let mut url = self.ds_url.clone(); url.set_path("/clients/register"); - let client_info = ClientInfo::new( - user.username.clone(), - user.key_packages() + let key_packages = ClientKeyPackages( + key_packages .into_iter() - .map(|(b, kp)| (b, KeyPackageIn::from(kp))) - .collect(), + .map(|(b, kp)| (b.into(), KeyPackageIn::from(kp))) + .collect::>() + .into(), ); - let response = post(&url, &client_info)?; + let request = RegisterClientRequest { key_packages }; + let response_bytes = post(&url, &request)?; + let response = + RegisterClientSuccessResponse::tls_deserialize(&mut response_bytes.as_slice()) + .map_err(|e| format!("Error decoding server response: {e:?}"))?; - Ok(String::from_utf8(response).unwrap()) + Ok(response.auth_token) } /// Get a list of all clients with name, ID, and key packages from the /// server. - pub fn list_clients(&self) -> Result, String> { + pub fn list_clients(&self) -> Result>, String> { let mut url = self.ds_url.clone(); url.set_path("/clients/list"); let response = get(&url)?; - match TlsVecU32::::tls_deserialize(&mut response.as_slice()) { + match TlsVecU32::>::tls_deserialize(&mut response.as_slice()) { Ok(clients) => Ok(clients.into()), Err(e) => Err(format!("Error decoding server response: {e:?}")), } @@ -59,14 +74,22 @@ impl Backend { } /// Publish client additional key packages - pub fn publish_key_packages(&self, user: &User, ckp: &ClientKeyPackages) -> Result<(), String> { + pub fn publish_key_packages(&self, user: &User, ckp: ClientKeyPackages) -> Result<(), String> { + let Some(auth_token) = user.auth_token() else { + return Err("Please register user before publishing key packages".to_string()); + }; let mut url = self.ds_url.clone(); let path = "/clients/key_packages/".to_string() + &base64::encode_config(user.identity.borrow().identity(), base64::URL_SAFE); url.set_path(&path); + let request = PublishKeyPackagesRequest { + key_packages: ckp, + auth_token: auth_token.clone(), + }; + // The response should be empty. - let _response = post(&url, &ckp)?; + let _response = post(&url, &request)?; Ok(()) } @@ -92,12 +115,19 @@ impl Backend { /// Get a list of all new messages for the user. pub fn recv_msgs(&self, user: &User) -> Result, String> { + let Some(auth_token) = user.auth_token() else { + return Err("Please register user before publishing key packages".to_string()); + }; let mut url = self.ds_url.clone(); let path = "/recv/".to_string() + &base64::encode_config(user.identity.borrow().identity(), base64::URL_SAFE); url.set_path(&path); - let response = get(&url)?; + let request = RecvMessageRequest { + auth_token: auth_token.clone(), + }; + + let response = get_with_body(&url, &request)?; match TlsVecU16::::tls_deserialize(&mut response.as_slice()) { Ok(r) => Ok(r.into()), Err(e) => Err(format!("Invalid message list: {e:?}")), diff --git a/cli/src/file_helpers.rs b/cli/src/file_helpers.rs deleted file mode 100644 index 956d75feee..0000000000 --- a/cli/src/file_helpers.rs +++ /dev/null @@ -1,6 +0,0 @@ -use std::{env, path::PathBuf}; - -pub fn get_file_path(file_name: &String) -> PathBuf { - let tmp_folder = env::temp_dir(); - tmp_folder.join(file_name) -} diff --git a/cli/src/identity.rs b/cli/src/identity.rs index 8e3eacd106..1fd2b7d3de 100644 --- a/cli/src/identity.rs +++ b/cli/src/identity.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use openmls::prelude::{config::CryptoConfig, *}; +use openmls::prelude::*; use openmls_basic_credential::SignatureKeyPair; use openmls_traits::OpenMlsProvider; @@ -21,22 +21,19 @@ impl Identity { pub(crate) fn new( ciphersuite: Ciphersuite, crypto: &OpenMlsRustPersistentCrypto, - id: &[u8], + username: &[u8], ) -> Self { - let credential = BasicCredential::new(id.to_vec()).unwrap(); + let credential = BasicCredential::new(username.to_vec()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); let credential_with_key = CredentialWithKey { credential: credential.into(), signature_key: signature_keys.to_public_vec().into(), }; - signature_keys.store(crypto.key_store()).unwrap(); + signature_keys.store(crypto.storage()).unwrap(); let key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, crypto, &signature_keys, credential_with_key.clone(), @@ -46,11 +43,12 @@ impl Identity { Self { kp: HashMap::from([( key_package + .key_package() .hash_ref(crypto.crypto()) .unwrap() .as_slice() .to_vec(), - key_package, + key_package.key_package().clone(), )]), credential_with_key, signer: signature_keys, @@ -65,10 +63,7 @@ impl Identity { ) -> KeyPackage { let key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, crypto, &self.signer, self.credential_with_key.clone(), @@ -77,17 +72,25 @@ impl Identity { self.kp.insert( key_package + .key_package() .hash_ref(crypto.crypto()) .unwrap() .as_slice() .to_vec(), - key_package.clone(), + key_package.key_package().clone(), ); - key_package + key_package.key_package().clone() } /// Get the plain identity as byte vector. pub fn identity(&self) -> &[u8] { self.credential_with_key.credential.serialized_content() } + + /// Get the plain identity as byte vector. + pub fn identity_as_string(&self) -> String { + std::str::from_utf8(self.credential_with_key.credential.serialized_content()) + .unwrap() + .to_string() + } } diff --git a/cli/src/main.rs b/cli/src/main.rs index 5d4bc0bfcd..f0ffcae940 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -7,11 +7,9 @@ use termion::input::TermRead; mod backend; mod conversation; -mod file_helpers; mod identity; mod networking; mod openmls_rust_persistent_crypto; -mod persistent_key_store; mod serialize_any_hashmap; mod user; @@ -71,7 +69,7 @@ fn main() { client = Some(user::User::new(client_name.to_string())); client.as_mut().unwrap().add_key_package(); client.as_mut().unwrap().add_key_package(); - client.as_ref().unwrap().register(); + client.as_mut().unwrap().register(); stdout .write_all(format!("registered new client {client_name}\n\n").as_bytes()) .unwrap(); @@ -114,7 +112,7 @@ fn main() { if op == "save" { if let Some(client) = &mut client { client.save(); - let name = &client.username; + let name = &client.identity.borrow().identity_as_string(); stdout .write_all(format!(" >>> client {name} state saved\n\n").as_bytes()) .unwrap(); @@ -130,7 +128,7 @@ fn main() { if op == "autosave" { if let Some(client) = &mut client { client.enable_auto_save(); - let name = &client.username; + let name = &client.identity.borrow().identity_as_string(); stdout .write_all(format!(" >>> autosave enabled for client {name} \n\n").as_bytes()) .unwrap(); diff --git a/cli/src/networking.rs b/cli/src/networking.rs index 7ec86fbd3a..31a5df5506 100644 --- a/cli/src/networking.rs +++ b/cli/src/networking.rs @@ -1,3 +1,4 @@ +use ds_lib::messages::AuthToken; use reqwest::{self, blocking::Client, StatusCode}; use url::Url; @@ -25,9 +26,25 @@ pub fn post(url: &Url, msg: &impl Serialize) -> Result, String> { } pub fn get(url: &Url) -> Result, String> { + let auth_token_option: Option<&AuthToken> = None; + get_internal(url, auth_token_option) +} + +pub fn get_with_body(url: &Url, body: &impl Serialize) -> Result, String> { + get_internal(url, Some(body)) +} + +fn get_internal(url: &Url, msg: Option<&impl Serialize>) -> Result, String> { log::debug!("Get {:?}", url); - let client = Client::new(); - let response = client.get(url.to_string()).send(); + let client = Client::new().get(url.to_string()); + let client = if let Some(msg) = msg { + let serialized_msg = msg.tls_serialize_detached().unwrap(); + log::trace!("Payload: {:?}", serialized_msg); + client.body(serialized_msg) + } else { + client + }; + let response = client.send(); if let Ok(r) = response { if r.status() != StatusCode::OK { return Err(format!("Error status code {:?}", r.status())); diff --git a/cli/src/openmls_rust_persistent_crypto.rs b/cli/src/openmls_rust_persistent_crypto.rs index d62264b46b..2c01d9a583 100644 --- a/cli/src/openmls_rust_persistent_crypto.rs +++ b/cli/src/openmls_rust_persistent_crypto.rs @@ -3,20 +3,19 @@ //! This is an implementation of the [`OpenMlsCryptoProvider`] trait to use with //! OpenMLS. -use super::persistent_key_store::PersistentKeyStore; -use openmls_rust_crypto::RustCrypto; +use openmls_rust_crypto::{MemoryStorage, RustCrypto}; use openmls_traits::OpenMlsProvider; #[derive(Default, Debug)] pub struct OpenMlsRustPersistentCrypto { crypto: RustCrypto, - key_store: PersistentKeyStore, + storage: MemoryStorage, } impl OpenMlsProvider for OpenMlsRustPersistentCrypto { type CryptoProvider = RustCrypto; type RandProvider = RustCrypto; - type KeyStoreProvider = PersistentKeyStore; + type StorageProvider = MemoryStorage; fn crypto(&self) -> &Self::CryptoProvider { &self.crypto @@ -26,17 +25,17 @@ impl OpenMlsProvider for OpenMlsRustPersistentCrypto { &self.crypto } - fn key_store(&self) -> &Self::KeyStoreProvider { - &self.key_store + fn storage(&self) -> &Self::StorageProvider { + &self.storage } } impl OpenMlsRustPersistentCrypto { pub fn save_keystore(&self, user_name: String) -> Result<(), String> { - self.key_store.save(user_name) + self.storage.save(user_name) } pub fn load_keystore(&mut self, user_name: String) -> Result<(), String> { - self.key_store.load(user_name) + self.storage.load(user_name) } } diff --git a/cli/src/persistent_key_store.rs b/cli/src/persistent_key_store.rs deleted file mode 100644 index 20c020ae2b..0000000000 --- a/cli/src/persistent_key_store.rs +++ /dev/null @@ -1,131 +0,0 @@ -use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore}; -use serde::{Deserialize, Serialize}; -use std::{ - collections::HashMap, - fs::File, - io::{BufReader, BufWriter}, - path::PathBuf, - sync::RwLock, -}; - -use super::file_helpers; - -#[derive(Debug, Default)] -pub struct PersistentKeyStore { - values: RwLock, Vec>>, -} - -#[derive(Debug, Default, Serialize, Deserialize)] -struct SerializableKeyStore { - values: HashMap, -} - -impl OpenMlsKeyStore for PersistentKeyStore { - /// The error type returned by the [`OpenMlsKeyStore`]. - type Error = PersistentKeyStoreError; - - /// Store a value `v` that implements the [`ToKeyStoreValue`] trait for - /// serialization for ID `k`. - /// - /// Returns an error if storing fails. - fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> { - let value = - serde_json::to_vec(v).map_err(|_| PersistentKeyStoreError::SerializationError)?; - // We unwrap here, because this is the only function claiming a write - // lock on `credential_bundles`. It only holds the lock very briefly and - // should not panic during that period. - let mut values = self.values.write().unwrap(); - values.insert(k.to_vec(), value); - Ok(()) - } - - /// Read and return a value stored for ID `k` that implements the - /// [`FromKeyStoreValue`] trait for deserialization. - /// - /// Returns [`None`] if no value is stored for `k` or reading fails. - fn read(&self, k: &[u8]) -> Option { - // We unwrap here, because the two functions claiming a write lock on - // `init_key_package_bundles` (this one and `generate_key_package_bundle`) only - // hold the lock very briefly and should not panic during that period. - let values = self.values.read().unwrap(); - if let Some(value) = values.get(k) { - serde_json::from_slice(value).ok() - } else { - None - } - } - - /// Delete a value stored for ID `k`. - /// - /// Returns an error if storing fails. - fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { - // We just delete both ... - let mut values = self.values.write().unwrap(); - values.remove(k); - Ok(()) - } -} - -impl PersistentKeyStore { - fn get_file_path(user_name: &String) -> PathBuf { - file_helpers::get_file_path(&("openmls_cli_".to_owned() + user_name + "_ks.json")) - } - - fn save_to_file(&self, output_file: &File) -> Result<(), String> { - let writer = BufWriter::new(output_file); - - let mut ser_ks = SerializableKeyStore::default(); - for (key, value) in &*self.values.read().unwrap() { - ser_ks - .values - .insert(base64::encode(key), base64::encode(value)); - } - - match serde_json::to_writer_pretty(writer, &ser_ks) { - Ok(()) => Ok(()), - Err(e) => Err(e.to_string()), - } - } - - pub fn save(&self, user_name: String) -> Result<(), String> { - let ks_output_path = PersistentKeyStore::get_file_path(&user_name); - - match File::create(ks_output_path) { - Ok(output_file) => self.save_to_file(&output_file), - Err(e) => Err(e.to_string()), - } - } - - fn load_from_file(&mut self, input_file: &File) -> Result<(), String> { - // Prepare file reader. - let reader = BufReader::new(input_file); - - // Read the JSON contents of the file as an instance of `SerializableKeyStore`. - match serde_json::from_reader::, SerializableKeyStore>(reader) { - Ok(ser_ks) => { - let mut ks_map = self.values.write().unwrap(); - for (key, value) in ser_ks.values { - ks_map.insert(base64::decode(key).unwrap(), base64::decode(value).unwrap()); - } - Ok(()) - } - Err(e) => Err(e.to_string()), - } - } - - pub fn load(&mut self, user_name: String) -> Result<(), String> { - let ks_input_path = PersistentKeyStore::get_file_path(&user_name); - - match File::open(ks_input_path) { - Ok(input_file) => self.load_from_file(&input_file), - Err(e) => Err(e.to_string()), - } - } -} - -/// Errors thrown by the key store. -#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)] -pub enum PersistentKeyStoreError { - #[error("Error serializing value.")] - SerializationError, -} diff --git a/cli/src/user.rs b/cli/src/user.rs index 45f5578bdb..25c5ec770d 100644 --- a/cli/src/user.rs +++ b/cli/src/user.rs @@ -5,12 +5,13 @@ use std::io::{BufReader, BufWriter}; use std::path::PathBuf; use std::{cell::RefCell, collections::HashMap, str}; +use ds_lib::messages::AuthToken; use ds_lib::{ClientKeyPackages, GroupMessage}; use openmls::prelude::{tls_codec::*, *}; use openmls_traits::OpenMlsProvider; use super::{ - backend::Backend, conversation::Conversation, conversation::ConversationMessage, file_helpers, + backend::Backend, conversation::Conversation, conversation::ConversationMessage, identity::Identity, openmls_rust_persistent_crypto::OpenMlsRustPersistentCrypto, serialize_any_hashmap, }; @@ -19,10 +20,15 @@ const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA2 #[derive(serde::Serialize, serde::Deserialize)] pub struct Contact { - username: String, id: Vec, } +impl Contact { + fn username(&self) -> String { + String::from_utf8(self.id.clone()).unwrap() + } +} + #[derive(serde::Serialize, serde::Deserialize)] pub struct Group { group_name: String, @@ -32,7 +38,6 @@ pub struct Group { #[derive(serde::Serialize, serde::Deserialize)] pub struct User { - pub(crate) username: String, #[serde( serialize_with = "serialize_any_hashmap::serialize_hashmap", deserialize_with = "serialize_any_hashmap::deserialize_hashmap" @@ -45,8 +50,9 @@ pub struct User { #[serde(skip)] backend: Backend, #[serde(skip)] - crypto: OpenMlsRustPersistentCrypto, + provider: OpenMlsRustPersistentCrypto, autosave_enabled: bool, + auth_token: Option, } #[derive(PartialEq)] @@ -60,20 +66,22 @@ impl User { pub fn new(username: String) -> Self { let crypto = OpenMlsRustPersistentCrypto::default(); let out = Self { - username: username.clone(), groups: RefCell::new(HashMap::new()), group_list: HashSet::new(), contacts: HashMap::new(), identity: RefCell::new(Identity::new(CIPHERSUITE, &crypto, username.as_bytes())), backend: Backend::default(), - crypto, + provider: crypto, autosave_enabled: false, + auth_token: None, }; out } fn get_file_path(user_name: &String) -> PathBuf { - file_helpers::get_file_path(&("openmls_cli_".to_owned() + user_name + ".json")) + openmls_memory_storage::persistence::get_file_path( + &("openmls_cli_".to_owned() + user_name + ".json"), + ) } fn load_from_file(input_file: &File) -> Result { @@ -100,16 +108,16 @@ impl User { if user_result.is_ok() { let mut user = user_result.ok().unwrap(); - match user.crypto.load_keystore(user_name) { + match user.provider.load_keystore(user_name) { Ok(_) => { let groups = user.groups.get_mut(); for group_name in &user.group_list { let mlsgroup = MlsGroup::load( + user.provider.storage(), &GroupId::from_slice(group_name.as_bytes()), - user.crypto.key_store(), ); let grp = Group { - mls_group: RefCell::new(mlsgroup.unwrap()), + mls_group: RefCell::new(mlsgroup.unwrap().unwrap()), group_name: group_name.clone(), conversation: Conversation::default(), }; @@ -135,23 +143,13 @@ impl User { } pub fn save(&mut self) { - let output_path = User::get_file_path(&self.username); + let output_path = User::get_file_path(&self.identity.borrow().identity_as_string()); match File::create(output_path) { Err(e) => log::error!("Error saving user state: {:?}", e.to_string()), Ok(output_file) => { - let groups = self.groups.get_mut(); - for (group_name, group) in groups { - self.group_list.replace(group_name.clone()); - group - .mls_group - .borrow_mut() - .save(self.crypto.key_store()) - .unwrap(); - } - self.save_to_file(&output_file); - match self.crypto.save_keystore(self.username.clone()) { + match self.provider.save_keystore(self.username()) { Ok(_) => log::info!("User state saved"), Err(e) => log::error!("Error saving user state : {:?}", e.to_string()), } @@ -175,9 +173,9 @@ impl User { let kp = self .identity .borrow_mut() - .add_key_package(CIPHERSUITE, &self.crypto); + .add_key_package(CIPHERSUITE, &self.provider); ( - kp.hash_ref(self.crypto.crypto()) + kp.hash_ref(self.provider.crypto()) .unwrap() .as_slice() .to_vec(), @@ -195,7 +193,7 @@ impl User { credential, } in mls_group.members() { - let credential = BasicCredential::try_from(&credential).unwrap(); + let credential = BasicCredential::try_from(credential).unwrap(); if credential.identity() == name.as_bytes() { return Ok(index); } @@ -210,9 +208,12 @@ impl User { Vec::from_iter(kpgs) } - pub fn register(&self) { - match self.backend.register_client(self) { - Ok(r) => log::debug!("Created new user: {:?}", r), + pub fn register(&mut self) { + match self.backend.register_client(self.key_packages()) { + Ok(token) => { + log::debug!("Created new user: {:?}", self.username()); + self.set_auth_token(token) + } Err(e) => log::error!("Error creating user: {:?}", e), } } @@ -237,7 +238,7 @@ impl User { .as_slice() != signature_key.as_slice() { - let credential = BasicCredential::try_from(&credential).unwrap(); + let credential = BasicCredential::try_from(credential).unwrap(); log::debug!( "Searching for contact {:?}", str::from_utf8(credential.identity()).unwrap() @@ -279,7 +280,7 @@ impl User { .into(), ); - match self.backend.publish_key_packages(self, &ckp) { + match self.backend.publish_key_packages(self, ckp) { Ok(()) => (), Err(e) => println!("Error sending new key package: {e:?}"), }; @@ -296,7 +297,11 @@ impl User { let message_out = group .mls_group .borrow_mut() - .create_message(&self.crypto, &self.identity.borrow().signer, msg.as_bytes()) + .create_message( + &self.provider, + &self.identity.borrow().signer, + msg.as_bytes(), + ) .map_err(|e| format!("{e}"))?; let msg = GroupMessage::new(message_out.into(), &self.recipients(group)); @@ -317,20 +322,18 @@ impl User { fn update_clients(&mut self) { match self.backend.list_clients() { Ok(mut v) => { - for c in v.drain(..) { - let client_id = c.id.clone(); + for client_id in v.drain(..) { log::debug!( "update::Processing client for contact {:?}", str::from_utf8(&client_id).unwrap() ); - if c.id != self.identity.borrow().identity() + if client_id != self.identity.borrow().identity() && self .contacts .insert( - c.id.clone(), + client_id.clone(), Contact { - username: c.client_name, - id: c.id, + id: client_id.clone(), }, ) .is_some() @@ -354,65 +357,57 @@ impl User { } } - /// Update the user. This involves: - /// * retrieving all new messages from the server - /// * update the contacts with all other clients known to the server - pub fn update( + fn process_protocol_message( &mut self, group_name: Option, - ) -> Result, String> { - log::debug!("Updating {} ...", self.username); - - let mut messages_out: Vec = Vec::new(); - - let mut process_protocol_message = |message: ProtocolMessage| -> Result< - (PostUpdateActions, Option), - String, - > { - let processed_message: ProcessedMessage; - let mut groups = self.groups.borrow_mut(); + message: ProtocolMessage, + ) -> Result< + ( + PostUpdateActions, + Option, + Option, + ), + String, + > { + let processed_message: ProcessedMessage; + let mut groups = self.groups.borrow_mut(); - let group = match groups.get_mut(str::from_utf8(message.group_id().as_slice()).unwrap()) - { - Some(g) => g, - None => { - log::error!( - "Error getting group {:?} for a message. Dropping message.", - message.group_id() - ); - return Err("error".to_string()); - } - }; - let mut mls_group = group.mls_group.borrow_mut(); - - processed_message = match mls_group.process_message(&self.crypto, message) { - Ok(msg) => msg, - Err(e) => { - log::error!( - "Error processing unverified message: {:?} - Dropping message.", - e - ); - return Err("error".to_string()); - } - }; + let group = match groups.get_mut(str::from_utf8(message.group_id().as_slice()).unwrap()) { + Some(g) => g, + None => { + log::error!( + "Error getting group {:?} for a message. Dropping message.", + message.group_id() + ); + return Err("error".to_string()); + } + }; + let mut mls_group = group.mls_group.borrow_mut(); - let processed_message_credential: Credential = processed_message.credential().clone(); + processed_message = match mls_group.process_message(&self.provider, message) { + Ok(msg) => msg, + Err(e) => { + log::error!( + "Error processing unverified message: {:?} - Dropping message.", + e + ); + return Err("error".to_string()); + } + }; - match processed_message.into_content() { - ProcessedMessageContent::ApplicationMessage(application_message) => { - let processed_message_credential = - BasicCredential::try_from(&processed_message_credential).unwrap(); + let processed_message_credential: Credential = processed_message.credential().clone(); - let sender_name = match self - .contacts - .get(processed_message_credential.identity()) - { - Some(c) => c.username.clone(), - None => { - // Contact list is not updated right now, get the identity from the - // mls_group member - let user_id = mls_group.members().find_map(|m| { - let m_credential = BasicCredential::try_from(&m.credential).unwrap(); + let message_out = match processed_message.into_content() { + ProcessedMessageContent::ApplicationMessage(application_message) => { + let processed_message_credential = + BasicCredential::try_from(processed_message_credential.clone()).unwrap(); + let sender_name = match self.contacts.get(processed_message_credential.identity()) { + Some(c) => c.id.clone(), + None => { + // Contact list is not updated right now, get the identity from the + // mls_group member + let user_id = mls_group.members().find_map(|m| { + let m_credential = BasicCredential::try_from(m.credential.clone()).unwrap(); if m_credential.identity() == processed_message_credential.identity() && (self @@ -431,49 +426,70 @@ impl User { None } }); - user_id.unwrap_or("".to_owned()) - } - }; - let conversation_message = ConversationMessage::new( - String::from_utf8(application_message.into_bytes()) - .unwrap() - .clone(), - sender_name.to_string(), - ); - if group_name.is_none() || group_name.clone().unwrap() == group.group_name { - messages_out.push(conversation_message.clone()); + user_id.unwrap_or("".to_owned()).as_bytes().to_vec() } - group.conversation.add(conversation_message); - } - ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { - // intentionally left blank. + }; + let conversation_message = ConversationMessage::new( + String::from_utf8(application_message.into_bytes()) + .unwrap() + .clone(), + String::from_utf8(sender_name).unwrap(), + ); + group.conversation.add(conversation_message.clone()); + if group_name.is_none() || group_name.clone().unwrap() == group.group_name { + Some(conversation_message) + } else { + None } - ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { - // intentionally left blank. + } + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + None + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. + None + } + ProcessedMessageContent::StagedCommitMessage(commit_ptr) => { + let mut remove_proposal: bool = false; + if commit_ptr.self_removed() { + remove_proposal = true; } - ProcessedMessageContent::StagedCommitMessage(commit_ptr) => { - let mut remove_proposal: bool = false; - if commit_ptr.self_removed() { - remove_proposal = true; - } - match mls_group.merge_staged_commit(&self.crypto, *commit_ptr) { - Ok(()) => { - if remove_proposal { - log::debug!("update::Processing StagedCommitMessage removing {} from group {} ", self.username, group.group_name); - return Ok(( - PostUpdateActions::Remove, - Some(mls_group.group_id().clone()), - )); - } + match mls_group.merge_staged_commit(&self.provider, *commit_ptr) { + Ok(()) => { + if remove_proposal { + log::debug!( + "update::Processing StagedCommitMessage removing {} from group {} ", + self.username(), + group.group_name + ); + return Ok(( + PostUpdateActions::Remove, + Some(mls_group.group_id().clone()), + None, + )); } - Err(e) => return Err(e.to_string()), } + Err(e) => return Err(e.to_string()), } + None } - Ok((PostUpdateActions::None, None)) }; + Ok((PostUpdateActions::None, None, message_out)) + } - log::debug!("update::Processing messages for {} ", self.username); + /// Update the user. This involves: + /// * retrieving all new messages from the server + /// * update the contacts with all other clients known to the server + pub fn update( + &mut self, + group_name: Option, + ) -> Result, String> { + log::debug!("Updating {} ...", self.username()); + + let mut messages_out: Vec = Vec::new(); + + log::debug!("update::Processing messages for {} ", self.username()); // Go through the list of messages and process or store them. for message in self.backend.recv_msgs(self)?.drain(..) { log::debug!("Reading message format {:#?} ...", message.wire_format()); @@ -484,10 +500,13 @@ impl User { self.join_group(welcome)?; } MlsMessageBodyIn::PrivateMessage(message) => { - match process_protocol_message(message.into()) { - Ok(p) => { - if p.0 == PostUpdateActions::Remove { - match p.1 { + match self.process_protocol_message(group_name.clone(), message.into()) { + Ok((post_update_actions, group_id_option, message_out_option)) => { + if let Some(message_out) = message_out_option { + messages_out.push(message_out); + } + if post_update_actions == PostUpdateActions::Remove { + match group_id_option { Some(gid) => { let mut grps = self.groups.borrow_mut(); grps.remove_entry(str::from_utf8(gid.as_slice()).unwrap()); @@ -506,7 +525,10 @@ impl User { }; } MlsMessageBodyIn::PublicMessage(message) => { - if process_protocol_message(message.into()).is_err() { + if self + .process_protocol_message(group_name.clone(), message.into()) + .is_err() + { continue; } } @@ -524,7 +546,7 @@ impl User { /// Create a group with the given name. pub fn create_group(&mut self, name: String) { - log::debug!("{} creates group {}", self.username, name); + log::debug!("{} creates group {}", self.username(), name); let group_id = name.as_bytes(); let mut group_aad = group_id.to_vec(); group_aad.extend(b" AAD"); @@ -536,14 +558,16 @@ impl User { .build(); let mut mls_group = MlsGroup::new_with_group_id( - &self.crypto, + &self.provider, &self.identity.borrow().signer, &group_config, GroupId::from_slice(group_id), self.identity.borrow().credential_with_key.clone(), ) .expect("Failed to create MlsGroup"); - mls_group.set_aad(group_aad.as_slice()); + mls_group + .set_aad(self.provider.storage(), group_aad.as_slice()) + .expect("Failed to write the AAD for the new group to storage"); let group = Group { group_name: name.clone(), @@ -563,7 +587,7 @@ impl User { /// Invite user with the given name to the group. pub fn invite(&mut self, name: String, group_name: String) -> Result<(), String> { // First we need to get the key package for {id} from the DS. - let contact = match self.contacts.values().find(|c| c.username == name) { + let contact = match self.contacts.values().find(|c| c.username() == name) { Some(v) => v, None => return Err(format!("No contact with name {name} known.")), }; @@ -582,7 +606,7 @@ impl User { .mls_group .borrow_mut() .add_members( - &self.crypto, + &self.provider, &self.identity.borrow().signer, &[joiner_key_package.into()], ) @@ -602,7 +626,7 @@ impl User { group .mls_group .borrow_mut() - .merge_pending_commit(&self.crypto) + .merge_pending_commit(&self.provider) .expect("error merging pending commit"); // Finally, send Welcome to the joiner. @@ -639,7 +663,11 @@ impl User { let (remove_message, _welcome, _group_info) = group .mls_group .borrow_mut() - .remove_members(&self.crypto, &self.identity.borrow().signer, &[leaf_index]) + .remove_members( + &self.provider, + &self.identity.borrow().signer, + &[leaf_index], + ) .map_err(|e| format!("Failed to remove member from group - {e}"))?; // First, send the MlsMessage remove commit to the group. @@ -654,7 +682,7 @@ impl User { group .mls_group .borrow_mut() - .merge_pending_commit(&self.crypto) + .merge_pending_commit(&self.provider) .expect("error merging pending commit"); drop(groups); @@ -666,7 +694,7 @@ impl User { /// Join a group with the provided welcome message. fn join_group(&self, welcome: Welcome) -> Result<(), String> { - log::debug!("{} joining group ...", self.username); + log::debug!("{} joining group ...", self.username()); let mut ident = self.identity.borrow_mut(); for secret in welcome.secrets().iter() { @@ -681,9 +709,9 @@ impl User { .use_ratchet_tree_extension(true) .build(); let mut mls_group = - StagedWelcome::new_from_welcome(&self.crypto, &group_config, welcome, None) + StagedWelcome::new_from_welcome(&self.provider, &group_config, welcome, None) .expect("Failed to create staged join") - .into_group(&self.crypto) + .into_group(&self.provider) .expect("Failed to create MlsGroup"); let group_id = mls_group.group_id().to_vec(); @@ -691,7 +719,9 @@ impl User { let group_name = String::from_utf8(group_id.clone()).unwrap(); let group_aad = group_name.clone() + " AAD"; - mls_group.set_aad(group_aad.as_bytes()); + mls_group + .set_aad(self.provider.storage(), group_aad.as_bytes()) + .expect("Failed to update the AAD in the storage"); let group = Group { group_name: group_name.clone(), @@ -706,4 +736,16 @@ impl User { None => Ok(()), } } + + pub(crate) fn username(&self) -> String { + self.identity.borrow().identity_as_string() + } + + pub(super) fn set_auth_token(&mut self, token: AuthToken) { + self.auth_token = Some(token); + } + + pub(super) fn auth_token(&self) -> Option<&AuthToken> { + self.auth_token.as_ref() + } } diff --git a/delivery-service/ds-lib/Cargo.toml b/delivery-service/ds-lib/Cargo.toml index c780fba5da..c4f13fcd86 100644 --- a/delivery-service/ds-lib/Cargo.toml +++ b/delivery-service/ds-lib/Cargo.toml @@ -6,9 +6,10 @@ edition = "2021" description = "Types to interact with the OpenMLS DS." [dependencies] +rand = { version = "^0.8" } openmls = { path = "../../openmls", features = ["test-utils"] } openmls_traits = { path = "../../traits" } openmls_rust_crypto = { path = "../../openmls_rust_crypto" } -openmls_memory_keystore = { path = "../../memory_keystore" } +openmls_memory_storage = { path = "../../memory_storage" } openmls_basic_credential = { path = "../../basic_credential" } serde = { version = "^1.0", features = ["derive"] } diff --git a/delivery-service/ds-lib/src/lib.rs b/delivery-service/ds-lib/src/lib.rs index 4ec3b8dbd9..ea2b9e67fb 100644 --- a/delivery-service/ds-lib/src/lib.rs +++ b/delivery-service/ds-lib/src/lib.rs @@ -5,22 +5,63 @@ //! //! Clients are represented by the `ClientInfo` struct. +pub mod messages; + use std::collections::HashSet; -use openmls::prelude::{tls_codec::*, *}; +use messages::AuthToken; +use openmls::prelude::tls_codec::*; +use openmls::prelude::*; /// Information about a client. /// To register a new client create a new `ClientInfo` and send it to /// `/clients/register`. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, TlsSize, TlsSerialize, TlsDeserialize)] pub struct ClientInfo { - pub client_name: String, + pub id: Vec, pub key_packages: ClientKeyPackages, /// map of reserved key_packages [group_id, key_package_hash] + #[tls_codec(with = "hashset_codec")] pub reserved_key_pkg_hash: HashSet>, - pub id: Vec, pub msgs: Vec, pub welcome_queue: Vec, + pub auth_token: AuthToken, +} + +mod hashset_codec { + use std::{ + collections::HashSet, + io::{Read, Write}, + }; + + use crate::tls_codec::{self, Deserialize, Serialize}; + + pub fn tls_serialized_len(hashset: &HashSet>) -> usize { + let hashset_len = hashset.len(); + let length_encoding_bytes = match hashset_len { + 0..=0x3f => 1, + 0x40..=0x3fff => 2, + 0x4000..=0x3fff_ffff => 4, + _ => 8, + }; + hashset_len + length_encoding_bytes + } + + pub fn tls_serialize( + hashset: &HashSet>, + writer: &mut W, + ) -> Result + where + W: Write, + { + let vec = hashset.iter().map(|v| v.as_slice()).collect::>(); + vec.tls_serialize(writer) + } + + pub fn tls_deserialize(bytes: &mut R) -> Result>, tls_codec::Error> { + let vec = Vec::>::tls_deserialize(bytes)?; + Ok(vec.into_iter().collect::>()) + } } /// The DS returns a list of key packages for a client as `ClientKeyPackages`. @@ -44,14 +85,10 @@ pub struct ClientKeyPackages(pub TlsVecU32<(TlsByteVecU8, KeyPackageIn)>); impl ClientInfo { /// Create a new `ClientInfo` struct for a given client name and vector of /// key packages with corresponding hashes. - pub fn new(client_name: String, mut key_packages: Vec<(Vec, KeyPackageIn)>) -> Self { + pub fn new(mut key_packages: Vec<(Vec, KeyPackageIn)>) -> Self { let key_package: KeyPackage = KeyPackage::from(key_packages[0].1.clone()); - let id = VLBytes::tls_deserialize_exact( - key_package.leaf_node().credential().serialized_content(), - ) - .unwrap(); + let id = key_package.leaf_node().credential().serialized_content(); Self { - client_name, id: id.into(), key_packages: ClientKeyPackages( key_packages @@ -63,6 +100,7 @@ impl ClientInfo { reserved_key_pkg_hash: HashSet::new(), msgs: Vec::new(), welcome_queue: Vec::new(), + auth_token: AuthToken::random(), } } @@ -114,34 +152,6 @@ impl GroupMessage { } } -impl tls_codec::Size for ClientInfo { - fn tls_serialized_len(&self) -> usize { - TlsByteSliceU16(self.client_name.as_bytes()).tls_serialized_len() - + self.key_packages.tls_serialized_len() - } -} - -impl tls_codec::Serialize for ClientInfo { - fn tls_serialize(&self, writer: &mut W) -> Result { - let written = TlsByteSliceU16(self.client_name.as_bytes()).tls_serialize(writer)?; - self.key_packages.tls_serialize(writer).map(|l| l + written) - } -} - -impl tls_codec::Deserialize for ClientInfo { - fn tls_deserialize(bytes: &mut R) -> Result { - let client_name = - String::from_utf8_lossy(TlsByteVecU16::tls_deserialize(bytes)?.as_slice()).into(); - let mut key_packages: Vec<(TlsByteVecU8, KeyPackageIn)> = - TlsVecU32::<(TlsByteVecU8, KeyPackageIn)>::tls_deserialize(bytes)?.into(); - let key_packages = key_packages - .drain(..) - .map(|(e1, e2)| (e1.into(), e2)) - .collect(); - Ok(Self::new(client_name, key_packages)) - } -} - impl tls_codec::Size for GroupMessage { fn tls_serialized_len(&self) -> usize { self.msg.tls_serialized_len() + self.recipients.tls_serialized_len() diff --git a/delivery-service/ds-lib/src/messages.rs b/delivery-service/ds-lib/src/messages.rs new file mode 100644 index 0000000000..ae124017c0 --- /dev/null +++ b/delivery-service/ds-lib/src/messages.rs @@ -0,0 +1,50 @@ +use crate::tls_codec::{self, TlsDeserialize, TlsSerialize, TlsSize}; +use rand::{thread_rng, Rng}; +use serde::{Deserialize, Serialize}; + +use crate::ClientKeyPackages; + +#[derive( + Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize, PartialEq, Serialize, Deserialize, +)] +pub struct AuthToken { + token: Vec, +} + +impl Default for AuthToken { + fn default() -> Self { + Self::random() + } +} + +impl AuthToken { + pub(super) fn random() -> Self { + let token = thread_rng().gen::<[u8; 32]>().to_vec(); + Self { token } + } +} + +#[derive(Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize)] +pub struct RegisterClientRequest { + pub key_packages: ClientKeyPackages, +} + +pub struct RegisterClientErrorResponse { + pub message: String, +} + +#[derive(Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize)] +pub struct RegisterClientSuccessResponse { + pub auth_token: AuthToken, +} + +#[derive(Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize)] +pub struct PublishKeyPackagesRequest { + pub key_packages: ClientKeyPackages, + pub auth_token: AuthToken, +} + +#[derive(Debug, Clone, TlsSize, TlsSerialize, TlsDeserialize)] +pub struct RecvMessageRequest { + pub auth_token: AuthToken, +} diff --git a/delivery-service/ds-lib/tests/test_codec.rs b/delivery-service/ds-lib/tests/test_codec.rs index 9c9c7845b1..cadd9be6cc 100644 --- a/delivery-service/ds-lib/tests/test_codec.rs +++ b/delivery-service/ds-lib/tests/test_codec.rs @@ -1,5 +1,5 @@ use ds_lib::{self, *}; -use openmls::prelude::{config::CryptoConfig, *}; +use openmls::prelude::*; use openmls_basic_credential::SignatureKeyPair; use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::OpenMlsProvider; @@ -11,35 +11,28 @@ fn test_client_info() { let client_name = "Client1"; let ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; - let credential = BasicCredential::new(client_name.as_bytes().to_vec()).unwrap(); + let credential = BasicCredential::new(client_name.as_bytes().to_vec()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); let credential_with_key = CredentialWithKey { credential: credential.into(), signature_key: signature_keys.to_public_vec().into(), }; - signature_keys.store(crypto.key_store()).unwrap(); + signature_keys.store(crypto.storage()).unwrap(); let client_key_package = KeyPackage::builder() - .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - crypto, - &signature_keys, - credential_with_key, - ) + .build(ciphersuite, crypto, &signature_keys, credential_with_key) .unwrap(); let client_key_package = vec![( client_key_package + .key_package() .hash_ref(crypto.crypto()) .expect("Could not hash KeyPackage.") .as_slice() .to_vec(), - KeyPackageIn::from(client_key_package), + KeyPackageIn::from(client_key_package.key_package().clone()), )]; - let client_data = ClientInfo::new(client_name.to_string(), client_key_package); + let client_data = ClientInfo::new(client_key_package); let encoded_client_data = client_data.tls_serialize_detached().unwrap(); let client_data2 = ClientInfo::tls_deserialize(&mut encoded_client_data.as_slice()) diff --git a/delivery-service/ds/README.md b/delivery-service/ds/README.md index 64886fe246..43d53b035e 100644 --- a/delivery-service/ds/README.md +++ b/delivery-service/ds/README.md @@ -4,7 +4,8 @@ This is a proof-of-concept for an MLS delivery service that can be used for test * Registering Clients via a POST request to `/clients/register` * Listing Clients via a GET request to `/clients/list` -* Get a list of key packages of a client via a GET request to `/clients/get/{name}` +* Get a list of key packages of a client via a GET request to `/clients/key_packages/{name}` +* Add a new key package for a client via a POST request to `/clients/key_packages/{name}` * Send an MLS group message via a POST request to `/send/message` * Send a Welcome message via a POST request to `/send/welcome` * Get a list of messages for a client via a GET request to `/recv/{name}` diff --git a/delivery-service/ds/src/main.rs b/delivery-service/ds/src/main.rs index d5734f265b..1776cdd535 100644 --- a/delivery-service/ds/src/main.rs +++ b/delivery-service/ds/src/main.rs @@ -37,7 +37,13 @@ use std::collections::HashMap; use std::sync::Mutex; use tls_codec::{Deserialize, Serialize, TlsSliceU16, TlsVecU32}; -use ds_lib::*; +use ds_lib::{ + messages::{ + PublishKeyPackagesRequest, RecvMessageRequest, RegisterClientRequest, + RegisterClientSuccessResponse, + }, + *, +}; use openmls::prelude::*; #[cfg(test)] @@ -84,23 +90,42 @@ async fn register_client(mut body: Payload, data: web::Data) -> impl Res while let Some(item) = body.next().await { bytes.extend_from_slice(&unwrap_item!(item)); } - let info = match ClientInfo::tls_deserialize(&mut &bytes[..]) { + let req = match RegisterClientRequest::tls_deserialize(&mut &bytes[..]) { Ok(i) => i, Err(_) => { log::error!("Invalid payload for /clients/register\n{:?}", bytes); return actix_web::HttpResponse::BadRequest().finish(); } }; - log::debug!("Registering client: {:?}", info); + + if req.key_packages.0.is_empty() { + log::error!("Invalid payload for /clients/register: no key packages"); + return actix_web::HttpResponse::BadRequest().finish(); + } + + let key_packages = req + .key_packages + .0 + .into_vec() + .into_iter() + .map(|(b, kp)| (b.into_vec(), kp)) + .collect(); + let new_client_info = ClientInfo::new(key_packages); + + log::debug!("Registering client: {:?}", new_client_info.id); + + let response = RegisterClientSuccessResponse { + auth_token: new_client_info.auth_token.clone(), + }; let mut clients = unwrap_data!(data.clients.lock()); - let client_name = info.client_name.clone(); - let old = clients.insert(info.id.clone(), info); - if old.is_some() { + if clients.contains_key(&new_client_info.id) { return actix_web::HttpResponse::Conflict().finish(); } + let old = clients.insert(new_client_info.id.clone(), new_client_info); + assert!(old.is_none()); - actix_web::HttpResponse::Ok().body(format!("Welcome {client_name}!\n")) + actix_web::HttpResponse::Ok().body(response.tls_serialize_detached().unwrap()) } /// Returns a list of clients with their names and IDs. @@ -110,10 +135,10 @@ async fn list_clients(_req: HttpRequest, data: web::Data) -> impl Respon let clients = unwrap_data!(data.clients.lock()); // XXX: we could encode while iterating to be less wasteful. - let clients: TlsVecU32 = clients + let clients: TlsVecU32> = clients .values() - .cloned() - .collect::>() + .map(|c| c.id().to_vec()) + .collect::>>() .into(); let mut out_bytes = Vec::new(); if clients.tls_serialize(&mut out_bytes).is_err() { @@ -124,7 +149,12 @@ async fn list_clients(_req: HttpRequest, data: web::Data) -> impl Respon /// Resets the server state. #[get("/reset")] -async fn reset(_req: HttpRequest, data: web::Data) -> impl Responder { +async fn reset(req: HttpRequest, data: web::Data) -> impl Responder { + if let Some(reset_key) = req.headers().get("reset-key") { + if reset_key != "poc-reset-password" { + return actix_web::HttpResponse::NetworkAuthenticationRequired().finish(); + } + } log::debug!("Resetting server"); let mut clients = unwrap_data!(data.clients.lock()); let mut groups = unwrap_data!(data.groups.lock()); @@ -144,6 +174,7 @@ async fn get_key_packages(path: web::Path, data: web::Data) -> i Ok(v) => v, Err(_) => return actix_web::HttpResponse::BadRequest().finish(), }; + log::debug!("Getting key packages for {:?}", id); let client = match clients.get(&id) { @@ -171,15 +202,15 @@ async fn publish_key_packages( Ok(v) => v, Err(_) => return actix_web::HttpResponse::BadRequest().finish(), }; - log::debug!("Add key package for {:?}", id); - let client = match clients.get_mut(&id) { - Some(client) => client, + let client = match clients.get(&id) { + Some(c) => c, None => return actix_web::HttpResponse::NotFound().finish(), }; - let key_packages = match ClientKeyPackages::tls_deserialize(&mut &bytes[..]) { - Ok(ckp) => ckp, + // Deserialize request + let req = match PublishKeyPackagesRequest::tls_deserialize(&mut &bytes[..]) { + Ok(i) => i, Err(_) => { log::error!( "Invalid payload for /clients/key_packages/{:?}\n{:?}", @@ -190,10 +221,22 @@ async fn publish_key_packages( } }; - key_packages + // Auth + if client.auth_token != req.auth_token { + return actix_web::HttpResponse::Unauthorized().finish(); + } + + log::debug!("Add key package for {:?}", id); + + let client = match clients.get_mut(&id) { + Some(client) => client, + None => return actix_web::HttpResponse::NotFound().finish(), + }; + + req.key_packages .0 - .iter() - .map(|(b, kp)| (b.clone(), kp.clone())) + .into_vec() + .into_iter() .for_each(|value| client.key_packages.0.push(value)); actix_web::HttpResponse::Ok().finish() @@ -319,19 +362,48 @@ async fn msg_send(mut body: Payload, data: web::Data) -> impl Responder /// details) the DS has stored for the given client. /// The messages are deleted on the DS when sent out. #[get("/recv/{id}")] -async fn msg_recv(path: web::Path, data: web::Data) -> impl Responder { +async fn msg_recv( + path: web::Path, + mut body: Payload, + data: web::Data, +) -> impl Responder { + let mut bytes = web::BytesMut::new(); + while let Some(item) = body.next().await { + bytes.extend_from_slice(&unwrap_item!(item)); + } + let mut clients = unwrap_data!(data.clients.lock()); let id = match base64::decode_config(path.into_inner(), base64::URL_SAFE) { Ok(v) => v, Err(_) => return actix_web::HttpResponse::BadRequest().finish(), }; - log::debug!("Getting messages for client {:?}", id); + let client = match clients.get_mut(&id) { Some(client) => client, None => return actix_web::HttpResponse::NotFound().finish(), }; + // Auth + // Deserialize request + let req = match RecvMessageRequest::tls_deserialize(&mut &bytes[..]) { + Ok(i) => i, + Err(_) => { + log::error!( + "Invalid payload for /clients/key_packages/{:?}\n{:?}", + id, + bytes + ); + return actix_web::HttpResponse::BadRequest().finish(); + } + }; + + if req.auth_token != client.auth_token { + return actix_web::HttpResponse::Unauthorized().finish(); + } + + log::debug!("Getting messages for client {:?}", id); + let mut out: Vec = Vec::new(); let mut welcomes: Vec = client.welcome_queue.drain(..).collect(); out.append(&mut welcomes); diff --git a/delivery-service/ds/src/test.rs b/delivery-service/ds/src/test.rs index e655a795e5..24afc91053 100644 --- a/delivery-service/ds/src/test.rs +++ b/delivery-service/ds/src/test.rs @@ -1,17 +1,16 @@ use super::*; use actix_web::{body::MessageBody, http::StatusCode, test, web, web::Bytes, App}; -use openmls::prelude::config::CryptoConfig; use openmls_basic_credential::SignatureKeyPair; use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::types::SignatureScheme; use openmls_traits::OpenMlsProvider; -use tls_codec::{TlsByteVecU8, TlsVecU16, VLBytes}; +use tls_codec::{TlsByteVecU8, TlsVecU16}; fn generate_credential( identity: Vec, signature_scheme: SignatureScheme, ) -> (CredentialWithKey, SignatureKeyPair) { - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); let signature_keys = SignatureKeyPair::new(signature_scheme).unwrap(); let credential_with_key = CredentialWithKey { credential: credential.into(), @@ -27,18 +26,10 @@ fn generate_key_package( extensions: Extensions, crypto_provider: &impl OpenMlsProvider, signer: &SignatureKeyPair, -) -> KeyPackage { +) -> KeyPackageBundle { KeyPackage::builder() .key_package_extensions(extensions) - .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - crypto_provider, - signer, - credential_with_key, - ) + .build(ciphersuite, crypto_provider, signer, credential_with_key) .unwrap() } @@ -79,10 +70,8 @@ async fn test_list_clients() { let crypto = &OpenMlsRustCrypto::default(); let (credential_with_key, signer) = generate_credential(client_name.into(), SignatureScheme::from(ciphersuite)); - let identity = - VLBytes::tls_deserialize_exact(credential_with_key.credential.serialized_content()) - .unwrap(); - let client_id = identity.as_slice().to_vec(); + let identity = credential_with_key.credential.serialized_content(); + let client_id = identity.to_vec(); let client_key_package = generate_key_package( ciphersuite, credential_with_key.clone(), @@ -92,22 +81,39 @@ async fn test_list_clients() { ); let client_key_package = vec![( client_key_package + .key_package() .hash_ref(crypto.crypto()) .unwrap() .as_slice() .to_vec(), KeyPackageIn::from(client_key_package.clone()), )]; - let client_data = ClientInfo::new(client_name.to_string(), client_key_package.clone()); + let mut client_data = ClientInfo::new(client_key_package.clone()); + let body = RegisterClientRequest { + key_packages: ClientKeyPackages( + client_key_package + .clone() + .into_iter() + .map(|(b, kp)| (b.into(), kp)) + .collect::>() + .into(), + ), + }; + let req = test::TestRequest::post() .uri("/clients/register") .set_payload(Bytes::copy_from_slice( - &client_data.tls_serialize_detached().unwrap(), + &body.tls_serialize_detached().unwrap(), )) .to_request(); let response = test::call_service(&app, req).await; assert_eq!(response.status(), StatusCode::OK); + let response_body = RegisterClientSuccessResponse::tls_deserialize_exact( + response.into_body().try_into_bytes().unwrap(), + ) + .unwrap(); + client_data.auth_token = response_body.auth_token; // There should be Client1 now. let req = test::TestRequest::with_uri("/clients/list").to_request(); @@ -116,13 +122,13 @@ async fn test_list_clients() { assert_eq!(response.status(), StatusCode::OK); let bytes = response.into_body().try_into_bytes().unwrap(); - let client_info = - TlsVecU32::::tls_deserialize(&mut bytes.as_ref()).expect("Invalid client list"); + let client_ids = + TlsVecU32::>::tls_deserialize(&mut bytes.as_ref()).expect("Invalid client list"); - let expected = TlsVecU32::::new(vec![client_data]); + let expected = TlsVecU32::>::new(vec![client_data.id().to_vec()]); assert_eq!( - client_info.tls_serialize_detached().unwrap(), + client_ids.tls_serialize_detached().unwrap(), expected.tls_serialize_detached().unwrap() ); @@ -173,6 +179,7 @@ async fn test_group() { let mut credentials_with_key = Vec::new(); let mut signers = Vec::new(); let mut client_ids = Vec::new(); + let mut client_data_vec = Vec::new(); for client_name in clients.iter() { let ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; let (credential_with_key, signer) = generate_credential( @@ -186,37 +193,52 @@ async fn test_group() { crypto, &signer, ); - let client_data = ClientInfo::new( - client_name.to_string(), - vec![( - client_key_package - .hash_ref(crypto.crypto()) - .unwrap() - .as_slice() - .to_vec(), - client_key_package.clone().into(), - )], + let client_key_packages = ( + client_key_package + .key_package() + .hash_ref(crypto.crypto()) + .unwrap() + .as_slice() + .to_vec(), + client_key_package.clone().into(), ); + + let mut client_data = ClientInfo::new(vec![(client_key_packages.clone())]); key_packages.push(client_key_package); - let id = - VLBytes::tls_deserialize_exact(credential_with_key.credential.serialized_content()) - .unwrap(); - client_ids.push(id.as_slice().to_vec()); + let id = credential_with_key.credential.serialized_content(); + client_ids.push(id.to_vec()); credentials_with_key.push(credential_with_key); signers.push(signer); + + let body = RegisterClientRequest { + key_packages: ClientKeyPackages( + vec![client_key_packages.clone()] + .into_iter() + .map(|(b, kp)| (b.into(), kp)) + .collect::>() + .into(), + ), + }; + let req = test::TestRequest::post() .uri("/clients/register") .set_payload(Bytes::copy_from_slice( - &client_data.tls_serialize_detached().unwrap(), + &body.tls_serialize_detached().unwrap(), )) .to_request(); let response = test::call_service(&app, req).await; assert_eq!(response.status(), StatusCode::OK); + let response_body = RegisterClientSuccessResponse::tls_deserialize_exact( + response.into_body().try_into_bytes().unwrap(), + ) + .unwrap(); + client_data.auth_token = response_body.auth_token; + client_data_vec.push(client_data); } // Add an additional key package for Client2 - let group_ciphersuite = key_packages[0].ciphersuite(); + let group_ciphersuite = key_packages[0].key_package().ciphersuite(); let key_package_2 = generate_key_package( group_ciphersuite, credentials_with_key.get(1).unwrap().clone(), @@ -227,6 +249,7 @@ async fn test_group() { let key_package_2 = ( key_package_2 + .key_package() .hash_ref(crypto.crypto()) .unwrap() .as_slice() @@ -245,10 +268,14 @@ async fn test_group() { // Publish key package to the DS for Client2 let path = "/clients/key_packages/".to_string() + &base64::encode_config(&client_ids[1], base64::URL_SAFE); + let body = PublishKeyPackagesRequest { + key_packages: ckp, + auth_token: client_data_vec[1].auth_token.clone(), + }; let req = test::TestRequest::post() .uri(&path) .set_payload(Bytes::copy_from_slice( - &ckp.tls_serialize_detached().unwrap(), + &body.tls_serialize_detached().unwrap(), )) .to_request(); @@ -304,8 +331,15 @@ async fn test_group() { assert_eq!(response.status(), StatusCode::OK); // There should be a welcome message now for Client2. + let body = RecvMessageRequest { + auth_token: client_data_vec[1].auth_token.clone(), + }; let path = "/recv/".to_owned() + &base64::encode_config(clients[1], base64::URL_SAFE); - let req = test::TestRequest::with_uri(&path).to_request(); + let req = test::TestRequest::with_uri(&path) + .set_payload(Bytes::copy_from_slice( + &body.tls_serialize_detached().unwrap(), + )) + .to_request(); let response = test::call_service(&app, req).await; assert_eq!(response.status(), StatusCode::OK); @@ -361,8 +395,15 @@ async fn test_group() { assert_eq!(response.status(), StatusCode::OK); // Client1 retrieves messages from the DS + let body = RecvMessageRequest { + auth_token: client_data_vec[0].auth_token.clone(), + }; let path = "/recv/".to_owned() + &base64::encode_config(clients[0], base64::URL_SAFE); - let req = test::TestRequest::with_uri(&path).to_request(); + let req = test::TestRequest::with_uri(&path) + .set_payload(Bytes::copy_from_slice( + &body.tls_serialize_detached().unwrap(), + )) + .to_request(); let response = test::call_service(&app, req).await; assert_eq!(response.status(), StatusCode::OK); diff --git a/interop_client/src/main.rs b/interop_client/src/main.rs index 065c2b22fa..4304a7113e 100644 --- a/interop_client/src/main.rs +++ b/interop_client/src/main.rs @@ -13,30 +13,21 @@ use mls_client::{ }; use mls_interop_proto::mls_client; use openmls::{ - ciphersuite::HpkePrivateKey, credentials::{BasicCredential, Credential, CredentialType, CredentialWithKey}, framing::{MlsMessageBodyIn, MlsMessageIn, MlsMessageOut, ProcessedMessageContent}, group::{ GroupEpoch, GroupId, MlsGroup, MlsGroupCreateConfig, MlsGroupJoinConfig, StagedWelcome, WireFormatPolicy, PURE_CIPHERTEXT_WIRE_FORMAT_POLICY, PURE_PLAINTEXT_WIRE_FORMAT_POLICY, }, - key_packages::KeyPackage, - prelude::{config::CryptoConfig, Capabilities, ExtensionType, SenderRatchetConfiguration}, + key_packages::{KeyPackage, KeyPackageBundle}, + prelude::{Capabilities, ExtensionType, SenderRatchetConfiguration}, schedule::{psk::ResumptionPskUsage, ExternalPsk, PreSharedKeyId, Psk}, - treesync::{ - test_utils::{read_keys_from_key_store, write_keys_from_key_store}, - RatchetTreeIn, - }, + treesync::RatchetTreeIn, versions::ProtocolVersion, }; use openmls_basic_credential::SignatureKeyPair; use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{ - key_store::OpenMlsKeyStore, - random::OpenMlsRand, - types::{Ciphersuite, HpkeKeyPair}, - OpenMlsProvider, -}; +use openmls_traits::{random::OpenMlsRand, types::Ciphersuite, OpenMlsProvider}; use tls_codec::{Deserialize, Serialize}; use tonic::{async_trait, transport::Server, Code, Request, Response, Status}; use tracing::{debug, error, info, instrument, trace, Span}; @@ -64,9 +55,7 @@ pub struct InteropGroup { } type PendingState = ( - KeyPackage, - HpkePrivateKey, - HpkeKeyPair, + KeyPackageBundle, Credential, SignatureKeyPair, OpenMlsRustCrypto, @@ -222,16 +211,16 @@ impl MlsClient for MlsClientImpl { let provider = OpenMlsRustCrypto::default(); let ciphersuite = Ciphersuite::try_from(request.cipher_suite as u16).unwrap(); - let credential = BasicCredential::new(request.identity.clone()).unwrap(); + let credential = BasicCredential::new(request.identity.clone()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); let wire_format_policy = wire_format_policy(request.encrypt_handshake); // Note: We just use some values here that make live testing work. // There is nothing special about the used numbers and they // can be increased (or decreased) depending on the available scenarios. let mls_group_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .max_past_epochs(32) .number_of_resumption_psks(32) .sender_ratchet_configuration(SenderRatchetConfiguration::default()) @@ -287,12 +276,12 @@ impl MlsClient for MlsClientImpl { "Creating key package." ); - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); let key_package = KeyPackage::builder() .leaf_node_capabilities(Capabilities::new( - Some(&[ProtocolVersion::Mls10, ProtocolVersion::Mls10Draft11]), + Some(&[ProtocolVersion::Mls10, ProtocolVersion::Other(999)]), Some(&[ Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, @@ -303,10 +292,7 @@ impl MlsClient for MlsClientImpl { Some(&CREDENTIAL_TYPES), )) .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, &crypto_provider, &signature_keys, CredentialWithKey { @@ -315,13 +301,6 @@ impl MlsClient for MlsClientImpl { }, ) .unwrap(); - let private_key = crypto_provider - .key_store() - .read::(key_package.hpke_init_key().as_slice()) - .unwrap(); - - let encryption_key_pair = - read_keys_from_key_store(&crypto_provider, key_package.leaf_node().encryption_key()); let transaction_id: [u8; 4] = crypto_provider.rand().random_array().unwrap(); let transaction_id = u32::from_be_bytes(transaction_id); @@ -332,11 +311,14 @@ impl MlsClient for MlsClientImpl { key_package: key_package_msg .tls_serialize_detached() .expect("error serializing key package"), - encryption_priv: encryption_key_pair - .private + encryption_priv: key_package + .encryption_private_key() + .tls_serialize_detached() + .unwrap(), + init_priv: key_package + .init_private_key() .tls_serialize_detached() .unwrap(), - init_priv: private_key.tls_serialize_detached().unwrap(), signature_priv: signature_keys.private().to_vec(), }; @@ -348,8 +330,6 @@ impl MlsClient for MlsClientImpl { request.identity.clone(), ( key_package, - private_key, - encryption_key_pair, credential.into(), signature_keys, crypto_provider, @@ -383,49 +363,29 @@ impl MlsClient for MlsClientImpl { .build(); let mut pending_key_packages = self.pending_state.lock().unwrap(); - let ( - my_key_package, - private_key, - encryption_keypair, - _my_credential, - my_signature_keys, - crypto_provider, - ) = pending_key_packages - .remove(&request.identity) - .ok_or(Status::aborted(format!( - "failed to find key package for identity {:x?}", - request.identity - )))?; - - // Store keys so OpenMLS can find them. - crypto_provider - .key_store() - .store(my_key_package.hpke_init_key().as_slice(), &private_key) - .map_err(|_| Status::aborted("failed to interact with the key store"))?; + let (my_key_package, _my_credential, my_signature_keys, crypto_provider) = + pending_key_packages + .remove(&request.identity) + .ok_or(Status::aborted(format!( + "failed to find key package for identity {:x?}", + request.identity + )))?; + + use openmls_traits::storage::StorageProvider as _; // Store the key package in the key store with the hash reference as id // for retrieval when parsing welcome messages. crypto_provider - .key_store() - .store( - my_key_package + .storage() + .write_key_package( + &my_key_package + .key_package() .hash_ref(crypto_provider.crypto()) - .map_err(into_status)? - .as_slice(), + .map_err(into_status)?, &my_key_package, ) .map_err(into_status)?; - // Store the encryption key pair in the key store. - write_keys_from_key_store(&crypto_provider, encryption_keypair); - - // Store the private part of the init_key into the key store. - // The key is the public key. - crypto_provider - .key_store() - .store::(my_key_package.hpke_init_key().as_slice(), &private_key) - .map_err(into_status)?; - let welcome = MlsMessageIn::tls_deserialize(&mut request.welcome.as_slice()) .map_err(|_| Status::aborted("failed to deserialize MlsMessage with a Welcome"))? .into_welcome() @@ -501,12 +461,12 @@ impl MlsClient for MlsClientImpl { let ciphersuite = verifiable_group_info.ciphersuite(); let (credential_with_key, signer) = { - let credential = BasicCredential::new(request.identity.to_vec()).unwrap(); + let credential = BasicCredential::new(request.identity.to_vec()); let signature_keypair = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); - signature_keypair.store(provider.key_store()).unwrap(); + signature_keypair.store(provider.storage()).unwrap(); let credential_with_key = CredentialWithKey { credential: credential.into(), @@ -616,7 +576,7 @@ impl MlsClient for MlsClientImpl { let exported_secret = interop_group .group .export_secret( - interop_group.crypto_provider.crypto(), + &interop_group.crypto_provider, &request.label, &request.context, request.key_length as usize, @@ -642,7 +602,13 @@ impl MlsClient for MlsClientImpl { .get_mut(request.state_id as usize) .ok_or_else(|| Status::new(Code::InvalidArgument, "unknown state_id"))?; - interop_group.group.set_aad(&request.authenticated_data); + interop_group + .group + .set_aad( + interop_group.crypto_provider.storage(), + &request.authenticated_data, + ) + .map_err(|err| tonic::Status::internal(format!("error setting aad: {err}")))?; let ciphertext = interop_group .group @@ -731,7 +697,7 @@ impl MlsClient for MlsClientImpl { let psk_id = PreSharedKeyId::new(ciphersuite, crypto_provider.rand(), external_psk) .map_err(|_| Status::internal("unable to create PreSharedKeyId from raw psk_id"))?; psk_id - .write_to_key_store(crypto_provider, ciphersuite, secret) + .store(crypto_provider, secret) .map_err(|_| Status::new(Code::Internal, "unable to store PSK"))?; Ok(()) } @@ -747,8 +713,8 @@ impl MlsClient for MlsClientImpl { .ok_or(Status::internal("Unable to retrieve pending state"))?; store( - pending_state.0.ciphersuite(), - &pending_state.5, + pending_state.0.key_package().ciphersuite(), + &pending_state.3, external_psk, &request.psk_secret, )?; @@ -807,7 +773,12 @@ impl MlsClient for MlsClientImpl { .number_of_resumption_psks(32) .wire_format_policy(interop_group.wire_format_policy) .build(); - interop_group.group.set_configuration(&mls_group_config); + interop_group + .group + .set_configuration(interop_group.crypto_provider.storage(), &mls_group_config) + .map_err(|err| { + tonic::Status::internal(format!("error setting configuration: {err}")) + })?; let (proposal, _) = interop_group .group .propose_add_member( @@ -851,7 +822,12 @@ impl MlsClient for MlsClientImpl { .use_ratchet_tree_extension(true) .wire_format_policy(interop_group.wire_format_policy) .build(); - interop_group.group.set_configuration(&mls_group_config); + interop_group + .group + .set_configuration(interop_group.crypto_provider.storage(), &mls_group_config) + .map_err(|err| { + tonic::Status::internal(format!("error setting configuration: {err}")) + })?; let (proposal, _) = interop_group .group .propose_self_update( @@ -882,7 +858,7 @@ impl MlsClient for MlsClientImpl { let request = request.get_ref(); info!(?request, "Request"); - let removed_credential = BasicCredential::new(request.removed_id.clone()).unwrap(); + let removed_credential = BasicCredential::new(request.removed_id.clone()); trace!(" for credential: {removed_credential:x?}"); let mut groups = self.groups.lock().unwrap(); @@ -900,7 +876,12 @@ impl MlsClient for MlsClientImpl { .use_ratchet_tree_extension(true) .wire_format_policy(interop_group.wire_format_policy) .build(); - interop_group.group.set_configuration(&mls_group_config); + interop_group + .group + .set_configuration(interop_group.crypto_provider.storage(), &mls_group_config) + .map_err(|err| { + tonic::Status::internal(format!("error setting configuration: {err}")) + })?; trace!(" prepared remove"); let (proposal, _) = interop_group @@ -977,7 +958,11 @@ impl MlsClient for MlsClientImpl { match processed_message.into_content() { ProcessedMessageContent::ApplicationMessage(_) => unreachable!(), ProcessedMessageContent::ProposalMessage(proposal) => { - group.store_pending_proposal(*proposal); + group + .store_pending_proposal(interop_group.crypto_provider.storage(), *proposal) + .map_err(|err| { + tonic::Status::internal(format!("error storing proposal: {err}")) + })?; } ProcessedMessageContent::ExternalJoinProposalMessage(_) => unreachable!(), ProcessedMessageContent::StagedCommitMessage(_) => unreachable!(), @@ -1009,8 +994,7 @@ impl MlsClient for MlsClientImpl { .map_err(|_| Status::internal("Unable to generate proposal by value"))? } "remove" => { - let removed_credential = - BasicCredential::new(proposal.removed_id.clone()).unwrap(); + let removed_credential = BasicCredential::new(proposal.removed_id.clone()); group .propose_remove_member_by_credential_by_value( @@ -1156,7 +1140,13 @@ impl MlsClient for MlsClientImpl { match processed_message.into_content() { ProcessedMessageContent::ApplicationMessage(_) => unreachable!(), ProcessedMessageContent::ProposalMessage(proposal) => { - group.store_pending_proposal(*proposal); + group + .store_pending_proposal(interop_group.crypto_provider.storage(), *proposal) + .map_err(|err| { + tonic::Status::internal(format!( + "error storing pending proposal: {err}" + )) + })?; } ProcessedMessageContent::ExternalJoinProposalMessage(_) => unreachable!(), ProcessedMessageContent::StagedCommitMessage(_) => unreachable!(), @@ -1258,7 +1248,7 @@ impl MlsClient for MlsClientImpl { let group_info = group .export_group_info( - interop_group.crypto_provider.crypto(), + &interop_group.crypto_provider, &interop_group.signature_keys, !request.external_tree, ) diff --git a/libcrux_crypto/Cargo.toml b/libcrux_crypto/Cargo.toml index 74c3698364..4fd77b8101 100644 --- a/libcrux_crypto/Cargo.toml +++ b/libcrux_crypto/Cargo.toml @@ -11,9 +11,7 @@ readme = "../README.md" [dependencies] getrandom = "0.2.12" -libcrux = { git = "https://github.com/cryspen/libcrux", rev = "cb55a1f2eeccd6f4aeb47f9597f94e4633c16222", features = [ - "rand", -] } +libcrux = { git = "https://github.com/cryspen/libcrux", features = ["rand"] } openmls_traits = { path = "../traits" } openmls_rust_crypto = { path = "../openmls_rust_crypto" } rand = "0.8.5" diff --git a/libcrux_crypto/src/crypto.rs b/libcrux_crypto/src/crypto.rs index cba409b814..c85d5321c6 100644 --- a/libcrux_crypto/src/crypto.rs +++ b/libcrux_crypto/src/crypto.rs @@ -28,9 +28,16 @@ impl Default for CryptoProvider { } } +impl CryptoProvider { + #[inline(always)] + fn aes_support(&self) -> bool { + libcrux::aes_ni_support() && cfg!(target_arch = "x86_64") + } +} + impl OpenMlsCrypto for CryptoProvider { fn supports(&self, ciphersuite: Ciphersuite) -> Result<(), CryptoError> { - match (ciphersuite.aead_algorithm(), libcrux::aes_ni_support()) { + match (ciphersuite.aead_algorithm(), self.aes_support()) { (AeadType::Aes128Gcm, true) | (AeadType::Aes256Gcm, true) | (AeadType::ChaCha20Poly1305, true) @@ -49,9 +56,7 @@ impl OpenMlsCrypto for CryptoProvider { match ciphersuite.hpke_aead_algorithm() { HpkeAeadType::ChaCha20Poly1305 => Ok(()), - HpkeAeadType::AesGcm128 | HpkeAeadType::AesGcm256 if libcrux::aes_ni_support() => { - Ok(()) - } + HpkeAeadType::AesGcm128 | HpkeAeadType::AesGcm256 if self.aes_support() => Ok(()), _ => Err(CryptoError::UnsupportedCiphersuite), }?; @@ -59,14 +64,18 @@ impl OpenMlsCrypto for CryptoProvider { } fn supported_ciphersuites(&self) -> Vec { - if libcrux::aes_ni_support() { + if self.aes_support() { vec![ Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, + Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519, ] } else { - vec![Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519] + vec![ + Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, + Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519, + ] } } @@ -425,6 +434,7 @@ fn hpke_kem(kem: HpkeKemType) -> libcrux::hpke::kem::KEM { HpkeKemType::DhKemP521 => libcrux::hpke::kem::KEM::DHKEM_P521_HKDF_SHA512, HpkeKemType::DhKem25519 => libcrux::hpke::kem::KEM::DHKEM_X25519_HKDF_SHA256, HpkeKemType::DhKem448 => libcrux::hpke::kem::KEM::DHKEM_X448_HKDF_SHA512, + HpkeKemType::XWingKemDraft2 => libcrux::hpke::kem::KEM::XWingDraft02, } } @@ -463,16 +473,6 @@ pub trait ReadU8: std::io::Read { } } -impl WriteU8 for W {} - -pub trait WriteU8: std::io::Write { - /// A small helper function to write a u8 to a Writer. - #[inline] - fn write_u8(&mut self, n: u8) -> std::io::Result<()> { - self.write_all(&[n]) - } -} - /// This function takes a DER encoded ECDSA signature and decodes it to the /// bytes representing the concatenated scalars. If the decoding fails, it /// will throw a `CryptoError`. diff --git a/libcrux_crypto/src/lib.rs b/libcrux_crypto/src/lib.rs index ed4bc23e32..c61a437588 100644 --- a/libcrux_crypto/src/lib.rs +++ b/libcrux_crypto/src/lib.rs @@ -12,15 +12,17 @@ pub use rand::RandProvider; pub struct Provider { crypto: crypto::CryptoProvider, rand: rand::RandProvider, - key_store: openmls_rust_crypto::MemoryKeyStore, + key_store: openmls_rust_crypto::MemoryStorage, } impl OpenMlsProvider for Provider { type CryptoProvider = CryptoProvider; - type RandProvider = RandProvider; + type StorageProvider = openmls_rust_crypto::MemoryStorage; - type KeyStoreProvider = openmls_rust_crypto::MemoryKeyStore; + fn storage(&self) -> &Self::StorageProvider { + &self.key_store + } fn crypto(&self) -> &Self::CryptoProvider { &self.crypto @@ -29,8 +31,4 @@ impl OpenMlsProvider for Provider { fn rand(&self) -> &Self::RandProvider { &self.rand } - - fn key_store(&self) -> &Self::KeyStoreProvider { - &self.key_store - } } diff --git a/memory_keystore/Cargo.toml b/memory_keystore/Cargo.toml deleted file mode 100644 index 922f9aaedb..0000000000 --- a/memory_keystore/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "openmls_memory_keystore" -authors = ["OpenMLS Authors"] -version = "0.2.0" -edition = "2021" -description = "A very basic key store for OpenMLS implementing openmls_traits." -license = "MIT" -documentation = "https://docs.rs/openmls_memory_keystore" -repository = "https://github.com/openmls/openmls/tree/main/memory_keystore" -readme = "README.md" - -[dependencies] -openmls_traits = { version = "0.2.0", path = "../traits" } -thiserror = "1.0" -serde_json = "1.0" diff --git a/memory_keystore/README.md b/memory_keystore/README.md deleted file mode 100644 index d6e60b7729..0000000000 --- a/memory_keystore/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# OpenMLS Memory Keystore - -A very basic in-memory key store implementing the `OpenMlsKeyStore` trait from `openmls_traits`. diff --git a/memory_keystore/src/lib.rs b/memory_keystore/src/lib.rs deleted file mode 100644 index d2194bbe95..0000000000 --- a/memory_keystore/src/lib.rs +++ /dev/null @@ -1,63 +0,0 @@ -use openmls_traits::key_store::{MlsEntity, OpenMlsKeyStore}; -use std::{collections::HashMap, sync::RwLock}; - -#[derive(Debug, Default)] -pub struct MemoryKeyStore { - values: RwLock, Vec>>, -} - -impl OpenMlsKeyStore for MemoryKeyStore { - /// The error type returned by the [`OpenMlsKeyStore`]. - type Error = MemoryKeyStoreError; - - /// Store a value `v` that implements the [`ToKeyStoreValue`] trait for - /// serialization for ID `k`. - /// - /// Returns an error if storing fails. - fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> { - let value = serde_json::to_vec(v).map_err(|_| MemoryKeyStoreError::SerializationError)?; - // We unwrap here, because this is the only function claiming a write - // lock on `credential_bundles`. It only holds the lock very briefly and - // should not panic during that period. - let mut values = self.values.write().unwrap(); - values.insert(k.to_vec(), value); - Ok(()) - } - - /// Read and return a value stored for ID `k` that implements the - /// [`FromKeyStoreValue`] trait for deserialization. - /// - /// Returns [`None`] if no value is stored for `k` or reading fails. - fn read(&self, k: &[u8]) -> Option { - // We unwrap here, because the two functions claiming a write lock on - // `init_key_package_bundles` (this one and `generate_key_package_bundle`) only - // hold the lock very briefly and should not panic during that period. - let values = self.values.read().unwrap(); - if let Some(value) = values.get(k) { - serde_json::from_slice(value).ok() - } else { - None - } - } - - /// Delete a value stored for ID `k`. - /// - /// Returns an error if storing fails. - fn delete(&self, k: &[u8]) -> Result<(), Self::Error> { - // We just delete both ... - let mut values = self.values.write().unwrap(); - values.remove(k); - Ok(()) - } -} - -/// Errors thrown by the key store. -#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)] -pub enum MemoryKeyStoreError { - #[error("The key store does not allow storing serialized values.")] - UnsupportedValueTypeBytes, - #[error("Updating is not supported by this key store.")] - UnsupportedMethod, - #[error("Error serializing value.")] - SerializationError, -} diff --git a/memory_keystore/CHANGELOG.md b/memory_storage/CHANGELOG.md similarity index 100% rename from memory_keystore/CHANGELOG.md rename to memory_storage/CHANGELOG.md diff --git a/memory_storage/Cargo.toml b/memory_storage/Cargo.toml new file mode 100644 index 0000000000..d6387d9d7e --- /dev/null +++ b/memory_storage/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "openmls_memory_storage" +authors = ["OpenMLS Authors"] +version = "0.2.0" +edition = "2021" +description = "A very basic storage for OpenMLS implementing openmls_traits." +license = "MIT" +documentation = "https://docs.rs/openmls_memory_storage" +repository = "https://github.com/openmls/openmls/tree/main/memory_storage" +readme = "README.md" + +[dependencies] +openmls_traits = { version = "0.2.0", path = "../traits" } +thiserror = "1.0" +serde_json = "1.0" +serde = { version = "1.0", features = ["derive"] } +log = { version = "0.4" } +hex = { version = "0.4", features = ["serde"], optional = true } +base64 = { version = "0.13", optional = true } + +[features] +test-utils = ["hex", "openmls_traits/test-utils"] # Enable test utilites +persistence = ["base64"] + +[dev-dependencies] +openmls_memory_storage = { path = ".", features = ["test-utils"] } diff --git a/memory_storage/README.md b/memory_storage/README.md new file mode 100644 index 0000000000..9c1fd45954 --- /dev/null +++ b/memory_storage/README.md @@ -0,0 +1,3 @@ +# OpenMLS Memory Storage + +A very basic in-memory storage implementing the `StorageProvider` trait from `openmls_traits`. diff --git a/memory_storage/src/lib.rs b/memory_storage/src/lib.rs new file mode 100644 index 0000000000..3eadce72f9 --- /dev/null +++ b/memory_storage/src/lib.rs @@ -0,0 +1,995 @@ +use openmls_traits::storage::*; + +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, sync::RwLock}; + +/// A storage for the V_TEST version. +#[cfg(any(test, feature = "test-utils"))] +mod test_store; + +#[cfg(feature = "persistence")] +pub mod persistence; + +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct MemoryStorage { + values: RwLock, Vec>>, +} + +impl MemoryStorage { + /// Internal helper to abstract write operations. + #[inline(always)] + fn write( + &self, + label: &[u8], + key: &[u8], + value: Vec, + ) -> Result<(), >::Error> { + let mut values = self.values.write().unwrap(); + let storage_key = build_key_from_vec::(label, key.to_vec()); + + #[cfg(feature = "test-utils")] + log::debug!(" write key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + values.insert(storage_key, value.to_vec()); + Ok(()) + } + + fn append( + &self, + label: &[u8], + key: &[u8], + value: Vec, + ) -> Result<(), >::Error> { + let mut values = self.values.write().unwrap(); + let storage_key = build_key_from_vec::(label, key.to_vec()); + + #[cfg(feature = "test-utils")] + log::debug!(" write key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + // fetch value from db, falling back to an empty list if doens't exist + let list_bytes = values.entry(storage_key).or_insert(b"[]".to_vec()); + + // parse old value and push new data + let mut list: Vec> = serde_json::from_slice(list_bytes)?; + list.push(value); + + // write back, reusing the old buffer + list_bytes.truncate(0); + serde_json::to_writer(list_bytes, &list)?; + + Ok(()) + } + + fn remove_item( + &self, + label: &[u8], + key: &[u8], + value: Vec, + ) -> Result<(), >::Error> { + let mut values = self.values.write().unwrap(); + let storage_key = build_key_from_vec::(label, key.to_vec()); + + #[cfg(feature = "test-utils")] + log::debug!(" write key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + // fetch value from db, falling back to an empty list if doens't exist + let list_bytes = values.entry(storage_key).or_insert(b"[]".to_vec()); + + // parse old value, find value to delete and remove it from list + let mut list: Vec> = serde_json::from_slice(list_bytes)?; + if let Some(pos) = list.iter().position(|stored_item| stored_item == &value) { + list.remove(pos); + } + + // write back, reusing the old buffer + list_bytes.truncate(0); + serde_json::to_writer(list_bytes, &list)?; + + Ok(()) + } + + /// Internal helper to abstract read operations. + #[inline(always)] + fn read>( + &self, + label: &[u8], + key: &[u8], + ) -> Result, >::Error> { + let values = self.values.read().unwrap(); + let storage_key = build_key_from_vec::(label, key.to_vec()); + + #[cfg(feature = "test-utils")] + log::debug!(" read key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + let value = values.get(&storage_key); + + if let Some(value) = value { + serde_json::from_slice(value) + .map_err(|_| MemoryStorageError::SerializationError) + .map(|v| Some(v)) + } else { + Ok(None) + } + } + + /// Internal helper to abstract read operations. + #[inline(always)] + fn read_list>( + &self, + label: &[u8], + key: &[u8], + ) -> Result, >::Error> { + let values = self.values.read().unwrap(); + + let mut storage_key = label.to_vec(); + storage_key.extend_from_slice(key); + storage_key.extend_from_slice(&u16::to_be_bytes(VERSION)); + + #[cfg(feature = "test-utils")] + log::debug!(" read list key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + let value: Vec> = match values.get(&storage_key) { + Some(list_bytes) => { + println!("{}", String::from_utf8(list_bytes.to_vec()).unwrap()); + serde_json::from_slice(list_bytes).unwrap() + } + None => vec![], + }; + + value + .iter() + .map(|value_bytes| serde_json::from_slice(value_bytes)) + .collect::, _>>() + .map_err(|_| MemoryStorageError::SerializationError) + } + + /// Internal helper to abstract delete operations. + #[inline(always)] + fn delete( + &self, + label: &[u8], + key: &[u8], + ) -> Result<(), >::Error> { + let mut values = self.values.write().unwrap(); + + let mut storage_key = label.to_vec(); + storage_key.extend_from_slice(key); + storage_key.extend_from_slice(&u16::to_be_bytes(VERSION)); + + #[cfg(feature = "test-utils")] + log::debug!(" delete key: {}", hex::encode(&storage_key)); + log::trace!("{}", std::backtrace::Backtrace::capture()); + + values.remove(&storage_key); + + Ok(()) + } +} + +/// Errors thrown by the key store. +#[derive(thiserror::Error, Debug, Copy, Clone, PartialEq, Eq)] +pub enum MemoryStorageError { + #[error("The key store does not allow storing serialized values.")] + UnsupportedValueTypeBytes, + #[error("Updating is not supported by this key store.")] + UnsupportedMethod, + #[error("Error serializing value.")] + SerializationError, + #[error("Value does not exist.")] + None, +} + +const KEY_PACKAGE_LABEL: &[u8] = b"KeyPackage"; +const PSK_LABEL: &[u8] = b"Psk"; +const ENCRYPTION_KEY_PAIR_LABEL: &[u8] = b"EncryptionKeyPair"; +const SIGNATURE_KEY_PAIR_LABEL: &[u8] = b"SignatureKeyPair"; +const EPOCH_KEY_PAIRS_LABEL: &[u8] = b"EpochKeyPairs"; + +// related to PublicGroup +const TREE_LABEL: &[u8] = b"Tree"; +const GROUP_CONTEXT_LABEL: &[u8] = b"GroupContext"; +const INTERIM_TRANSCRIPT_HASH_LABEL: &[u8] = b"InterimTranscriptHash"; +const CONFIRMATION_TAG_LABEL: &[u8] = b"ConfirmationTag"; + +// related to CoreGroup +const OWN_LEAF_NODE_INDEX_LABEL: &[u8] = b"OwnLeafNodeIndex"; +const EPOCH_SECRETS_LABEL: &[u8] = b"EpochSecrets"; +const RESUMPTION_PSK_STORE_LABEL: &[u8] = b"ResumptionPsk"; +const MESSAGE_SECRETS_LABEL: &[u8] = b"MessageSecrets"; +const USE_RATCHET_TREE_LABEL: &[u8] = b"UseRatchetTree"; + +// related to MlsGroup +const JOIN_CONFIG_LABEL: &[u8] = b"MlsGroupJoinConfig"; +const OWN_LEAF_NODES_LABEL: &[u8] = b"OwnLeafNodes"; +const AAD_LABEL: &[u8] = b"AAD"; +const GROUP_STATE_LABEL: &[u8] = b"GroupState"; +const QUEUED_PROPOSAL_LABEL: &[u8] = b"QueuedProposal"; +const PROPOSAL_QUEUE_REFS_LABEL: &[u8] = b"ProposalQueueRefs"; + +impl StorageProvider for MemoryStorage { + type Error = MemoryStorageError; + + fn queue_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + proposal: &QueuedProposal, + ) -> Result<(), Self::Error> { + // write proposal to key (group_id, proposal_ref) + let key = serde_json::to_vec(&(group_id, proposal_ref))?; + let value = serde_json::to_vec(proposal)?; + self.write::(QUEUED_PROPOSAL_LABEL, &key, value)?; + + // update proposal list for group_id + let key = serde_json::to_vec(group_id)?; + let value = serde_json::to_vec(proposal_ref)?; + self.append::(PROPOSAL_QUEUE_REFS_LABEL, &key, value)?; + + Ok(()) + } + + fn write_tree< + GroupId: traits::GroupId, + TreeSync: traits::TreeSync, + >( + &self, + group_id: &GroupId, + tree: &TreeSync, + ) -> Result<(), Self::Error> { + self.write::( + TREE_LABEL, + &serde_json::to_vec(&group_id).unwrap(), + serde_json::to_vec(&tree).unwrap(), + ) + } + + fn write_interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + interim_transcript_hash: &InterimTranscriptHash, + ) -> Result<(), Self::Error> { + let mut values = self.values.write().unwrap(); + let key = build_key::(INTERIM_TRANSCRIPT_HASH_LABEL, group_id); + let value = serde_json::to_vec(&interim_transcript_hash).unwrap(); + + values.insert(key, value); + Ok(()) + } + + fn write_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + group_context: &GroupContext, + ) -> Result<(), Self::Error> { + let mut values = self.values.write().unwrap(); + let key = build_key::(GROUP_CONTEXT_LABEL, group_id); + let value = serde_json::to_vec(&group_context).unwrap(); + + values.insert(key, value); + Ok(()) + } + + fn write_confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + confirmation_tag: &ConfirmationTag, + ) -> Result<(), Self::Error> { + let mut values = self.values.write().unwrap(); + let key = build_key::(CONFIRMATION_TAG_LABEL, group_id); + let value = serde_json::to_vec(&confirmation_tag).unwrap(); + + values.insert(key, value); + Ok(()) + } + + fn write_signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + signature_key_pair: &SignatureKeyPair, + ) -> Result<(), Self::Error> { + let mut values = self.values.write().unwrap(); + let key = + build_key::(SIGNATURE_KEY_PAIR_LABEL, public_key); + let value = serde_json::to_vec(&signature_key_pair).unwrap(); + + values.insert(key, value); + Ok(()) + } + + fn queued_proposal_refs< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?) + } + + fn queued_proposals< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let refs: Vec = + self.read_list(PROPOSAL_QUEUE_REFS_LABEL, &serde_json::to_vec(group_id)?)?; + + refs.into_iter() + .map(|proposal_ref| -> Result<_, _> { + let key = (group_id, &proposal_ref); + let key = serde_json::to_vec(&key)?; + + let proposal = self.read(QUEUED_PROPOSAL_LABEL, &key)?.unwrap(); + Ok((proposal_ref, proposal)) + }) + .collect::, _>>() + } + + fn treesync< + GroupId: traits::GroupId, + TreeSync: traits::TreeSync, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let values = self.values.read().unwrap(); + let key = build_key::(TREE_LABEL, group_id); + + let value = values.get(&key).unwrap(); + let value = serde_json::from_slice(value).unwrap(); + + Ok(value) + } + + fn group_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let values = self.values.read().unwrap(); + let key = build_key::(GROUP_CONTEXT_LABEL, group_id); + + let value = values.get(&key).unwrap(); + let value = serde_json::from_slice(value).unwrap(); + + Ok(value) + } + + fn interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let values = self.values.read().unwrap(); + let key = build_key::(INTERIM_TRANSCRIPT_HASH_LABEL, group_id); + + let value = values.get(&key).unwrap(); + let value = serde_json::from_slice(value).unwrap(); + + Ok(value) + } + + fn confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let values = self.values.read().unwrap(); + let key = build_key::(CONFIRMATION_TAG_LABEL, group_id); + + let value = values.get(&key).unwrap(); + let value = serde_json::from_slice(value).unwrap(); + + Ok(value) + } + + fn signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + ) -> Result, Self::Error> { + let values = self.values.read().unwrap(); + + let key = + build_key::(SIGNATURE_KEY_PAIR_LABEL, public_key); + + let value = values.get(&key).unwrap(); + let value = serde_json::from_slice(value).unwrap(); + + Ok(value) + } + + fn write_key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &HashReference, + key_package: &KeyPackage, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(&hash_ref).unwrap(); + let value = serde_json::to_vec(&key_package).unwrap(); + + self.write::(KEY_PACKAGE_LABEL, &key, value) + .unwrap(); + + Ok(()) + } + + fn write_psk< + PskId: traits::PskId, + PskBundle: traits::PskBundle, + >( + &self, + psk_id: &PskId, + psk: &PskBundle, + ) -> Result<(), Self::Error> { + self.write::( + PSK_LABEL, + &serde_json::to_vec(&psk_id).unwrap(), + serde_json::to_vec(&psk).unwrap(), + ) + } + + fn write_encryption_key_pair< + EncryptionKey: traits::EncryptionKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + public_key: &EncryptionKey, + key_pair: &HpkeKeyPair, + ) -> Result<(), Self::Error> { + self.write::( + ENCRYPTION_KEY_PAIR_LABEL, + &serde_json::to_vec(public_key).unwrap(), + serde_json::to_vec(key_pair).unwrap(), + ) + } + + fn key_package< + KeyPackageRef: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &KeyPackageRef, + ) -> Result, Self::Error> { + let key = serde_json::to_vec(&hash_ref).unwrap(); + self.read(KEY_PACKAGE_LABEL, &key) + } + + fn psk, PskId: traits::PskId>( + &self, + psk_id: &PskId, + ) -> Result, Self::Error> { + self.read(PSK_LABEL, &serde_json::to_vec(&psk_id).unwrap()) + } + + fn encryption_key_pair< + HpkeKeyPair: traits::HpkeKeyPair, + EncryptionKey: traits::EncryptionKey, + >( + &self, + public_key: &EncryptionKey, + ) -> Result, Self::Error> { + self.read( + ENCRYPTION_KEY_PAIR_LABEL, + &serde_json::to_vec(public_key).unwrap(), + ) + } + + fn delete_signature_key_pair< + SignaturePublicKeuy: traits::SignaturePublicKey, + >( + &self, + public_key: &SignaturePublicKeuy, + ) -> Result<(), Self::Error> { + self.delete::( + SIGNATURE_KEY_PAIR_LABEL, + &serde_json::to_vec(public_key).unwrap(), + ) + } + + fn delete_encryption_key_pair>( + &self, + public_key: &EncryptionKey, + ) -> Result<(), Self::Error> { + self.delete::( + ENCRYPTION_KEY_PAIR_LABEL, + &serde_json::to_vec(&public_key).unwrap(), + ) + } + + fn delete_key_package>( + &self, + hash_ref: &KeyPackageRef, + ) -> Result<(), Self::Error> { + self.delete::(KEY_PACKAGE_LABEL, &serde_json::to_vec(&hash_ref)?) + } + + fn delete_psk>( + &self, + psk_id: &PskKey, + ) -> Result<(), Self::Error> { + self.delete::(PSK_LABEL, &serde_json::to_vec(&psk_id)?) + } + + fn group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(GROUP_STATE_LABEL, &serde_json::to_vec(&group_id)?) + } + + fn write_group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + group_id: &GroupId, + group_state: &GroupState, + ) -> Result<(), Self::Error> { + self.write::( + GROUP_STATE_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(group_state)?, + ) + } + + fn delete_group_state>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(GROUP_STATE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(MESSAGE_SECRETS_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + message_secrets: &MessageSecrets, + ) -> Result<(), Self::Error> { + self.write::( + MESSAGE_SECRETS_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(message_secrets)?, + ) + } + + fn delete_message_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(MESSAGE_SECRETS_LABEL, &serde_json::to_vec(group_id)?) + } + + fn resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(RESUMPTION_PSK_STORE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + resumption_psk_store: &ResumptionPskStore, + ) -> Result<(), Self::Error> { + self.write::( + RESUMPTION_PSK_STORE_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(resumption_psk_store)?, + ) + } + + fn delete_all_resumption_psk_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(RESUMPTION_PSK_STORE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(OWN_LEAF_NODE_INDEX_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + own_leaf_index: &LeafNodeIndex, + ) -> Result<(), Self::Error> { + self.write::( + OWN_LEAF_NODE_INDEX_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(own_leaf_index)?, + ) + } + + fn delete_own_leaf_index>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(OWN_LEAF_NODE_INDEX_LABEL, &serde_json::to_vec(group_id)?) + } + + fn use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(USE_RATCHET_TREE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn set_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + value: bool, + ) -> Result<(), Self::Error> { + self.write::( + USE_RATCHET_TREE_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(&value)?, + ) + } + + fn delete_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(USE_RATCHET_TREE_LABEL, &serde_json::to_vec(group_id)?) + } + + fn group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(EPOCH_SECRETS_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + group_epoch_secrets: &GroupEpochSecrets, + ) -> Result<(), Self::Error> { + self.write::( + EPOCH_SECRETS_LABEL, + &serde_json::to_vec(group_id)?, + serde_json::to_vec(group_epoch_secrets)?, + ) + } + + fn delete_group_epoch_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(EPOCH_SECRETS_LABEL, &serde_json::to_vec(group_id)?) + } + + fn write_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + key_pairs: &[HpkeKeyPair], + ) -> Result<(), Self::Error> { + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + let value = serde_json::to_vec(key_pairs)?; + log::debug!("Writing encryption epoch key pairs"); + #[cfg(feature = "test-utils")] + { + log::debug!(" key: {}", hex::encode(&key)); + log::debug!(" value: {}", hex::encode(&value)); + } + + self.write::(EPOCH_KEY_PAIRS_LABEL, &key, value) + } + + fn encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result, Self::Error> { + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + let storage_key = build_key_from_vec::(EPOCH_KEY_PAIRS_LABEL, key); + log::debug!("Reading encryption epoch key pairs"); + + let values = self.values.read().unwrap(); + let value = values.get(&storage_key); + + #[cfg(feature = "test-utils")] + log::debug!(" key: {}", hex::encode(&storage_key)); + + if let Some(value) = value { + #[cfg(feature = "test-utils")] + log::debug!(" value: {}", hex::encode(value)); + return Ok(serde_json::from_slice(value).unwrap()); + } + + Err(MemoryStorageError::None) + } + + fn delete_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result<(), Self::Error> { + let key = epoch_key_pairs_id(group_id, epoch, leaf_index)?; + self.delete::(EPOCH_KEY_PAIRS_LABEL, &key) + } + + fn clear_proposal_queue< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let mut values = self.values.write().unwrap(); + + let key = build_key::(QUEUED_PROPOSAL_LABEL, group_id); + + // XXX #1566: also remove the proposal refs. can't be done now because they are stored in a + // non-recoverable way + values.remove(&key); + + Ok(()) + } + + fn mls_group_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read(JOIN_CONFIG_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn write_mls_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + config: &MlsGroupJoinConfig, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(group_id).unwrap(); + let value = serde_json::to_vec(config).unwrap(); + + self.write::(JOIN_CONFIG_LABEL, &key, value) + } + + fn own_leaf_nodes< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + self.read_list(OWN_LEAF_NODES_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn append_own_leaf_node< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + group_id: &GroupId, + leaf_node: &LeafNode, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(group_id)?; + let value = serde_json::to_vec(leaf_node)?; + self.append::(OWN_LEAF_NODES_LABEL, &key, value) + } + + fn clear_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(group_id)?; + self.delete::(OWN_LEAF_NODES_LABEL, &key) + } + + fn aad>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error> { + let key = serde_json::to_vec(group_id)?; + self.read::>(AAD_LABEL, &key) + .map(|v| { + // When we didn't find the value, we return an empty vector as + // required by the trait. + v.unwrap_or_default() + }) + } + + fn write_aad>( + &self, + group_id: &GroupId, + aad: &[u8], + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(group_id)?; + self.write::(AAD_LABEL, &key, serde_json::to_vec(aad).unwrap()) + } + + fn delete_aad>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(AAD_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn delete_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(OWN_LEAF_NODES_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn delete_group_config>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(JOIN_CONFIG_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn delete_tree>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(TREE_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn delete_confirmation_tag>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::( + CONFIRMATION_TAG_LABEL, + &serde_json::to_vec(group_id).unwrap(), + ) + } + + fn delete_context>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::(GROUP_CONTEXT_LABEL, &serde_json::to_vec(group_id).unwrap()) + } + + fn delete_interim_transcript_hash>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error> { + self.delete::( + INTERIM_TRANSCRIPT_HASH_LABEL, + &serde_json::to_vec(group_id).unwrap(), + ) + } + + fn remove_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(group_id).unwrap(); + let value = serde_json::to_vec(proposal_ref).unwrap(); + + self.remove_item::(PROPOSAL_QUEUE_REFS_LABEL, &key, value)?; + + let key = serde_json::to_vec(&(group_id, proposal_ref)).unwrap(); + self.delete::(QUEUED_PROPOSAL_LABEL, &key) + } +} + +/// Build a key with version and label. +fn build_key_from_vec(label: &[u8], key: Vec) -> Vec { + let mut key_out = label.to_vec(); + key_out.extend_from_slice(&key); + key_out.extend_from_slice(&u16::to_be_bytes(V)); + key_out +} + +/// Build a key with version and label. +fn build_key(label: &[u8], key: K) -> Vec { + build_key_from_vec::(label, serde_json::to_vec(&key).unwrap()) +} + +fn epoch_key_pairs_id( + group_id: &impl traits::GroupId, + epoch: &impl traits::EpochKey, + leaf_index: u32, +) -> Result, >::Error> { + let mut key = serde_json::to_vec(group_id)?; + key.extend_from_slice(&serde_json::to_vec(epoch)?); + key.extend_from_slice(&serde_json::to_vec(&leaf_index)?); + Ok(key) +} + +impl From for MemoryStorageError { + fn from(_: serde_json::Error) -> Self { + Self::SerializationError + } +} diff --git a/memory_storage/src/persistence.rs b/memory_storage/src/persistence.rs new file mode 100644 index 0000000000..b746aba51d --- /dev/null +++ b/memory_storage/src/persistence.rs @@ -0,0 +1,76 @@ +use std::{ + collections::HashMap, + env, + fs::File, + io::{BufReader, BufWriter}, + path::PathBuf, +}; + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Default, Serialize, Deserialize)] +struct SerializableKeyStore { + values: HashMap, +} + +pub fn get_file_path(file_name: &String) -> PathBuf { + let tmp_folder = env::temp_dir(); + tmp_folder.join(file_name) +} + +impl super::MemoryStorage { + fn get_file_path(user_name: &String) -> PathBuf { + get_file_path(&("openmls_cli_".to_owned() + user_name + "_ks.json")) + } + + fn save_to_file(&self, output_file: &File) -> Result<(), String> { + let writer = BufWriter::new(output_file); + + let mut ser_ks = SerializableKeyStore::default(); + for (key, value) in &*self.values.read().unwrap() { + ser_ks + .values + .insert(base64::encode(key), base64::encode(value)); + } + + match serde_json::to_writer_pretty(writer, &ser_ks) { + Ok(()) => Ok(()), + Err(e) => Err(e.to_string()), + } + } + + pub fn save(&self, user_name: String) -> Result<(), String> { + let ks_output_path = Self::get_file_path(&user_name); + + match File::create(ks_output_path) { + Ok(output_file) => self.save_to_file(&output_file), + Err(e) => Err(e.to_string()), + } + } + + fn load_from_file(&mut self, input_file: &File) -> Result<(), String> { + // Prepare file reader. + let reader = BufReader::new(input_file); + + // Read the JSON contents of the file as an instance of `SerializableKeyStore`. + match serde_json::from_reader::, SerializableKeyStore>(reader) { + Ok(ser_ks) => { + let mut ks_map = self.values.write().unwrap(); + for (key, value) in ser_ks.values { + ks_map.insert(base64::decode(key).unwrap(), base64::decode(value).unwrap()); + } + Ok(()) + } + Err(e) => Err(e.to_string()), + } + } + + pub fn load(&mut self, user_name: String) -> Result<(), String> { + let ks_input_path = Self::get_file_path(&user_name); + + match File::open(ks_input_path) { + Ok(input_file) => self.load_from_file(&input_file), + Err(e) => Err(e.to_string()), + } + } +} diff --git a/memory_storage/src/test_store.rs b/memory_storage/src/test_store.rs new file mode 100644 index 0000000000..f8c7b31bfd --- /dev/null +++ b/memory_storage/src/test_store.rs @@ -0,0 +1,582 @@ +use super::*; +use std::io::Write; + +impl StorageProvider for MemoryStorage { + type Error = MemoryStorageError; + + fn write_encryption_key_pair< + EncryptionKey: traits::EncryptionKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + public_key: &EncryptionKey, + key_pair: &HpkeKeyPair, + ) -> Result<(), Self::Error> { + self.write::( + ENCRYPTION_KEY_PAIR_LABEL, + &serde_json::to_vec(&public_key).unwrap(), + serde_json::to_vec(&key_pair).unwrap(), + ) + } + + fn encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result, Self::Error> { + let mut key = vec![]; + write!( + &mut key, + "{group_id},{epoch},{leaf_index}", + group_id = serde_json::to_string(group_id).unwrap(), + epoch = serde_json::to_string(epoch).unwrap(), + ) + .unwrap(); + self.read_list(ENCRYPTION_KEY_PAIR_LABEL, &key) + } + + fn key_package< + KeyPackageRef: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &KeyPackageRef, + ) -> Result, Self::Error> { + let key = serde_json::to_vec(&hash_ref).unwrap(); + + println!("getting key package at {key:?} for version {V_TEST}"); + println!( + "the whole store when trying to get the key package: {:?}", + self.values.read().unwrap() + ); + self.read(KEY_PACKAGE_LABEL, &key) + } + + fn write_key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &HashReference, + key_package: &KeyPackage, + ) -> Result<(), Self::Error> { + let key = serde_json::to_vec(&hash_ref).unwrap(); + println!("setting key package at {key:?} for version {V_TEST}"); + let value = serde_json::to_vec(&key_package).unwrap(); + + self.write::(KEY_PACKAGE_LABEL, &key, value) + .unwrap(); + + self.key_package::(hash_ref) + .unwrap(); + + Ok(()) + } + + fn queue_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + _group_id: &GroupId, + _proposal_ref: &ProposalRef, + _proposal: &QueuedProposal, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_tree, TreeSync: traits::TreeSync>( + &self, + _group_id: &GroupId, + _tree: &TreeSync, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + _group_id: &GroupId, + _interim_transcript_hash: &InterimTranscriptHash, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + _group_id: &GroupId, + _group_context: &GroupContext, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + _group_id: &GroupId, + _confirmation_tag: &ConfirmationTag, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + _public_key: &SignaturePublicKey, + _signature_key_pair: &SignatureKeyPair, + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + _group_id: &GroupId, + _epoch: &EpochKey, + _leaf_index: u32, + _key_pairs: &[HpkeKeyPair], + ) -> Result<(), Self::Error> { + todo!() + } + + fn write_psk, PskBundle: traits::PskBundle>( + &self, + _psk_id: &PskId, + _psk: &PskBundle, + ) -> Result<(), Self::Error> { + todo!() + } + + fn queued_proposal_refs< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn treesync, TreeSync: traits::TreeSync>( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn group_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + _public_key: &SignaturePublicKey, + ) -> Result, Self::Error> { + todo!() + } + + fn encryption_key_pair< + HpkeKeyPair: traits::HpkeKeyPair, + EncryptionKey: traits::EncryptionKey, + >( + &self, + _public_key: &EncryptionKey, + ) -> Result, Self::Error> { + todo!() + } + + fn psk, PskId: traits::PskId>( + &self, + _psk_id: &PskId, + ) -> Result, Self::Error> { + todo!() + } + + fn delete_signature_key_pair>( + &self, + _public_key: &SignaturePublicKeuy, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_encryption_key_pair>( + &self, + _public_key: &EncryptionKey, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + >( + &self, + _group_id: &GroupId, + _epoch: &EpochKey, + _leaf_index: u32, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_key_package>( + &self, + _hash_ref: &KeyPackageRef, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_psk>( + &self, + _psk_id: &PskKey, + ) -> Result<(), Self::Error> { + todo!() + } + + fn group_state, GroupId: traits::GroupId>( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + _group_id: &GroupId, + _group_state: &GroupState, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_group_state>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + _group_id: &GroupId, + _message_secrets: &MessageSecrets, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_message_secrets>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + _group_id: &GroupId, + _resumption_psk_store: &ResumptionPskStore, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_all_resumption_psk_secrets>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + _group_id: &GroupId, + _own_leaf_index: &LeafNodeIndex, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_own_leaf_index>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn use_ratchet_tree_extension>( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn set_use_ratchet_tree_extension>( + &self, + _group_id: &GroupId, + _value: bool, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_use_ratchet_tree_extension>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + _group_id: &GroupId, + _group_epoch_secrets: &GroupEpochSecrets, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_group_epoch_secrets>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn clear_proposal_queue< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn mls_group_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_mls_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + _group_id: &GroupId, + _config: &MlsGroupJoinConfig, + ) -> Result<(), Self::Error> { + todo!() + } + + fn own_leaf_nodes, LeafNode: traits::LeafNode>( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn append_own_leaf_node< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + _group_id: &GroupId, + _leaf_node: &LeafNode, + ) -> Result<(), Self::Error> { + todo!() + } + + fn clear_own_leaf_nodes>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn aad>( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn write_aad>( + &self, + _group_id: &GroupId, + _aad: &[u8], + ) -> Result<(), Self::Error> { + todo!() + } + + fn queued_proposals< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + _group_id: &GroupId, + ) -> Result, Self::Error> { + todo!() + } + + fn remove_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + _group_id: &GroupId, + _proposal_ref: &ProposalRef, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_aad>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_own_leaf_nodes>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_group_config>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_tree>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_confirmation_tag>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_context>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } + + fn delete_interim_transcript_hash>( + &self, + _group_id: &GroupId, + ) -> Result<(), Self::Error> { + todo!() + } +} diff --git a/openmls-wasm/src/lib.rs b/openmls-wasm/src/lib.rs index d91a80e447..bdd6738e6b 100644 --- a/openmls-wasm/src/lib.rs +++ b/openmls-wasm/src/lib.rs @@ -4,11 +4,10 @@ use js_sys::Uint8Array; use openmls::{ credentials::{BasicCredential, CredentialWithKey}, framing::{MlsMessageBodyIn, MlsMessageIn, MlsMessageOut}, - group::{config::CryptoConfig, GroupId, MlsGroup, MlsGroupJoinConfig, StagedWelcome}, + group::{GroupId, MlsGroup, MlsGroupJoinConfig, StagedWelcome}, key_packages::KeyPackage as OpenMlsKeyPackage, prelude::SignatureScheme, treesync::RatchetTreeIn, - versions::ProtocolVersion, }; use openmls_basic_credential::SignatureKeyPair; use openmls_rust_crypto::OpenMlsRustCrypto; @@ -29,15 +28,6 @@ extern "C" { /// The ciphersuite used here. Fixed in order to reduce the binary size. static CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519; -/// The protocol version. We only support RFC MLS. -static VERSION: ProtocolVersion = ProtocolVersion::Mls10; - -/// The config used in all calls that need a CryptoConfig, using hardcoded settings. -static CRYPTO_CONFIG: CryptoConfig = CryptoConfig { - ciphersuite: CIPHERSUITE, - version: VERSION, -}; - #[wasm_bindgen] #[derive(Default)] pub struct Provider(OpenMlsRustCrypto); @@ -48,6 +38,12 @@ impl AsRef for Provider { } } +impl AsMut for Provider { + fn as_mut(&mut self) -> &mut OpenMlsRustCrypto { + &mut self.0 + } +} + #[wasm_bindgen] impl Provider { #[wasm_bindgen(constructor)] @@ -73,10 +69,10 @@ impl Identity { pub fn new(provider: &Provider, name: &str) -> Result { let signature_scheme = SignatureScheme::ED25519; let identity = name.bytes().collect(); - let credential = BasicCredential::new(identity)?; + let credential = BasicCredential::new(identity); let keypair = SignatureKeyPair::new(signature_scheme)?; - keypair.store(provider.0.key_store())?; + keypair.store(provider.0.storage())?; let credential_with_key = CredentialWithKey { credential: credential.into(), @@ -93,12 +89,14 @@ impl Identity { KeyPackage( OpenMlsKeyPackage::builder() .build( - CRYPTO_CONFIG, + CIPHERSUITE, &provider.0, &self.keypair, self.credential_with_key.clone(), ) - .unwrap(), + .unwrap() + .key_package() + .clone(), ) } } @@ -145,7 +143,7 @@ impl Group { let group_id_bytes = group_id.bytes().collect::>(); let mls_group = MlsGroup::builder() - .crypto_config(CRYPTO_CONFIG) + .ciphersuite(CIPHERSUITE) .with_group_id(GroupId::from_slice(&group_id_bytes)) .build( &provider.0, @@ -206,15 +204,15 @@ impl Group { }) } - pub fn merge_pending_commit(&mut self, provider: &Provider) -> Result<(), JsError> { + pub fn merge_pending_commit(&mut self, provider: &mut Provider) -> Result<(), JsError> { self.mls_group - .merge_pending_commit(provider.as_ref()) + .merge_pending_commit(provider.as_mut()) .map_err(|e| e.into()) } pub fn process_message( &mut self, - provider: &Provider, + provider: &mut Provider, mut msg: &[u8], ) -> Result, JsError> { let msg = MlsMessageIn::tls_deserialize(&mut msg).unwrap(); @@ -242,7 +240,7 @@ impl Group { } openmls::framing::ProcessedMessageContent::StagedCommitMessage(staged_commit) => { self.mls_group - .merge_staged_commit(provider.as_ref(), *staged_commit)?; + .merge_staged_commit(provider.as_mut(), *staged_commit)?; Ok(vec![]) } } @@ -256,7 +254,7 @@ impl Group { key_length: usize, ) -> Result, JsError> { self.mls_group - .export_secret(provider.as_ref().crypto(), label, context, key_length) + .export_secret(provider.as_ref(), label, context, key_length) .map_err(|e| { println!("export key error: {e}"); e.into() @@ -360,7 +358,7 @@ mod tests { #[test] fn basic() { - let alice_provider = Provider::new(); + let mut alice_provider = Provider::new(); let bob_provider = Provider::new(); let alice = Identity::new(&alice_provider, "alice") @@ -380,7 +378,7 @@ mod tests { .unwrap(); chess_club_alice - .merge_pending_commit(&alice_provider) + .merge_pending_commit(&mut alice_provider) .map_err(js_error_to_string) .unwrap(); diff --git a/openmls-wasm/src/utils.rs b/openmls-wasm/src/utils.rs index b1d7929dc9..c4be847eed 100644 --- a/openmls-wasm/src/utils.rs +++ b/openmls-wasm/src/utils.rs @@ -1,3 +1,4 @@ +#[allow(dead_code)] pub fn set_panic_hook() { // When the `console_error_panic_hook` feature is enabled, we can call the // `set_panic_hook` function at least once during initialization, and then diff --git a/openmls/Cargo.toml b/openmls/Cargo.toml index 6af75f8d66..17a6daa946 100644 --- a/openmls/Cargo.toml +++ b/openmls/Cargo.toml @@ -24,16 +24,18 @@ serde_json = { version = "1.0", optional = true } # Crypto providers required for KAT and testing - "test-utils" feature itertools = { version = "0.10", optional = true } openmls_rust_crypto = { version = "0.2.0", path = "../openmls_rust_crypto", optional = true } -openmls_libcrux_crypto = { path = "../libcrux_crypto", optional = true } openmls_basic_credential = { version = "0.2.0", path = "../basic_credential", optional = true, features = [ "clonable", "test-utils", ] } -rstest = { version = "^0.16", optional = true } -rstest_reuse = { version = "0.4", optional = true } -wasm-bindgen-test = {version = "0.3.40", optional = true} -getrandom = {version = "0.2.12", optional = true, features = [ "js" ]} -fluvio-wasm-timer = {version = "0.2.5", optional = true} +wasm-bindgen-test = { version = "0.3.40", optional = true } +getrandom = { version = "0.2.12", optional = true, features = ["js"] } +fluvio-wasm-timer = { version = "0.2.5", optional = true } +openmls_memory_storage = { path = "../memory_storage", features = [ + "test-utils", +], optional = true } +openmls_test = { path = "../openmls_test", optional = true } +openmls_libcrux_crypto = { path = "../libcrux_crypto", optional = true } [features] default = ["backtrace"] @@ -43,33 +45,42 @@ test-utils = [ "dep:itertools", "dep:openmls_rust_crypto", "dep:rand", - "dep:rstest", - "dep:rstest_reuse", "dep:wasm-bindgen-test", "dep:openmls_basic_credential", + "dep:openmls_memory_storage", + "dep:openmls_test", +] +libcrux-provider = [ + "dep:openmls_libcrux_crypto", + "openmls_test?/libcrux-provider", ] -libcrux-provider = ["dep:openmls_libcrux_crypto",] crypto-debug = [] # ☣️ Enable logging of sensitive cryptographic information content-debug = [] # ☣️ Enable logging of sensitive message content -js = ["dep:getrandom", "dep:fluvio-wasm-timer"] # enable js randomness source for provider +js = [ + "dep:getrandom", + "dep:fluvio-wasm-timer", +] # enable js randomness source for provider [dev-dependencies] backtrace = "0.3" -criterion = {version = "^0.5", default-features = false} # need to disable default features for wasm +criterion = { version = "^0.5", default-features = false } # need to disable default features for wasm hex = { version = "0.4", features = ["serde"] } itertools = "0.10" lazy_static = "1.4" -openmls = { path = ".", features = ["test-utils"] } openmls_traits = { version = "0.2.0", path = "../traits", features = [ "test-utils", ] } pretty_env_logger = "0.5" -rstest = "^0.16" -rstest_reuse = "0.4" tempfile = "3" wasm-bindgen = "0.2.90" wasm-bindgen-test = "0.3.40" +# Disable for wasm32 and Win32 +[target.'cfg(not(any(target_arch = "wasm32", all(target_arch = "x86", target_os = "windows"))))'.dev-dependencies] +openmls = { path = ".", features = ["test-utils", "libcrux-provider"] } +[target.'cfg(any(target_arch = "wasm32", all(target_arch = "x86", target_os = "windows")))'.dev-dependencies] +openmls = { path = ".", features = ["test-utils"] } + [[bench]] name = "benchmark" harness = false diff --git a/openmls/benches/benchmark.rs b/openmls/benches/benchmark.rs index 7a3e185bbd..0d11c40bd3 100644 --- a/openmls/benches/benchmark.rs +++ b/openmls/benches/benchmark.rs @@ -4,11 +4,12 @@ extern crate openmls; extern crate rand; use criterion::Criterion; -use openmls::prelude::{config::CryptoConfig, *}; +use openmls::prelude::*; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{crypto::OpenMlsCrypto, OpenMlsProvider}; +pub type OpenMlsRustCrypto = openmls_rust_crypto::OpenMlsRustCrypto; + fn criterion_kp_bundle(c: &mut Criterion, provider: &impl OpenMlsProvider) { for &ciphersuite in provider.crypto().supported_ciphersuites().iter() { c.bench_function( @@ -16,7 +17,7 @@ fn criterion_kp_bundle(c: &mut Criterion, provider: &impl OpenMlsProvider) { move |b| { b.iter_with_setup( || { - let credential = BasicCredential::new(vec![1, 2, 3]).unwrap(); + let credential = BasicCredential::new(vec![1, 2, 3]); let signer = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); let credential_with_key = CredentialWithKey { @@ -28,15 +29,7 @@ fn criterion_kp_bundle(c: &mut Criterion, provider: &impl OpenMlsProvider) { }, |(credential_with_key, signer)| { let _key_package = KeyPackage::builder() - .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - provider, - &signer, - credential_with_key, - ) + .build(ciphersuite, provider, &signer, credential_with_key) .expect("An unexpected error occurred."); }, ); diff --git a/openmls/src/binary_tree/array_representation/treemath.rs b/openmls/src/binary_tree/array_representation/treemath.rs index b74380e539..9e9b613eab 100644 --- a/openmls/src/binary_tree/array_representation/treemath.rs +++ b/openmls/src/binary_tree/array_representation/treemath.rs @@ -151,7 +151,7 @@ impl TreeNodeIndex { } /// Re-exported for testing. - #[cfg(any(feature = "test-utils", test))] + #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))] pub(crate) fn test_u32(&self) -> u32 { self.u32() } diff --git a/openmls/src/ciphersuite/aead.rs b/openmls/src/ciphersuite/aead.rs index 2ab76a6dde..1fc4e6ddd6 100644 --- a/openmls/src/ciphersuite/aead.rs +++ b/openmls/src/ciphersuite/aead.rs @@ -39,10 +39,10 @@ impl core::fmt::Debug for AeadNonce { impl AeadKey { /// Create an `AeadKey` from a `Secret`. TODO: This function should /// disappear when tackling issue #103. - pub(crate) fn from_secret(secret: Secret) -> Self { - log::trace!("AeadKey::from_secret with {}", secret.ciphersuite); + pub(crate) fn from_secret(secret: Secret, ciphersuite: Ciphersuite) -> Self { + log::trace!("AeadKey::from_secret with {}", ciphersuite); AeadKey { - aead_mode: secret.ciphersuite.aead_algorithm(), + aead_mode: ciphersuite.aead_algorithm(), value: secret.value, } } @@ -162,8 +162,8 @@ mod unit_tests { /// Make sure that xoring works by xoring a nonce with a reuse guard, testing if /// it has changed, xoring it again and testing that it's back in its original /// state. - #[apply(providers)] - fn test_xor(provider: &impl OpenMlsProvider) { + #[openmls_test::openmls_test] + fn test_xor() { let reuse_guard: ReuseGuard = ReuseGuard::try_from_random(provider.rand()).expect("An unexpected error occurred."); let original_nonce = AeadNonce::random(provider.rand()); diff --git a/openmls/src/ciphersuite/codec.rs b/openmls/src/ciphersuite/codec.rs index c8b0296517..857aad5928 100644 --- a/openmls/src/ciphersuite/codec.rs +++ b/openmls/src/ciphersuite/codec.rs @@ -25,8 +25,6 @@ impl Deserialize for Secret { let value = Vec::tls_deserialize(bytes)?; Ok(Secret { value: value.into(), - mls_version: ProtocolVersion::default(), - ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, }) } } diff --git a/openmls/src/ciphersuite/mac.rs b/openmls/src/ciphersuite/mac.rs index d61fb39a7a..a9ed36febd 100644 --- a/openmls/src/ciphersuite/mac.rs +++ b/openmls/src/ciphersuite/mac.rs @@ -23,15 +23,13 @@ impl Mac { /// Compute the HMAC on `salt` with key `ikm`. pub(crate) fn new( crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, salt: &Secret, ikm: &[u8], ) -> Result { Ok(Mac { mac_value: salt - .hkdf_extract( - crypto, - &Secret::from_slice(ikm, salt.mls_version, salt.ciphersuite), - )? + .hkdf_extract(crypto, ciphersuite, &Secret::from_slice(ikm))? .value .as_slice() .into(), diff --git a/openmls/src/ciphersuite/mod.rs b/openmls/src/ciphersuite/mod.rs index 3edf14fe55..99656f23b4 100644 --- a/openmls/src/ciphersuite/mod.rs +++ b/openmls/src/ciphersuite/mod.rs @@ -3,7 +3,6 @@ //! This file contains the API to interact with ciphersuites. //! See `codec.rs` and `ciphersuites.rs` for internals. -use crate::versions::ProtocolVersion; use ::tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes}; use openmls_traits::{ crypto::OpenMlsCrypto, diff --git a/openmls/src/ciphersuite/secret.rs b/openmls/src/ciphersuite/secret.rs index 8a03b0d8ee..4f4dfea7bc 100644 --- a/openmls/src/ciphersuite/secret.rs +++ b/openmls/src/ciphersuite/secret.rs @@ -12,31 +12,24 @@ use super::{kdf_label::KdfLabel, *}; /// Please update as well when changing this struct. #[derive(Clone, Serialize, Deserialize, Eq)] pub(crate) struct Secret { - pub(in crate::ciphersuite) ciphersuite: Ciphersuite, pub(in crate::ciphersuite) value: SecretVLBytes, - pub(in crate::ciphersuite) mls_version: ProtocolVersion, } impl Debug for Secret { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { let mut ds = f.debug_struct("Secret"); - ds.field("ciphersuite", &self.ciphersuite); #[cfg(feature = "crypto-debug")] - ds.field("value", &self.value); + return ds.field("value", &self.value).finish(); #[cfg(not(feature = "crypto-debug"))] - ds.field("value", &"***"); - - ds.field("mls_version", &self.mls_version).finish() + ds.field("value", &"***").finish() } } impl Default for Secret { fn default() -> Self { Self { - ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, value: Vec::new().into(), - mls_version: ProtocolVersion::default(), } } } @@ -46,23 +39,10 @@ impl PartialEq for Secret { fn eq(&self, other: &Secret) -> bool { // These values can be considered public and checked before the actual // comparison. - if self.ciphersuite != other.ciphersuite - || self.mls_version != other.mls_version - || self.value.as_slice().len() != other.value.as_slice().len() - { + if self.value.as_slice().len() != other.value.as_slice().len() { log::error!("Incompatible secrets"); - log::trace!( - " {} {} {}", - self.ciphersuite, - self.mls_version, - self.value.as_slice().len() - ); - log::trace!( - " {} {} {}", - other.ciphersuite, - other.mls_version, - other.value.as_slice().len() - ); + log::trace!(" {}", self.value.as_slice().len()); + log::trace!(" {}", other.value.as_slice().len()); return false; } equal_ct(self.value.as_slice(), other.value.as_slice()) @@ -76,43 +56,26 @@ impl Secret { pub(crate) fn random( ciphersuite: Ciphersuite, rand: &impl OpenMlsRand, - version: impl Into>, ) -> Result { - let mls_version = version.into().unwrap_or_default(); - log::trace!( - "Creating a new random secret for {:?} and {:?}", - ciphersuite, - mls_version - ); Ok(Secret { value: rand .random_vec(ciphersuite.hash_length()) .map_err(|_| CryptoError::InsufficientRandomness)? .into(), - mls_version, - ciphersuite, }) } /// Create an all zero secret. - pub(crate) fn zero(ciphersuite: Ciphersuite, mls_version: ProtocolVersion) -> Self { + pub(crate) fn zero(ciphersuite: Ciphersuite) -> Self { Self { value: vec![0u8; ciphersuite.hash_length()].into(), - mls_version, - ciphersuite, } } /// Create a new secret from a byte vector. - pub(crate) fn from_slice( - bytes: &[u8], - mls_version: ProtocolVersion, - ciphersuite: Ciphersuite, - ) -> Self { + pub(crate) fn from_slice(bytes: &[u8]) -> Self { Secret { value: bytes.into(), - mls_version, - ciphersuite, } } @@ -120,38 +83,21 @@ impl Secret { pub(crate) fn hkdf_extract<'a>( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ikm_option: impl Into>, ) -> Result { - log::trace!("HKDF extract with {:?}", self.ciphersuite); + log::trace!("HKDF extract with"); log_crypto!(trace, " salt: {:x?}", self.value); - let zero_secret = Self::zero(self.ciphersuite, self.mls_version); + let zero_secret = Self::zero(ciphersuite); let ikm = ikm_option.into().unwrap_or(&zero_secret); log_crypto!(trace, " ikm: {:x?}", ikm.value); - // We don't return an error here to keep the error propagation from - // blowing up. If this fails, something in the library is really wrong - // and we can't recover from it. - assert!( - self.mls_version == ikm.mls_version, - "{} != {}", - self.mls_version, - ikm.mls_version - ); - assert!( - self.ciphersuite == ikm.ciphersuite, - "{} != {}", - self.ciphersuite, - ikm.ciphersuite - ); - Ok(Self { value: crypto.hkdf_extract( - self.ciphersuite.hash_algorithm(), + ciphersuite.hash_algorithm(), self.value.as_slice(), ikm.value.as_slice(), )?, - mls_version: self.mls_version, - ciphersuite: self.ciphersuite, }) } @@ -159,12 +105,13 @@ impl Secret { pub(crate) fn hkdf_expand( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, info: &[u8], okm_len: usize, ) -> Result { let key = crypto .hkdf_expand( - self.ciphersuite.hash_algorithm(), + ciphersuite.hash_algorithm(), self.value.as_slice(), info, okm_len, @@ -173,11 +120,7 @@ impl Secret { if key.as_slice().is_empty() { return Err(CryptoError::InvalidLength); } - Ok(Self { - value: key, - mls_version: self.mls_version, - ciphersuite: self.ciphersuite, - }) + Ok(Self { value: key }) } /// Expand a `Secret` to a new `Secret` of length `length` including a @@ -185,21 +128,22 @@ impl Secret { pub(crate) fn kdf_expand_label( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, label: &str, context: &[u8], length: usize, ) -> Result { - let full_label = format!("{} {}", self.mls_version, label); + let full_label = format!("MLS 1.0 {}", label); log::trace!( "KDF expand with label \"{}\" and {:?} with context {:x?}", &full_label, - self.ciphersuite, + ciphersuite, context ); let info = KdfLabel::serialized_label(context, full_label, length)?; log::trace!(" serialized info: {:x?}", info); log_crypto!(trace, " secret: {:x?}", self.value); - self.hkdf_expand(crypto, &info, length) + self.hkdf_expand(crypto, ciphersuite, &info, length) } /// Derive a new `Secret` from the this one by expanding it with the given @@ -207,41 +151,22 @@ impl Secret { pub(crate) fn derive_secret( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, label: &str, ) -> Result { log_crypto!( trace, - "derive secret from {:x?} with label {} and {:?}", + "derive secret from {:x?} with label {}", self.value, - label, - self.ciphersuite + label ); - self.kdf_expand_label(crypto, label, &[], self.ciphersuite.hash_length()) - } - - /// Update the ciphersuite and MLS version of this secret. - /// Ideally we wouldn't need this function but the way decoding works right - /// now this is the easiest for now. - pub(crate) fn config(&mut self, ciphersuite: Ciphersuite, mls_version: ProtocolVersion) { - self.ciphersuite = ciphersuite; - self.mls_version = mls_version; + self.kdf_expand_label(crypto, ciphersuite, label, &[], ciphersuite.hash_length()) } /// Returns the inner bytes of a secret pub(crate) fn as_slice(&self) -> &[u8] { self.value.as_slice() } - - /// Returns the ciphersuite of the secret - pub(crate) fn ciphersuite(&self) -> Ciphersuite { - self.ciphersuite - } - - /// Returns the version of the secret. TODO: This function should go away - /// when tackling issue #641. - pub(crate) fn version(&self) -> ProtocolVersion { - self.mls_version - } } #[cfg(any(feature = "test-utils", test))] @@ -250,8 +175,6 @@ impl From<&[u8]> for Secret { log::trace!("Secret from slice"); Secret { value: bytes.into(), - mls_version: ProtocolVersion::default(), - ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, } } } diff --git a/openmls/src/ciphersuite/signature.rs b/openmls/src/ciphersuite/signature.rs index 47d94e7b1c..c0a685e15f 100644 --- a/openmls/src/ciphersuite/signature.rs +++ b/openmls/src/ciphersuite/signature.rs @@ -128,8 +128,9 @@ pub struct OpenMlsSignaturePublicKey { pub(in crate::ciphersuite) value: VLBytes, } -#[cfg(test)] +#[cfg(any(test, feature = "test-utils"))] impl Signature { + #[cfg(test)] pub(crate) fn modify(&mut self, value: &[u8]) { self.value = value.to_vec().into(); } diff --git a/openmls/src/ciphersuite/tests.rs b/openmls/src/ciphersuite/tests.rs index 2851571a90..a1c700565a 100644 --- a/openmls/src/ciphersuite/tests.rs +++ b/openmls/src/ciphersuite/tests.rs @@ -1,7 +1,6 @@ //! Unit tests for the ciphersuites. mod test_ciphersuite; -mod test_secrets; // Test vector for basic crypto functionality mod kat_crypto_basics; diff --git a/openmls/src/ciphersuite/tests/kat_crypto_basics.rs b/openmls/src/ciphersuite/tests/kat_crypto_basics.rs index 9c42b83e47..76e900e9fa 100644 --- a/openmls/src/ciphersuite/tests/kat_crypto_basics.rs +++ b/openmls/src/ciphersuite/tests/kat_crypto_basics.rs @@ -226,7 +226,6 @@ pub fn run_test_vector( use crate::{ prelude_test::{hash_ref, hpke, OpenMlsSignaturePublicKey, Secret}, tree::secret_tree::derive_tree_secret, - versions::ProtocolVersion, }; let ciphersuite = Ciphersuite::try_from(test.cipher_suite).unwrap(); @@ -258,8 +257,14 @@ pub fn run_test_vector( let label = test.expand_with_label.label; let context = hex_to_bytes(&test.expand_with_label.context); let length = test.expand_with_label.length; - let out = Secret::from_slice(&secret, ProtocolVersion::default(), ciphersuite) - .kdf_expand_label(provider.crypto(), &label, &context, length.into()) + let out = Secret::from_slice(&secret) + .kdf_expand_label( + provider.crypto(), + ciphersuite, + &label, + &context, + length.into(), + ) .unwrap(); assert_eq!(&hex_to_bytes(&test.expand_with_label.out), out.as_slice()); @@ -269,8 +274,8 @@ pub fn run_test_vector( { let label = test.derive_secret.label; let secret = hex_to_bytes(&test.derive_secret.secret); - let out = Secret::from_slice(&secret, ProtocolVersion::default(), ciphersuite) - .derive_secret(provider.crypto(), &label) + let out = Secret::from_slice(&secret) + .derive_secret(provider.crypto(), ciphersuite, &label) .unwrap(); assert_eq!(&hex_to_bytes(&test.derive_secret.out), out.as_slice()); @@ -378,7 +383,8 @@ pub fn run_test_vector( let out = hex_to_bytes(&test.derive_tree_secret.out); let tree_secret = derive_tree_secret( - &Secret::from_slice(&secret, ProtocolVersion::Mls10, ciphersuite), + ciphersuite, + &Secret::from_slice(&secret), &label, generation, length.into(), diff --git a/openmls/src/ciphersuite/tests/test_ciphersuite.rs b/openmls/src/ciphersuite/tests/test_ciphersuite.rs index 71f99dc7bb..aa5ab6f605 100644 --- a/openmls/src/ciphersuite/tests/test_ciphersuite.rs +++ b/openmls/src/ciphersuite/tests/test_ciphersuite.rs @@ -1,18 +1,18 @@ //! Unit tests for the ciphersuites. -use openmls_rust_crypto::OpenMlsRustCrypto; + use openmls_traits::types::HpkeCiphertext; use crate::{ciphersuite::*, test_utils::*}; // Spot test to make sure hpke seal/open work. -#[apply(ciphersuites_and_providers)] -fn test_hpke_seal_open(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_hpke_seal_open() { let plaintext = &[1, 2, 3]; let kp = provider .crypto() .derive_hpke_keypair( ciphersuite.hpke_config(), - Secret::random(ciphersuite, provider.rand(), None) + Secret::random(ciphersuite, provider.rand()) .expect("Not enough randomness.") .as_slice(), ) diff --git a/openmls/src/ciphersuite/tests/test_secrets.rs b/openmls/src/ciphersuite/tests/test_secrets.rs deleted file mode 100644 index cb3fb3165d..0000000000 --- a/openmls/src/ciphersuite/tests/test_secrets.rs +++ /dev/null @@ -1,33 +0,0 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; - -use crate::{ - ciphersuite::{Ciphersuite, Secret}, - test_utils::*, - versions::ProtocolVersion, -}; - -#[apply(ciphersuites_and_providers)] -fn secret_init(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - // These two secrets must be incompatible - let default_secret = - Secret::random(ciphersuite, provider.rand(), None).expect("Not enough randomness."); - let draft_secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10Draft11) - .expect("Not enough randomness."); - - let derived_default_secret = default_secret.derive_secret(provider.crypto(), "my_test_label"); - let derived_draft_secret = draft_secret.derive_secret(provider.crypto(), "my_test_label"); - assert_ne!(derived_default_secret, derived_draft_secret); -} - -#[should_panic] -#[apply(ciphersuites_and_providers)] -fn secret_incompatible(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - // These two secrets must be incompatible - let default_secret = - Secret::random(ciphersuite, provider.rand(), None).expect("Not enough randomness."); - let draft_secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10Draft11) - .expect("Not enough randomness."); - - // This must panic because the two secrets have incompatible MLS versions. - let _default_extracted = default_secret.hkdf_extract(provider.crypto(), &draft_secret); -} diff --git a/openmls/src/credentials/mod.rs b/openmls/src/credentials/mod.rs index a7bd9388a0..d00294affc 100644 --- a/openmls/src/credentials/mod.rs +++ b/openmls/src/credentials/mod.rs @@ -27,7 +27,7 @@ use std::io::{Read, Write}; use serde::{Deserialize, Serialize}; use tls_codec::{ Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait, - Size, VLBytes, + Size, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes, }; #[cfg(test)] @@ -184,67 +184,23 @@ pub struct Certificate { /// }; /// } Credential; /// ``` -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[derive( + Debug, + PartialEq, + Eq, + Clone, + Serialize, + Deserialize, + TlsSize, + TlsSerialize, + TlsDeserialize, + TlsDeserializeBytes, +)] pub struct Credential { credential_type: CredentialType, serialized_credential_content: VLBytes, } -impl tls_codec::Size for Credential { - fn tls_serialized_len(&self) -> usize { - CredentialType::tls_serialized_len(&CredentialType::Basic) - + self.serialized_credential_content.as_ref().len() - } -} - -impl tls_codec::Serialize for Credential { - fn tls_serialize(&self, writer: &mut W) -> Result { - self.credential_type.tls_serialize(writer)?; - writer.write_all(self.serialized_credential_content.as_slice())?; - Ok(self.tls_serialized_len()) - } -} - -impl tls_codec::Deserialize for Credential { - fn tls_deserialize(bytes: &mut R) -> Result { - // We can not deserialize arbitrary credentials because we don't know - // their structure. While we don't care, we still need to parse it - // in order to move the reader forward and read the values in the struct - // after this credential. - - // The credential type is important, so we read that. - let credential_type = CredentialType::tls_deserialize(bytes)?; - - // Now we don't know what we get unfortunately. - // We assume that it is a variable-sized vector. This works for the - // currently specified credentials and any other credential MUST be - // encoded in a vector as well. Otherwise OpenMLS may fail later on - // or exhibit unexpected behaviour. - let (length, _) = tls_codec::vlen::read_length(bytes)?; - let mut actual_credential_content = vec![0u8; length]; - bytes.read_exact(&mut actual_credential_content)?; - - // Rebuild the credential again. - let mut serialized_credential = Vec::new(); - tls_codec::vlen::write_length(&mut serialized_credential, length)?; - serialized_credential.append(&mut actual_credential_content); - - Ok(Self { - serialized_credential_content: serialized_credential.into(), - credential_type, - }) - } -} - -impl tls_codec::DeserializeBytes for Credential { - fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> { - let mut bytes_ref = bytes; - let secret = Self::tls_deserialize(&mut bytes_ref)?; - let remainder = &bytes[secret.tls_serialized_len()..]; - Ok((secret, remainder)) - } -} - impl Credential { /// Returns the credential type. pub fn credential_type(&self) -> CredentialType { @@ -288,8 +244,7 @@ impl Credential { /// `openmls_basic_credential` crate. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct BasicCredential { - header: Vec, - identity: Vec, + identity: VLBytes, } impl BasicCredential { @@ -299,38 +254,35 @@ impl BasicCredential { /// /// Returns a [`BasicCredentialError`] if the length of the identity is too /// large to be encoded as a variable-length vector. - pub fn new(identity: Vec) -> Result { - let mut header = Vec::new(); - tls_codec::vlen::write_length(&mut header, identity.len())?; - Ok(Self { header, identity }) + pub fn new(identity: Vec) -> Self { + Self { + identity: identity.into(), + } } /// Get the identity of this basic credential as byte slice. pub fn identity(&self) -> &[u8] { - &self.identity + self.identity.as_slice() } } impl From for Credential { - fn from(mut credential: BasicCredential) -> Self { - let mut serialized_credential_content = credential.header; - serialized_credential_content.append(&mut credential.identity); + fn from(credential: BasicCredential) -> Self { Credential { - serialized_credential_content: serialized_credential_content.into(), credential_type: CredentialType::Basic, + serialized_credential_content: credential.identity, } } } -impl TryFrom<&Credential> for BasicCredential { +impl TryFrom for BasicCredential { type Error = BasicCredentialError; - fn try_from(credential: &Credential) -> Result { - match credential.credential_type() { - CredentialType::Basic => { - let identity: VLBytes = credential.deserialized().unwrap(); - Ok(BasicCredential::new(identity.into())?) - } + fn try_from(credential: Credential) -> Result { + match credential.credential_type { + CredentialType::Basic => Ok(BasicCredential::new( + credential.serialized_credential_content.into(), + )), _ => Err(errors::BasicCredentialError::WrongCredentialType), } } @@ -374,9 +326,9 @@ pub mod test_utils { identity: &[u8], signature_scheme: SignatureScheme, ) -> (CredentialWithKey, SignatureKeyPair) { - let credential = BasicCredential::new(identity.into()).unwrap(); + let credential = BasicCredential::new(identity.into()); let signature_keys = SignatureKeyPair::new(signature_scheme).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); ( CredentialWithKey { @@ -390,15 +342,17 @@ pub mod test_utils { #[cfg(test)] mod unit_tests { - use tls_codec::{DeserializeBytes, Serialize}; + use tls_codec::{ + DeserializeBytes, Serialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, + }; - use super::{BasicCredential, Credential}; + use super::{BasicCredential, Credential, CredentialType}; #[test] fn basic_credential_identity_and_codec() { const IDENTITY: &str = "identity"; // Test the identity getter. - let basic_credential = BasicCredential::new(IDENTITY.into()).unwrap(); + let basic_credential = BasicCredential::new(IDENTITY.into()); assert_eq!(basic_credential.identity(), IDENTITY.as_bytes()); // Test the encoding and decoding. @@ -412,11 +366,45 @@ mod unit_tests { deserialized.serialized_content() ); - let deserialized_basic_credential = BasicCredential::try_from(&deserialized).unwrap(); + let deserialized_basic_credential = BasicCredential::try_from(deserialized).unwrap(); assert_eq!( deserialized_basic_credential.identity(), IDENTITY.as_bytes() ); assert_eq!(basic_credential, deserialized_basic_credential); } + + /// Test the [`Credential`] with a custom credential. + #[test] + fn custom_credential() { + #[derive( + Debug, Clone, PartialEq, Eq, TlsSize, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, + )] + struct CustomCredential { + custom_field1: u32, + custom_field2: Vec, + custom_field3: Option, + } + + let custom_credential = CustomCredential { + custom_field1: 42, + custom_field2: vec![1, 2, 3], + custom_field3: Some(2), + }; + + let credential = Credential::new( + CredentialType::Other(1234), + custom_credential.tls_serialize_detached().unwrap(), + ); + + let serialized = credential.tls_serialize_detached().unwrap(); + let deserialized = Credential::tls_deserialize_exact_bytes(&serialized).unwrap(); + assert_eq!(credential, deserialized); + + let deserialized_custom_credential = + CustomCredential::tls_deserialize_exact_bytes(deserialized.serialized_content()) + .unwrap(); + + assert_eq!(custom_credential, deserialized_custom_credential); + } } diff --git a/openmls/src/credentials/tests.rs b/openmls/src/credentials/tests.rs index 3f1138e160..d32cb961c6 100644 --- a/openmls/src/credentials/tests.rs +++ b/openmls/src/credentials/tests.rs @@ -6,23 +6,23 @@ use super::*; fn test_protocol_version() { use crate::versions::ProtocolVersion; let mls10_version = ProtocolVersion::Mls10; - let default_version = ProtocolVersion::default(); + let other_version = ProtocolVersion::Other(999); let mls10_e = mls10_version .tls_serialize_detached() .expect("An unexpected error occurred."); assert_eq!( - ProtocolVersion::try_from(u16::from_be_bytes(mls10_e[0..2].try_into().unwrap())).unwrap(), + ProtocolVersion::from(u16::from_be_bytes(mls10_e[0..2].try_into().unwrap())), mls10_version ); - let default_e = default_version + let default_e = other_version .tls_serialize_detached() .expect("An unexpected error occurred."); assert_eq!( - ProtocolVersion::try_from(u16::from_be_bytes(default_e[0..2].try_into().unwrap())).unwrap(), - default_version + ProtocolVersion::from(u16::from_be_bytes(default_e[0..2].try_into().unwrap())), + other_version ); assert_eq!(u16::from_be_bytes(mls10_e[0..2].try_into().unwrap()), 1); - assert_eq!(u16::from_be_bytes(default_e[0..2].try_into().unwrap()), 1); + assert_eq!(u16::from_be_bytes(default_e[0..2].try_into().unwrap()), 999); } #[test] diff --git a/openmls/src/extensions/external_pub_extension.rs b/openmls/src/extensions/external_pub_extension.rs index 9190078a67..8fb7551985 100644 --- a/openmls/src/extensions/external_pub_extension.rs +++ b/openmls/src/extensions/external_pub_extension.rs @@ -39,12 +39,12 @@ impl ExternalPubExtension { #[cfg(test)] mod test { - use openmls_rust_crypto::OpenMlsRustCrypto; + use crate::test_utils::OpenMlsRustCrypto; use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite, OpenMlsProvider}; use tls_codec::{Deserialize, Serialize}; use super::*; - use crate::{prelude_test::Secret, versions::ProtocolVersion}; + use crate::prelude_test::Secret; #[test] fn test_serialize_deserialize() { @@ -59,7 +59,6 @@ mod test { let ikm = Secret::random( Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, provider.rand(), - ProtocolVersion::default(), ) .unwrap(); let init_key = provider.crypto().derive_hpke_keypair( diff --git a/openmls/src/extensions/external_sender_extension.rs b/openmls/src/extensions/external_sender_extension.rs index dcba1442ce..d1c1fd21ac 100644 --- a/openmls/src/extensions/external_sender_extension.rs +++ b/openmls/src/extensions/external_sender_extension.rs @@ -85,19 +85,18 @@ impl SenderExtensionIndex { #[cfg(test)] mod test { use openmls_basic_credential::SignatureKeyPair; - use openmls_traits::types::Ciphersuite; use tls_codec::{Deserialize, Serialize}; use super::*; - use crate::{credentials::BasicCredential, test_utils::*}; + use crate::credentials::BasicCredential; - #[apply(ciphersuites)] - fn test_serialize_deserialize(ciphersuite: Ciphersuite) { + #[openmls_test::openmls_test] + fn test_serialize_deserialize() { let tests = { let mut external_sender_extensions = Vec::new(); for _ in 0..8 { - let credential = BasicCredential::new(b"Alice".to_vec()).unwrap(); + let credential = BasicCredential::new(b"Alice".to_vec()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); diff --git a/openmls/src/extensions/required_capabilities.rs b/openmls/src/extensions/required_capabilities.rs index a02745cb1b..945c78a381 100644 --- a/openmls/src/extensions/required_capabilities.rs +++ b/openmls/src/extensions/required_capabilities.rs @@ -5,7 +5,7 @@ use crate::{ treesync::node::leaf_node::default_extensions, }; -use super::{Deserialize, ExtensionError, ExtensionType, Serialize}; +use super::{Deserialize, ExtensionType, Serialize}; /// # Required Capabilities Extension. /// @@ -64,18 +64,17 @@ impl RequiredCapabilitiesExtension { } /// Get a slice with the required extension types. - pub(crate) fn extension_types(&self) -> &[ExtensionType] { + pub fn extension_types(&self) -> &[ExtensionType] { self.extension_types.as_slice() } /// Get a slice with the required proposal types. - pub(crate) fn proposal_types(&self) -> &[ProposalType] { + pub fn proposal_types(&self) -> &[ProposalType] { self.proposal_types.as_slice() } /// Get a slice with the required credential types. - #[allow(unused)] - pub(crate) fn credential_types(&self) -> &[CredentialType] { + pub fn credential_types(&self) -> &[CredentialType] { self.credential_types.as_slice() } @@ -83,14 +82,4 @@ impl RequiredCapabilitiesExtension { pub(crate) fn requires_extension_type_support(&self, ext_type: ExtensionType) -> bool { self.extension_types.contains(&ext_type) || default_extensions().contains(&ext_type) } - - /// Check if all extension and proposal types are supported. - pub(crate) fn check_support(&self) -> Result<(), ExtensionError> { - for proposal in self.proposal_types() { - if !proposal.is_supported() { - return Err(ExtensionError::UnsupportedProposalType); - } - } - Ok(()) - } } diff --git a/openmls/src/extensions/test_extensions.rs b/openmls/src/extensions/test_extensions.rs index cfa223dcde..492c5bbc6b 100644 --- a/openmls/src/extensions/test_extensions.rs +++ b/openmls/src/extensions/test_extensions.rs @@ -2,25 +2,21 @@ //! Some basic unit tests for extensions //! Proper testing is done through the public APIs. -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::key_store::OpenMlsKeyStore; use tls_codec::{Deserialize, Serialize}; use super::*; use crate::{ - ciphersuite::HpkePrivateKey, credentials::*, framing::*, - group::{config::CryptoConfig, errors::*, *}, + group::{errors::*, *}, key_packages::*, messages::proposals::ProposalType, prelude::{Capabilities, RatchetTreeIn}, prelude_test::HpkePublicKey, schedule::psk::store::ResumptionPskStore, - test_utils::*, - treesync::node::encryption_keys::EncryptionKeyPair, versions::ProtocolVersion, }; +use openmls_traits::prelude::*; #[test] fn application_id() { @@ -40,8 +36,8 @@ fn application_id() { // This tests the ratchet tree extension to deliver the public ratcheting tree // in-band -#[apply(ciphersuites_and_providers)] -fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn ratchet_tree_extension() { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -53,7 +49,7 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi test_utils::new_credential(provider, b"Bob", ciphersuite.signature_algorithm()); // Generate KeyPackages - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( provider, &bob_signature_keys, ciphersuite, @@ -68,7 +64,7 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi // === Alice creates a group with the ratchet tree extension === let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key.clone(), ) .with_config(config) @@ -132,7 +128,7 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi // === Alice creates a group without the ratchet tree extension === // Generate KeyPackages - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( provider, &bob_signature_keys, ciphersuite, @@ -146,7 +142,7 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key, ) .with_config(config) @@ -197,10 +193,10 @@ fn ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi .err(); // We expect an error because the ratchet tree is missing - assert_eq!( + assert!(matches!( error.expect("We expected an error"), WelcomeError::MissingRatchetTree - ); + )); } #[test] @@ -244,51 +240,8 @@ fn required_capabilities() { assert_eq!(extension_bytes, encoded); } -#[apply(ciphersuites_and_providers)] -fn test_metadata(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - // Create credentials and keys - let alice_credential_with_key_and_signer = tests::utils::generate_credential_with_key( - b"Alice".into(), - ciphersuite.signature_algorithm(), - provider, - ); - - // example metadata (opaque data -- test hex string is "1cedc0ffee") - let metadata = vec![0x1c, 0xed, 0xc0, 0xff, 0xee]; - let ext = Extension::Unknown(0xf001, UnknownExtension(metadata.clone())); - let extensions = Extensions::from_vec(vec![ext]).expect("could not build extensions struct"); - - let config = MlsGroupCreateConfig::builder() - .with_group_context_extensions(extensions) - .unwrap() - .build(); - - // === Alice creates a group with the ratchet tree extension === - let alice_group = MlsGroup::new( - provider, - &alice_credential_with_key_and_signer.signer, - &config, - alice_credential_with_key_and_signer - .credential_with_key - .clone(), - ) - .expect("failed to build group"); - - let got_metadata = alice_group - .export_group_context() - .extensions() - .find_by_type(ExtensionType::Unknown(0xf001)) - .expect("failed to read group metadata"); - - if let Extension::Unknown(0xf001, UnknownExtension(got_metadata)) = got_metadata { - assert_eq!(got_metadata, &metadata); - } else { - panic!("metadata extension has wrong extension enum variant") - } -} - -#[apply(ciphersuites_and_providers)] -fn with_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn with_group_context_extensions() { // create an extension that we can check for later let test_extension = Extension::Unknown(0xf023, UnknownExtension(vec![0xca, 0xfe])); let extensions = Extensions::single(test_extension.clone()); @@ -302,7 +255,7 @@ fn with_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenM let mls_group_create_config = MlsGroupCreateConfig::builder() .with_group_context_extensions(extensions) .expect("failed to apply extensions at group config builder") - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -323,11 +276,8 @@ fn with_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenM assert_eq!(found_test_extension, &test_extension); } -#[apply(ciphersuites_and_providers)] -fn wrong_extension_with_group_context_extensions( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test::openmls_test] +fn wrong_extension_with_group_context_extensions() { // Extension types that are known to not be allowed here: // - application id // - external pub @@ -339,8 +289,6 @@ fn wrong_extension_with_group_context_extensions( provider, ); - let crypto_config = CryptoConfig::with_default_version(ciphersuite); - // create an extension that we can check for later let test_extension = Extension::ApplicationId(ApplicationIdExtension::new(&[0xca, 0xfe])); let extensions = Extensions::single(test_extension.clone()); @@ -352,7 +300,7 @@ fn wrong_extension_with_group_context_extensions( assert_eq!(err, InvalidExtensionError::IllegalInGroupContext); let err = PublicGroup::builder( GroupId::from_slice(&[0xbe, 0xef]), - crypto_config, + ciphersuite, alice_credential_with_key_and_signer .credential_with_key .clone(), @@ -373,7 +321,7 @@ fn wrong_extension_with_group_context_extensions( let err = PublicGroup::builder( GroupId::from_slice(&[0xbe, 0xef]), - crypto_config, + ciphersuite, alice_credential_with_key_and_signer .credential_with_key .clone(), @@ -395,7 +343,7 @@ fn wrong_extension_with_group_context_extensions( let err = PublicGroup::builder( GroupId::from_slice(&[0xbe, 0xef]), - crypto_config, + ciphersuite, alice_credential_with_key_and_signer .credential_with_key .clone(), @@ -405,17 +353,16 @@ fn wrong_extension_with_group_context_extensions( assert_eq!(err, InvalidExtensionError::IllegalInGroupContext); } -#[apply(ciphersuites_and_providers)] -fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn last_resort_extension() { let last_resort = Extension::LastResort(LastResortExtension::default()); // Build a KeyPackage with a last resort extension - let credential = BasicCredential::new(b"Bob".to_vec()).unwrap(); + let credential = BasicCredential::new(b"Bob".to_vec()); let signer = openmls_basic_credential::SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); let extensions = Extensions::single(last_resort); - let crypto_config = CryptoConfig::with_default_version(ciphersuite); let capabilities = Capabilities::new( None, None, @@ -428,7 +375,7 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid .key_package_extensions(extensions) .leaf_node_capabilities(capabilities) .build( - crypto_config, + ciphersuite, provider, &signer, CredentialWithKey { @@ -437,8 +384,9 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid }, ) .expect("error building key package with last resort extension"); - assert!(kp.last_resort()); + assert!(kp.key_package().last_resort()); let encoded_kp = kp + .key_package() .tls_serialize_detached() .expect("error encoding key package with last resort extension"); let decoded_kp = KeyPackageIn::tls_deserialize(&mut encoded_kp.as_slice()) @@ -457,7 +405,7 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid ); let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -475,7 +423,7 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid .add_members( provider, &alice_credential_with_key_and_signer.signer, - &[kp.clone()], + &[kp.key_package().clone()], ) .expect("An unexpected error occurred."); @@ -484,7 +432,7 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid let welcome: MlsMessageIn = welcome.into(); let welcome = welcome.into_welcome().expect("expected a welcome"); - let mut bob_group = StagedWelcome::new_from_welcome( + let _bob_group = StagedWelcome::new_from_welcome( provider, mls_group_create_config.join_config(), welcome, @@ -494,31 +442,15 @@ fn last_resort_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvid .into_group(provider) .expect("An unexpected error occurred."); - // === Bob sends a commit == - - let (_message, _welcome, _group_info) = bob_group - .self_update(provider, &signer) - .expect("An unexpected error occurred."); - bob_group - .merge_pending_commit(provider) - .expect("An unexpected error occurred."); - - // This should not have deleted the KP or private keys from the store - let kp: Option = provider.key_store().read( - kp.hash_ref(provider.crypto()) - .expect("error hashing kp") - .as_slice(), - ); - assert!(kp.is_some()); - - let kp = kp.unwrap(); + use openmls_traits::storage::StorageProvider; - let leaf_keypair = - EncryptionKeyPair::read_from_key_store(provider, kp.leaf_node().encryption_key()); - assert!(leaf_keypair.is_some()); - - let private_key = provider - .key_store() - .read::(kp.hpke_init_key().as_slice()); - assert!(private_key.is_some()); + let _: KeyPackageBundle = provider + .storage() + .key_package( + &kp.key_package() + .hash_ref(provider.crypto()) + .expect("error hashing key package"), + ) + .expect("error retrieving key package") + .expect("key package does not exist"); } diff --git a/openmls/src/framing/codec.rs b/openmls/src/framing/codec.rs index 3d194e007f..35e21e8090 100644 --- a/openmls/src/framing/codec.rs +++ b/openmls/src/framing/codec.rs @@ -1,5 +1,5 @@ -use std::io::{Read, Write}; -use tls_codec::{Deserialize, Serialize, Size}; +use std::io::Read; +use tls_codec::{Deserialize, Size}; use crate::versions::ProtocolVersion; @@ -8,34 +8,6 @@ use super::{ private_message_in::PrivateMessageContentIn, *, }; -impl Size for PrivateMessageContent { - fn tls_serialized_len(&self) -> usize { - self.content.tls_serialized_len() + - self.auth.tls_serialized_len() + - // Note: The padding is appended as a "raw" all-zero byte slice - // with length `length_of_padding`. Thus, we only need to add - // this length here. - self.length_of_padding - } -} - -impl Serialize for PrivateMessageContent { - fn tls_serialize(&self, writer: &mut W) -> Result { - let mut written = 0; - - // The `content` field is serialized without the `content_type`, which - // is not part of the struct as per MLS spec. - written += self.content.serialize_without_type(writer)?; - - written += self.auth.tls_serialize(writer)?; - let padding = vec![0u8; self.length_of_padding]; - writer.write_all(&padding)?; - written += self.length_of_padding; - - Ok(written) - } -} - /// This function implements deserialization manually, as it requires `content_type` as additional input. pub(super) fn deserialize_ciphertext_content( bytes: &mut R, diff --git a/openmls/src/framing/errors.rs b/openmls/src/framing/errors.rs index c6a8902131..21621a27f6 100644 --- a/openmls/src/framing/errors.rs +++ b/openmls/src/framing/errors.rs @@ -35,7 +35,7 @@ pub enum MessageDecryptionError { /// Message encryption error #[derive(Error, Debug, PartialEq, Clone)] -pub(crate) enum MessageEncryptionError { +pub(crate) enum MessageEncryptionError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -45,6 +45,9 @@ pub(crate) enum MessageEncryptionError { /// See [`SecretTreeError`] for more details. #[error(transparent)] SecretTreeError(#[from] SecretTreeError), + /// Error reading from or writing to storage + #[error("Error reading from or writing to storage: {0}")] + StorageError(StorageError), } /// MlsMessage error diff --git a/openmls/src/framing/message_out.rs b/openmls/src/framing/message_out.rs index 5e0fe9e0e4..c392229b2e 100644 --- a/openmls/src/framing/message_out.rs +++ b/openmls/src/framing/message_out.rs @@ -11,7 +11,10 @@ use tls_codec::Serialize; use super::*; -use crate::{key_packages::KeyPackage, messages::group_info::GroupInfo, versions::ProtocolVersion}; +use crate::{ + key_packages::KeyPackage, messages::group_info::GroupInfo, prelude::KeyPackageBundle, + versions::ProtocolVersion, +}; #[cfg(any(feature = "test-utils", test))] use crate::messages::group_info::VerifiableGroupInfo; @@ -116,6 +119,15 @@ impl From for MlsMessageOut { } } +impl From for MlsMessageOut { + fn from(key_package: KeyPackageBundle) -> Self { + Self { + version: key_package.key_package().protocol_version(), + body: MlsMessageBodyOut::KeyPackage(key_package.key_package), + } + } +} + impl MlsMessageOut { /// Create an [`MlsMessageOut`] from a [`PrivateMessage`], as well as the /// currently used [`ProtocolVersion`]. diff --git a/openmls/src/framing/private_message.rs b/openmls/src/framing/private_message.rs index 957f4f2fc8..a4bab73314 100644 --- a/openmls/src/framing/private_message.rs +++ b/openmls/src/framing/private_message.rs @@ -1,15 +1,12 @@ -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::types::Ciphersuite; use std::io::Write; use tls_codec::{Serialize, Size, TlsSerialize, TlsSize}; -use super::{ - mls_auth_content::{AuthenticatedContent, FramedContentAuthData}, - mls_content::FramedContentBody, -}; +use super::mls_auth_content::AuthenticatedContent; use crate::{ binary_tree::array_representation::LeafNodeIndex, error::LibraryError, - tree::secret_tree::SecretType, + storage::OpenMlsProvider, tree::secret_tree::SecretType, }; use super::*; @@ -69,13 +66,13 @@ impl PrivateMessage { /// /// TODO #1148: Refactor theses constructors to avoid test code in main and /// to avoid validation using a special feature flag. - pub(crate) fn try_from_authenticated_content( + pub(crate) fn try_from_authenticated_content( public_message: &AuthenticatedContent, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &Provider, message_secrets: &mut MessageSecrets, padding_size: usize, - ) -> Result { + ) -> Result> { log::debug!("PrivateMessage::try_from_authenticated_content"); log::trace!(" ciphersuite: {}", ciphersuite); // Check the message has the correct wire format @@ -93,13 +90,13 @@ impl PrivateMessage { } #[cfg(any(feature = "test-utils", test))] - pub(crate) fn encrypt_without_check( + pub(crate) fn encrypt_without_check( public_message: &AuthenticatedContent, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &Provider, message_secrets: &mut MessageSecrets, padding_size: usize, - ) -> Result { + ) -> Result> { Self::encrypt_content( None, public_message, @@ -111,14 +108,14 @@ impl PrivateMessage { } #[cfg(test)] - pub(crate) fn encrypt_with_different_header( + pub(crate) fn encrypt_with_different_header( public_message: &AuthenticatedContent, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &Provider, header: MlsMessageHeader, message_secrets: &mut MessageSecrets, padding_size: usize, - ) -> Result { + ) -> Result> { Self::encrypt_content( Some(header), public_message, @@ -131,14 +128,14 @@ impl PrivateMessage { /// Internal function to encrypt content. The extra message header is only used /// for tests. Otherwise, the data from the given `AuthenticatedContent` is used. - fn encrypt_content( + fn encrypt_content( test_header: Option, public_message: &AuthenticatedContent, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &Provider, message_secrets: &mut MessageSecrets, padding_size: usize, - ) -> Result { + ) -> Result> { let sender_index = if let Some(index) = public_message.sender().as_member() { index } else { @@ -197,7 +194,7 @@ impl PrivateMessage { // Derive the sender data key from the key schedule using the ciphertext. let sender_data_key = message_secrets .sender_data_secret() - .derive_aead_key(provider.crypto(), &ciphertext) + .derive_aead_key(provider.crypto(), ciphersuite, &ciphertext) .map_err(LibraryError::unexpected_crypto_error)?; // Derive initial nonce from the key schedule using the ciphertext. let sender_data_nonce = message_secrets @@ -299,51 +296,6 @@ impl PrivateMessage { } } -// === Helper structs === - -/// PrivateMessageContent -/// -/// ```c -/// struct { -/// select (PrivateMessage.content_type) { -/// case application: -/// opaque application_data; -/// -/// case proposal: -/// Proposal proposal; -/// -/// case commit: -/// Commit commit; -/// } -/// -/// FramedContentAuthData auth; -/// opaque padding[length_of_padding]; -/// } PrivateMessageContent; -/// ``` -#[derive(Debug, Clone)] -pub(crate) struct PrivateMessageContent { - // The `content` field is serialized and deserialized manually without the - // `content_type`, which is not part of the struct as per MLS spec. See the - // implementation of `TlsSerialize` for `PrivateMessageContent`, as well as - // `deserialize_ciphertext_content`. - pub(crate) content: FramedContentBody, - pub(crate) auth: FramedContentAuthData, - /// Length of the all-zero padding. - /// - /// We do not retain any bytes here to avoid the need to - /// keep track that all of them are zero. Instead, we only - /// use `length_of_padding` to track the (theoretical) size - /// of the all-zero byte slice. - /// - /// Note, however, that we MUST make sure to (de)serialize these bytes! - /// Otherwise this mechanism would not make any sense because it would - /// not add to the ciphertext size to hide the original message length. - /// - /// Sadly, we cannot `derive(TlsSerialize, TlsDeserialize)` due to this - /// "custom" mechanism. - pub(crate) length_of_padding: usize, -} - #[derive(TlsSerialize, TlsSize)] pub(crate) struct PrivateContentAad<'a> { pub(crate) group_id: GroupId, diff --git a/openmls/src/framing/private_message_in.rs b/openmls/src/framing/private_message_in.rs index 7ef35caee4..be46e9410e 100644 --- a/openmls/src/framing/private_message_in.rs +++ b/openmls/src/framing/private_message_in.rs @@ -57,7 +57,7 @@ impl PrivateMessageIn { // Derive key from the key schedule using the ciphertext. let sender_data_key = message_secrets .sender_data_secret() - .derive_aead_key(crypto, self.ciphertext.as_slice()) + .derive_aead_key(crypto, ciphersuite, self.ciphertext.as_slice()) .map_err(LibraryError::unexpected_crypto_error)?; // Derive initial nonce from the key schedule using the ciphertext. let sender_data_nonce = message_secrets diff --git a/openmls/src/framing/public_message.rs b/openmls/src/framing/public_message.rs index 05127baf65..285ea41a32 100644 --- a/openmls/src/framing/public_message.rs +++ b/openmls/src/framing/public_message.rs @@ -138,6 +138,7 @@ impl PublicMessage { pub(crate) fn set_membership_tag( &mut self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, membership_key: &MembershipKey, serialized_context: &[u8], ) -> Result<(), LibraryError> { @@ -150,7 +151,7 @@ impl PublicMessage { ) .map_err(LibraryError::missing_bound_check)?; let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?; - let membership_tag = membership_key.tag_message(crypto, tbm_payload)?; + let membership_tag = membership_key.tag_message(crypto, ciphersuite, tbm_payload)?; self.membership_tag = Some(membership_tag); Ok(()) diff --git a/openmls/src/framing/public_message_in.rs b/openmls/src/framing/public_message_in.rs index e7a6d9715e..54aa22a748 100644 --- a/openmls/src/framing/public_message_in.rs +++ b/openmls/src/framing/public_message_in.rs @@ -13,6 +13,7 @@ use super::{ *, }; +use openmls_traits::types::Ciphersuite; use std::io::{Read, Write}; use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait}; @@ -106,6 +107,7 @@ impl PublicMessageIn { pub(crate) fn set_membership_tag( &mut self, provider: &impl openmls_traits::OpenMlsProvider, + ciphersuite: Ciphersuite, membership_key: &MembershipKey, serialized_context: &[u8], ) -> Result<(), LibraryError> { @@ -118,7 +120,8 @@ impl PublicMessageIn { ) .map_err(LibraryError::missing_bound_check)?; let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?; - let membership_tag = membership_key.tag_message(provider.crypto(), tbm_payload)?; + let membership_tag = + membership_key.tag_message(provider.crypto(), ciphersuite, tbm_payload)?; self.membership_tag = Some(membership_tag); Ok(()) @@ -131,6 +134,7 @@ impl PublicMessageIn { pub(crate) fn verify_membership( &self, crypto: &impl openmls_traits::crypto::OpenMlsCrypto, + ciphersuite: Ciphersuite, membership_key: &MembershipKey, serialized_context: &[u8], ) -> Result<(), ValidationError> { @@ -146,7 +150,8 @@ impl PublicMessageIn { ) .map_err(LibraryError::missing_bound_check)?; let tbm_payload = AuthenticatedContentTbm::new(&tbs_payload, &self.auth)?; - let expected_membership_tag = &membership_key.tag_message(crypto, tbm_payload)?; + let expected_membership_tag = + &membership_key.tag_message(crypto, ciphersuite, tbm_payload)?; // Verify the membership tag if let Some(membership_tag) = &self.membership_tag { diff --git a/openmls/src/framing/test_framing.rs b/openmls/src/framing/test_framing.rs index 08d4dcad1d..124e215686 100644 --- a/openmls/src/framing/test_framing.rs +++ b/openmls/src/framing/test_framing.rs @@ -1,10 +1,7 @@ use openmls_basic_credential::SignatureKeyPair; -use openmls_traits::{random::OpenMlsRand, types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::prelude::*; +use openmls_traits::types::Ciphersuite; -use rstest::*; -use rstest_reuse::{self, *}; - -use openmls_rust_crypto::OpenMlsRustCrypto; use signable::Verifiable; use tls_codec::{Deserialize, Serialize}; @@ -20,13 +17,14 @@ use crate::{ }, key_packages::{test_key_packages::key_package, KeyPackageBundle}, schedule::psk::{store::ResumptionPskStore, PskSecret}, + storage::OpenMlsProvider, + test_utils::frankenstein::*, tree::{secret_tree::SecretTree, sender_ratchet::SenderRatchetConfiguration}, - versions::ProtocolVersion, }; /// This tests serializing/deserializing PublicMessage -#[apply(ciphersuites_and_providers)] -fn codec_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn codec_plaintext(ciphersuite: Ciphersuite, provider: &Provider) { let (_credential, signature_keys) = test_utils::new_credential(provider, b"Creator", ciphersuite.signature_algorithm()); let sender = Sender::build_member(LeafNodeIndex::new(987543210)); @@ -57,11 +55,15 @@ fn codec_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .into(); let membership_key = MembershipKey::from_secret( - Secret::random(ciphersuite, provider.rand(), None /* MLS version */) - .expect("Not enough randomness."), + Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."), ); - orig.set_membership_tag(provider.crypto(), &membership_key, &serialized_context) - .expect("Error setting membership tag."); + orig.set_membership_tag( + provider.crypto(), + ciphersuite, + &membership_key, + &serialized_context, + ) + .expect("Error setting membership tag."); let enc = orig .tls_serialize_detached() @@ -73,8 +75,8 @@ fn codec_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } /// This tests serializing/deserializing PrivateMessage -#[apply(ciphersuites_and_providers)] -fn codec_ciphertext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn codec_ciphertext() { let (_credential, signature_keys) = test_utils::new_credential(provider, b"Creator", ciphersuite.signature_algorithm()); let sender = Sender::build_member(LeafNodeIndex::new(0)); @@ -106,8 +108,8 @@ fn codec_ciphertext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let mut key_schedule = KeySchedule::init( ciphersuite, provider.crypto(), - &JoinerSecret::random(ciphersuite, provider.rand(), ProtocolVersion::default()), - PskSecret::from(Secret::zero(ciphersuite, ProtocolVersion::Mls10)), + &JoinerSecret::random(ciphersuite, provider.rand()), + PskSecret::from(Secret::zero(ciphersuite)), ) .expect("Could not create KeySchedule."); @@ -147,8 +149,8 @@ fn codec_ciphertext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } /// This tests the correctness of wire format checks -#[apply(ciphersuites_and_providers)] -fn wire_format_checks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn wire_format_checks() { let configuration = &SenderRatchetConfiguration::default(); let (plaintext, _credential, _keys) = create_content(ciphersuite, WireFormat::PrivateMessage, provider); @@ -159,16 +161,8 @@ fn wire_format_checks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) .rand() .random_vec(ciphersuite.hash_length()) .expect("An unexpected error occurred."); - let sender_encryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); - let receiver_encryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); + let sender_encryption_secret = EncryptionSecret::from_slice(&encryption_secret_bytes[..]); + let receiver_encryption_secret = EncryptionSecret::from_slice(&encryption_secret_bytes[..]); let sender_secret_tree = SecretTree::new( sender_encryption_secret, TreeSize::new(2u32), @@ -271,7 +265,7 @@ fn wire_format_checks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) message_secrets.replace_secret_tree(sender_secret_tree); // Try to encrypt an PublicMessage with the wrong wire format - assert_eq!( + assert!(matches!( PrivateMessage::try_from_authenticated_content( &plaintext, ciphersuite, @@ -281,7 +275,7 @@ fn wire_format_checks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) ) .expect_err("Could encrypt despite wrong wire format."), MessageEncryptionError::WrongWireFormat - ); + )); } fn create_content( @@ -319,8 +313,8 @@ fn create_content( (content, credential, signature_keys) } -#[apply(ciphersuites_and_providers)] -fn membership_tag(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn membership_tag() { let (_credential, signature_keys) = test_utils::new_credential(provider, b"Creator", ciphersuite.signature_algorithm()); let group_context = GroupContext::new( @@ -332,8 +326,7 @@ fn membership_tag(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { Extensions::empty(), ); let membership_key = MembershipKey::from_secret( - Secret::random(ciphersuite, provider.rand(), None /* MLS version */) - .expect("Not enough randomness."), + Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."), ); let public_message: PublicMessage = AuthenticatedContent::new_application( LeafNodeIndex::new(987543210), @@ -349,17 +342,27 @@ fn membership_tag(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let serialized_context = group_context.tls_serialize_detached().unwrap(); public_message - .set_membership_tag(provider, &membership_key, &serialized_context) + .set_membership_tag(provider, ciphersuite, &membership_key, &serialized_context) .expect("Error setting membership tag."); println!( "Membership tag error: {:?}", - public_message.verify_membership(provider.crypto(), &membership_key, &serialized_context) + public_message.verify_membership( + provider.crypto(), + ciphersuite, + &membership_key, + &serialized_context + ) ); // Verify signature & membership tag assert!(public_message - .verify_membership(provider.crypto(), &membership_key, &serialized_context) + .verify_membership( + provider.crypto(), + ciphersuite, + &membership_key, + &serialized_context + ) .is_ok()); // Change the content of the plaintext message @@ -367,31 +370,49 @@ fn membership_tag(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Expect the signature & membership tag verification to fail assert!(public_message - .verify_membership(provider.crypto(), &membership_key, &serialized_context) + .verify_membership( + provider.crypto(), + ciphersuite, + &membership_key, + &serialized_context + ) .is_err()); } -#[apply(ciphersuites_and_providers)] -fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn unknown_sender(ciphersuite: Ciphersuite, provider: &Provider) { + let _ = pretty_env_logger::try_init(); + + let alice_provider = provider; + let bob_provider = provider; + let charlie_provider = provider; + let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); let configuration = &SenderRatchetConfiguration::default(); // Define credentials with keys let (alice_credential, alice_signature_keys) = - test_utils::new_credential(provider, b"Alice", ciphersuite.signature_algorithm()); + test_utils::new_credential(alice_provider, b"Alice", ciphersuite.signature_algorithm()); let (bob_credential, bob_signature_keys) = - test_utils::new_credential(provider, b"Bob", ciphersuite.signature_algorithm()); - let (charlie_credential, charlie_signature_keys) = - test_utils::new_credential(provider, b"Charlie", ciphersuite.signature_algorithm()); + test_utils::new_credential(bob_provider, b"Bob", ciphersuite.signature_algorithm()); + let (charlie_credential, charlie_signature_keys) = test_utils::new_credential( + charlie_provider, + b"Charlie", + ciphersuite.signature_algorithm(), + ); // Generate KeyPackages - let bob_key_package_bundle = - KeyPackageBundle::new(provider, &bob_signature_keys, ciphersuite, bob_credential); + let bob_key_package_bundle = KeyPackageBundle::generate( + bob_provider, + &bob_signature_keys, + ciphersuite, + bob_credential, + ); let bob_key_package = bob_key_package_bundle.key_package(); - let charlie_key_package_bundle = KeyPackageBundle::new( - provider, + let charlie_key_package_bundle = KeyPackageBundle::generate( + charlie_provider, &charlie_signature_keys, ciphersuite, charlie_credential, @@ -400,11 +421,11 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Alice creates a group let mut group_alice = CoreGroup::builder( - GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + GroupId::random(alice_provider.rand()), + ciphersuite, alice_credential, ) - .build(provider, &alice_signature_keys) + .build(alice_provider, &alice_signature_keys) .expect("Error creating group."); // Alice adds Bob @@ -419,7 +440,7 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let mut proposal_store = ProposalStore::from_queued_proposal( QueuedProposal::from_authenticated_content_by_ref( ciphersuite, - provider.crypto(), + alice_provider.crypto(), bob_add_proposal, ) .expect("Could not create QueuedProposal."), @@ -431,11 +452,11 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .force_self_update(false) .build(); let create_commit_result = group_alice - .create_commit(params, provider, &alice_signature_keys) + .create_commit(params, alice_provider, &alice_signature_keys) .expect("Error creating Commit"); group_alice - .merge_commit(provider, create_commit_result.staged_commit) + .merge_commit(alice_provider, create_commit_result.staged_commit) .expect("error merging pending commit"); let _group_bob = StagedCoreWelcome::new_from_welcome( @@ -444,10 +465,10 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("An unexpected error occurred."), Some(group_alice.public_group().export_ratchet_tree().into()), bob_key_package_bundle, - provider, + bob_provider, ResumptionPskStore::new(1024), ) - .and_then(|staged_join| staged_join.into_core_group(provider)) + .and_then(|staged_join| staged_join.into_core_group(bob_provider)) .expect("Bob: Error creating group from Welcome"); // Alice adds Charlie @@ -464,7 +485,7 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { proposal_store.add( QueuedProposal::from_authenticated_content_by_ref( ciphersuite, - provider.crypto(), + alice_provider.crypto(), charlie_add_proposal, ) .expect("Could not create staged proposal."), @@ -476,11 +497,11 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .force_self_update(false) .build(); let create_commit_result = group_alice - .create_commit(params, provider, &alice_signature_keys) + .create_commit(params, alice_provider, &alice_signature_keys) .expect("Error creating Commit"); group_alice - .merge_commit(provider, create_commit_result.staged_commit) + .merge_commit(alice_provider, create_commit_result.staged_commit) .expect("error merging pending commit"); let mut group_charlie = StagedCoreWelcome::new_from_welcome( @@ -489,10 +510,10 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("An unexpected error occurred."), Some(group_alice.public_group().export_ratchet_tree().into()), charlie_key_package_bundle, - provider, + charlie_provider, ResumptionPskStore::new(1024), ) - .and_then(|staged_join| staged_join.into_core_group(provider)) + .and_then(|staged_join| staged_join.into_core_group(charlie_provider)) .expect("Charlie: Error creating group from Welcome"); // Alice removes Bob @@ -508,7 +529,7 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { proposal_store.add( QueuedProposal::from_authenticated_content_by_ref( ciphersuite, - provider.crypto(), + alice_provider.crypto(), bob_remove_proposal, ) .expect("Could not create staged proposal."), @@ -520,18 +541,23 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .force_self_update(false) .build(); let create_commit_result = group_alice - .create_commit(params, provider, &alice_signature_keys) + .create_commit(params, alice_provider, &alice_signature_keys) .expect("Error creating Commit"); let staged_commit = group_charlie - .read_keys_and_stage_commit(&create_commit_result.commit, &proposal_store, &[], provider) + .read_keys_and_stage_commit( + &create_commit_result.commit, + &proposal_store, + &[], + alice_provider, + ) .expect("Charlie: Could not stage Commit"); group_charlie - .merge_commit(provider, staged_commit) + .merge_commit(charlie_provider, staged_commit) .expect("error merging commit"); group_alice - .merge_commit(provider, create_commit_result.staged_commit) + .merge_commit(alice_provider, create_commit_result.staged_commit) .expect("error merging pending commit"); group_alice.print_ratchet_tree("Alice tree"); @@ -551,7 +577,7 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let enc_message = PrivateMessage::encrypt_with_different_header( &bogus_sender_message, ciphersuite, - provider, + alice_provider, MlsMessageHeader { group_id: group_alice.group_id().clone(), epoch: group_alice.context().epoch(), @@ -563,7 +589,7 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("Encryption error"); let received_message = group_charlie.decrypt_message( - provider.crypto(), + charlie_provider.crypto(), ProtocolMessage::from(PrivateMessageIn::from(enc_message)), configuration, ); @@ -575,8 +601,8 @@ fn unknown_sender(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); } -#[apply(ciphersuites_and_providers)] -fn confirmation_tag_presence(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn confirmation_tag_presence() { let (framing_parameters, group_alice, alice_signature_keys, group_bob, _, _) = setup_alice_bob_group(ciphersuite, provider); @@ -601,9 +627,9 @@ fn confirmation_tag_presence(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr assert_eq!(err, StageCommitError::ConfirmationTagMissing); } -pub(crate) fn setup_alice_bob_group( +pub(crate) fn setup_alice_bob_group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &Provider, ) -> ( FramingParameters, CoreGroup, @@ -622,7 +648,7 @@ pub(crate) fn setup_alice_bob_group( test_utils::new_credential(provider, b"Bob", ciphersuite.signature_algorithm()); // Generate KeyPackages - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( provider, &bob_signature_keys, ciphersuite, @@ -633,7 +659,7 @@ pub(crate) fn setup_alice_bob_group( // Alice creates a group let mut group_alice = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential, ) .build(provider, &alice_signature_keys) @@ -704,16 +730,18 @@ pub(crate) fn setup_alice_bob_group( } /// Test divergent protocol versions in KeyPackages -#[apply(ciphersuites_and_providers)] -fn key_package_version(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let (mut key_package, _, _) = key_package(ciphersuite, provider); +#[openmls_test::openmls_test] +fn key_package_version() { + let (key_package, _, _) = key_package(ciphersuite, provider); + + let mut franken_key_package = FrankenKeyPackage::from(key_package); // Set an invalid protocol version - key_package.set_version(ProtocolVersion::Mls10Draft11); + franken_key_package.payload.protocol_version = 999; - let message = MlsMessageOut { - version: ProtocolVersion::Mls10, - body: MlsMessageBodyOut::KeyPackage(key_package), + let message = FrankenMlsMessage { + version: 1, + body: FrankenMlsMessageBody::KeyPackage(franken_key_package), }; let encoded = message diff --git a/openmls/src/framing/validation.rs b/openmls/src/framing/validation.rs index 391cdc2af7..6e27b814fb 100644 --- a/openmls/src/framing/validation.rs +++ b/openmls/src/framing/validation.rs @@ -34,6 +34,7 @@ use crate::{ core_group::{proposals::QueuedProposal, staged_commit::StagedCommit}, errors::ValidationError, }, + storage::OpenMlsProvider, tree::sender_ratchet::SenderRatchetConfiguration, treesync::TreeSync, versions::ProtocolVersion, @@ -67,6 +68,7 @@ impl DecryptedMessage { message_secrets_option: impl Into>, serialized_context: Vec, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ) -> Result { if public_message.sender().is_member() { // ValSem007 Membership tag presence @@ -80,6 +82,7 @@ impl DecryptedMessage { // ValSem008 public_message.verify_membership( crypto, + ciphersuite, message_secrets.membership_key(), message_secrets.serialized_context(), )?; @@ -268,18 +271,23 @@ impl UnverifiedMessage { /// Verify the [`UnverifiedMessage`]. Returns the [`AuthenticatedContent`] /// and the internal [`Credential`]. - pub(crate) fn verify( + pub(crate) fn verify( self, ciphersuite: Ciphersuite, - crypto: &impl OpenMlsCrypto, + provider: &Provider, protocol_version: ProtocolVersion, - ) -> Result<(AuthenticatedContent, Credential), ProcessMessageError> { + ) -> Result<(AuthenticatedContent, Credential), ProcessMessageError> + { let content: AuthenticatedContentIn = self .verifiable_content - .verify(crypto, &self.sender_pk) + .verify(provider.crypto(), &self.sender_pk) .map_err(|_| ProcessMessageError::InvalidSignature)?; - let content = - content.validate(ciphersuite, crypto, self.sender_context, protocol_version)?; + let content = content.validate( + ciphersuite, + provider.crypto(), + self.sender_context, + protocol_version, + )?; Ok((content, self.credential)) } diff --git a/openmls/src/group/config.rs b/openmls/src/group/config.rs deleted file mode 100644 index e0d3c63411..0000000000 --- a/openmls/src/group/config.rs +++ /dev/null @@ -1,40 +0,0 @@ -//! # Group Configurations -//! -//! This modules holds helper structs to group together configurations and -//! parameters. - -use openmls_traits::types::Ciphersuite; -use serde::{Deserialize, Serialize}; - -use crate::versions::ProtocolVersion; - -/// A config struct for commonly used values when performing cryptographic -/// operations. -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] -pub struct CryptoConfig { - /// The [`Ciphersuite`] used. - pub ciphersuite: Ciphersuite, - - /// The MLS [`ProtocolVersion`] used. - pub version: ProtocolVersion, -} - -impl CryptoConfig { - /// Create a new crypto config with the given ciphersuite and the default - /// protocol version. - pub fn with_default_version(ciphersuite: Ciphersuite) -> Self { - Self { - ciphersuite, - version: ProtocolVersion::default(), - } - } -} - -impl Default for CryptoConfig { - fn default() -> Self { - Self { - ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, - version: ProtocolVersion::default(), - } - } -} diff --git a/openmls/src/group/core_group/kat_passive_client.rs b/openmls/src/group/core_group/kat_passive_client.rs index 8d57b16b99..37566545ec 100644 --- a/openmls/src/group/core_group/kat_passive_client.rs +++ b/openmls/src/group/core_group/kat_passive_client.rs @@ -1,12 +1,11 @@ use log::{debug, info, warn}; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{crypto::OpenMlsCrypto, key_store::OpenMlsKeyStore, OpenMlsProvider}; +use openmls_traits::{crypto::OpenMlsCrypto, storage::StorageProvider, OpenMlsProvider}; use serde::{self, Deserialize, Serialize}; use tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize}; use crate::{ framing::{MlsMessageBodyIn, MlsMessageIn, MlsMessageOut, ProcessedMessageContent}, - group::{config::CryptoConfig, *}, + group::*, key_packages::*, schedule::psk::PreSharedKeyId, test_utils::*, @@ -127,11 +126,7 @@ pub fn run_test_vector(test_vector: PassiveClientWelcomeTestVector) { .number_of_resumption_psks(16) .build(); - let mut passive_client = PassiveClient::new( - group_config, - cipher_suite, - test_vector.external_psks.clone(), - ); + let mut passive_client = PassiveClient::new(group_config, test_vector.external_psks.clone()); passive_client.inject_key_package( test_vector.key_package, @@ -206,11 +201,7 @@ struct PassiveClient { } impl PassiveClient { - fn new( - group_config: MlsGroupJoinConfig, - ciphersuite: Ciphersuite, - psks: Vec, - ) -> Self { + fn new(group_config: MlsGroupJoinConfig, psks: Vec) -> Self { let provider = OpenMlsRustCrypto::default(); // Load all PSKs into key store. @@ -219,9 +210,7 @@ impl PassiveClient { // We only construct this to easily save the PSK in the keystore. // The nonce is not saved, so it can be empty... let psk_id = PreSharedKeyId::external(psk.psk_id, vec![]); - psk_id - .write_to_key_store(&provider, ciphersuite, &psk.psk) - .unwrap(); + psk_id.store(&provider, &psk.psk).unwrap(); } Self { @@ -251,28 +240,15 @@ impl PassiveClient { let key_package_bundle = KeyPackageBundle { key_package: key_package.clone(), - private_key: init_priv, + private_init_key: init_priv, + private_encryption_key: encryption_priv.clone().into(), }; // Store key package. + let hash_ref = key_package.hash_ref(self.provider.crypto()).unwrap(); self.provider - .key_store() - .store( - key_package - .hash_ref(self.provider.crypto()) - .unwrap() - .as_slice(), - &key_package, - ) - .unwrap(); - - // Store init key. - self.provider - .key_store() - .store::( - key_package.hpke_init_key().as_slice(), - key_package_bundle.private_key(), - ) + .storage() + .write_key_package(&hash_ref, &key_package_bundle) .unwrap(); // Store encryption key @@ -281,9 +257,7 @@ impl PassiveClient { EncryptionPrivateKey::from(encryption_priv), )); - key_pair - .write_to_key_store(self.provider.key_store()) - .unwrap(); + key_pair.write(self.provider.storage()).unwrap(); } fn join_by_welcome( @@ -322,7 +296,8 @@ impl PassiveClient { self.group .as_mut() .unwrap() - .store_pending_proposal(*queued_proposal); + .store_pending_proposal(self.provider.storage(), *queued_proposal) + .unwrap(); } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { self.group @@ -345,16 +320,16 @@ impl PassiveClient { } } -pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTestVector { +pub fn generate_test_vector(ciphersuite: Ciphersuite) -> PassiveClientWelcomeTestVector { let group_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(cipher_suite)) + .ciphersuite(ciphersuite) .use_ratchet_tree_extension(true) .build(); let creator_provider = OpenMlsRustCrypto::default(); let creator = - generate_group_candidate(b"Alice (Creator)", cipher_suite, &creator_provider, true); + generate_group_candidate(b"Alice (Creator)", ciphersuite, &creator_provider, true); let mut creator_group = MlsGroup::new( &creator_provider, @@ -369,7 +344,7 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe let passive = generate_group_candidate( b"Bob (Passive Client)", - cipher_suite, + ciphersuite, &OpenMlsRustCrypto::default(), false, ); @@ -378,7 +353,7 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe .add_members( &creator_provider, &creator.signature_keypair, - &[passive.key_package.clone()], + &[passive.key_package.key_package().clone()], ) .unwrap(); @@ -392,7 +367,7 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe let epoch2 = { let proposals = vec![propose_add( - cipher_suite, + ciphersuite, &creator_provider, &creator, &mut creator_group, @@ -432,14 +407,14 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe let epoch4 = { let proposals = vec![ propose_add( - cipher_suite, + ciphersuite, &creator_provider, &creator, &mut creator_group, b"Daniel", ), propose_add( - cipher_suite, + ciphersuite, &creator_provider, &creator, &mut creator_group, @@ -462,7 +437,7 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe let proposals = vec![ propose_remove(&creator_provider, &creator, &mut creator_group, b"Daniel"), propose_add( - cipher_suite, + ciphersuite, &creator_provider, &creator, &mut creator_group, @@ -499,9 +474,11 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe }; let epochs = vec![epoch1, epoch2, epoch3, epoch4, epoch5, epoch6]; + let init_priv = passive.key_package.init_private_key().to_vec(); + let encryption_priv = passive.key_package.encryption_private_key().to_vec(); PassiveClientWelcomeTestVector { - cipher_suite: cipher_suite.into(), + cipher_suite: ciphersuite.into(), external_psks: vec![], key_package: MlsMessageOut::from(passive.key_package) @@ -509,8 +486,8 @@ pub fn generate_test_vector(cipher_suite: Ciphersuite) -> PassiveClientWelcomeTe .unwrap(), signature_priv: passive.signature_keypair.private().to_vec(), - encryption_priv: passive.encryption_keypair.private_key().key().to_vec(), - init_priv: passive.init_keypair.private.to_vec(), + encryption_priv, + init_priv, welcome: mls_message_welcome.tls_serialize_detached().unwrap(), ratchet_tree: None, @@ -540,7 +517,7 @@ fn propose_add( .propose_add_member( provider, &candidate.signature_keypair, - &add_candidate.key_package, + add_candidate.key_package.key_package(), ) .unwrap(); group.merge_pending_commit(provider).unwrap(); @@ -556,10 +533,7 @@ fn propose_remove( ) -> TestProposal { let remove = group .members() - .find(|Member { credential, .. }| { - let identity = VLBytes::tls_deserialize_exact(credential.serialized_content()).unwrap(); - identity.as_slice() == remove_identity - }) + .find(|Member { credential, .. }| credential.serialized_content() == remove_identity) .unwrap() .index; diff --git a/openmls/src/group/core_group/kat_welcome.rs b/openmls/src/group/core_group/kat_welcome.rs index e306fd7ffb..4ed16f194e 100644 --- a/openmls/src/group/core_group/kat_welcome.rs +++ b/openmls/src/group/core_group/kat_welcome.rs @@ -19,8 +19,9 @@ //! from the key schedule epoch and the `confirmed_transcript_hash` from the //! decrypted GroupContext -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{crypto::OpenMlsCrypto, key_store::OpenMlsKeyStore, OpenMlsProvider}; +use crate::test_utils::OpenMlsRustCrypto; +use kat_welcome::core_group::node::encryption_keys::EncryptionPrivateKey; +use openmls_traits::{crypto::OpenMlsCrypto, storage::StorageProvider, OpenMlsProvider}; use serde::{self, Deserialize, Serialize}; use tls_codec::{Deserialize as TlsDeserialize, Serialize as TlsSerialize}; @@ -163,23 +164,14 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st let key_package_bundle = KeyPackageBundle { key_package: key_package.clone(), - private_key: init_priv, + private_init_key: init_priv, + private_encryption_key: EncryptionPrivateKey::from(vec![]), }; + let hash_ref = key_package.hash_ref(provider.crypto()).unwrap(); provider - .key_store() - .store( - key_package.hash_ref(provider.crypto()).unwrap().as_slice(), - &key_package, - ) - .unwrap(); - - provider - .key_store() - .store::( - key_package.hpke_init_key().as_slice(), - key_package_bundle.private_key(), - ) + .storage() + .write_key_package(&hash_ref, &key_package_bundle) .unwrap(); // Verification: @@ -197,7 +189,7 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st // // // * Decrypt the encrypted group secrets using `init_priv` let group_secrets = GroupSecrets::try_from_ciphertext( - key_package_bundle.private_key(), + key_package_bundle.init_private_key(), encrypted_group_secrets.encrypted_group_secrets(), welcome.encrypted_group_info(), welcome.ciphersuite(), @@ -210,7 +202,7 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st let psk_secret = { let resumption_psk_store = ResumptionPskStore::new(1024); - let psks = load_psks(provider.key_store(), &resumption_psk_store, &[]).unwrap(); + let psks = load_psks(provider.storage(), &resumption_psk_store, &[]).unwrap(); PskSecret::new(provider.crypto(), cipher_suite, psks).unwrap() }; @@ -226,9 +218,9 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st let group_info: GroupInfo = { let verifiable_group_info: VerifiableGroupInfo = { let (welcome_key, welcome_nonce) = key_schedule - .welcome(provider.crypto()) + .welcome(provider.crypto(), welcome.ciphersuite()) .unwrap() - .derive_welcome_key_nonce(provider.crypto()) + .derive_welcome_key_nonce(provider.crypto(), welcome.ciphersuite()) .unwrap(); VerifiableGroupInfo::try_from_ciphertext( @@ -261,7 +253,9 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st .unwrap(); let (_group_epoch_secrets, message_secrets) = { - let epoch_secrets = key_schedule.epoch_secrets(provider.crypto()).unwrap(); + let epoch_secrets = key_schedule + .epoch_secrets(provider.crypto(), welcome.ciphersuite()) + .unwrap(); epoch_secrets.split_secrets( serialized_group_context.to_vec(), @@ -272,7 +266,11 @@ pub fn run_test_vector(test_vector: WelcomeTestVector) -> Result<(), &'static st let confirmation_tag = message_secrets .confirmation_key() - .tag(provider.crypto(), group_context.confirmed_transcript_hash()) + .tag( + provider.crypto(), + welcome.ciphersuite(), + group_context.confirmed_transcript_hash(), + ) .unwrap(); assert_eq!(&confirmation_tag, group_info.confirmation_tag()); diff --git a/openmls/src/group/core_group/mod.rs b/openmls/src/group/core_group/mod.rs index af0dca4164..11f9f3deac 100644 --- a/openmls/src/group/core_group/mod.rs +++ b/openmls/src/group/core_group/mod.rs @@ -34,7 +34,7 @@ mod test_proposals; use log::{debug, trace}; use openmls_traits::{ - crypto::OpenMlsCrypto, key_store::OpenMlsKeyStore, signatures::Signer, types::Ciphersuite, + crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider as _, types::Ciphersuite, }; use serde::{Deserialize, Serialize}; use tls_codec::Serialize as TlsSerializeTrait; @@ -63,7 +63,7 @@ use crate::{ error::LibraryError, extensions::errors::InvalidExtensionError, framing::{mls_auth_content::AuthenticatedContent, *}, - group::{config::CryptoConfig, *}, + group::*, key_packages::*, messages::{ group_info::{GroupInfo, GroupInfoTBS, VerifiableGroupInfo}, @@ -75,6 +75,7 @@ use crate::{ psk::{load_psks, store::ResumptionPskStore, PskSecret}, *, }, + storage::{OpenMlsProvider, StorageProvider}, tree::{secret_tree::SecretTreeError, sender_ratchet::SenderRatchetConfiguration}, treesync::{node::encryption_keys::EncryptionKeyPair, *}, versions::ProtocolVersion, @@ -150,8 +151,8 @@ pub(crate) struct StagedCoreWelcome { /// The [`VerifiableGroupInfo`] from the [`Welcome`] message. verifiable_group_info: VerifiableGroupInfo, - /// The keypair used to decrypt the [`Welcome`] message. - leaf_keypair: EncryptionKeyPair, + /// The key package bundle used for this welcome. + pub(crate) key_package_bundle: KeyPackageBundle, /// If we got a path secret, these are the derived path keys. path_keypairs: Option>, @@ -190,11 +191,10 @@ impl CoreGroupBuilder { /// Create a new [`CoreGroupBuilder`]. pub(crate) fn new( group_id: GroupId, - crypto_config: CryptoConfig, + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, ) -> Self { - let public_group_builder = - PublicGroup::builder(group_id, crypto_config, credential_with_key); + let public_group_builder = PublicGroup::builder(group_id, ciphersuite, credential_with_key); Self { config: None, psk_ids: vec![], @@ -266,17 +266,16 @@ impl CoreGroupBuilder { /// /// This function performs cryptographic operations and there requires an /// [`OpenMlsProvider`]. - pub(crate) fn build( + pub(crate) fn build( self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result> { + ) -> Result> { let (public_group_builder, commit_secret, leaf_keypair) = self.public_group_builder.get_secrets(provider, signer)?; - let ciphersuite = public_group_builder.crypto_config().ciphersuite; + let ciphersuite = public_group_builder.group_context().ciphersuite(); let config = self.config.unwrap_or_default(); - let version = public_group_builder.crypto_config().version; let serialized_group_context = public_group_builder .group_context() @@ -290,8 +289,9 @@ impl CoreGroupBuilder { // We use a random `InitSecret` for initialization. let joiner_secret = JoinerSecret::new( provider.crypto(), + ciphersuite, commit_secret, - &InitSecret::random(ciphersuite, provider.rand(), version) + &InitSecret::random(ciphersuite, provider.rand()) .map_err(LibraryError::unexpected_crypto_error)?, &serialized_group_context, ) @@ -302,7 +302,7 @@ impl CoreGroupBuilder { // Prepare the PskSecret let psk_secret = { - let psks = load_psks(provider.key_store(), &resumption_psk_store, &self.psk_ids)?; + let psks = load_psks(provider.storage(), &resumption_psk_store, &self.psk_ids)?; PskSecret::new(provider.crypto(), ciphersuite, psks)? }; @@ -314,7 +314,7 @@ impl CoreGroupBuilder { .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; let epoch_secrets = key_schedule - .epoch_secrets(provider.crypto()) + .epoch_secrets(provider.crypto(), ciphersuite) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; let (group_epoch_secrets, message_secrets) = epoch_secrets.split_secrets( @@ -325,7 +325,7 @@ impl CoreGroupBuilder { let initial_confirmation_tag = message_secrets .confirmation_key() - .tag(provider.crypto(), &[]) + .tag(provider.crypto(), ciphersuite, &[]) .map_err(LibraryError::unexpected_crypto_error)?; let message_secrets_store = @@ -344,10 +344,15 @@ impl CoreGroupBuilder { resumption_psk_store, }; + // Store the group state + group + .store(provider.storage()) + .map_err(CoreGroupBuildError::StorageError)?; + // Store the private key of the own leaf in the key store as an epoch keypair. group - .store_epoch_keypairs(provider.key_store(), &[leaf_keypair]) - .map_err(CoreGroupBuildError::KeyStoreError)?; + .store_epoch_keypairs(provider.storage(), &[leaf_keypair]) + .map_err(CoreGroupBuildError::StorageError)?; Ok(group) } @@ -357,10 +362,10 @@ impl CoreGroup { /// Get a builder for [`CoreGroup`]. pub(crate) fn builder( group_id: GroupId, - crypto_config: CryptoConfig, + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, ) -> CoreGroupBuilder { - CoreGroupBuilder::new(group_id, crypto_config, credential_with_key) + CoreGroupBuilder::new(group_id, ciphersuite, credential_with_key) } // === Create handshake messages === @@ -466,15 +471,31 @@ impl CoreGroup { ) } + pub(crate) fn create_custom_proposal( + &self, + framing_parameters: FramingParameters, + custom_proposal: CustomProposal, + signer: &impl Signer, + ) -> Result { + let proposal = Proposal::Custom(custom_proposal); + AuthenticatedContent::member_proposal( + framing_parameters, + self.own_leaf_index(), + proposal, + self.context(), + signer, + ) + } + // Create application message - pub(crate) fn create_application_message( + pub(crate) fn create_application_message( &mut self, aad: &[u8], msg: &[u8], padding_size: usize, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result { + ) -> Result> { let public_message = AuthenticatedContent::new_application( self.own_leaf_index(), aad, @@ -486,19 +507,26 @@ impl CoreGroup { } // Encrypt an PublicMessage into an PrivateMessage - pub(crate) fn encrypt( + pub(crate) fn encrypt( &mut self, public_message: AuthenticatedContent, padding_size: usize, - provider: &impl OpenMlsProvider, - ) -> Result { - PrivateMessage::try_from_authenticated_content( + provider: &Provider, + ) -> Result> { + let msg = PrivateMessage::try_from_authenticated_content( &public_message, self.ciphersuite(), provider, self.message_secrets_store.message_secrets_mut(), padding_size, - ) + )?; + + provider + .storage() + .write_message_secrets(self.group_id(), &self.message_secrets_store) + .map_err(MessageEncryptionError::StorageError)?; + + Ok(msg) } /// Exporter @@ -563,7 +591,11 @@ impl CoreGroup { extensions, self.message_secrets() .confirmation_key() - .tag(crypto, self.context().confirmed_transcript_hash()) + .tag( + crypto, + self.ciphersuite(), + self.context().confirmed_transcript_hash(), + ) .map_err(LibraryError::unexpected_crypto_error)?, self.own_leaf_index(), ); @@ -604,11 +636,6 @@ impl CoreGroup { self.public_group.group_context() } - #[cfg(test)] - pub(crate) fn group_context_extensions(&self) -> &Extensions { - self.context().extensions() - } - /// Get the group ID pub(crate) fn group_id(&self) -> &GroupId { self.public_group.group_id() @@ -697,38 +724,96 @@ impl CoreGroup { .ok_or_else(|| LibraryError::custom("Tree has no own leaf.")) } + /// Stores the [`CoreGroup`]. Called from methods creating a new group and mutating an + /// existing group, both inside [`CoreGroup`] and in [`MlsGroup`]. + pub(super) fn store( + &self, + storage: &Storage, + ) -> Result<(), Storage::Error> { + let group_id = self.group_id(); + + self.public_group.store(storage)?; + storage.write_own_leaf_index(group_id, &self.own_leaf_index())?; + storage.write_group_epoch_secrets(group_id, &self.group_epoch_secrets)?; + storage.set_use_ratchet_tree_extension(group_id, self.use_ratchet_tree_extension)?; + storage.write_message_secrets(group_id, &self.message_secrets_store)?; + storage.write_resumption_psk_store(group_id, &self.resumption_psk_store)?; + + Ok(()) + } + + /// Loads a [`CoreGroup`]. Called in [`MlsGroup::load`]. + pub(super) fn load( + storage: &Storage, + group_id: &GroupId, + ) -> Result, Storage::Error> { + let public_group = PublicGroup::load(storage, group_id)?; + let group_epoch_secrets = storage.group_epoch_secrets(group_id)?; + let own_leaf_index = storage.own_leaf_index(group_id)?; + let use_ratchet_tree_extension = storage.use_ratchet_tree_extension(group_id)?; + let message_secrets_store = storage.message_secrets(group_id)?; + let resumption_psk_store = storage.resumption_psk_store(group_id)?; + + let build = || -> Option { + Some(Self { + public_group: public_group?, + group_epoch_secrets: group_epoch_secrets?, + own_leaf_index: own_leaf_index?, + use_ratchet_tree_extension: use_ratchet_tree_extension?, + message_secrets_store: message_secrets_store?, + resumption_psk_store: resumption_psk_store?, + }) + }; + + Ok(build()) + } + + pub(super) fn delete( + &self, + storage: &Storage, + ) -> Result<(), Storage::Error> { + self.public_group.delete(storage)?; + storage.delete_own_leaf_index(self.group_id())?; + storage.delete_group_epoch_secrets(self.group_id())?; + storage.delete_use_ratchet_tree_extension(self.group_id())?; + storage.delete_message_secrets(self.group_id())?; + storage.delete_all_resumption_psk_secrets(self.group_id())?; + + Ok(()) + } + /// Store the given [`EncryptionKeyPair`]s in the `provider`'s key store /// indexed by this group's [`GroupId`] and [`GroupEpoch`]. /// /// Returns an error if access to the key store fails. - pub(super) fn store_epoch_keypairs( + pub(super) fn store_epoch_keypairs( &self, - store: &KeyStore, + store: &Storage, keypair_references: &[EncryptionKeyPair], - ) -> Result<(), KeyStore::Error> { - let k = EpochKeypairId::new( + ) -> Result<(), Storage::Error> { + store.write_encryption_epoch_key_pairs( self.group_id(), - self.context().epoch().as_u64(), - self.own_leaf_index(), - ); - store.store(&k.0, &keypair_references.to_vec()) + &self.context().epoch(), + self.own_leaf_index().u32(), + keypair_references, + ) } /// Read the [`EncryptionKeyPair`]s of this group and its current - /// [`GroupEpoch`] from the `provider`'s key store. + /// [`GroupEpoch`] from the `provider`'s storage. /// - /// Returns `None` if access to the key store fails. - pub(super) fn read_epoch_keypairs( + /// Returns an empty vector if access to the store fails or it can't find + /// any keys. + pub(super) fn read_epoch_keypairs( &self, - store: &KeyStore, + store: &Storage, ) -> Vec { - let k = EpochKeypairId::new( - self.group_id(), - self.context().epoch().as_u64(), - self.own_leaf_index(), - ); store - .read::>(&k.0) + .encryption_epoch_key_pairs( + self.group_id(), + &self.context().epoch(), + self.own_leaf_index().u32(), + ) .unwrap_or_default() } @@ -736,24 +821,23 @@ impl CoreGroup { /// the `provider`'s key store. /// /// Returns an error if access to the key store fails. - pub(super) fn delete_previous_epoch_keypairs( + pub(super) fn delete_previous_epoch_keypairs( &self, - store: &KeyStore, - ) -> Result<(), KeyStore::Error> { - let k = EpochKeypairId::new( + store: &Storage, + ) -> Result<(), Storage::Error> { + store.delete_encryption_epoch_key_pairs( self.group_id(), - self.context().epoch().as_u64() - 1, - self.own_leaf_index(), - ); - store.delete::>(&k.0) + &GroupEpoch::from(self.context().epoch().as_u64() - 1), + self.own_leaf_index().u32(), + ) } - pub(crate) fn create_commit( + pub(crate) fn create_commit( &self, mut params: CreateCommitParams, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result> { + ) -> Result> { let ciphersuite = self.ciphersuite(); let sender = match params.commit_type() { @@ -789,6 +873,10 @@ impl CoreGroup { // Validate the proposals by doing the following checks: + // ValSem113: All Proposals: The proposal type must be supported by all + // members of the group + self.public_group + .validate_proposal_type_support(&proposal_queue)?; // ValSem101 // ValSem102 // ValSem103 @@ -885,6 +973,7 @@ impl CoreGroup { let joiner_secret = JoinerSecret::new( provider.crypto(), + ciphersuite, path_computation_result.commit_secret, self.group_epoch_secrets().init_secret(), &serialized_provisional_group_context, @@ -894,7 +983,7 @@ impl CoreGroup { // Prepare the PskSecret let psk_secret = { let psks = load_psks( - provider.key_store(), + provider.storage(), &self.resumption_psk_store, &apply_proposals_values.presharedkeys, )?; @@ -912,13 +1001,13 @@ impl CoreGroup { .map_err(LibraryError::missing_bound_check)?; let welcome_secret = key_schedule - .welcome(provider.crypto()) + .welcome(provider.crypto(), self.ciphersuite()) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; key_schedule .add_context(provider.crypto(), &serialized_provisional_group_context) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; let provisional_epoch_secrets = key_schedule - .epoch_secrets(provider.crypto()) + .epoch_secrets(provider.crypto(), self.ciphersuite()) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; // Calculate the confirmation tag @@ -926,6 +1015,7 @@ impl CoreGroup { .confirmation_key() .tag( provider.crypto(), + self.ciphersuite(), diff.group_context().confirmed_transcript_hash(), ) .map_err(LibraryError::unexpected_crypto_error)?; @@ -979,7 +1069,7 @@ impl CoreGroup { let welcome_option = if !apply_proposals_values.invitation_list.is_empty() { // Encrypt GroupInfo object let (welcome_key, welcome_nonce) = welcome_secret - .derive_welcome_key_nonce(provider.crypto()) + .derive_welcome_key_nonce(provider.crypto(), self.ciphersuite()) .map_err(LibraryError::unexpected_crypto_error)?; let encrypted_group_info = welcome_key .aead_seal( @@ -1045,12 +1135,13 @@ impl CoreGroup { } /// Create a new group context extension proposal - pub(crate) fn create_group_context_ext_proposal( + pub(crate) fn create_group_context_ext_proposal( &self, framing_parameters: FramingParameters, extensions: Extensions, signer: &impl Signer, - ) -> Result> { + ) -> Result> + { // Ensure that the group supports all the extensions that are wanted. let required_extension = extensions .iter() @@ -1058,7 +1149,6 @@ impl CoreGroup { if let Some(required_extension) = required_extension { let required_capabilities = required_extension.as_required_capabilities_extension()?; // Ensure we support all the capabilities. - required_capabilities.check_support()?; self.own_leaf_node()? .capabilities() .supports_required_capabilities(required_capabilities)?; @@ -1128,19 +1218,3 @@ pub(crate) struct CoreGroupConfig { /// Defaults to false. pub(crate) add_ratchet_tree_extension: bool, } - -/// Composite key for key material of a client within an epoch -pub struct EpochKeypairId(Vec); - -impl EpochKeypairId { - fn new(group_id: &GroupId, epoch: u64, leaf_index: LeafNodeIndex) -> Self { - Self( - [ - group_id.as_slice(), - &leaf_index.u32().to_be_bytes(), - &epoch.to_be_bytes(), - ] - .concat(), - ) - } -} diff --git a/openmls/src/group/core_group/new_from_external_init.rs b/openmls/src/group/core_group/new_from_external_init.rs index 7ebd947910..29a708405f 100644 --- a/openmls/src/group/core_group/new_from_external_init.rs +++ b/openmls/src/group/core_group/new_from_external_init.rs @@ -5,6 +5,7 @@ use crate::{ errors::ExternalCommitError, }, messages::proposals::{ExternalInitProposal, Proposal}, + storage::OpenMlsProvider, }; use super::CoreGroup; @@ -24,13 +25,13 @@ impl CoreGroup { /// /// Note: If there is a group member in the group with the same identity as us, /// this will create a remove proposal. - pub(crate) fn join_by_external_commit( - provider: &impl OpenMlsProvider, + pub(crate) fn join_by_external_commit( + provider: &Provider, signer: &impl Signer, mut params: CreateCommitParams, ratchet_tree: Option, verifiable_group_info: VerifiableGroupInfo, - ) -> Result { + ) -> Result> { // Build the ratchet tree // Set nodes either from the extension or from the `nodes_option`. @@ -47,7 +48,7 @@ impl CoreGroup { }; let (public_group, group_info) = PublicGroup::from_external( - provider.crypto(), + provider, ratchet_tree, verifiable_group_info, // Existing proposals are discarded when joining by external commit. @@ -72,8 +73,12 @@ impl CoreGroup { // The `EpochSecrets` we create here are essentially zero, with the // exception of the `InitSecret`, which is all we need here for the // external commit. - let epoch_secrets = EpochSecrets::with_init_secret(provider.crypto(), init_secret) - .map_err(LibraryError::unexpected_crypto_error)?; + let epoch_secrets = EpochSecrets::with_init_secret( + provider.crypto(), + group_info.group_context().ciphersuite(), + init_secret, + ) + .map_err(LibraryError::unexpected_crypto_error)?; let (group_epoch_secrets, message_secrets) = epoch_secrets.split_secrets( group_context .tls_serialize_detached() @@ -129,6 +134,10 @@ impl CoreGroup { "Error creating commit {create_commit_result:?}" ); + group + .store(provider.storage()) + .map_err(ExternalCommitError::StorageError)?; + Ok(( group, create_commit_result.map_err(|_| ExternalCommitError::CommitError)?, diff --git a/openmls/src/group/core_group/new_from_welcome.rs b/openmls/src/group/core_group/new_from_welcome.rs index bd4fe10175..0bb4a7c123 100644 --- a/openmls/src/group/core_group/new_from_welcome.rs +++ b/openmls/src/group/core_group/new_from_welcome.rs @@ -1,14 +1,11 @@ use log::debug; -use openmls_traits::key_store::OpenMlsKeyStore; use crate::{ ciphersuite::hash_ref::HashReference, group::{core_group::*, errors::WelcomeError}, schedule::psk::store::ResumptionPskStore, - treesync::{ - errors::{DerivePathError, PublicTreeError}, - node::encryption_keys::EncryptionKeyPair, - }, + storage::OpenMlsProvider, + treesync::errors::{DerivePathError, PublicTreeError}, }; impl StagedCoreWelcome { @@ -17,36 +14,14 @@ impl StagedCoreWelcome { /// group. /// Note: calling this function will consume the key material for decrypting the [`Welcome`] /// message, even if the caller does not turn the [`StagedCoreWelcome`] into a [`CoreGroup`]. - pub fn new_from_welcome( + pub fn new_from_welcome( welcome: Welcome, ratchet_tree: Option, key_package_bundle: KeyPackageBundle, - provider: &impl OpenMlsProvider, + provider: &Provider, mut resumption_psk_store: ResumptionPskStore, - ) -> Result> { + ) -> Result> { log::debug!("CoreGroup::new_from_welcome_internal"); - - // Read the encryption key pair from the key store and delete it there. - // TODO #1207: Key store access happens as early as possible so it can - // be pulled up later more easily. - let leaf_keypair = EncryptionKeyPair::read_from_key_store( - provider, - key_package_bundle.key_package.leaf_node().encryption_key(), - ) - .ok_or(WelcomeError::NoMatchingEncryptionKey)?; - - // Delete the leaf encryption keypair from the - // key store, but only if it doesn't have a last resort extension. - if !key_package_bundle.key_package().last_resort() { - leaf_keypair - .delete_from_key_store(provider.key_store()) - .map_err(|_| WelcomeError::NoMatchingEncryptionKey)?; - } else { - log::debug!( - "Found last resort extension, not deleting leaf encryption keypair from key store" - ); - } - let ciphersuite = welcome.ciphersuite(); // Find key_package in welcome secrets @@ -67,7 +42,7 @@ impl StagedCoreWelcome { } let group_secrets = GroupSecrets::try_from_ciphertext( - key_package_bundle.private_key(), + key_package_bundle.init_private_key(), egs.encrypted_group_secrets(), welcome.encrypted_group_info(), ciphersuite, @@ -77,7 +52,7 @@ impl StagedCoreWelcome { // Prepare the PskSecret let psk_secret = { let psks = load_psks( - provider.key_store(), + provider.storage(), &resumption_psk_store, &group_secrets.psks, )?; @@ -95,9 +70,9 @@ impl StagedCoreWelcome { // Derive welcome key & nonce from the key schedule let (welcome_key, welcome_nonce) = key_schedule - .welcome(provider.crypto()) + .welcome(provider.crypto(), ciphersuite) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))? - .derive_welcome_key_nonce(provider.crypto()) + .derive_welcome_key_nonce(provider.crypto(), ciphersuite) .map_err(LibraryError::unexpected_crypto_error)?; let verifiable_group_info = VerifiableGroupInfo::try_from_ciphertext( @@ -112,9 +87,6 @@ impl StagedCoreWelcome { if let Some(required_capabilities) = verifiable_group_info.extensions().required_capabilities() { - required_capabilities - .check_support() - .map_err(|_| WelcomeError::UnsupportedCapability)?; // Also check that our key package actually supports the extensions. // Per spec the sender must have checked this. But you never know. key_package_bundle @@ -142,7 +114,7 @@ impl StagedCoreWelcome { // Since there is currently only the external pub extension, there is no // group info extension of interest here. let (public_group, _group_info_extensions) = PublicGroup::from_external( - provider.crypto(), + provider, ratchet_tree, verifiable_group_info.clone(), ProposalStore::new(), @@ -180,7 +152,7 @@ impl StagedCoreWelcome { .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; let epoch_secrets = key_schedule - .epoch_secrets(provider.crypto()) + .epoch_secrets(provider.crypto(), ciphersuite) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; epoch_secrets.split_secrets( @@ -194,6 +166,7 @@ impl StagedCoreWelcome { .confirmation_key() .tag( provider.crypto(), + ciphersuite, public_group.group_context().confirmed_transcript_hash(), ) .map_err(LibraryError::unexpected_crypto_error)?; @@ -242,7 +215,7 @@ impl StagedCoreWelcome { message_secrets_store, resumption_psk_store, verifiable_group_info, - leaf_keypair, + key_package_bundle, path_keypairs, }; @@ -265,45 +238,35 @@ impl StagedCoreWelcome { } /// Consumes the [`StagedCoreWelcome`] and returns the respective [`CoreGroup`]. - pub fn into_core_group( + pub fn into_core_group( self, - provider: &impl OpenMlsProvider, - ) -> Result> { - let Self { - public_group, - group_epoch_secrets, - own_leaf_index, - use_ratchet_tree_extension, - message_secrets_store, - resumption_psk_store, - leaf_keypair, - path_keypairs, - .. - } = self; - + provider: &Provider, + ) -> Result> { // If we got a path secret, derive the path (which also checks if the // public keys match) and store the derived keys in the key store. - let group_keypairs = if let Some(path_keypairs) = path_keypairs { - vec![leaf_keypair] - .into_iter() - .chain(path_keypairs) - .collect() + let group_keypairs = if let Some(path_keypairs) = self.path_keypairs { + let mut keypairs = vec![self.key_package_bundle.encryption_key_pair()]; + keypairs.extend_from_slice(&path_keypairs); + keypairs } else { - vec![leaf_keypair] + vec![self.key_package_bundle.encryption_key_pair()] }; let group = CoreGroup { - public_group, - group_epoch_secrets, - own_leaf_index, - use_ratchet_tree_extension, - message_secrets_store, - resumption_psk_store, + public_group: self.public_group, + group_epoch_secrets: self.group_epoch_secrets, + own_leaf_index: self.own_leaf_index, + use_ratchet_tree_extension: self.use_ratchet_tree_extension, + message_secrets_store: self.message_secrets_store, + resumption_psk_store: self.resumption_psk_store, }; group - .store_epoch_keypairs(provider.key_store(), group_keypairs.as_slice()) - .map_err(WelcomeError::KeyStoreError)?; + .store(provider.storage()) + .map_err(WelcomeError::StorageError)?; + group + .store_epoch_keypairs(provider.storage(), group_keypairs.as_slice()) + .map_err(WelcomeError::StorageError)?; Ok(group) } diff --git a/openmls/src/group/core_group/process.rs b/openmls/src/group/core_group/process.rs index 4b47fc9139..e203643af5 100644 --- a/openmls/src/group/core_group/process.rs +++ b/openmls/src/group/core_group/process.rs @@ -25,6 +25,8 @@ impl CoreGroup { /// - ValSem110 /// - ValSem111 /// - ValSem112 + /// - ValSem113: All Proposals: The proposal type must be supported by all + /// members of the group /// - ValSem200 /// - ValSem201 /// - ValSem202: Path must be the right length @@ -37,19 +39,19 @@ impl CoreGroup { /// - ValSem242 /// - ValSem244 /// - ValSem246 (as part of ValSem010) - pub(crate) fn process_unverified_message( + pub(crate) fn process_unverified_message( &self, - provider: &impl OpenMlsProvider, + provider: &Provider, unverified_message: UnverifiedMessage, proposal_store: &ProposalStore, old_epoch_keypairs: Vec, leaf_node_keypairs: Vec, - ) -> Result { + ) -> Result> { // Checks the following semantic validation: // - ValSem010 // - ValSem246 (as part of ValSem010) let (content, credential) = - unverified_message.verify(self.ciphersuite(), provider.crypto(), self.version())?; + unverified_message.verify(self.ciphersuite(), provider, self.version())?; match content.sender() { Sender::Member(_) | Sender::NewMemberCommit | Sender::NewMemberProposal => { @@ -68,6 +70,7 @@ impl CoreGroup { provider.crypto(), content, )?); + if matches!(sender, Sender::NewMemberProposal) { ProcessedMessageContent::ExternalJoinProposalMessage(proposal) } else { @@ -151,6 +154,8 @@ impl CoreGroup { /// - ValSem110 /// - ValSem111 /// - ValSem112 + /// - ValSem113: All Proposals: The proposal type must be supported by all + /// members of the group /// - ValSem200 /// - ValSem201 /// - ValSem202: Path must be the right length @@ -164,14 +169,14 @@ impl CoreGroup { /// - ValSem244 /// - ValSem245 /// - ValSem246 (as part of ValSem010) - pub(crate) fn process_message( + pub(crate) fn process_message( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, message: impl Into, sender_ratchet_configuration: &SenderRatchetConfiguration, proposal_store: &ProposalStore, own_leaf_nodes: &[LeafNode], - ) -> Result { + ) -> Result> { let message: ProtocolMessage = message.into(); // Checks the following semantic validation: @@ -246,6 +251,7 @@ impl CoreGroup { message_secrets, message_secrets.serialized_context().to_vec(), crypto, + self.ciphersuite(), ) } ProtocolMessage::PrivateMessage(ciphertext) => { @@ -267,7 +273,7 @@ impl CoreGroup { own_leaf_nodes: &[LeafNode], ) -> Result<(Vec, Vec), StageCommitError> { // All keys from the previous epoch are potential decryption keypairs. - let old_epoch_keypairs = self.read_epoch_keypairs(provider.key_store()); + let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()); // If we are processing an update proposal that originally came from // us, the keypair corresponding to the leaf in the update is also a @@ -275,7 +281,7 @@ impl CoreGroup { let leaf_node_keypairs = own_leaf_nodes .iter() .map(|leaf_node| { - EncryptionKeyPair::read_from_key_store(provider, leaf_node.encryption_key()) + EncryptionKeyPair::read(provider, leaf_node.encryption_key()) .ok_or(StageCommitError::MissingDecryptionKey) }) .collect::, StageCommitError>>()?; @@ -284,12 +290,12 @@ impl CoreGroup { } /// Merge a [StagedCommit] into the group after inspection - pub(crate) fn merge_staged_commit( + pub(crate) fn merge_staged_commit( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, staged_commit: StagedCommit, proposal_store: &mut ProposalStore, - ) -> Result<(), MergeCommitError> { + ) -> Result<(), MergeCommitError> { // Save the past epoch let past_epoch = self.context().epoch(); // Get all the full leaves diff --git a/openmls/src/group/core_group/proposals.rs b/openmls/src/group/core_group/proposals.rs index eb80807067..f08695864c 100644 --- a/openmls/src/group/core_group/proposals.rs +++ b/openmls/src/group/core_group/proposals.rs @@ -512,6 +512,12 @@ impl ProposalQueue { proposal_pool.insert(queued_proposal.proposal_reference(), queued_proposal); } Proposal::AppAck(_) => unimplemented!("See #291"), + Proposal::Custom(_) => { + // Other/unknown proposals are always considered valid and + // have to be checked by the application instead. + valid_proposals.add(queued_proposal.proposal_reference()); + proposal_pool.insert(queued_proposal.proposal_reference(), queued_proposal); + } } } // Check for presence of Removes and delete Updates diff --git a/openmls/src/group/core_group/staged_commit.rs b/openmls/src/group/core_group/staged_commit.rs index d484400e9c..534047dbf6 100644 --- a/openmls/src/group/core_group/staged_commit.rs +++ b/openmls/src/group/core_group/staged_commit.rs @@ -1,15 +1,18 @@ use core::fmt::Debug; use std::mem; -use openmls_traits::key_store::OpenMlsKeyStore; use public_group::diff::{apply_proposals::ApplyProposalsValues, StagedPublicGroupDiff}; +use self::public_group::staged_commit::PublicStagedCommitState; + use super::{super::errors::*, proposals::ProposalStore, *}; use crate::{ - framing::mls_auth_content::AuthenticatedContent, + ciphersuite::Secret, framing::mls_auth_content::AuthenticatedContent, treesync::node::encryption_keys::EncryptionKeyPair, }; +use openmls_traits::storage::StorageProvider as _; + impl CoreGroup { fn derive_epoch_secrets( &self, @@ -39,6 +42,7 @@ impl CoreGroup { )?; JoinerSecret::new( provider.crypto(), + self.ciphersuite(), commit_secret, &init_secret, serialized_provisional_group_context, @@ -47,6 +51,7 @@ impl CoreGroup { } else { JoinerSecret::new( provider.crypto(), + self.ciphersuite(), commit_secret, epoch_secrets.init_secret(), serialized_provisional_group_context, @@ -56,8 +61,8 @@ impl CoreGroup { // Prepare the PskSecret let psk_secret = { - let psks = load_psks( - provider.key_store(), + let psks: Vec<(&PreSharedKeyId, Secret)> = load_psks( + provider.storage(), &self.resumption_psk_store, &apply_proposals_values.presharedkeys, )?; @@ -77,21 +82,21 @@ impl CoreGroup { .add_context(provider.crypto(), serialized_provisional_group_context) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?; Ok(key_schedule - .epoch_secrets(provider.crypto()) + .epoch_secrets(provider.crypto(), self.ciphersuite()) .map_err(|_| LibraryError::custom("Using the key schedule in the wrong state"))?) } - /// Stages a commit message that was sent by another group member. - /// This function does the following: + /// Stages a commit message that was sent by another group member. This + /// function does the following: /// - Applies the proposals covered by the commit to the tree /// - Applies the (optional) update path to the tree /// - Decrypts and calculates the path secrets /// - Initializes the key schedule for epoch rollover /// - Verifies the confirmation tag /// - /// Returns a [StagedCommit] that can be inspected and later merged - /// into the group state with [CoreGroup::merge_commit()] - /// This function does the following checks: + /// Returns a [StagedCommit] that can be inspected and later merged into the + /// group state with [CoreGroup::merge_commit()] This function does the + /// following checks: /// - ValSem101 /// - ValSem102 /// - ValSem104 @@ -102,6 +107,8 @@ impl CoreGroup { /// - ValSem110 /// - ValSem111 /// - ValSem112 + /// - ValSem113: All Proposals: The proposal type must be supported by all + /// members of the group /// - ValSem200 /// - ValSem201 /// - ValSem202: Path must be the right length @@ -112,9 +119,8 @@ impl CoreGroup { /// - ValSem240 /// - ValSem241 /// - ValSem242 - /// - ValSem244 - /// Returns an error if the given commit was sent by the owner of this - /// group. + /// - ValSem244 Returns an error if the given commit was sent by the owner + /// of this group. pub(crate) fn stage_commit( &self, mls_content: &AuthenticatedContent, @@ -146,9 +152,13 @@ impl CoreGroup { // Check if we were removed from the group if apply_proposals_values.self_removed { let staged_diff = diff.into_staged_diff(provider.crypto(), ciphersuite)?; + let staged_state = PublicStagedCommitState::new( + staged_diff, + commit.path.as_ref().map(|path| path.leaf_node().clone()), + ); return Ok(StagedCommit::new( proposal_queue, - StagedCommitState::PublicState(Box::new(staged_diff)), + StagedCommitState::PublicState(Box::new(staged_state)), )); } @@ -229,12 +239,7 @@ impl CoreGroup { apply_proposals_values.extensions.clone(), )?; - ( - CommitSecret::zero_secret(ciphersuite, self.version()), - vec![], - None, - None, - ) + (CommitSecret::zero_secret(ciphersuite), vec![], None, None) }; // Update the confirmed transcript hash before we compute the confirmation tag. @@ -269,6 +274,7 @@ impl CoreGroup { .confirmation_key() .tag( provider.crypto(), + self.ciphersuite(), diff.group_context().confirmed_transcript_hash(), ) .map_err(LibraryError::unexpected_crypto_error)?; @@ -303,17 +309,20 @@ impl CoreGroup { /// /// This function should not fail and only returns a [`Result`], because it /// might throw a `LibraryError`. - pub(crate) fn merge_commit( + pub(crate) fn merge_commit( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, staged_commit: StagedCommit, - ) -> Result, MergeCommitError> { + ) -> Result, MergeCommitError> { // Get all keypairs from the old epoch, so we can later store the ones // that are still relevant in the new epoch. - let old_epoch_keypairs = self.read_epoch_keypairs(provider.key_store()); + let old_epoch_keypairs = self.read_epoch_keypairs(provider.storage()); match staged_commit.state { - StagedCommitState::PublicState(staged_diff) => { - self.public_group.merge_diff(*staged_diff); + StagedCommitState::PublicState(staged_state) => { + self.public_group + .merge_diff(staged_state.into_staged_diff()); + self.store(provider.storage()) + .map_err(MergeCommitError::StorageError)?; Ok(None) } StagedCommitState::GroupMember(state) => { @@ -349,8 +358,8 @@ impl CoreGroup { .chain(leaf_keypair) .filter(|keypair| new_owned_encryption_keys.contains(keypair.public_key())) .collect(); - // We should have private keys for all owned encryption keys. + // We should have private keys for all owned encryption keys. debug_assert_eq!(new_owned_encryption_keys.len(), epoch_keypairs.len()); if new_owned_encryption_keys.len() != epoch_keypairs.len() { return Err(LibraryError::custom( @@ -358,16 +367,32 @@ impl CoreGroup { ) .into()); } + + // store the updated group state + let storage = provider.storage(); + let group_id = self.group_id(); + + self.public_group + .store(storage) + .map_err(MergeCommitError::StorageError)?; + storage + .write_group_epoch_secrets(group_id, &self.group_epoch_secrets) + .map_err(MergeCommitError::StorageError)?; + storage + .write_message_secrets(group_id, &self.message_secrets_store) + .map_err(MergeCommitError::StorageError)?; + // Store the relevant keys under the new epoch - self.store_epoch_keypairs(provider.key_store(), epoch_keypairs.as_slice()) - .map_err(MergeCommitError::KeyStoreError)?; + self.store_epoch_keypairs(storage, epoch_keypairs.as_slice()) + .map_err(MergeCommitError::StorageError)?; + // Delete the old keys. - self.delete_previous_epoch_keypairs(provider.key_store()) - .map_err(MergeCommitError::KeyStoreError)?; + self.delete_previous_epoch_keypairs(storage) + .map_err(MergeCommitError::StorageError)?; if let Some(keypair) = state.new_leaf_keypair_option { keypair - .delete_from_key_store(provider.key_store()) - .map_err(MergeCommitError::KeyStoreError)?; + .delete(storage) + .map_err(MergeCommitError::StorageError)?; } Ok(Some(message_secrets)) @@ -400,7 +425,7 @@ impl CoreGroup { #[derive(Debug, Serialize, Deserialize)] pub(crate) enum StagedCommitState { - PublicState(Box), + PublicState(Box), GroupMember(Box), } @@ -449,7 +474,9 @@ impl StagedCommit { /// Returns the leaf node of the (optional) update path. pub fn update_path_leaf_node(&self) -> Option<&LeafNode> { match self.state { - StagedCommitState::PublicState(_) => None, + StagedCommitState::PublicState(ref public_state) => { + public_state.update_path_leaf_node() + } StagedCommitState::GroupMember(ref group_member_state) => { group_member_state.update_path_leaf_node.as_ref() } @@ -512,7 +539,7 @@ impl StagedCommit { /// Returns the [`GroupContext`] of the staged commit state. pub fn group_context(&self) -> &GroupContext { match self.state { - StagedCommitState::PublicState(ref ps) => ps.group_context(), + StagedCommitState::PublicState(ref ps) => ps.staged_diff().group_context(), StagedCommitState::GroupMember(ref gm) => gm.group_context(), } } diff --git a/openmls/src/group/core_group/test_core_group.rs b/openmls/src/group/core_group/test_core_group.rs index 4c41f220d2..796c292616 100644 --- a/openmls/src/group/core_group/test_core_group.rs +++ b/openmls/src/group/core_group/test_core_group.rs @@ -1,14 +1,13 @@ use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{crypto::OpenMlsCrypto, types::HpkeCiphertext, OpenMlsProvider}; -use tls_codec::{Deserialize, Serialize}; +use openmls_traits::types::HpkeCiphertext; +use tls_codec::Serialize; use crate::{ binary_tree::*, ciphersuite::{signable::Signable, AeadNonce}, credentials::*, framing::*, - group::{config::CryptoConfig, errors::*, *}, + group::{errors::*, *}, key_packages::*, messages::{group_info::GroupInfoTBS, *}, schedule::psk::{store::ResumptionPskStore, ExternalPsk, PreSharedKeyId, Psk}, @@ -18,7 +17,7 @@ use crate::{ pub(crate) fn setup_alice_group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ( CoreGroup, CredentialWithKey, @@ -37,7 +36,7 @@ pub(crate) fn setup_alice_group( // Alice creates a group let group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key.clone(), ) .build(provider, &alice_signature_keys) @@ -55,8 +54,11 @@ pub fn flip_last_byte(ctxt: &mut HpkeCiphertext) { ctxt.ciphertext.push(last_bits); } -#[apply(ciphersuites_and_providers)] -fn test_failed_groupinfo_decryption(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_failed_groupinfo_decryption( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let epoch = 123; let group_id = GroupId::random(provider.rand()); let tree_hash = vec![1, 2, 3, 4, 5, 6, 7, 8, 9]; @@ -70,7 +72,7 @@ fn test_failed_groupinfo_decryption(ciphersuite: Ciphersuite, provider: &impl Op let (alice_credential_with_key, alice_signature_keys) = test_utils::new_credential(provider, b"Alice", ciphersuite.signature_algorithm()); - let key_package_bundle = KeyPackageBundle::new( + let key_package_bundle = KeyPackageBundle::generate( provider, &alice_signature_keys, ciphersuite, @@ -104,7 +106,7 @@ fn test_failed_groupinfo_decryption(ciphersuite: Ciphersuite, provider: &impl Op .crypto() .derive_hpke_keypair( ciphersuite.hpke_config(), - Secret::random(ciphersuite, provider.rand(), None) + Secret::random(ciphersuite, provider.rand()) .expect("Not enough randomness.") .as_slice(), ) @@ -161,16 +163,16 @@ fn test_failed_groupinfo_decryption(ciphersuite: Ciphersuite, provider: &impl Op .and_then(|staged_join| staged_join.into_core_group(provider)) .expect_err("Creation of core group from a broken Welcome was successful."); - assert_eq!( + assert!(matches!( error, WelcomeError::GroupSecrets(GroupSecretsError::DecryptionFailed) - ) + )) } /// Test what happens if the KEM ciphertext for the receiver in the UpdatePath /// is broken. -#[apply(ciphersuites_and_providers)] -fn test_update_path(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_update_path() { // === Alice creates a group with her and Bob === let ( framing_parameters, @@ -185,7 +187,7 @@ fn test_update_path(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let bob_old_leaf = group_bob.own_leaf_node().unwrap(); let bob_update_leaf_node = bob_old_leaf .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_bob.own_tree_position()), provider, &bob_signature_keys, @@ -270,7 +272,7 @@ fn test_update_path(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { fn setup_alice_bob( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ( CredentialWithKey, SignatureKeyPair, @@ -285,7 +287,7 @@ fn setup_alice_bob( // Generate Bob's KeyPackage let bob_key_package_bundle = - KeyPackageBundle::new(provider, &bob_signer, ciphersuite, bob_credential_with_key); + KeyPackageBundle::generate(provider, &bob_signer, ciphersuite, bob_credential_with_key); ( alice_credential_with_key, @@ -296,8 +298,8 @@ fn setup_alice_bob( } // Test several scenarios when PSKs are used in a group -#[apply(ciphersuites_and_providers)] -fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_psks() { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -312,18 +314,15 @@ fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // === Alice creates a group with a PSK === let psk_id = vec![1u8, 2, 3]; - let secret = Secret::random(ciphersuite, provider.rand(), None /* MLS version */) - .expect("Not enough randomness."); + let secret = Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let external_psk = ExternalPsk::new(psk_id); let preshared_key_id = PreSharedKeyId::new(ciphersuite, provider.rand(), Psk::External(external_psk)) .expect("An unexpected error occured."); - preshared_key_id - .write_to_key_store(provider, ciphersuite, secret.as_slice()) - .unwrap(); + preshared_key_id.store(provider, secret.as_slice()).unwrap(); let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key, ) .with_psk(vec![preshared_key_id.clone()]) @@ -394,7 +393,7 @@ fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let bob_old_leaf = group_bob.own_leaf_node().unwrap(); let bob_update_leaf_node = bob_old_leaf .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_bob.own_tree_position()), provider, &bob_signature_keys, @@ -427,8 +426,11 @@ fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // Test several scenarios when PSKs are used in a group -#[apply(ciphersuites_and_providers)] -fn test_staged_commit_creation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_staged_commit_creation( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -439,7 +441,7 @@ fn test_staged_commit_creation(ciphersuite: Ciphersuite, provider: &impl OpenMls // === Alice creates a group === let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key, ) .build(provider, &alice_signature_keys) @@ -500,8 +502,11 @@ fn test_staged_commit_creation(ciphersuite: Ciphersuite, provider: &impl OpenMls } // Test processing of own commits -#[apply(ciphersuites_and_providers)] -fn test_own_commit_processing(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_own_commit_processing( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -513,7 +518,7 @@ fn test_own_commit_processing(ciphersuite: Ciphersuite, provider: &impl OpenMlsP // === Alice creates a group === let alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key, ) .build(provider, &alice_signature_keys) @@ -540,7 +545,7 @@ fn test_own_commit_processing(ciphersuite: Ciphersuite, provider: &impl OpenMlsP pub(crate) fn setup_client( id: &str, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ( CredentialWithKey, KeyPackageBundle, @@ -556,7 +561,7 @@ pub(crate) fn setup_client( .unwrap(); // Generate the KeyPackage - let key_package_bundle = KeyPackageBundle::new( + let key_package_bundle = KeyPackageBundle::generate( provider, &signature_keys, ciphersuite, @@ -565,10 +570,10 @@ pub(crate) fn setup_client( (credential_with_key, key_package_bundle, signature_keys, pk) } -#[apply(ciphersuites_and_providers)] +#[openmls_test::openmls_test] fn test_proposal_application_after_self_was_removed( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { // We're going to test if proposals are still applied, even after a client // notices that it was removed from a group. We do so by having Alice @@ -587,7 +592,7 @@ fn test_proposal_application_after_self_was_removed( let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - config::CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key, ) .build(provider, &alice_signature_keys) @@ -647,11 +652,7 @@ fn test_proposal_application_after_self_was_removed( index: _, credential, .. - }| { - let identity = - VLBytes::tls_deserialize_exact(credential.serialized_content()).unwrap(); - identity.as_slice() == b"Bob" - }, + }| { credential.serialized_content() == b"Bob" }, ) .expect("Couldn't find Bob in tree.") .index; @@ -740,28 +741,23 @@ fn test_proposal_application_after_self_was_removed( // didn't get updated. assert_eq!(alice_member.index, bob_member.index); - let alice_id = - VLBytes::tls_deserialize_exact(alice_member.credential.serialized_content()).unwrap(); - let bob_id = - VLBytes::tls_deserialize_exact(bob_member.credential.serialized_content()).unwrap(); - let charlie_id = - VLBytes::tls_deserialize_exact(charlie_member.credential.serialized_content()).unwrap(); - assert_eq!(alice_id.as_slice(), bob_id.as_slice()); + let alice_id = alice_member.credential.serialized_content(); + let bob_id = bob_member.credential.serialized_content(); + let charlie_id = charlie_member.credential.serialized_content(); + assert_eq!(alice_id, bob_id); assert_eq!(alice_member.signature_key, bob_member.signature_key); assert_eq!(charlie_member.index, bob_member.index); - assert_eq!(charlie_id.as_slice(), bob_id.as_slice()); + assert_eq!(charlie_id, bob_id); assert_eq!(charlie_member.signature_key, bob_member.signature_key); assert_eq!(charlie_member.encryption_key, alice_member.encryption_key); } let mut bob_members = bob_group.public_group().members(); - let bob_next_id = - VLBytes::tls_deserialize_exact(bob_members.next().unwrap().credential.serialized_content()) - .unwrap(); - assert_eq!(bob_next_id.as_slice(), b"Alice"); - let bob_next_id = - VLBytes::tls_deserialize_exact(bob_members.next().unwrap().credential.serialized_content()) - .unwrap(); - assert_eq!(bob_next_id.as_slice(), b"Charlie"); + let member = bob_members.next().unwrap(); + let bob_next_id = member.credential.serialized_content(); + assert_eq!(bob_next_id, b"Alice"); + let member = bob_members.next().unwrap(); + let bob_next_id = member.credential.serialized_content(); + assert_eq!(bob_next_id, b"Charlie"); } diff --git a/openmls/src/group/core_group/test_create_commit_params.rs b/openmls/src/group/core_group/test_create_commit_params.rs index d90fd08355..3ec1ba1425 100644 --- a/openmls/src/group/core_group/test_create_commit_params.rs +++ b/openmls/src/group/core_group/test_create_commit_params.rs @@ -1,10 +1,8 @@ -use crate::test_utils::*; - use super::*; // Tests that the builder for CreateCommitParams works as expected -#[apply(providers)] -fn build_create_commit_params(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn build_create_commit_params(provider: &Provider) { let _ = provider; let framing_parameters: FramingParameters = FramingParameters::new(&[1, 2, 3], WireFormat::PrivateMessage); diff --git a/openmls/src/group/core_group/test_external_init.rs b/openmls/src/group/core_group/test_external_init.rs index dae51f51b5..4aa36d8375 100644 --- a/openmls/src/group/core_group/test_external_init.rs +++ b/openmls/src/group/core_group/test_external_init.rs @@ -9,16 +9,14 @@ use crate::{ CreateCommitParams, }, messages::proposals::{ProposalOrRef, ProposalType}, - test_utils::*, + storage::OpenMlsProvider, }; - -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::prelude::*; use super::{proposals::ProposalStore, CoreGroup}; -#[apply(ciphersuites_and_providers)] -fn test_external_init(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_external_init() { let ( framing_parameters, mut group_alice, @@ -180,11 +178,8 @@ fn test_external_init(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) ); } -#[apply(ciphersuites_and_providers)] -fn test_external_init_single_member_group( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test::openmls_test] +fn test_external_init_single_member_group() { let (mut group_alice, _alice_credential_with_key, alice_signer, _alice_pk) = setup_alice_group(ciphersuite, provider); @@ -242,8 +237,8 @@ fn test_external_init_single_member_group( ); } -#[apply(ciphersuites_and_providers)] -fn test_external_init_broken_signature(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_external_init_broken_signature() { let ( framing_parameters, group_alice, @@ -271,15 +266,19 @@ fn test_external_init_broken_signature(ciphersuite: Ciphersuite, provider: &impl .framing_parameters(framing_parameters) .proposal_store(&proposal_store) .build(); - assert_eq!( - ExternalCommitError::PublicGroupError(CreationFromExternalError::InvalidGroupInfoSignature), - CoreGroup::join_by_external_commit( - provider, - &charlie_signer, - params, - None, - verifiable_group_info + + let result = CoreGroup::join_by_external_commit( + provider, + &charlie_signer, + params, + None, + verifiable_group_info, + ) + .expect_err("Signature was corrupted. This should have failed."); + assert!(matches!( + result, + ExternalCommitError::<::StorageError>::PublicGroupError( + CreationFromExternalError::InvalidGroupInfoSignature ) - .expect_err("Signature was corrupted. This should have failed.") - ); + )); } diff --git a/openmls/src/group/core_group/test_past_secrets.rs b/openmls/src/group/core_group/test_past_secrets.rs index 2a607947de..e2654fdf00 100644 --- a/openmls/src/group/core_group/test_past_secrets.rs +++ b/openmls/src/group/core_group/test_past_secrets.rs @@ -5,8 +5,8 @@ use crate::{ schedule::message_secrets::MessageSecrets, test_utils::*, }; -#[apply(ciphersuites_and_providers)] -fn test_secret_tree_store(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_secret_tree_store() { // Create a store that keeps up to 3 epochs let mut message_secrets_store = MessageSecretsStore::new_with_secret( 3, @@ -44,8 +44,8 @@ fn test_secret_tree_store(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi assert!(message_secrets_store.secrets_for_epoch_mut(6).is_none()); } -#[apply(ciphersuites_and_providers)] -fn test_empty_secret_tree_store(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_empty_secret_tree_store() { // Create a store that keeps no epochs let mut message_secrets_store = MessageSecretsStore::new_with_secret( 0, diff --git a/openmls/src/group/core_group/test_proposals.rs b/openmls/src/group/core_group/test_proposals.rs index 41ac185d23..240c990bde 100644 --- a/openmls/src/group/core_group/test_proposals.rs +++ b/openmls/src/group/core_group/test_proposals.rs @@ -1,6 +1,3 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{key_store::OpenMlsKeyStore, types::Ciphersuite, OpenMlsProvider}; - use super::CoreGroup; use crate::{ binary_tree::LeafNodeIndex, @@ -11,10 +8,8 @@ use crate::{ mls_auth_content::AuthenticatedContent, sender::Sender, FramingParameters, WireFormat, }, group::{ - config::CryptoConfig, errors::*, proposals::{ProposalQueue, ProposalStore, QueuedProposal}, - public_group::errors::PublicGroupBuildError, test_core_group::setup_client, CreateCommitParams, GroupContext, GroupId, StagedCoreWelcome, }, @@ -28,8 +23,11 @@ use crate::{ /// This test makes sure ProposalQueue works as intended. This functionality is /// used in `create_commit` to filter the epoch proposals. Expected result: /// `filtered_queued_proposals` returns only proposals of a certain type -#[apply(ciphersuites_and_providers)] -fn proposal_queue_functions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn proposal_queue_functions( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Framing parameters let framing_parameters = FramingParameters::new(&[], WireFormat::PublicMessage); // Define identities @@ -40,7 +38,7 @@ fn proposal_queue_functions(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro let bob_key_package = bob_key_package_bundle.key_package(); let alice_update_key_package_bundle = - KeyPackageBundle::new(provider, &alice_signer, ciphersuite, alice_credential); + KeyPackageBundle::generate(provider, &alice_signer, ciphersuite, alice_credential); let alice_update_key_package = alice_update_key_package_bundle.key_package(); let kpi = KeyPackageIn::from(alice_update_key_package.clone()); assert!(kpi @@ -172,8 +170,8 @@ fn proposal_queue_functions(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro } /// Test, that we QueuedProposalQueue is iterated in the right order. -#[apply(ciphersuites_and_providers)] -fn proposal_queue_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn proposal_queue_order() { // Framing parameters let framing_parameters = FramingParameters::new(&[], WireFormat::PublicMessage); // Define identities @@ -184,7 +182,7 @@ fn proposal_queue_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide let bob_key_package = bob_key_package_bundle.key_package(); let alice_update_key_package_bundle = - KeyPackageBundle::new(provider, &alice_signer, ciphersuite, alice_credential); + KeyPackageBundle::generate(provider, &alice_signer, ciphersuite, alice_credential); let alice_update_key_package = alice_update_key_package_bundle.key_package(); let kpi = KeyPackageIn::from(alice_update_key_package.clone()); assert!(kpi @@ -280,42 +278,10 @@ fn proposal_queue_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide assert_eq!(proposal_collection[1].proposal(), &proposal_add_alice1); } -#[apply(ciphersuites_and_providers)] -fn test_required_unsupported_proposals(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let (alice_credential, _, alice_signer, _alice_pk) = - setup_client("Alice", ciphersuite, provider); - - // Set required capabilities - let extensions = &[]; - let proposals = &[ProposalType::GroupContextExtensions, ProposalType::AppAck]; - let credentials = &[CredentialType::Basic]; - let required_capabilities = - RequiredCapabilitiesExtension::new(extensions, proposals, credentials); - - // This must fail because we don't actually support AppAck proposals - let e = CoreGroup::builder( - GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), - alice_credential, - ) - .with_group_context_extensions(Extensions::single(Extension::RequiredCapabilities( - required_capabilities, - ))) - .unwrap() - .build(provider, &alice_signer) - .expect_err( - "CoreGroup creation must fail because AppAck proposals aren't supported in OpenMLS yet.", - ); - assert_eq!( - e, - CoreGroupBuildError::PublicGroupBuildError(PublicGroupBuildError::UnsupportedProposalType) - ) -} - -#[apply(ciphersuites_and_providers)] +#[openmls_test::openmls_test] fn test_required_extension_key_package_mismatch( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { // Basic group setup. let group_aad = b"Alice's test group"; @@ -337,7 +303,7 @@ fn test_required_extension_key_package_mismatch( let alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential, ) .with_group_context_extensions(Extensions::single(Extension::RequiredCapabilities( @@ -362,8 +328,11 @@ fn test_required_extension_key_package_mismatch( ); } -#[apply(ciphersuites_and_providers)] -fn test_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_group_context_extensions( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -389,7 +358,7 @@ fn test_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenM let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential, ) .with_group_context_extensions(Extensions::single(Extension::RequiredCapabilities( @@ -443,10 +412,10 @@ fn test_group_context_extensions(ciphersuite: Ciphersuite, provider: &impl OpenM .expect("Error joining group."); } -#[apply(ciphersuites_and_providers)] +#[openmls_test::openmls_test] fn test_group_context_extension_proposal_fails( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { // Basic group setup. let group_aad = b"Alice's test group"; @@ -471,7 +440,7 @@ fn test_group_context_extension_proposal_fails( let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential, ) .with_group_context_extensions(Extensions::single(Extension::RequiredCapabilities( @@ -542,11 +511,8 @@ fn test_group_context_extension_proposal_fails( // ); } -#[apply(ciphersuites_and_providers)] -fn test_group_context_extension_proposal( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test::openmls_test] +fn test_group_context_extension_proposal(ciphersuite: Ciphersuite, provider: &Provider) { // Basic group setup. let group_aad = b"Alice's test group"; let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -560,7 +526,7 @@ fn test_group_context_extension_proposal( let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential, ) .build(provider, &alice_signer) @@ -617,7 +583,7 @@ fn test_group_context_extension_proposal( &[CredentialType::Basic], )); let gce_proposal = alice_group - .create_group_context_ext_proposal::( + .create_group_context_ext_proposal::( framing_parameters, Extensions::single(required_application_id), &alice_signer, diff --git a/openmls/src/group/errors.rs b/openmls/src/group/errors.rs index c4bd00a17c..50e2e8317a 100644 --- a/openmls/src/group/errors.rs +++ b/openmls/src/group/errors.rs @@ -20,7 +20,7 @@ use crate::{ /// Welcome error #[derive(Error, Debug, PartialEq, Clone)] -pub enum WelcomeError { +pub enum WelcomeError { /// See [`GroupSecretsError`] for more details. #[error(transparent)] GroupSecrets(#[from] GroupSecretsError), @@ -78,24 +78,24 @@ pub enum WelcomeError { /// No matching key package was found in the key store. #[error("No matching key package was found in the key store.")] NoMatchingKeyPackage, - /// Error accessing the key store. - #[error("Error accessing the key store.")] - KeyStoreError(KeyStoreError), /// This error indicates the public tree is invalid. See [`PublicTreeError`] for more details. #[error(transparent)] PublicTreeError(#[from] PublicTreeError), /// This error indicates the public tree is invalid. See /// [`CreationFromExternalError`] for more details. #[error(transparent)] - PublicGroupError(#[from] CreationFromExternalError), + PublicGroupError(#[from] CreationFromExternalError), /// This error indicates the leaf node is invalid. See [`LeafNodeValidationError`] for more details. #[error(transparent)] LeafNodeValidation(#[from] LeafNodeValidationError), + /// This error indicates that an error occurred while reading or writing from/to storage. + #[error("An error occurred when querying storage")] + StorageError(StorageError), } /// External Commit error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ExternalCommitError { +pub enum ExternalCommitError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -120,10 +120,13 @@ pub enum ExternalCommitError { /// This error indicates the public tree is invalid. See /// [`CreationFromExternalError`] for more details. #[error(transparent)] - PublicGroupError(#[from] CreationFromExternalError), + PublicGroupError(#[from] CreationFromExternalError), /// Credential is missing from external commit. #[error("Credential is missing from external commit.")] MissingCredential, + /// An erorr occurred when writing group to storage + #[error("An error occurred when writing group to storage.")] + StorageError(StorageError), } /// Stage Commit error @@ -201,7 +204,7 @@ pub enum StageCommitError { /// Create commit error #[derive(Error, Debug, PartialEq, Clone)] -pub enum CreateCommitError { +pub enum CreateCommitError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -223,12 +226,12 @@ pub enum CreateCommitError { /// See [`ProposalValidationError`] for more details. #[error(transparent)] ProposalValidationError(#[from] ProposalValidationError), - /// Error interacting with the key store. - #[error("Error interacting with the key store.")] - KeyStoreError(KeyStoreError), + /// Error interacting with storage. + #[error("Error interacting with storage.")] + KeyStoreError(StorageError), /// See [`KeyPackageNewError`] for more details. #[error(transparent)] - KeyPackageGenerationError(#[from] KeyPackageNewError), + KeyPackageGenerationError(#[from] KeyPackageNewError), /// See [`SignatureError`] for more details. #[error(transparent)] SignatureError(#[from] SignatureError), @@ -374,6 +377,9 @@ pub enum ProposalValidationError { /// See [`PskError`] for more details. #[error(transparent)] Psk(#[from] PskError), + /// The proposal type is not supported by all group members. + #[error("The proposal type is not supported by all group members.")] + UnsupportedProposalType, } /// External Commit validaton error @@ -464,7 +470,7 @@ pub(crate) enum FromCommittedProposalsError { // Core group build error #[derive(Error, Debug, PartialEq, Clone)] -pub(crate) enum CoreGroupBuildError { +pub(crate) enum CoreGroupBuildError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -475,8 +481,8 @@ pub(crate) enum CoreGroupBuildError { #[error(transparent)] Psk(#[from] PskError), /// Error storing leaf private key in key store. - #[error("Error storing leaf private key in key store.")] - KeyStoreError(KeyStoreError), + #[error("Error saving data to storage: {0}.")] + StorageError(StorageError), } // CoreGroup parse message error @@ -494,7 +500,7 @@ pub(crate) enum CoreGroupParseMessageError { /// Create group context ext proposal error #[derive(Error, Debug, PartialEq, Clone)] -pub enum CreateGroupContextExtProposalError { +pub enum CreateGroupContextExtProposalError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -509,21 +515,21 @@ pub enum CreateGroupContextExtProposalError { LeafNodeValidation(#[from] LeafNodeValidationError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - MlsGroupStateError(#[from] MlsGroupStateError), + MlsGroupStateError(#[from] MlsGroupStateError), /// See [`CreateCommitError`] for more details. #[error(transparent)] - CreateCommitError(#[from] CreateCommitError), + CreateCommitError(#[from] CreateCommitError), } /// Error merging a commit. #[derive(Error, Debug, PartialEq, Clone)] -pub enum MergeCommitError { +pub enum MergeCommitError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), - /// Error accessing the key store. - #[error("Error accessing the key store.")] - KeyStoreError(KeyStoreError), + /// Error writing updated group to storage. + #[error("Error writing updated group data to storage.")] + StorageError(StorageError), } /// Error validation a GroupContextExtensions proposal. diff --git a/openmls/src/group/mls_group/application.rs b/openmls/src/group/mls_group/application.rs index 4ac7ef7fdc..52601ff4a1 100644 --- a/openmls/src/group/mls_group/application.rs +++ b/openmls/src/group/mls_group/application.rs @@ -1,5 +1,7 @@ use openmls_traits::signatures::Signer; +use crate::storage::OpenMlsProvider; + use super::{errors::CreateMessageError, *}; impl MlsGroup { @@ -11,12 +13,12 @@ impl MlsGroup { /// Returns `CreateMessageError::MlsGroupStateError::PendingProposal` if pending proposals /// exist. In that case `.process_pending_proposals()` must be called first /// and incoming messages from the DS must be processed afterwards. - pub fn create_message( + pub fn create_message( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, message: &[u8], - ) -> Result { + ) -> Result> { if !self.is_active() { return Err(CreateMessageError::GroupStateError( MlsGroupStateError::UseAfterEviction, @@ -40,9 +42,6 @@ impl MlsGroup { // We know the application message is wellformed and we have the key material of the current epoch .map_err(|_| LibraryError::custom("Malformed plaintext"))?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok(MlsMessageOut::from_private_message( ciphertext, self.group.version(), diff --git a/openmls/src/group/mls_group/builder.rs b/openmls/src/group/mls_group/builder.rs index 6607727dbb..59fe151b02 100644 --- a/openmls/src/group/mls_group/builder.rs +++ b/openmls/src/group/mls_group/builder.rs @@ -1,20 +1,21 @@ -use openmls_traits::{key_store::OpenMlsKeyStore, signatures::Signer, OpenMlsProvider}; +use openmls_traits::{signatures::Signer, types::Ciphersuite}; use crate::{ credentials::CredentialWithKey, error::LibraryError, extensions::{errors::InvalidExtensionError, Extensions}, group::{ - config::CryptoConfig, public_group::errors::PublicGroupBuildError, CoreGroup, - CoreGroupBuildError, CoreGroupConfig, GroupId, MlsGroupCreateConfig, - MlsGroupCreateConfigBuilder, NewGroupError, ProposalStore, WireFormatPolicy, + public_group::errors::PublicGroupBuildError, CoreGroup, CoreGroupBuildError, + CoreGroupConfig, GroupId, MlsGroupCreateConfig, MlsGroupCreateConfigBuilder, NewGroupError, + ProposalStore, WireFormatPolicy, }, key_packages::Lifetime, + storage::OpenMlsProvider, tree::sender_ratchet::SenderRatchetConfiguration, treesync::node::leaf_node::Capabilities, }; -use super::{InnerState, MlsGroup, MlsGroupState}; +use super::{MlsGroup, MlsGroupState}; #[derive(Default, Debug)] pub struct MlsGroupBuilder { @@ -34,12 +35,12 @@ impl MlsGroupBuilder { } /// Build a new group as configured by this builder. - pub fn build( + pub fn build( self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, credential_with_key: CredentialWithKey, - ) -> Result> { + ) -> Result> { self.build_internal(provider, signer, credential_with_key, None) } @@ -48,13 +49,13 @@ impl MlsGroupBuilder { /// If an [`MlsGroupCreateConfig`] is provided, it will be used to configure the /// group. Otherwise, the internal builder is used to build one with the /// parameters set on this builder. - pub(super) fn build_internal( + pub(super) fn build_internal( self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, credential_with_key: CredentialWithKey, mls_group_create_config_option: Option, - ) -> Result> { + ) -> Result> { let mls_group_create_config = mls_group_create_config_option .unwrap_or_else(|| self.mls_group_create_config_builder.build()); let group_id = self @@ -69,7 +70,7 @@ impl MlsGroupBuilder { let mut group = CoreGroup::builder( group_id, - mls_group_create_config.crypto_config, + mls_group_create_config.ciphersuite, credential_with_key, ) .with_config(group_config) @@ -86,15 +87,9 @@ impl MlsGroupBuilder { log::debug!("Unexpected PSK error: {:?}", e); LibraryError::custom("Unexpected PSK error").into() } - CoreGroupBuildError::KeyStoreError(e) => NewGroupError::KeyStoreError(e), + CoreGroupBuildError::StorageError(e) => NewGroupError::StorageError(e), CoreGroupBuildError::PublicGroupBuildError(e) => match e { PublicGroupBuildError::LibraryError(e) => e.into(), - PublicGroupBuildError::UnsupportedProposalType => { - NewGroupError::UnsupportedProposalType - } - PublicGroupBuildError::UnsupportedExtensionType => { - NewGroupError::UnsupportedExtensionType - } PublicGroupBuildError::InvalidExtensions(e) => NewGroupError::InvalidExtensions(e), }, })?; @@ -112,9 +107,23 @@ impl MlsGroupBuilder { own_leaf_nodes: vec![], aad: vec![], group_state: MlsGroupState::Operational, - state_changed: InnerState::Changed, }; + use openmls_traits::storage::StorageProvider as _; + + provider + .storage() + .write_mls_join_config(mls_group.group_id(), &mls_group.mls_group_config) + .map_err(NewGroupError::StorageError)?; + provider + .storage() + .write_group_state(mls_group.group_id(), &mls_group.group_state) + .map_err(NewGroupError::StorageError)?; + mls_group + .group + .store(provider.storage()) + .map_err(NewGroupError::StorageError)?; + Ok(mls_group) } @@ -188,10 +197,11 @@ impl MlsGroupBuilder { self } - /// Sets the `crypto_config` of the MlsGroup. - pub fn crypto_config(mut self, config: CryptoConfig) -> Self { - self.mls_group_create_config_builder = - self.mls_group_create_config_builder.crypto_config(config); + /// Sets the `ciphersuite` of the MlsGroup. + pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self { + self.mls_group_create_config_builder = self + .mls_group_create_config_builder + .ciphersuite(ciphersuite); self } diff --git a/openmls/src/group/mls_group/config.rs b/openmls/src/group/mls_group/config.rs index 041fbcd8d8..faf4804adc 100644 --- a/openmls/src/group/mls_group/config.rs +++ b/openmls/src/group/mls_group/config.rs @@ -29,7 +29,7 @@ use super::*; use crate::{ - extensions::errors::InvalidExtensionError, group::config::CryptoConfig, key_packages::Lifetime, + extensions::errors::InvalidExtensionError, key_packages::Lifetime, tree::sender_ratchet::SenderRatchetConfiguration, treesync::node::leaf_node::Capabilities, }; use serde::{Deserialize, Serialize}; @@ -81,14 +81,14 @@ impl MlsGroupJoinConfig { /// Specifies configuration for the creation of an [`MlsGroup`]. Refer to the /// [User Manual](https://openmls.tech/book/user_manual/group_config.html) for /// more information about the different configuration values. -#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct MlsGroupCreateConfig { /// Capabilities advertised in the creator's leaf node pub(crate) capabilities: Capabilities, /// Lifetime of the own leaf node pub(crate) lifetime: Lifetime, /// Ciphersuite and protocol version - pub(crate) crypto_config: CryptoConfig, + pub(crate) ciphersuite: Ciphersuite, /// Configuration parameters relevant to group operation at runtime pub(crate) join_config: MlsGroupJoinConfig, /// List of initial group context extensions @@ -97,6 +97,19 @@ pub struct MlsGroupCreateConfig { pub(crate) leaf_node_extensions: Extensions, } +impl Default for MlsGroupCreateConfig { + fn default() -> Self { + Self { + capabilities: Capabilities::default(), + lifetime: Lifetime::default(), + ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, + join_config: MlsGroupJoinConfig::default(), + group_context_extensions: Extensions::default(), + leaf_node_extensions: Extensions::default(), + } + } +} + /// Builder struct for an [`MlsGroupJoinConfig`]. #[derive(Default)] pub struct MlsGroupJoinConfigBuilder { @@ -204,9 +217,9 @@ impl MlsGroupCreateConfig { &self.lifetime } - /// Returns the [`CryptoConfig`]. - pub fn crypto_config(&self) -> &CryptoConfig { - &self.crypto_config + /// Returns the [`Ciphersuite`]. + pub fn ciphersuite(&self) -> Ciphersuite { + self.ciphersuite } #[cfg(any(feature = "test-utils", test))] @@ -216,7 +229,7 @@ impl MlsGroupCreateConfig { OutgoingWireFormatPolicy::AlwaysPlaintext, IncomingWireFormatPolicy::Mixed, )) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build() } @@ -301,9 +314,9 @@ impl MlsGroupCreateConfigBuilder { self } - /// Sets the `crypto_config` property of the MlsGroupCreateConfig. - pub fn crypto_config(mut self, config: CryptoConfig) -> Self { - self.config.crypto_config = config; + /// Sets the `ciphersuite` property of the MlsGroupCreateConfig. + pub fn ciphersuite(mut self, ciphersuite: Ciphersuite) -> Self { + self.config.ciphersuite = ciphersuite; self } diff --git a/openmls/src/group/mls_group/creation.rs b/openmls/src/group/mls_group/creation.rs index 4161861d60..ff0522bb86 100644 --- a/openmls/src/group/mls_group/creation.rs +++ b/openmls/src/group/mls_group/creation.rs @@ -1,8 +1,7 @@ -use openmls_traits::signatures::Signer; +use openmls_traits::{signatures::Signer, storage::StorageProvider as StorageProviderTrait}; use super::{builder::MlsGroupBuilder, *}; use crate::{ - ciphersuite::HpkePrivateKey, credentials::CredentialWithKey, group::{ core_group::create_commit_params::CreateCommitParams, @@ -13,6 +12,7 @@ use crate::{ Welcome, }, schedule::psk::store::ResumptionPskStore, + storage::OpenMlsProvider, treesync::RatchetTreeIn, }; @@ -30,12 +30,12 @@ impl MlsGroup { /// /// This function removes the private key corresponding to the /// `key_package` from the key store. - pub fn new( - provider: &impl OpenMlsProvider, + pub fn new( + provider: &Provider, signer: &impl Signer, mls_group_create_config: &MlsGroupCreateConfig, credential_with_key: CredentialWithKey, - ) -> Result> { + ) -> Result> { MlsGroupBuilder::new().build_internal( provider, signer, @@ -46,13 +46,13 @@ impl MlsGroup { /// Creates a new group with a given group ID with the creator as the only /// member. - pub fn new_with_group_id( - provider: &impl OpenMlsProvider, + pub fn new_with_group_id( + provider: &Provider, signer: &impl Signer, mls_group_create_config: &MlsGroupCreateConfig, group_id: GroupId, credential_with_key: CredentialWithKey, - ) -> Result> { + ) -> Result> { MlsGroupBuilder::new() .with_group_id(group_id) .build_internal( @@ -77,15 +77,16 @@ impl MlsGroup { /// /// Note: If there is a group member in the group with the same identity as /// us, this will create a remove proposal. - pub fn join_by_external_commit( - provider: &impl OpenMlsProvider, + pub fn join_by_external_commit( + provider: &Provider, signer: &impl Signer, ratchet_tree: Option, verifiable_group_info: VerifiableGroupInfo, mls_group_config: &MlsGroupJoinConfig, aad: &[u8], credential_with_key: CredentialWithKey, - ) -> Result<(Self, MlsMessageOut, Option), ExternalCommitError> { + ) -> Result<(Self, MlsMessageOut, Option), ExternalCommitError> + { // Prepare the commit parameters let framing_parameters = FramingParameters::new(aad, WireFormat::PublicMessage); @@ -113,9 +114,21 @@ impl MlsGroup { group_state: MlsGroupState::PendingCommit(Box::new(PendingCommitState::External( create_commit_result.staged_commit, ))), - state_changed: InnerState::Changed, }; + provider + .storage() + .write_mls_join_config(mls_group.group_id(), &mls_group.mls_group_config) + .map_err(ExternalCommitError::StorageError)?; + provider + .storage() + .write_group_state(mls_group.group_id(), &mls_group.group_state) + .map_err(ExternalCommitError::StorageError)?; + mls_group + .group + .store(provider.storage()) + .map_err(ExternalCommitError::StorageError)?; + let public_message: PublicMessage = create_commit_result.commit.into(); Ok(( @@ -126,6 +139,14 @@ impl MlsGroup { } } +fn transpose_err_opt(v: Result, E>) -> Option> { + match v { + Ok(Some(v)) => Some(Ok(v)), + Ok(None) => None, + Err(err) => Some(Err(err)), + } +} + impl StagedWelcome { /// Creates a new staged welcome from a [`Welcome`] message. Returns an error /// ([`WelcomeError::NoMatchingKeyPackage`]) if no [`KeyPackage`] @@ -134,44 +155,36 @@ impl StagedWelcome { /// message, even if the caller does not turn the [`StagedWelcome`] into an [`MlsGroup`]. /// /// [`Welcome`]: crate::messages::Welcome - pub fn new_from_welcome( - provider: &impl OpenMlsProvider, + pub fn new_from_welcome( + provider: &Provider, mls_group_config: &MlsGroupJoinConfig, welcome: Welcome, ratchet_tree: Option, - ) -> Result> { + ) -> Result> { let resumption_psk_store = ResumptionPskStore::new(mls_group_config.number_of_resumption_psks); - let (key_package, _) = welcome + let key_package_bundle: KeyPackageBundle = welcome .secrets() .iter() .find_map(|egs| { - let new_member = egs.new_member(); - let hash_ref = new_member.as_slice(); - provider - .key_store() - .read(hash_ref) - .map(|kp: KeyPackage| (kp, hash_ref.to_vec())) + let hash_ref = egs.new_member(); + + transpose_err_opt( + provider + .storage() + .key_package(&hash_ref) + .map_err(WelcomeError::StorageError), + ) }) - .ok_or(WelcomeError::NoMatchingKeyPackage)?; - - // TODO #751 - let private_key = provider - .key_store() - .read::(key_package.hpke_init_key().as_slice()) - .ok_or(WelcomeError::NoMatchingKeyPackage)?; - let key_package_bundle = KeyPackageBundle { - key_package, - private_key, - }; + .ok_or(WelcomeError::NoMatchingKeyPackage)??; // Delete the [`KeyPackage`] and the corresponding private key from the // key store, but only if it doesn't have a last resort extension. if !key_package_bundle.key_package().last_resort() { - key_package_bundle - .key_package - .delete(provider) - .map_err(WelcomeError::KeyStoreError)?; + provider + .storage() + .delete_key_package(&key_package_bundle.key_package.hash_ref(provider.crypto())?) + .map_err(WelcomeError::StorageError)?; } else { log::debug!("Key package has last resort extension, not deleting"); } @@ -207,10 +220,10 @@ impl StagedWelcome { } /// Consumes the [`StagedWelcome`] and returns the respective [`MlsGroup`]. - pub fn into_group( + pub fn into_group( self, - provider: &impl OpenMlsProvider, - ) -> Result> { + provider: &Provider, + ) -> Result> { let mut group = self.group.into_core_group(provider)?; group.set_max_past_epochs(self.mls_group_config.max_past_epochs); @@ -221,9 +234,17 @@ impl StagedWelcome { own_leaf_nodes: vec![], aad: vec![], group_state: MlsGroupState::Operational, - state_changed: InnerState::Changed, }; + provider + .storage() + .write_mls_join_config(mls_group.group_id(), &mls_group.mls_group_config) + .map_err(WelcomeError::StorageError)?; + provider + .storage() + .write_group_state(mls_group.group_id(), &MlsGroupState::Operational) + .map_err(WelcomeError::StorageError)?; + Ok(mls_group) } } diff --git a/openmls/src/group/mls_group/errors.rs b/openmls/src/group/mls_group/errors.rs index 01d4c82744..abf94bdbec 100644 --- a/openmls/src/group/mls_group/errors.rs +++ b/openmls/src/group/mls_group/errors.rs @@ -23,16 +23,16 @@ use crate::{ /// New group error #[derive(Error, Debug, PartialEq, Clone)] -pub enum NewGroupError { +pub enum NewGroupError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// No matching KeyPackage was found in the key store. #[error("No matching KeyPackage was found in the key store.")] NoMatchingKeyPackage, - /// Error accessing the key store. - #[error("Error accessing the key store.")] - KeyStoreError(KeyStoreError), + /// Error accessing the storage. + #[error("Error accessing the storage.")] + StorageError(StorageError), /// Unsupported proposal type in required capabilities. #[error("Unsupported proposal type in required capabilities.")] UnsupportedProposalType, @@ -57,7 +57,7 @@ pub enum EmptyInputError { /// Group state error #[derive(Error, Debug, PartialEq, Clone)] -pub enum MlsGroupStateError { +pub enum MlsGroupStateError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -76,22 +76,25 @@ pub enum MlsGroupStateError { /// Requested pending proposal hasn't been found in local pending proposals #[error("Requested pending proposal hasn't been found in local pending proposals.")] PendingProposalNotFound, + /// An error ocurred while writing to storage + #[error("An error ocurred while writing to storage")] + StorageError(StorageError), } /// Error merging pending commit #[derive(Error, Debug, PartialEq, Clone)] -pub enum MergePendingCommitError { +pub enum MergePendingCommitError { /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - MlsGroupStateError(#[from] MlsGroupStateError), + MlsGroupStateError(#[from] MlsGroupStateError), /// See [`MergeCommitError`] for more details. #[error(transparent)] - MergeCommitError(#[from] MergeCommitError), + MergeCommitError(#[from] MergeCommitError), } /// Process message error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProcessMessageError { +pub enum ProcessMessageError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -103,7 +106,7 @@ pub enum ProcessMessageError { ValidationError(#[from] ValidationError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// The message's signature is invalid. #[error("The message's signature is invalid.")] InvalidSignature, @@ -120,18 +123,18 @@ pub enum ProcessMessageError { /// Create message error #[derive(Error, Debug, PartialEq, Clone)] -pub enum CreateMessageError { +pub enum CreateMessageError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), } /// Add members error #[derive(Error, Debug, PartialEq, Clone)] -pub enum AddMembersError { +pub enum AddMembersError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -140,15 +143,18 @@ pub enum AddMembersError { EmptyInput(#[from] EmptyInputError), /// See [`CreateCommitError`] for more details. #[error(transparent)] - CreateCommitError(#[from] CreateCommitError), + CreateCommitError(#[from] CreateCommitError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), + /// Error writing to storage. + #[error("Error writing to storage")] + StorageError(StorageError), } /// Propose add members error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProposeAddMemberError { +pub enum ProposeAddMemberError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -157,29 +163,35 @@ pub enum ProposeAddMemberError { UnsupportedExtensions, /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// See [`LeafNodeValidationError`] for more details. #[error(transparent)] LeafNodeValidation(#[from] LeafNodeValidationError), + /// Error writing to storage + #[error("Error writing to storage: {0}")] + StorageError(StorageError), } /// Propose remove members error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProposeRemoveMemberError { +pub enum ProposeRemoveMemberError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// The member that should be removed can not be found. #[error("The member that should be removed can not be found.")] UnknownMember, + /// Error writing to storage + #[error("Error writing to storage: {0}")] + StorageError(StorageError), } /// Remove members error #[derive(Error, Debug, PartialEq, Clone)] -pub enum RemoveMembersError { +pub enum RemoveMembersError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -188,56 +200,59 @@ pub enum RemoveMembersError { EmptyInput(#[from] EmptyInputError), /// See [`CreateCommitError`] for more details. #[error(transparent)] - CreateCommitError(#[from] CreateCommitError), + CreateCommitError(#[from] CreateCommitError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// The member that should be removed can not be found. #[error("The member that should be removed can not be found.")] UnknownMember, + /// Error writing to storage + #[error("Error writing to storage: {0}")] + StorageError(StorageError), } /// Leave group error #[derive(Error, Debug, PartialEq, Clone)] -pub enum LeaveGroupError { +pub enum LeaveGroupError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), } /// Self update error #[derive(Error, Debug, PartialEq, Clone)] -pub enum SelfUpdateError { +pub enum SelfUpdateError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`CreateCommitError`] for more details. #[error(transparent)] - CreateCommitError(#[from] CreateCommitError), + CreateCommitError(#[from] CreateCommitError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), - /// Error accessing the key store. - #[error("Error accessing the key store.")] - KeyStoreError, + GroupStateError(#[from] MlsGroupStateError), + /// Error accessing the storage. + #[error("Error accessing the storage.")] + StorageError(StorageError), } /// Propose self update error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProposeSelfUpdateError { +pub enum ProposeSelfUpdateError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), - /// Error accessing the key store. - #[error("Error accessing the key store.")] - KeyStoreError(KeyStoreError), + GroupStateError(#[from] MlsGroupStateError), + /// Error accessing storage. + #[error("Error accessing storage.")] + StorageError(StorageError), /// See [`PublicTreeError`] for more details. #[error(transparent)] PublicTreeError(#[from] PublicTreeError), @@ -245,32 +260,35 @@ pub enum ProposeSelfUpdateError { /// Commit to pending proposals error #[derive(Error, Debug, PartialEq, Clone)] -pub enum CommitToPendingProposalsError { +pub enum CommitToPendingProposalsError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`CreateCommitError`] for more details. #[error(transparent)] - CreateCommitError(#[from] CreateCommitError), + CreateCommitError(#[from] CreateCommitError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), + /// Error writing to storage + #[error("Error writing to storage: {0}")] + StorageError(StorageError), } /// Errors that can happen when exporting a group info object. #[derive(Error, Debug, PartialEq, Clone)] -pub enum ExportGroupInfoError { +pub enum ExportGroupInfoError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), } /// Export secret error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ExportSecretError { +pub enum ExportSecretError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -279,18 +297,18 @@ pub enum ExportSecretError { KeyLengthTooLong, /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), } /// Propose PSK error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProposePskError { +pub enum ProposePskError { /// See [`PskError`] for more details. #[error(transparent)] Psk(#[from] PskError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -298,29 +316,32 @@ pub enum ProposePskError { /// Export secret error #[derive(Error, Debug, PartialEq, Clone)] -pub enum ProposalError { +pub enum ProposalError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// See [`ProposeAddMemberError`] for more details. #[error(transparent)] - ProposeAddMemberError(#[from] ProposeAddMemberError), + ProposeAddMemberError(#[from] ProposeAddMemberError), /// See [`CreateAddProposalError`] for more details. #[error(transparent)] CreateAddProposalError(#[from] CreateAddProposalError), /// See [`ProposeSelfUpdateError`] for more details. #[error(transparent)] - ProposeSelfUpdateError(#[from] ProposeSelfUpdateError), + ProposeSelfUpdateError(#[from] ProposeSelfUpdateError), /// See [`ProposeRemoveMemberError`] for more details. #[error(transparent)] - ProposeRemoveMemberError(#[from] ProposeRemoveMemberError), + ProposeRemoveMemberError(#[from] ProposeRemoveMemberError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - GroupStateError(#[from] MlsGroupStateError), + GroupStateError(#[from] MlsGroupStateError), /// See [`ValidationError`] for more details. #[error(transparent)] ValidationError(#[from] ValidationError), /// See [`CreateGroupContextExtProposalError`] for more details. #[error(transparent)] - CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError), + CreateGroupContextExtProposalError(#[from] CreateGroupContextExtProposalError), + /// Error writing proposal to storage. + #[error("error writing proposal to storage")] + StorageError(StorageError), } diff --git a/openmls/src/group/mls_group/exporting.rs b/openmls/src/group/mls_group/exporting.rs index 19acb98904..36da09ea8d 100644 --- a/openmls/src/group/mls_group/exporting.rs +++ b/openmls/src/group/mls_group/exporting.rs @@ -1,7 +1,6 @@ -use openmls_traits::crypto::OpenMlsCrypto; use openmls_traits::signatures::Signer; -use crate::{group::errors::ExporterError, schedule::EpochAuthenticator}; +use crate::{group::errors::ExporterError, schedule::EpochAuthenticator, storage::OpenMlsProvider}; use super::*; @@ -13,13 +12,15 @@ impl MlsGroup { /// key length is too long. /// Returns [`ExportSecretError::GroupStateError(MlsGroupStateError::UseAfterEviction)`](MlsGroupStateError::UseAfterEviction) /// if the group is not active. - pub fn export_secret( + pub fn export_secret( &self, - crypto: &impl OpenMlsCrypto, + provider: &Provider, label: &str, context: &[u8], key_length: usize, - ) -> Result, ExportSecretError> { + ) -> Result, ExportSecretError> { + let crypto = provider.crypto(); + if self.is_active() { Ok(self .group @@ -52,15 +53,15 @@ impl MlsGroup { } /// Export a group info object for this group. - pub fn export_group_info( + pub fn export_group_info( &self, - crypto: &impl OpenMlsCrypto, + provider: &Provider, signer: &impl Signer, with_ratchet_tree: bool, - ) -> Result { + ) -> Result> { Ok(self .group - .export_group_info(crypto, signer, with_ratchet_tree)? + .export_group_info(provider.crypto(), signer, with_ratchet_tree)? .into()) } } diff --git a/openmls/src/group/mls_group/membership.rs b/openmls/src/group/mls_group/membership.rs index 043b545db3..b22c723ed4 100644 --- a/openmls/src/group/mls_group/membership.rs +++ b/openmls/src/group/mls_group/membership.rs @@ -3,7 +3,7 @@ //! This module contains membership-related operations and exposes [`RemoveOperation`]. use core_group::create_commit_params::CreateCommitParams; -use openmls_traits::signatures::Signer; +use openmls_traits::{signatures::Signer, storage::StorageProvider as _}; use super::{ errors::{AddMembersError, LeaveGroupError, RemoveMembersError}, @@ -11,7 +11,7 @@ use super::{ }; use crate::{ binary_tree::array_representation::LeafNodeIndex, messages::group_info::GroupInfo, - treesync::LeafNode, + storage::OpenMlsProvider, treesync::LeafNode, }; impl MlsGroup { @@ -31,13 +31,15 @@ impl MlsGroup { /// [`Welcome`]: crate::messages::Welcome // FIXME: #1217 #[allow(clippy::type_complexity)] - pub fn add_members( + pub fn add_members( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, key_packages: &[KeyPackage], - ) -> Result<(MlsMessageOut, MlsMessageOut, Option), AddMembersError> - { + ) -> Result< + (MlsMessageOut, MlsMessageOut, Option), + AddMembersError, + > { self.is_operational()?; if key_packages.is_empty() { @@ -80,8 +82,10 @@ impl MlsGroup { create_commit_result.staged_commit, ))); - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(AddMembersError::StorageError)?; Ok(( mls_messages, @@ -111,14 +115,14 @@ impl MlsGroup { /// [`Welcome`]: crate::messages::Welcome // FIXME: #1217 #[allow(clippy::type_complexity)] - pub fn remove_members( + pub fn remove_members( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, members: &[LeafNodeIndex], ) -> Result< (MlsMessageOut, Option, Option), - RemoveMembersError, + RemoveMembersError, > { self.is_operational()?; @@ -153,8 +157,10 @@ impl MlsGroup { create_commit_result.staged_commit, ))); - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(RemoveMembersError::StorageError)?; Ok(( mls_message, @@ -171,11 +177,11 @@ impl MlsGroup { /// The Remove Proposal is returned as a [`MlsMessageOut`]. /// /// Returns an error if there is a pending commit. - pub fn leave_group( + pub fn leave_group( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result { + ) -> Result> { self.is_operational()?; let removed = self.group.own_leaf_index(); diff --git a/openmls/src/group/mls_group/mod.rs b/openmls/src/group/mls_group/mod.rs index ce5d88e966..501dc9c0d3 100644 --- a/openmls/src/group/mls_group/mod.rs +++ b/openmls/src/group/mls_group/mod.rs @@ -13,9 +13,10 @@ use crate::{ key_packages::{KeyPackage, KeyPackageBundle}, messages::proposals::*, schedule::ResumptionPskSecret, + storage::{OpenMlsProvider, StorageProvider}, treesync::{node::leaf_node::LeafNode, RatchetTree}, }; -use openmls_traits::{key_store::OpenMlsKeyStore, types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::types::Ciphersuite; // Private mod application; @@ -25,7 +26,6 @@ mod exporting; mod updates; use config::*; -use errors::*; // Crate pub(crate) mod config; @@ -167,10 +167,6 @@ pub struct MlsGroup { // A variable that indicates the state of the group. See [`MlsGroupState`] // for more information. group_state: MlsGroupState, - // A flag that indicates if the group state has changed and needs to be persisted again. The value - // is set to `InnerState::Changed` whenever an the internal group state is change and is set to - // `InnerState::Persisted` once the state has been persisted. - state_changed: InnerState, } impl MlsGroup { @@ -182,11 +178,13 @@ impl MlsGroup { } /// Sets the configuration. - pub fn set_configuration(&mut self, mls_group_config: &MlsGroupJoinConfig) { + pub fn set_configuration( + &mut self, + storage: &Storage, + mls_group_config: &MlsGroupJoinConfig, + ) -> Result<(), Storage::Error> { self.mls_group_config = mls_group_config.clone(); - - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + storage.write_mls_join_config(self.group_id(), mls_group_config) } /// Returns the AAD used in the framing. @@ -195,11 +193,13 @@ impl MlsGroup { } /// Sets the AAD used in the framing. - pub fn set_aad(&mut self, aad: &[u8]) { + pub fn set_aad( + &mut self, + storage: &Storage, + aad: &[u8], + ) -> Result<(), Storage::Error> { self.aad = aad.to_vec(); - - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + storage.write_aad(self.group_id(), aad) } // === Advanced functions === @@ -217,7 +217,9 @@ impl MlsGroup { /// Returns own credential. If the group is inactive, it returns a /// `UseAfterEviction` error. - pub fn credential(&self) -> Result<&Credential, MlsGroupStateError> { + pub fn credential( + &self, + ) -> Result<&Credential, MlsGroupStateError> { if !self.is_active() { return Err(MlsGroupStateError::UseAfterEviction); } @@ -277,14 +279,20 @@ impl MlsGroup { /// the pending commit will not be used in the group. In particular, if a /// pending commit is later accepted by the group, this client will lack the /// key material to encrypt or decrypt group messages. - pub fn clear_pending_commit(&mut self) { + pub fn clear_pending_commit( + &mut self, + storage: &Storage, + ) -> Result<(), Storage::Error> { match self.group_state { MlsGroupState::PendingCommit(ref pending_commit_state) => { if let PendingCommitState::Member(_) = **pending_commit_state { - self.group_state = MlsGroupState::Operational + self.group_state = MlsGroupState::Operational; + storage.write_group_state(self.group_id(), &self.group_state) + } else { + Ok(()) } } - MlsGroupState::Operational | MlsGroupState::Inactive => (), + MlsGroupState::Operational | MlsGroupState::Inactive => Ok(()), } } @@ -294,14 +302,20 @@ impl MlsGroup { /// a Commit message that references those proposals. Only use this /// function as a last resort, e.g. when a call to /// `MlsGroup::commit_to_pending_proposals` fails. - pub fn clear_pending_proposals(&mut self) { + pub fn clear_pending_proposals( + &mut self, + storage: &Storage, + ) -> Result<(), Storage::Error> { // If the proposal store is not empty... if !self.proposal_store.is_empty() { // Empty the proposal store self.proposal_store.empty(); - // Since the state of the group is changed, arm the state flag - self.flag_state_change(); + + // Clear proposals in storage + storage.clear_proposal_queue::(self.group_id())?; } + + Ok(()) } /// Get a reference to the group context [`Extensions`] of this [`MlsGroup`]. @@ -309,30 +323,62 @@ impl MlsGroup { self.group.public_group().group_context().extensions() } - // === Load & save === + /// Returns the index of the sender of a staged, external commit. + pub fn ext_commit_sender_index( + &self, + commit: &StagedCommit, + ) -> Result { + self.group.public_group().ext_commit_sender_index(commit) + } + + // === Storage Methods === + + /// Loads the state of the group with given id from persisted state. + pub fn load( + storage: &Storage, + group_id: &GroupId, + ) -> Result, Storage::Error> { + let group_config = storage.mls_group_join_config(group_id)?; + let core_group = CoreGroup::load(storage, group_id)?; + let proposals: Vec<(ProposalRef, QueuedProposal)> = storage.queued_proposals(group_id)?; + let own_leaf_nodes = storage.own_leaf_nodes(group_id)?; + let aad = storage.aad(group_id)?; + let group_state = storage.group_state(group_id)?; + let mut proposal_store = ProposalStore::new(); + + for (_ref, proposal) in proposals { + proposal_store.add(proposal); + } - /// Loads the state from persisted state. - pub fn load(group_id: &GroupId, store: &impl OpenMlsKeyStore) -> Option { - store.read(group_id.as_slice()) + let build = || -> Option { + Some(Self { + mls_group_config: group_config?, + group: core_group?, + proposal_store, + own_leaf_nodes, + aad, + group_state: group_state?, + }) + }; + + Ok(build()) } - /// Persists the state. - pub fn save( + /// Remove the persisted state from storage + pub fn delete( &mut self, - store: &KeyStore, - ) -> Result<(), KeyStore::Error> { - store.store(self.group_id().as_slice(), &*self)?; + storage: &StorageProvider, + ) -> Result<(), StorageProvider::Error> { + self.group.delete(storage)?; + storage.delete_group_config(self.group_id())?; + storage.clear_proposal_queue::(self.group_id())?; + storage.delete_own_leaf_nodes(self.group_id())?; + storage.delete_aad(self.group_id())?; + storage.delete_group_state(self.group_id())?; - self.state_changed = InnerState::Persisted; Ok(()) } - /// Returns `true` if the internal state has changed and needs to be persisted and - /// `false` otherwise. Calling [`Self::save()`] resets the value to `false`. - pub fn state_changed(&self) -> InnerState { - self.state_changed - } - // === Extensions === /// Exports the Ratchet Tree. @@ -358,6 +404,7 @@ impl MlsGroup { if plaintext.sender().is_member() { plaintext.set_membership_tag( provider.crypto(), + self.ciphersuite(), self.group.message_secrets().membership_key(), self.group.message_secrets().serialized_context(), )?; @@ -380,11 +427,6 @@ impl MlsGroup { Ok(msg) } - /// Arm the state changed flag function - fn flag_state_change(&mut self) { - self.state_changed = InnerState::Changed; - } - /// Group framing parameters pub(crate) fn framing_parameters(&self) -> FramingParameters { FramingParameters::new( @@ -395,7 +437,7 @@ impl MlsGroup { /// Check if the group is operational. Throws an error if the group is /// inactive or if there is a pending commit. - fn is_operational(&self) -> Result<(), MlsGroupStateError> { + fn is_operational(&self) -> Result<(), MlsGroupStateError> { match self.group_state { MlsGroupState::PendingCommit(_) => Err(MlsGroupStateError::PendingCommit), MlsGroupState::Inactive => Err(MlsGroupStateError::UseAfterEviction), @@ -428,27 +470,20 @@ impl MlsGroup { } /// Removes a specific proposal from the store. - pub fn remove_pending_proposal( + pub fn remove_pending_proposal( &mut self, + storage: &Storage, proposal_ref: ProposalRef, - ) -> Result<(), MlsGroupStateError> { + ) -> Result<(), MlsGroupStateError> { + storage + .remove_proposal(self.group_id(), &proposal_ref) + .map_err(MlsGroupStateError::StorageError)?; self.proposal_store .remove(proposal_ref) .ok_or(MlsGroupStateError::PendingProposalNotFound) } } -/// `Enum` that indicates whether the inner group state has been modified since the last time it was persisted. -/// `InnerState::Changed` indicates that the state has changed and that [`.save()`] should be called. -/// `InnerState::Persisted` indicates that the state has not been modified and therefore doesn't need to be persisted. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum InnerState { - /// The inner group state has changed and needs to be persisted. - Changed, - /// The inner group state hasn't changed and doesn't need to be persisted. - Persisted, -} - /// A [`StagedWelcome`] can be inspected and then turned into a [`MlsGroup`]. /// This allows checking who authored the Welcome message. #[derive(Debug)] diff --git a/openmls/src/group/mls_group/processing.rs b/openmls/src/group/mls_group/processing.rs index 8875e8e239..bc61815a28 100644 --- a/openmls/src/group/mls_group/processing.rs +++ b/openmls/src/group/mls_group/processing.rs @@ -3,8 +3,9 @@ use std::mem; use core_group::staged_commit::StagedCommit; -use openmls_traits::signatures::Signer; +use openmls_traits::{signatures::Signer, storage::StorageProvider as _}; +use crate::storage::OpenMlsProvider; use crate::{ group::core_group::create_commit_params::CreateCommitParams, messages::group_info::GroupInfo, }; @@ -23,11 +24,11 @@ impl MlsGroup { /// # Errors: /// Returns an [`ProcessMessageError`] when the validation checks fail /// with the exact reason of the failure. - pub fn process_message( + pub fn process_message( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, message: impl Into, - ) -> Result { + ) -> Result> { // Make sure we are still a member of the group if !self.is_active() { return Err(ProcessMessageError::GroupStateError( @@ -48,9 +49,6 @@ impl MlsGroup { return Err(ProcessMessageError::IncompatibleWireFormat); } - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - // Parse the message let sender_ratchet_configuration = self.configuration().sender_ratchet_configuration().clone(); @@ -64,12 +62,16 @@ impl MlsGroup { } /// Stores a standalone proposal in the internal [ProposalStore] - pub fn store_pending_proposal(&mut self, proposal: QueuedProposal) { + pub fn store_pending_proposal( + &mut self, + storage: &Storage, + proposal: QueuedProposal, + ) -> Result<(), Storage::Error> { + storage.queue_proposal(self.group_id(), &proposal.proposal_reference(), &proposal)?; // Store the proposal in in the internal ProposalStore self.proposal_store.add(proposal); - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + Ok(()) } /// Creates a Commit message that covers the pending proposals that are @@ -83,13 +85,13 @@ impl MlsGroup { /// [`Welcome`]: crate::messages::Welcome // FIXME: #1217 #[allow(clippy::type_complexity)] - pub fn commit_to_pending_proposals( + pub fn commit_to_pending_proposals( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, ) -> Result< (MlsMessageOut, Option, Option), - CommitToPendingProposalsError, + CommitToPendingProposalsError, > { self.is_operational()?; @@ -110,9 +112,10 @@ impl MlsGroup { self.group_state = MlsGroupState::PendingCommit(Box::new(PendingCommitState::Member( create_commit_result.staged_commit, ))); - - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(CommitToPendingProposalsError::StorageError)?; Ok(( mls_message, @@ -125,18 +128,19 @@ impl MlsGroup { /// Merge a [StagedCommit] into the group after inspection. As this advances /// the epoch of the group, it also clears any pending commits. - pub fn merge_staged_commit( + pub fn merge_staged_commit( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, staged_commit: StagedCommit, - ) -> Result<(), MergeCommitError> { + ) -> Result<(), MergeCommitError> { // Check if we were removed from the group if staged_commit.self_removed() { self.group_state = MlsGroupState::Inactive; } - - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(MergeCommitError::StorageError)?; // Merge staged commit self.group @@ -150,19 +154,24 @@ impl MlsGroup { // Delete own KeyPackageBundles self.own_leaf_nodes.clear(); + provider + .storage() + .clear_own_leaf_nodes(self.group_id()) + .map_err(MergeCommitError::StorageError)?; // Delete a potential pending commit - self.clear_pending_commit(); + self.clear_pending_commit(provider.storage()) + .map_err(MergeCommitError::StorageError)?; Ok(()) } /// Merges the pending [`StagedCommit`] if there is one, and /// clears the field by setting it to `None`. - pub fn merge_pending_commit( + pub fn merge_pending_commit( &mut self, - provider: &impl OpenMlsProvider, - ) -> Result<(), MergePendingCommitError> { + provider: &Provider, + ) -> Result<(), MergePendingCommitError> { match &self.group_state { MlsGroupState::PendingCommit(_) => { let old_state = mem::replace(&mut self.group_state, MlsGroupState::Operational); diff --git a/openmls/src/group/mls_group/proposal.rs b/openmls/src/group/mls_group/proposal.rs index e5593d5b89..b76a46bd3c 100644 --- a/openmls/src/group/mls_group/proposal.rs +++ b/openmls/src/group/mls_group/proposal.rs @@ -1,12 +1,10 @@ -use openmls_traits::{ - key_store::OpenMlsKeyStore, signatures::Signer, types::Ciphersuite, OpenMlsProvider, -}; +use openmls_traits::{signatures::Signer, storage::StorageProvider, types::Ciphersuite}; use super::{ core_group::create_commit_params::CreateCommitParams, errors::{ProposalError, ProposeAddMemberError, ProposeRemoveMemberError}, - CreateGroupContextExtProposalError, GroupContextExtensionProposal, MlsGroup, MlsGroupState, - PendingCommitState, Proposal, + CreateGroupContextExtProposalError, CustomProposal, GroupContextExtensionProposal, MlsGroup, + MlsGroupState, PendingCommitState, Proposal, }; use crate::{ binary_tree::LeafNodeIndex, @@ -19,6 +17,7 @@ use crate::{ messages::{group_info::GroupInfo, proposals::ProposalOrRefType}, prelude::LibraryError, schedule::PreSharedKeyId, + storage::OpenMlsProvider, treesync::LeafNode, versions::ProtocolVersion, }; @@ -54,6 +53,9 @@ pub enum Propose { /// Propose adding new group context extensions. GroupContextExtensions(Extensions), + + /// A custom proposal with semantics to be implemented by the application. + Custom(CustomProposal), } macro_rules! impl_propose_fun { @@ -62,12 +64,12 @@ macro_rules! impl_propose_fun { /// Creates proposals to add an external PSK to the key schedule. /// /// Returns an error if there is a pending commit. - pub fn $name( + pub fn $name( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, value: $value_ty, - ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { self.is_operational()?; let proposal = self @@ -81,14 +83,16 @@ macro_rules! impl_propose_fun { $ref_or_value, )?; let proposal_ref = queued_proposal.proposal_reference(); + log::trace!("Storing proposal in queue {:?}", queued_proposal); + provider + .storage() + .queue_proposal(self.group.group_id(), &proposal_ref, &queued_proposal) + .map_err(ProposalError::StorageError)?; self.proposal_store.add(queued_proposal); let mls_message = self.content_to_mls_message(proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } }; @@ -123,14 +127,28 @@ impl MlsGroup { ProposalOrRefType::Proposal ); + impl_propose_fun!( + propose_custom_proposal_by_value, + CustomProposal, + create_custom_proposal, + ProposalOrRefType::Proposal + ); + + impl_propose_fun!( + propose_custom_proposal_by_reference, + CustomProposal, + create_custom_proposal, + ProposalOrRefType::Reference + ); + /// Generate a proposal - pub fn propose( + pub fn propose( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, propose: Propose, ref_or_value: ProposalOrRefType, - ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { match propose { Propose::Add(key_package) => match ref_or_value { ProposalOrRefType::Proposal => { @@ -196,18 +214,26 @@ impl MlsGroup { Propose::GroupContextExtensions(_) => Err(ProposalError::LibraryError( LibraryError::custom("Unsupported proposal type GroupContextExtensions"), )), + Propose::Custom(custom_proposal) => match ref_or_value { + ProposalOrRefType::Proposal => { + self.propose_custom_proposal_by_value(provider, signer, custom_proposal) + } + ProposalOrRefType::Reference => { + self.propose_custom_proposal_by_reference(provider, signer, custom_proposal) + } + }, } } /// Creates proposals to add members to the group. /// /// Returns an error if there is a pending commit. - pub fn propose_add_member( + pub fn propose_add_member( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, key_package: &KeyPackage, - ) -> Result<(MlsMessageOut, ProposalRef), ProposeAddMemberError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposeAddMemberError> { self.is_operational()?; let add_proposal = self @@ -226,13 +252,14 @@ impl MlsGroup { add_proposal.clone(), )?; let proposal_ref = proposal.proposal_reference(); + provider + .storage() + .queue_proposal(self.group_id(), &proposal_ref, &proposal) + .map_err(ProposeAddMemberError::StorageError)?; self.proposal_store.add(proposal); let mls_message = self.content_to_mls_message(add_proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } @@ -240,12 +267,13 @@ impl MlsGroup { /// The `member` has to be the member's leaf index. /// /// Returns an error if there is a pending commit. - pub fn propose_remove_member( + pub fn propose_remove_member( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, member: LeafNodeIndex, - ) -> Result<(MlsMessageOut, ProposalRef), ProposeRemoveMemberError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposeRemoveMemberError> + { self.is_operational()?; let remove_proposal = self @@ -259,13 +287,14 @@ impl MlsGroup { remove_proposal.clone(), )?; let proposal_ref = proposal.proposal_reference(); + provider + .storage() + .queue_proposal(self.group_id(), &proposal_ref, &proposal) + .map_err(ProposeRemoveMemberError::StorageError)?; self.proposal_store.add(proposal); let mls_message = self.content_to_mls_message(remove_proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } @@ -273,12 +302,13 @@ impl MlsGroup { /// The `member` has to be the member's credential. /// /// Returns an error if there is a pending commit. - pub fn propose_remove_member_by_credential( + pub fn propose_remove_member_by_credential( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, member: &Credential, - ) -> Result<(MlsMessageOut, ProposalRef), ProposeRemoveMemberError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposeRemoveMemberError> + { // Find the user for the credential first. let member_index = self .group @@ -298,12 +328,12 @@ impl MlsGroup { /// The `member` has to be the member's credential. /// /// Returns an error if there is a pending commit. - pub fn propose_remove_member_by_credential_by_value( + pub fn propose_remove_member_by_credential_by_value( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, member: &Credential, - ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { // Find the user for the credential first. let member_index = self .group @@ -325,15 +355,15 @@ impl MlsGroup { /// /// Returns an error when the group does not support all the required capabilities /// in the new `extensions`. - pub fn propose_group_context_extensions( + pub fn propose_group_context_extensions( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, extensions: Extensions, signer: &impl Signer, - ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposalError> { self.is_operational()?; - let proposal = self.group.create_group_context_ext_proposal::( + let proposal = self.group.create_group_context_ext_proposal::( self.framing_parameters(), extensions, signer, @@ -346,13 +376,14 @@ impl MlsGroup { )?; let proposal_ref = queued_proposal.proposal_reference(); + provider + .storage() + .queue_proposal(self.group_id(), &proposal_ref, &queued_proposal) + .map_err(ProposalError::StorageError)?; self.proposal_store.add(queued_proposal); let mls_message = self.content_to_mls_message(proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } @@ -361,14 +392,14 @@ impl MlsGroup { /// Returns an error when the group does not support all the required capabilities /// in the new `extensions`. #[allow(clippy::type_complexity)] - pub fn update_group_context_extensions( + pub fn update_group_context_extensions( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, extensions: Extensions, signer: &impl Signer, ) -> Result< (MlsMessageOut, Option, Option), - CreateGroupContextExtProposalError, + CreateGroupContextExtProposalError, > { self.is_operational()?; @@ -389,9 +420,6 @@ impl MlsGroup { create_commit_result.staged_commit, ))); - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok(( mls_messages, create_commit_result diff --git a/openmls/src/group/mls_group/ser.rs b/openmls/src/group/mls_group/ser.rs index 7977fad823..4449b3a0bb 100644 --- a/openmls/src/group/mls_group/ser.rs +++ b/openmls/src/group/mls_group/ser.rs @@ -4,7 +4,6 @@ use super::*; use crate::schedule::psk::store::ResumptionPskStore; -use openmls_traits::key_store::{MlsEntity, MlsEntityId}; use serde::{ ser::{SerializeStruct, Serializer}, Deserialize, Serialize, @@ -36,15 +35,10 @@ impl Into for SerializedMlsGroup { own_leaf_nodes: self.own_leaf_nodes, aad: self.aad, group_state: self.group_state, - state_changed: InnerState::Persisted, } } } -impl MlsEntity for MlsGroup { - const ID: MlsEntityId = MlsEntityId::GroupState; -} - impl Serialize for MlsGroup { fn serialize(&self, serializer: S) -> Result where diff --git a/openmls/src/group/mls_group/test_mls_group.rs b/openmls/src/group/mls_group/test_mls_group.rs index 466e972735..25fc1958f9 100644 --- a/openmls/src/group/mls_group/test_mls_group.rs +++ b/openmls/src/group/mls_group/test_mls_group.rs @@ -1,13 +1,13 @@ use core_group::test_core_group::setup_client; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{key_store::OpenMlsKeyStore, OpenMlsProvider}; +use openmls_test::openmls_test; +use openmls_traits::OpenMlsProvider as _; use tls_codec::{Deserialize, Serialize}; use crate::{ binary_tree::LeafNodeIndex, extensions::errors::InvalidExtensionError, framing::*, - group::{config::CryptoConfig, errors::*, *}, + group::{errors::*, *}, key_packages::*, messages::proposals::*, prelude::Capabilities, @@ -15,12 +15,11 @@ use crate::{ errors::ClientError, noop_authentication_service, ActionType::Commit, CodecUse, MlsGroupTestSetup, }, - test_utils::*, tree::sender_ratchet::SenderRatchetConfiguration, }; -#[apply(ciphersuites_and_providers)] -fn test_mls_group_persistence(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_mls_group_persistence() { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -30,7 +29,7 @@ fn test_mls_group_persistence(ciphersuite: Ciphersuite, provider: &impl OpenMlsP let mls_group_config = MlsGroupCreateConfig::test_default(ciphersuite); // === Alice creates a group === - let mut alice_group = MlsGroup::new_with_group_id( + let alice_group = MlsGroup::new_with_group_id( provider, &alice_signer, &mls_group_config, @@ -39,32 +38,30 @@ fn test_mls_group_persistence(ciphersuite: Ciphersuite, provider: &impl OpenMlsP ) .expect("An unexpected error occurred."); - // Check the internal state has changed - assert_eq!(alice_group.state_changed(), InnerState::Changed); - - alice_group - .save(provider.key_store()) - .expect("Could not write group state to file"); - - let alice_group_deserialized = - MlsGroup::load(&group_id, provider.key_store()).expect("Could not deserialize MlsGroup"); + let alice_group_deserialized = MlsGroup::load(provider.storage(), &group_id) + .expect("Could not deserialize MlsGroup: error") + .expect("Could not deserialize MlsGroup: doesn't exist"); assert_eq!( ( alice_group.export_ratchet_tree(), - alice_group.export_secret(provider.crypto(), "test", &[], 32) + alice_group + .export_secret(provider, "test", &[], 32) + .unwrap() ), ( alice_group_deserialized.export_ratchet_tree(), - alice_group_deserialized.export_secret(provider.crypto(), "test", &[], 32) + alice_group_deserialized + .export_secret(provider, "test", &[], 32) + .unwrap() ) ); } // This tests if the remover is correctly passed to the callback when one member // issues a RemoveProposal and another members issues the next Commit. -#[apply(ciphersuites_and_providers)] -fn remover(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn remover() { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -76,7 +73,7 @@ fn remover(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -174,7 +171,9 @@ fn remover(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Bob was removed assert_eq!(remove_proposal.removed(), LeafNodeIndex::new(1)); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal.clone()); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .unwrap(); } else { unreachable!("Expected a Proposal."); } @@ -214,8 +213,8 @@ fn remover(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // TODO #524: Check that Alice removed Bob } -#[apply(ciphersuites_and_providers)] -fn export_secret(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn export_secret() { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -236,24 +235,24 @@ fn export_secret(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert!( alice_group - .export_secret(provider.crypto(), "test1", &[], ciphersuite.hash_length()) + .export_secret(provider, "test1", &[], ciphersuite.hash_length()) .expect("An unexpected error occurred.") != alice_group - .export_secret(provider.crypto(), "test2", &[], ciphersuite.hash_length()) + .export_secret(provider, "test2", &[], ciphersuite.hash_length()) .expect("An unexpected error occurred.") ); assert!( alice_group - .export_secret(provider.crypto(), "test", &[0u8], ciphersuite.hash_length()) + .export_secret(provider, "test", &[0u8], ciphersuite.hash_length()) .expect("An unexpected error occurred.") != alice_group - .export_secret(provider.crypto(), "test", &[1u8], ciphersuite.hash_length()) + .export_secret(provider, "test", &[1u8], ciphersuite.hash_length()) .expect("An unexpected error occurred.") ) } -#[apply(ciphersuites_and_providers)] -fn staged_join(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn staged_join() { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, alice_kpb, alice_signer, _alice_pk) = @@ -310,21 +309,21 @@ fn staged_join(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert_eq!( alice_group - .export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + .export_secret(provider, "test", &[], ciphersuite.hash_length()) .expect("An unexpected error occurred."), bob_group - .export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + .export_secret(provider, "test", &[], ciphersuite.hash_length()) .expect("An unexpected error occurred.") ); } -#[apply(ciphersuites_and_providers)] -fn test_invalid_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_invalid_plaintext() { // Some basic setup functions for the MlsGroup. let mls_group_create_config = MlsGroupCreateConfig::test_default(ciphersuite); let number_of_clients = 20; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( mls_group_create_config, number_of_clients, CodecUse::StructMessages, @@ -375,7 +374,8 @@ fn test_invalid_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi MlsMessageBodyOut::PublicMessage(pt) => { pt.set_sender(random_sender); pt.set_membership_tag( - provider.crypto(), + client.provider.crypto(), + ciphersuite, membership_key, client_group.group().message_secrets().serialized_context(), ) @@ -425,10 +425,10 @@ fn test_invalid_plaintext(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi ); } -#[apply(ciphersuites_and_providers)] +#[openmls_test] fn test_verify_staged_commit_credentials( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { let group_id = GroupId::from_slice(b"Test Group"); @@ -468,7 +468,9 @@ fn test_verify_staged_commit_credentials( if let ProcessedMessageContent::ProposalMessage(staged_proposal) = alice_processed_message.into_content() { - alice_group.store_pending_proposal(*staged_proposal); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a StagedCommit."); } @@ -505,8 +507,12 @@ fn test_verify_staged_commit_credentials( alice_group.export_ratchet_tree() ); assert_eq!( - bob_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()), - alice_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + bob_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap(), + alice_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap() ); // Bob is added and the state aligns. @@ -582,8 +588,12 @@ fn test_verify_staged_commit_credentials( alice_group.export_ratchet_tree() ); assert_eq!( - bob_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()), - alice_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + bob_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap(), + alice_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap() ); } else { unreachable!() @@ -594,10 +604,10 @@ fn test_verify_staged_commit_credentials( assert!(alice_group.pending_commit().is_none()); } -#[apply(ciphersuites_and_providers)] +#[openmls_test] fn test_commit_with_update_path_leaf_node( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { let group_id = GroupId::from_slice(b"Test Group"); @@ -637,7 +647,9 @@ fn test_commit_with_update_path_leaf_node( if let ProcessedMessageContent::ProposalMessage(staged_proposal) = alice_processed_message.into_content() { - alice_group.store_pending_proposal(*staged_proposal); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a StagedCommit."); } @@ -676,8 +688,12 @@ fn test_commit_with_update_path_leaf_node( alice_group.export_ratchet_tree() ); assert_eq!( - bob_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()), - alice_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + bob_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap(), + alice_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap() ); // Bob is added and the state aligns. @@ -765,8 +781,12 @@ fn test_commit_with_update_path_leaf_node( alice_group.export_ratchet_tree() ); assert_eq!( - bob_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()), - alice_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + bob_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap(), + alice_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap() ); } else { unreachable!() @@ -777,8 +797,11 @@ fn test_commit_with_update_path_leaf_node( assert!(alice_group.pending_commit().is_none()); } -#[apply(ciphersuites_and_providers)] -fn test_pending_commit_logic(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_pending_commit_logic( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -817,7 +840,9 @@ fn test_pending_commit_logic(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr if let ProcessedMessageContent::ProposalMessage(staged_proposal) = alice_processed_message.into_content() { - alice_group.store_pending_proposal(*staged_proposal); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a StagedCommit."); } @@ -839,55 +864,57 @@ fn test_pending_commit_logic(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr let error = alice_group .add_members(provider, &alice_signer, &[bob_key_package.clone()]) .expect_err("no error committing while a commit is pending"); - assert_eq!( + assert!(matches!( error, AddMembersError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .propose_add_member(provider, &alice_signer, bob_key_package) .expect_err("no error creating a proposal while a commit is pending"); - assert_eq!( + assert!(matches!( error, ProposeAddMemberError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .remove_members(provider, &alice_signer, &[LeafNodeIndex::new(1)]) .expect_err("no error committing while a commit is pending"); - assert_eq!( + assert!(matches!( error, RemoveMembersError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .propose_remove_member(provider, &alice_signer, LeafNodeIndex::new(1)) .expect_err("no error creating a proposal while a commit is pending"); - assert_eq!( + assert!(matches!( error, ProposeRemoveMemberError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .commit_to_pending_proposals(provider, &alice_signer) .expect_err("no error committing while a commit is pending"); - assert_eq!( + assert!(matches!( error, CommitToPendingProposalsError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .self_update(provider, &alice_signer) .expect_err("no error committing while a commit is pending"); - assert_eq!( + assert!(matches!( error, SelfUpdateError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); let error = alice_group .propose_self_update(provider, &alice_signer, None) .expect_err("no error creating a proposal while a commit is pending"); - assert_eq!( + assert!(matches!( error, ProposeSelfUpdateError::GroupStateError(MlsGroupStateError::PendingCommit) - ); + )); // Clearing the pending commit should actually clear it. - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); assert!(alice_group.pending_commit().is_none()); // Creating a new commit should commit the same proposals. @@ -922,8 +949,12 @@ fn test_pending_commit_logic(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr alice_group.export_ratchet_tree() ); assert_eq!( - bob_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()), - alice_group.export_secret(provider.crypto(), "test", &[], ciphersuite.hash_length()) + bob_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap(), + alice_group + .export_secret(provider, "test", &[], ciphersuite.hash_length()) + .unwrap() ); // While a commit is pending, merging Bob's commit should clear the pending commit. @@ -954,8 +985,8 @@ fn test_pending_commit_logic(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr // Test that the key package and the corresponding private key are deleted when // creating a new group for a welcome message. -#[apply(ciphersuites_and_providers)] -fn key_package_deletion(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn key_package_deletion() { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -966,7 +997,7 @@ fn key_package_deletion(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -1002,30 +1033,24 @@ fn key_package_deletion(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide .into_group(provider) .expect("Error creating group from staged join"); - // TEST: The private key must be gone from the key store. - assert!(provider - .key_store() - .read::(bob_key_package.hpke_init_key().as_slice()) - .is_none(), - "The HPKE private key is still in the key store after creating a new group from the key package."); + use openmls_traits::storage::StorageProvider; // TEST: The key package must be gone from the key store. + let result: Option = provider + .storage() + .key_package(&bob_key_package.hash_ref(provider.crypto()).unwrap()) + .unwrap(); assert!( - provider - .key_store() - .read::( - bob_key_package - .hash_ref(provider.crypto()) - .unwrap() - .as_slice() - ) - .is_none(), + result.is_none(), "The key package is still in the key store after creating a new group from it." ); } -#[apply(ciphersuites_and_providers)] -fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn remove_prosposal_by_ref( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -1039,7 +1064,7 @@ fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -1080,13 +1105,15 @@ fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv assert_eq!(alice_group.proposal_store.proposals().count(), 1); // clearing the proposal by reference alice_group - .remove_pending_proposal(reference.clone()) + .remove_pending_proposal(provider.storage(), reference.clone()) .unwrap(); assert!(alice_group.proposal_store.is_empty()); // the proposal should not be stored anymore - let err = alice_group.remove_pending_proposal(reference).unwrap_err(); - assert_eq!(err, MlsGroupStateError::PendingProposalNotFound); + let err = alice_group + .remove_pending_proposal(provider.storage(), reference) + .unwrap_err(); + assert!(matches!(err, MlsGroupStateError::PendingProposalNotFound)); // the commit should have no proposal let (commit, _, _) = alice_group @@ -1114,150 +1141,14 @@ fn remove_prosposal_by_ref(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv } // Test that the builder pattern accurately configures the new group. -#[apply(ciphersuites_and_providers)] -fn immutable_metadata(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = - setup_client("Alice", ciphersuite, provider); - - let metadata = Metadata::new(b"this is a test group".to_vec()); - - let required_capabilities_extension = RequiredCapabilitiesExtension::new( - &[ - ExtensionType::RequiredCapabilities, - ExtensionType::ImmutableMetadata, - ], - &[], - &[], - ); - - let extensions_with_metadata = Extensions::from_vec(vec![ - Extension::ImmutableMetadata(metadata.clone()), - Extension::RequiredCapabilities(required_capabilities_extension), - ]) - .unwrap(); - - // === Create a Group with Metadata === - let capabilities = Capabilities::new( - None, - None, - Some(&[ExtensionType::ImmutableMetadata]), - None, - None, - ); - let mut group_with_metadata = MlsGroup::builder() - .with_group_context_extensions(extensions_with_metadata.clone()) - .expect("error when setting initial metadata group extension") - .with_capabilities(capabilities) - .build(provider, &alice_signer, alice_credential_with_key.clone()) - .expect("error creating group using builder"); - - let got_metadata = group_with_metadata - .group() - .group_context_extensions() - .immutable_metadata() - .expect("error getting metadata from group"); - - // check we get back the same metadata we initially set - assert_eq!(got_metadata, &metadata); - - // changing the metadata should fail - let new_metadata = Metadata::new(b"this is a not test group".to_vec()); - - let new_extensions_with_metadata = - Extensions::single(Extension::ImmutableMetadata(new_metadata.clone())); - - group_with_metadata - .propose_group_context_extensions(provider, new_extensions_with_metadata, &alice_signer) - .expect("error proposing GCE proposal with same metadata"); - - assert_eq!(group_with_metadata.pending_proposals().count(), 1); - - group_with_metadata - .commit_to_pending_proposals(provider, &alice_signer) - .expect_err("should not have been able to commit to proposal that changes metadata"); - - group_with_metadata.clear_pending_proposals(); - - assert_eq!(group_with_metadata.pending_proposals().count(), 0); - - // using the same metadata should succeed - group_with_metadata - .propose_group_context_extensions(provider, extensions_with_metadata.clone(), &alice_signer) - .expect("error proposing GCE proposal with same metadata"); - - assert_eq!(group_with_metadata.pending_proposals().count(), 1); - - group_with_metadata - .commit_to_pending_proposals(provider, &alice_signer) - .expect("failed to commit to pending proposals"); - - group_with_metadata - .merge_pending_commit(provider) - .expect("error merging pending commit"); - - let got_metadata = group_with_metadata - .group() - .group_context_extensions() - .immutable_metadata() - .expect("couldn't get immutable_metadata"); - - // check we get back the same metadata we initially set - assert_eq!(got_metadata, &metadata); - - // === Create a Group without Metadata === - let mut group_without_metadata = MlsGroup::builder() - .build(provider, &alice_signer, alice_credential_with_key) - .expect("error creating group using builder"); - - assert!(group_without_metadata - .group() - .group_context_extensions() - .immutable_metadata() - .is_none()); - - // changing the metadata should fail - let new_metadata = Metadata::new(b"this is a new metadata".to_vec()); - - let new_extensions_with_metadata = - Extensions::single(Extension::ImmutableMetadata(new_metadata.clone())); - - group_without_metadata - .propose_group_context_extensions(provider, new_extensions_with_metadata, &alice_signer) - .expect("error proposing GCE proposal with metadata"); - - // since the GCEs are the same as befroe, the proposal handling logic "deduplicates" the proposal - // (with the implicit "identity proposal" I guess), so there isn't actually a proposal here - assert_eq!(group_with_metadata.pending_proposals().count(), 0); - - // using the same metadata should succeed - group_without_metadata - .propose_group_context_extensions(provider, Extensions::empty(), &alice_signer) - .expect("error proposing GCE proposal with no metadata"); - - // since the GCEs are empty, the proposal handling logic "deduplicates" the proposal - // (with the implicit "identity proposal" I guess), so there isn't actually a proposal here - assert_eq!(group_with_metadata.pending_proposals().count(), 0); - - // check we still get no metadata - assert!(group_without_metadata - .group() - .group_context_extensions() - .immutable_metadata() - .is_none()); - - // TODO: we need to test that processing an invalid commit also fails. - // however, we can't generate this commit, because our functions for - // constructing commits does not permit it. See #1476 -} - -// Test that the builder pattern accurately configures the new group. -#[apply(ciphersuites_and_providers)] -fn group_context_extensions_proposal(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn group_context_extensions_proposal() { let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = setup_client("Alice", ciphersuite, provider); // === Alice creates a group === let mut alice_group = MlsGroup::builder() + .ciphersuite(ciphersuite) .build(provider, &alice_signer, alice_credential_with_key) .expect("error creating group using builder"); @@ -1338,8 +1229,8 @@ fn group_context_extensions_proposal(ciphersuite: Ciphersuite, provider: &impl O } // Test that the builder pattern accurately configures the new group. -#[apply(ciphersuites_and_providers)] -fn builder_pattern(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn builder_pattern() { let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = setup_client("Alice", ciphersuite, provider); @@ -1361,7 +1252,7 @@ fn builder_pattern(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ]) .expect("error creating group context extensions"); - let test_crypto_config = CryptoConfig::with_default_version(ciphersuite); + let test_ciphersuite = ciphersuite; let test_sender_ratchet_config = SenderRatchetConfiguration::new(10, 2000); let test_max_past_epochs = 10; let test_number_of_resumption_psks = 5; @@ -1384,7 +1275,7 @@ fn builder_pattern(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .sender_ratchet_configuration(test_sender_ratchet_config.clone()) .with_group_context_extensions(test_gc_extensions.clone()) .expect("error adding group context extension to builder") - .crypto_config(test_crypto_config) + .ciphersuite(test_ciphersuite) .with_wire_format_policy(test_wire_format_policy) .lifetime(test_lifetime) .use_ratchet_tree_extension(true) @@ -1425,11 +1316,7 @@ fn builder_pattern(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { Extension::ExternalSenders(external_senders), test_external_senders ); - let crypto_config = CryptoConfig { - ciphersuite, - version: group_context.protocol_version(), - }; - assert_eq!(crypto_config, test_crypto_config); + assert_eq!(ciphersuite, test_ciphersuite); let extensions = group_context.extensions(); assert_eq!(extensions, &test_gc_extensions); let lifetime = alice_group @@ -1456,107 +1343,21 @@ fn builder_pattern(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert_eq!(builder_err, InvalidExtensionError::IllegalInLeafNodes); } -// Test that unknown group context and leaf node extensions can be used in groups -#[apply(ciphersuites_and_providers)] -fn unknown_extensions(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = - setup_client("Alice", ciphersuite, provider); - - let unknown_gc_extension = Extension::Unknown(0xff00, UnknownExtension(vec![0, 1, 2, 3])); - let unknown_leaf_extension = Extension::Unknown(0xff01, UnknownExtension(vec![4, 5, 6, 7])); - let unknown_kp_extension = Extension::Unknown(0xff02, UnknownExtension(vec![8, 9, 10, 11])); - let required_extensions = &[ - ExtensionType::Unknown(0xff00), - ExtensionType::Unknown(0xff01), - ]; - let required_capabilities = - Extension::RequiredCapabilities(RequiredCapabilitiesExtension::new(&[], &[], &[])); - let capabilities = Capabilities::new(None, None, Some(required_extensions), None, None); - let test_gc_extensions = Extensions::from_vec(vec![ - unknown_gc_extension.clone(), - required_capabilities.clone(), - ]) - .expect("error creating group context extensions"); - let test_kp_extensions = Extensions::single(unknown_kp_extension.clone()); - - // === Alice creates a group === - let config = CryptoConfig { - ciphersuite, - version: crate::versions::ProtocolVersion::default(), - }; - let mut alice_group = MlsGroup::builder() - .crypto_config(config) - .with_capabilities(capabilities.clone()) - .with_leaf_node_extensions(Extensions::single(unknown_leaf_extension.clone())) - .expect("error adding unknown leaf extension to builder") - .with_group_context_extensions(test_gc_extensions.clone()) - .expect("error adding unknown extension to builder") - .build(provider, &alice_signer, alice_credential_with_key) - .expect("error creating group using builder"); - - // Check that everything was added successfully - let group_context = alice_group.export_group_context(); - assert_eq!(group_context.extensions(), &test_gc_extensions); - let leaf_node = alice_group.own_leaf().expect("error getting own leaf"); - assert_eq!( - leaf_node.extensions(), - &Extensions::single(unknown_leaf_extension) - ); - - // Now let's add Bob to the group and make sure that he joins the group successfully - - // === Alice adds Bob === - let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) = - setup_client("Bob", ciphersuite, provider); - - // Generate a KP that supports the unknown extensions - let bob_key_package = KeyPackage::builder() - .leaf_node_capabilities(capabilities) - .key_package_extensions(test_kp_extensions.clone()) - .build(config, provider, &bob_signer, bob_credential_with_key) - .expect("error building key package"); - - assert_eq!( - bob_key_package.extensions(), - &Extensions::single(unknown_kp_extension) - ); - - // alice adds bob and bob processes the welcome to ensure that the unknown - // extensions are processed correctly - let (_, welcome, _) = alice_group - .add_members(provider, &alice_signer, &[bob_key_package.clone()]) - .unwrap(); - alice_group.merge_pending_commit(provider).unwrap(); - - let welcome: MlsMessageIn = welcome.into(); - let welcome = welcome - .into_welcome() - .expect("expected message to be a welcome"); - - let _bob_group = StagedWelcome::new_from_welcome( - provider, - &MlsGroupJoinConfig::default(), - welcome, - Some(alice_group.export_ratchet_tree().into()), - ) - .expect("Error creating staged join from Welcome") - .into_group(provider) - .expect("Error creating group from staged join"); -} - // Test the successful update of Group Context Extension with type Extension::Unknown(0xff11) -#[apply(ciphersuites_and_providers)] -fn update_group_context_with_unknown_extension( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn update_group_context_with_unknown_extension() { + let alice_provider = Provider::default(); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = - setup_client("Alice", ciphersuite, provider); + setup_client("Alice", ciphersuite, &alice_provider); // === Define the unknown group context extension and initial data === + const UNKNOWN_EXTENSION_TYPE: u16 = 0xff11; let unknown_extension_data = vec![1, 2]; - let unknown_gc_extension = Extension::Unknown(0xff11, UnknownExtension(unknown_extension_data)); - let required_extension_types = &[ExtensionType::Unknown(0xff11)]; + let unknown_gc_extension = Extension::Unknown( + UNKNOWN_EXTENSION_TYPE, + UnknownExtension(unknown_extension_data), + ); + let required_extension_types = &[ExtensionType::Unknown(UNKNOWN_EXTENSION_TYPE)]; let required_capabilities = Extension::RequiredCapabilities( RequiredCapabilitiesExtension::new(required_extension_types, &[], &[]), ); @@ -1570,12 +1371,12 @@ fn update_group_context_with_unknown_extension( .with_group_context_extensions(test_gc_extensions.clone()) .expect("error adding unknown extension to config") .capabilities(capabilities.clone()) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === let mut alice_group = MlsGroup::new( - provider, + &alice_provider, &alice_signer, &mls_group_create_config, alice_credential_with_key, @@ -1586,7 +1387,7 @@ fn update_group_context_with_unknown_extension( let group_context_extensions = alice_group.group().context().extensions(); let mut extracted_data = None; for extension in group_context_extensions.iter() { - if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension { + if let Extension::Unknown(UNKNOWN_EXTENSION_TYPE, UnknownExtension(data)) = extension { extracted_data = Some(data.clone()); } } @@ -1597,44 +1398,49 @@ fn update_group_context_with_unknown_extension( ); // === Alice adds Bob === + let bob_provider: Provider = Default::default(); let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) = - setup_client("Bob", ciphersuite, provider); + setup_client("Bob", ciphersuite, &bob_provider); let bob_key_package = KeyPackage::builder() .leaf_node_capabilities(capabilities) .build( - CryptoConfig::with_default_version(ciphersuite), - provider, + ciphersuite, + &bob_provider, &bob_signer, bob_credential_with_key, ) .expect("error building key package"); let (_, welcome, _) = alice_group - .add_members(provider, &alice_signer, &[bob_key_package.clone()]) + .add_members( + &alice_provider, + &alice_signer, + &[bob_key_package.key_package().clone()], + ) .unwrap(); - alice_group.merge_pending_commit(provider).unwrap(); + alice_group.merge_pending_commit(&alice_provider).unwrap(); let welcome: MlsMessageIn = welcome.into(); let welcome = welcome .into_welcome() .expect("expected message to be a welcome"); - let bob_group = StagedWelcome::new_from_welcome( - provider, + let mut bob_group = StagedWelcome::new_from_welcome( + &bob_provider, &MlsGroupJoinConfig::default(), welcome, Some(alice_group.export_ratchet_tree().into()), ) .expect("Error creating staged join from Welcome") - .into_group(provider) + .into_group(&bob_provider) .expect("Error creating group from staged join"); // === Verify Bob's initial group context extension data is correct === let group_context_extensions = bob_group.group().context().extensions(); let mut extracted_data_2 = None; for extension in group_context_extensions.iter() { - if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension { + if let Extension::Unknown(UNKNOWN_EXTENSION_TYPE, UnknownExtension(data)) = extension { extracted_data_2 = Some(data.clone()); } } @@ -1647,13 +1453,13 @@ fn update_group_context_with_unknown_extension( // === Propose the new group context extension === let updated_unknown_extension_data = vec![3, 4]; // Sample data for the extension let updated_unknown_gc_extension = Extension::Unknown( - 0xff11, + UNKNOWN_EXTENSION_TYPE, UnknownExtension(updated_unknown_extension_data.clone()), ); let mut updated_extensions = test_gc_extensions.clone(); updated_extensions.add_or_replace(updated_unknown_gc_extension); - alice_group + let (update_proposal, _) = alice_group .propose_group_context_extensions(provider, updated_extensions, &alice_signer) .expect("failed to propose group context extensions with unknown extension"); @@ -1664,7 +1470,7 @@ fn update_group_context_with_unknown_extension( ); // === Commit to the proposed group context extension === - alice_group + let (update_commit, _, _) = alice_group .commit_to_pending_proposals(provider, &alice_signer) .expect("failed to commit to pending group context extensions"); @@ -1672,15 +1478,54 @@ fn update_group_context_with_unknown_extension( .merge_pending_commit(provider) .expect("error merging pending commit"); - alice_group - .save(provider.key_store()) - .expect("error saving group"); + // === let bob process the updates === + assert_eq!( + bob_group.pending_proposals().count(), + 0, + "Expected no pending proposals" + ); + + let processed_update_message = bob_group + .process_message( + &bob_provider, + update_proposal.into_protocol_message().unwrap(), + ) + .expect("bob failed processing the update"); + + match processed_update_message.into_content() { + ProcessedMessageContent::ProposalMessage(msg) => { + bob_group + .store_pending_proposal(bob_provider.storage(), *msg) + .unwrap(); + } + other => panic!("expected proposal, got {other:?}"), + } + + assert_eq!( + bob_group.pending_proposals().count(), + 1, + "Expected one pending proposal" + ); + + let processed_commit_message = bob_group + .process_message( + &bob_provider, + update_commit.into_protocol_message().unwrap(), + ) + .expect("bob failed processing the update"); + + match processed_commit_message.into_content() { + ProcessedMessageContent::StagedCommitMessage(staged_commit) => bob_group + .merge_staged_commit(&bob_provider, *staged_commit) + .expect("error merging group context update commit"), + other => panic!("expected commit, got {other:?}"), + }; // === Verify the group context extension was updated === let group_context_extensions = alice_group.group().context().extensions(); let mut extracted_data_updated = None; for extension in group_context_extensions.iter() { - if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension { + if let Extension::Unknown(UNKNOWN_EXTENSION_TYPE, UnknownExtension(data)) = extension { extracted_data_updated = Some(data.clone()); } } @@ -1691,12 +1536,13 @@ fn update_group_context_with_unknown_extension( ); // === Verify Bob sees the group context extension updated === - let bob_group_loaded = MlsGroup::load(bob_group.group().group_id(), provider.key_store()) - .expect("error loading group"); + let bob_group_loaded = MlsGroup::load(bob_provider.storage(), bob_group.group().group_id()) + .expect("error loading group") + .expect("no such group"); let group_context_extensions_2 = bob_group_loaded.export_group_context().extensions(); let mut extracted_data_2 = None; for extension in group_context_extensions_2.iter() { - if let Extension::Unknown(0xff11, UnknownExtension(data)) = extension { + if let Extension::Unknown(UNKNOWN_EXTENSION_TYPE, UnknownExtension(data)) = extension { extracted_data_2 = Some(data.clone()); } } @@ -1707,10 +1553,98 @@ fn update_group_context_with_unknown_extension( ); } -#[apply(ciphersuites_and_providers)] +// Test that unknown group context and leaf node extensions can be used in groups +#[openmls_test] +fn unknown_extensions() { + let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = + setup_client("Alice", ciphersuite, provider); + + let unknown_gc_extension = Extension::Unknown(0xff00, UnknownExtension(vec![0, 1, 2, 3])); + let unknown_leaf_extension = Extension::Unknown(0xff01, UnknownExtension(vec![4, 5, 6, 7])); + let unknown_kp_extension = Extension::Unknown(0xff02, UnknownExtension(vec![8, 9, 10, 11])); + let required_extensions = &[ + ExtensionType::Unknown(0xff00), + ExtensionType::Unknown(0xff01), + ]; + let required_capabilities = + Extension::RequiredCapabilities(RequiredCapabilitiesExtension::new(&[], &[], &[])); + let capabilities = Capabilities::new(None, None, Some(required_extensions), None, None); + let test_gc_extensions = Extensions::from_vec(vec![ + unknown_gc_extension.clone(), + required_capabilities.clone(), + ]) + .expect("error creating group context extensions"); + let test_kp_extensions = Extensions::single(unknown_kp_extension.clone()); + + // === Alice creates a group === + let mut alice_group = MlsGroup::builder() + .ciphersuite(ciphersuite) + .with_capabilities(capabilities.clone()) + .with_leaf_node_extensions(Extensions::single(unknown_leaf_extension.clone())) + .expect("error adding unknown leaf extension to builder") + .with_group_context_extensions(test_gc_extensions.clone()) + .expect("error adding unknown extension to builder") + .build(provider, &alice_signer, alice_credential_with_key) + .expect("error creating group using builder"); + + // Check that everything was added successfully + let group_context = alice_group.export_group_context(); + assert_eq!(group_context.extensions(), &test_gc_extensions); + let leaf_node = alice_group.own_leaf().expect("error getting own leaf"); + assert_eq!( + leaf_node.extensions(), + &Extensions::single(unknown_leaf_extension) + ); + + // Now let's add Bob to the group and make sure that he joins the group successfully + + // === Alice adds Bob === + let (bob_credential_with_key, _bob_kpb, bob_signer, _bob_pk) = + setup_client("Bob", ciphersuite, provider); + + // Generate a KP that supports the unknown extensions + let bob_key_package = KeyPackage::builder() + .leaf_node_capabilities(capabilities) + .key_package_extensions(test_kp_extensions.clone()) + .build(ciphersuite, provider, &bob_signer, bob_credential_with_key) + .expect("error building key package"); + + assert_eq!( + bob_key_package.key_package().extensions(), + &Extensions::single(unknown_kp_extension) + ); + + // alice adds bob and bob processes the welcome to ensure that the unknown + // extensions are processed correctly + let (_, welcome, _) = alice_group + .add_members( + provider, + &alice_signer, + &[bob_key_package.key_package().clone()], + ) + .unwrap(); + alice_group.merge_pending_commit(provider).unwrap(); + + let welcome: MlsMessageIn = welcome.into(); + let welcome = welcome + .into_welcome() + .expect("expected message to be a welcome"); + + let _bob_group = StagedWelcome::new_from_welcome( + provider, + &MlsGroupJoinConfig::default(), + welcome, + Some(alice_group.export_ratchet_tree().into()), + ) + .expect("Error creating staged join from Welcome") + .into_group(provider) + .expect("Error creating group from staged join"); +} + +#[openmls_test] fn join_multiple_groups_last_resort_extension( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) { // start with alice, bob, charlie, common config items let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -1719,22 +1653,18 @@ fn join_multiple_groups_last_resort_extension( setup_client("bob", ciphersuite, provider); let (charlie_credential_with_key, _charlie_kpb, charlie_signer, _charlie_pk) = setup_client("charlie", ciphersuite, provider); - let config = CryptoConfig { - ciphersuite, - version: crate::versions::ProtocolVersion::default(), - }; let leaf_capabilities = Capabilities::new(None, None, Some(&[ExtensionType::LastResort]), None, None); let keypkg_extensions = Extensions::single(Extension::LastResort(LastResortExtension::new())); // alice creates MlsGroup let mut alice_group = MlsGroup::builder() - .crypto_config(config) + .ciphersuite(ciphersuite) .use_ratchet_tree_extension(true) .build(provider, &alice_signer, alice_credential_with_key) .expect("error creating group for alice using builder"); // bob creates MlsGroup let mut bob_group = MlsGroup::builder() - .crypto_config(config) + .ciphersuite(ciphersuite) .use_ratchet_tree_extension(true) .build(provider, &bob_signer, bob_credential_with_key) .expect("error creating group for bob using builder"); @@ -1743,7 +1673,7 @@ fn join_multiple_groups_last_resort_extension( .leaf_node_capabilities(leaf_capabilities) .key_package_extensions(keypkg_extensions.clone()) .build( - config, + ciphersuite, provider, &charlie_signer, charlie_credential_with_key, @@ -1751,7 +1681,11 @@ fn join_multiple_groups_last_resort_extension( .expect("error building key package for charlie"); // alice calls add_members(...) with charlie's KeyPackage; produces Commit and Welcome messages let (_, alice_welcome, _) = alice_group - .add_members(provider, &alice_signer, &[charlie_keypkg.clone()]) + .add_members( + provider, + &alice_signer, + &[charlie_keypkg.key_package().clone()], + ) .expect("error adding charlie to alice's group"); alice_group .merge_pending_commit(provider) @@ -1775,7 +1709,11 @@ fn join_multiple_groups_last_resort_extension( // bob calls add_members(...) with charlie's KeyPackage; produces Commit and Welcome messages let (_, bob_welcome, _) = bob_group - .add_members(provider, &bob_signer, &[charlie_keypkg.clone()]) + .add_members( + provider, + &bob_signer, + &[charlie_keypkg.key_package().clone()], + ) .expect("error adding charlie to bob's group"); bob_group .merge_pending_commit(provider) diff --git a/openmls/src/group/mls_group/updates.rs b/openmls/src/group/mls_group/updates.rs index 0e3c0e31cc..93508f984d 100644 --- a/openmls/src/group/mls_group/updates.rs +++ b/openmls/src/group/mls_group/updates.rs @@ -1,7 +1,7 @@ use core_group::create_commit_params::CreateCommitParams; -use openmls_traits::signatures::Signer; +use openmls_traits::{signatures::Signer, storage::StorageProvider as _}; -use crate::{messages::group_info::GroupInfo, treesync::LeafNode, versions::ProtocolVersion}; +use crate::{messages::group_info::GroupInfo, storage::OpenMlsProvider, treesync::LeafNode}; use super::*; @@ -23,13 +23,13 @@ impl MlsGroup { /// [`Welcome`]: crate::messages::Welcome // FIXME: #1217 #[allow(clippy::type_complexity)] - pub fn self_update( + pub fn self_update( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, ) -> Result< (MlsMessageOut, Option, Option), - SelfUpdateError, + SelfUpdateError, > { self.is_operational()?; @@ -51,8 +51,13 @@ impl MlsGroup { create_commit_result.staged_commit, ))); - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); + provider + .storage() + .write_group_state(self.group_id(), &self.group_state) + .map_err(SelfUpdateError::StorageError)?; + self.group + .store(provider.storage()) + .map_err(SelfUpdateError::StorageError)?; Ok(( mls_message, @@ -66,12 +71,12 @@ impl MlsGroup { /// Creates a proposal to update the own leaf node. Optionally, a /// [`LeafNode`] can be provided to update the leaf node. Note that its /// private key must be manually added to the key store. - fn _propose_self_update( + fn _propose_self_update( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, leaf_node: Option, - ) -> Result> { + ) -> Result> { self.is_operational()?; // Here we clone our own leaf to rekey it such that we don't change the @@ -97,14 +102,13 @@ impl MlsGroup { self.group_id(), self.own_leaf_index(), self.ciphersuite(), - ProtocolVersion::default(), // XXX: openmls/openmls#1065 provider, signer, )?; // TODO #1207: Move to the top of the function. keypair - .write_to_key_store(provider.key_store()) - .map_err(ProposeSelfUpdateError::KeyStoreError)?; + .write(provider.storage()) + .map_err(ProposeSelfUpdateError::StorageError)?; }; let update_proposal = self.group.create_update_proposal( @@ -113,18 +117,22 @@ impl MlsGroup { signer, )?; + provider + .storage() + .append_own_leaf_node(self.group_id(), &own_leaf) + .map_err(ProposeSelfUpdateError::StorageError)?; self.own_leaf_nodes.push(own_leaf); Ok(update_proposal) } /// Creates a proposal to update the own leaf node. - pub fn propose_self_update( + pub fn propose_self_update( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, leaf_node: Option, - ) -> Result<(MlsMessageOut, ProposalRef), ProposeSelfUpdateError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposeSelfUpdateError> { let update_proposal = self._propose_self_update(provider, signer, leaf_node)?; let proposal = QueuedProposal::from_authenticated_content_by_ref( self.ciphersuite(), @@ -132,23 +140,24 @@ impl MlsGroup { update_proposal.clone(), )?; let proposal_ref = proposal.proposal_reference(); + provider + .storage() + .queue_proposal(self.group_id(), &proposal_ref, &proposal) + .map_err(ProposeSelfUpdateError::StorageError)?; self.proposal_store.add(proposal); let mls_message = self.content_to_mls_message(update_proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } /// Creates a proposal to update the own leaf node. - pub fn propose_self_update_by_value( + pub fn propose_self_update_by_value( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, leaf_node: Option, - ) -> Result<(MlsMessageOut, ProposalRef), ProposeSelfUpdateError> { + ) -> Result<(MlsMessageOut, ProposalRef), ProposeSelfUpdateError> { let update_proposal = self._propose_self_update(provider, signer, leaf_node)?; let proposal = QueuedProposal::from_authenticated_content_by_value( self.ciphersuite(), @@ -156,13 +165,14 @@ impl MlsGroup { update_proposal.clone(), )?; let proposal_ref = proposal.proposal_reference(); + provider + .storage() + .queue_proposal(self.group_id(), &proposal_ref, &proposal) + .map_err(ProposeSelfUpdateError::StorageError)?; self.proposal_store.add(proposal); let mls_message = self.content_to_mls_message(update_proposal, provider)?; - // Since the state of the group might be changed, arm the state flag - self.flag_state_change(); - Ok((mls_message, proposal_ref)) } } diff --git a/openmls/src/group/mod.rs b/openmls/src/group/mod.rs index 8dd6e77dfd..c905da9818 100644 --- a/openmls/src/group/mod.rs +++ b/openmls/src/group/mod.rs @@ -12,7 +12,6 @@ use crate::extensions::*; #[cfg(test)] use crate::utils::*; -use openmls_traits::OpenMlsProvider; use serde::{Deserialize, Serialize}; use tls_codec::*; @@ -20,12 +19,10 @@ use tls_codec::*; pub(crate) mod core_group; pub(crate) mod public_group; pub(crate) use core_group::*; +pub(crate) mod errors; pub(crate) mod mls_group; // Public -pub mod config; -pub(crate) mod errors; - pub use core_group::proposals::*; pub use core_group::staged_commit::StagedCommit; pub use errors::*; diff --git a/openmls/src/group/public_group/builder.rs b/openmls/src/group/public_group/builder.rs index 0e02a525fc..ea0269c536 100644 --- a/openmls/src/group/public_group/builder.rs +++ b/openmls/src/group/public_group/builder.rs @@ -1,14 +1,13 @@ -use openmls_traits::{crypto::OpenMlsCrypto, signatures::Signer, OpenMlsProvider}; +use openmls_traits::{ + crypto::OpenMlsCrypto, signatures::Signer, types::Ciphersuite, OpenMlsProvider, +}; use super::{errors::PublicGroupBuildError, PublicGroup}; use crate::{ credentials::CredentialWithKey, error::LibraryError, - extensions::{ - errors::{ExtensionError, InvalidExtensionError}, - Extensions, - }, - group::{config::CryptoConfig, ExtensionType, GroupContext, GroupId}, + extensions::{errors::InvalidExtensionError, Extensions}, + group::{ExtensionType, GroupContext, GroupId}, key_packages::Lifetime, messages::ConfirmationTag, schedule::CommitSecret, @@ -16,12 +15,13 @@ use crate::{ node::{encryption_keys::EncryptionKeyPair, leaf_node::Capabilities}, TreeSync, }, + versions::ProtocolVersion, }; #[derive(Debug)] pub(crate) struct TempBuilderPG1 { group_id: GroupId, - crypto_config: CryptoConfig, + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, lifetime: Option, capabilities: Option, @@ -81,17 +81,6 @@ impl TempBuilderPG1 { if let Some(required_capabilities) = self.group_context_extensions.required_capabilities() { - // Also, while we're at it, check if we support all required - // capabilities ourselves. - required_capabilities.check_support().map_err(|e| match e { - ExtensionError::UnsupportedProposalType => { - PublicGroupBuildError::UnsupportedProposalType - } - ExtensionError::UnsupportedExtensionType => { - PublicGroupBuildError::UnsupportedExtensionType - } - _ => LibraryError::custom("Unexpected ExtensionError").into(), - })?; ( Some(required_capabilities.extension_types()), Some(required_capabilities.proposal_types()), @@ -101,8 +90,8 @@ impl TempBuilderPG1 { (None, None, None) }; let capabilities = self.capabilities.unwrap_or(Capabilities::new( - Some(&[self.crypto_config.version]), - Some(&[self.crypto_config.ciphersuite]), + Some(&[ProtocolVersion::default()]), + Some(&[self.ciphersuite]), required_extensions, required_proposals, required_credentials, @@ -110,7 +99,7 @@ impl TempBuilderPG1 { let (treesync, commit_secret, leaf_keypair) = TreeSync::new( provider, signer, - self.crypto_config, + self.ciphersuite, self.credential_with_key, self.lifetime.unwrap_or_default(), capabilities, @@ -118,7 +107,7 @@ impl TempBuilderPG1 { )?; let group_context = GroupContext::create_initial_group_context( - self.crypto_config.ciphersuite, + self.ciphersuite, self.group_id, treesync.tree_hash().to_vec(), self.group_context_extensions, @@ -148,13 +137,6 @@ impl TempBuilderPG2 { } } - pub(crate) fn crypto_config(&self) -> CryptoConfig { - CryptoConfig { - ciphersuite: self.group_context.ciphersuite(), - version: self.group_context.protocol_version(), - } - } - pub(crate) fn group_context(&self) -> &GroupContext { &self.group_context } @@ -185,12 +167,12 @@ impl PublicGroup { /// Create a new [`PublicGroupBuilder`]. pub(crate) fn builder( group_id: GroupId, - crypto_config: CryptoConfig, + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, ) -> TempBuilderPG1 { TempBuilderPG1 { group_id, - crypto_config, + ciphersuite, credential_with_key, lifetime: None, capabilities: None, diff --git a/openmls/src/group/public_group/diff.rs b/openmls/src/group/public_group/diff.rs index e2b3eafb75..685b7760c1 100644 --- a/openmls/src/group/public_group/diff.rs +++ b/openmls/src/group/public_group/diff.rs @@ -131,7 +131,6 @@ impl<'a> PublicGroupDiff<'a> { exclusion_list: &HashSet<&LeafNodeIndex>, ) -> Result<(Vec, CommitSecret), ApplyUpdatePathError> { let params = DecryptPathParams { - version: self.group_context().protocol_version(), update_path, sender_leaf_index, exclusion_list, diff --git a/openmls/src/group/public_group/diff/compute_path.rs b/openmls/src/group/public_group/diff/compute_path.rs index b41d8784b9..6caa719c2a 100644 --- a/openmls/src/group/public_group/diff/compute_path.rs +++ b/openmls/src/group/public_group/diff/compute_path.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use openmls_traits::{key_store::OpenMlsKeyStore, signatures::Signer, OpenMlsProvider}; +use openmls_traits::signatures::Signer; use tls_codec::Serialize; use crate::{ @@ -8,12 +8,10 @@ use crate::{ credentials::CredentialWithKey, error::LibraryError, extensions::Extensions, - group::{ - config::CryptoConfig, core_group::create_commit_params::CommitType, - errors::CreateCommitError, - }, + group::{core_group::create_commit_params::CommitType, errors::CreateCommitError}, key_packages::{KeyPackage, KeyPackageCreationResult}, schedule::CommitSecret, + storage::OpenMlsProvider, treesync::{ node::{ encryption_keys::EncryptionKeyPair, leaf_node::LeafNode, @@ -37,17 +35,16 @@ pub(crate) struct PathComputationResult { impl<'a> PublicGroupDiff<'a> { #[allow(clippy::too_many_arguments)] - pub(crate) fn compute_path( + pub(crate) fn compute_path( &mut self, - provider: &impl OpenMlsProvider, + provider: &Provider, leaf_index: LeafNodeIndex, exclusion_list: HashSet<&LeafNodeIndex>, commit_type: CommitType, signer: &impl Signer, credential_with_key: Option, extensions: Option, - ) -> Result> { - let version = self.group_context().protocol_version(); + ) -> Result> { let ciphersuite = self.group_context().ciphersuite(); let group_id = self.group_context().group_id().clone(); @@ -61,11 +58,8 @@ impl<'a> PublicGroupDiff<'a> { // The KeyPackage is immediately put into the group. No need for // the init key. init_private_key: _, - } = KeyPackage::builder().build_without_key_storage( - CryptoConfig { - ciphersuite, - version, - }, + } = KeyPackage::builder().build_without_storage( + ciphersuite, provider, signer, credential_with_key.ok_or(CreateCommitError::MissingCredential)?, @@ -82,14 +76,8 @@ impl<'a> PublicGroupDiff<'a> { .diff .leaf_mut(leaf_index) .ok_or_else(|| LibraryError::custom("Unable to get own leaf from diff"))?; - let encryption_keypair = own_diff_leaf.rekey( - &group_id, - leaf_index, - ciphersuite, - version, - provider, - signer, - )?; + let encryption_keypair = + own_diff_leaf.rekey(&group_id, leaf_index, ciphersuite, provider, signer)?; vec![encryption_keypair] }; diff --git a/openmls/src/group/public_group/errors.rs b/openmls/src/group/public_group/errors.rs index a819f68bba..f9dad67fd5 100644 --- a/openmls/src/group/public_group/errors.rs +++ b/openmls/src/group/public_group/errors.rs @@ -7,7 +7,7 @@ use crate::{ /// Public group creation from external error. #[derive(Error, Debug, PartialEq, Clone)] -pub enum CreationFromExternalError { +pub enum CreationFromExternalError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -26,6 +26,9 @@ pub enum CreationFromExternalError { /// We don't support the version of the group we are trying to join. #[error("We don't support the version of the group we are trying to join.")] UnsupportedMlsVersion, + /// Error writing to storage + #[error("Error writing to storage: {0}")] + WriteToStorageError(StorageError), } /// Public group builder error. @@ -34,12 +37,6 @@ pub enum PublicGroupBuildError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), - /// Unsupported proposal type in required capabilities. - #[error("Unsupported proposal type in required capabilities.")] - UnsupportedProposalType, - /// Unsupported extension type in required capabilities. - #[error("Unsupported extension type in required capabilities.")] - UnsupportedExtensionType, /// Invalid extensions set in configuration #[error("Invalid extensions set in configuration")] InvalidExtensions(#[from] InvalidExtensionError), diff --git a/openmls/src/group/public_group/mod.rs b/openmls/src/group/public_group/mod.rs index 8f394c5990..b4ffff0810 100644 --- a/openmls/src/group/public_group/mod.rs +++ b/openmls/src/group/public_group/mod.rs @@ -11,8 +11,6 @@ //! To avoid duplication of code and functionality, [`CoreGroup`] internally //! relies on a [`PublicGroup`] as well. -#[cfg(test)] -use crate::prelude::OpenMlsProvider; #[cfg(test)] use std::collections::HashSet; @@ -38,6 +36,7 @@ use crate::{ ConfirmationTag, PathSecret, }, schedule::CommitSecret, + storage::{OpenMlsProvider, StorageProvider}, treesync::{ errors::{DerivePathError, TreeSyncFromNodesError}, node::{ @@ -72,6 +71,10 @@ pub struct PublicGroup { confirmation_tag: ConfirmationTag, } +/// This is a wrapper type, because we can't implement the storage traits on `Vec`. +#[derive(Debug, Serialize, Deserialize)] +pub struct InterimTranscriptHash(pub Vec); + impl PublicGroup { /// Create a new PublicGroup from a [`TreeSync`] instance and a /// [`GroupInfo`]. @@ -105,13 +108,14 @@ impl PublicGroup { /// This function performs basic validation checks and returns an error if /// one of the checks fails. See [`CreationFromExternalError`] for more /// details. - pub fn from_external( - crypto: &impl OpenMlsCrypto, + pub fn from_external( + provider: &Provider, ratchet_tree: RatchetTreeIn, verifiable_group_info: VerifiableGroupInfo, proposal_store: ProposalStore, - ) -> Result<(Self, GroupInfo), CreationFromExternalError> { + ) -> Result<(Self, GroupInfo), CreationFromExternalError> { let ciphersuite = verifiable_group_info.ciphersuite(); + let crypto = provider.crypto(); let group_id = verifiable_group_info.group_id(); let ratchet_tree = ratchet_tree @@ -160,16 +164,19 @@ impl PublicGroup { )? }; - Ok(( - Self { - treesync, - group_context, - interim_transcript_hash, - confirmation_tag: group_info.confirmation_tag().clone(), - proposal_store, - }, - group_info, - )) + let public_group = Self { + treesync, + group_context, + interim_transcript_hash, + confirmation_tag: group_info.confirmation_tag().clone(), + proposal_store, + }; + + public_group + .store(provider.storage()) + .map_err(CreationFromExternalError::WriteToStorageError)?; + + Ok((public_group, group_info)) } /// Returns the index of the sender of a staged, external commit. @@ -343,6 +350,64 @@ impl PublicGroup { pub(crate) fn owned_encryption_keys(&self, leaf_index: LeafNodeIndex) -> Vec { self.treesync().owned_encryption_keys(leaf_index) } + + /// Stores the [`PublicGroup`] to storage. Called from methods creating a new group and mutating an + /// existing group, both inside [`PublicGroup`] and in [`CoreGroup`]. + /// + /// [`CoreGroup`]: crate::group::core_group::CoreGroup + pub(crate) fn store( + &self, + storage: &Storage, + ) -> Result<(), Storage::Error> { + let group_id = self.group_context.group_id(); + storage.write_tree(group_id, self.treesync())?; + storage.write_confirmation_tag(group_id, self.confirmation_tag())?; + storage.write_context(group_id, self.group_context())?; + storage.write_interim_transcript_hash( + group_id, + &InterimTranscriptHash(self.interim_transcript_hash.clone()), + )?; + Ok(()) + } + + /// Deletes the [`PublicGroup`] from storage. + pub(crate) fn delete( + &self, + storage: &Storage, + ) -> Result<(), Storage::Error> { + storage.delete_tree(self.group_id())?; + storage.delete_confirmation_tag(self.group_id())?; + storage.delete_context(self.group_id())?; + storage.delete_interim_transcript_hash(self.group_id())?; + + Ok(()) + } + + /// Loads the [`PublicGroup`] from storage. Called from [`CoreGroup::load`]. + /// + /// [`CoreGroup::load`]: crate::group::core_group::CoreGroup::load + pub(crate) fn load( + storage: &Storage, + group_id: &GroupId, + ) -> Result, Storage::Error> { + let treesync = storage.treesync(group_id)?; + let group_context = storage.group_context(group_id)?; + let interim_transcript_hash: Option = + storage.interim_transcript_hash(group_id)?; + let confirmation_tag = storage.confirmation_tag(group_id)?; + + let build = || -> Option { + Some(Self { + treesync: treesync?, + proposal_store: ProposalStore::new(), + group_context: group_context?, + interim_transcript_hash: interim_transcript_hash?.0, + confirmation_tag: confirmation_tag?, + }) + }; + + Ok(build()) + } } // Test functions diff --git a/openmls/src/group/public_group/process.rs b/openmls/src/group/public_group/process.rs index 5812a55702..fe6daade9c 100644 --- a/openmls/src/group/public_group/process.rs +++ b/openmls/src/group/public_group/process.rs @@ -1,4 +1,3 @@ -use openmls_traits::crypto::OpenMlsCrypto; use tls_codec::Serialize; use crate::{ @@ -16,6 +15,7 @@ use crate::{ past_secrets::MessageSecretsStore, }, messages::proposals::Proposal, + storage::OpenMlsProvider, }; use super::PublicGroup; @@ -129,11 +129,12 @@ impl PublicGroup { /// - ValSem244 /// - ValSem245 /// - ValSem246 (as part of ValSem010) - pub fn process_message( + pub fn process_message( &self, - crypto: &impl OpenMlsCrypto, + provider: &Provider, message: impl Into, - ) -> Result { + ) -> Result> { + let crypto = provider.crypto(); let protocol_message = message.into(); // Checks the following semantic validation: // - ValSem002 @@ -152,6 +153,7 @@ impl PublicGroup { .tls_serialize_detached() .map_err(LibraryError::missing_bound_check)?, crypto, + self.ciphersuite(), )? } }; @@ -159,7 +161,7 @@ impl PublicGroup { let unverified_message = self .parse_message(decrypted_message, None) .map_err(ProcessMessageError::from)?; - self.process_unverified_message(crypto, unverified_message, &self.proposal_store) + self.process_unverified_message(provider, unverified_message, &self.proposal_store) } } @@ -190,17 +192,18 @@ impl PublicGroup { /// - ValSem242 /// - ValSem244 /// - ValSem246 (as part of ValSem010) - pub(crate) fn process_unverified_message( + pub(crate) fn process_unverified_message( &self, - crypto: &impl OpenMlsCrypto, + provider: &Provider, unverified_message: UnverifiedMessage, proposal_store: &ProposalStore, - ) -> Result { + ) -> Result> { + let crypto = provider.crypto(); // Checks the following semantic validation: // - ValSem010 // - ValSem246 (as part of ValSem010) let (content, credential) = - unverified_message.verify(self.ciphersuite(), crypto, self.version())?; + unverified_message.verify(self.ciphersuite(), provider, self.version())?; match content.sender() { Sender::Member(_) | Sender::NewMemberCommit | Sender::NewMemberProposal => { diff --git a/openmls/src/group/public_group/staged_commit.rs b/openmls/src/group/public_group/staged_commit.rs index c627248e71..342dbcf77f 100644 --- a/openmls/src/group/public_group/staged_commit.rs +++ b/openmls/src/group/public_group/staged_commit.rs @@ -9,8 +9,39 @@ use crate::{ StagedCommit, }, messages::{proposals::ProposalOrRef, Commit}, + storage::StorageProvider, }; +#[derive(Debug, Serialize, Deserialize)] +pub struct PublicStagedCommitState { + pub(super) staged_diff: StagedPublicGroupDiff, + pub(super) update_path_leaf_node: Option, +} + +impl PublicStagedCommitState { + pub fn new( + staged_diff: StagedPublicGroupDiff, + update_path_leaf_node: Option, + ) -> Self { + Self { + staged_diff, + update_path_leaf_node, + } + } + + pub(crate) fn into_staged_diff(self) -> StagedPublicGroupDiff { + self.staged_diff + } + + pub fn update_path_leaf_node(&self) -> Option<&LeafNode> { + self.update_path_leaf_node.as_ref() + } + + pub fn staged_diff(&self) -> &StagedPublicGroupDiff { + &self.staged_diff + } +} + impl PublicGroup { pub(crate) fn validate_commit<'a>( &self, @@ -69,6 +100,7 @@ impl PublicGroup { })?; // Validate the staged proposals by doing the following checks: + // ValSem101 // ValSem102 // ValSem103 @@ -82,6 +114,9 @@ impl PublicGroup { // ValSem107 // ValSem108 self.validate_remove_proposals(&proposal_queue)?; + // ValSem113: All Proposals: The proposal type must be supported by all + // members of the group + self.validate_proposal_type_support(&proposal_queue)?; // ValSem208 // ValSem209 self.validate_group_context_extensions_proposal(&proposal_queue)?; @@ -172,23 +207,22 @@ impl PublicGroup { /// - ValSem244 /// Returns an error if the given commit was sent by the owner of this /// group. - /// TODO #1255: This will be used by the `process_message` function of the - /// `PublicGroup` later on. - #[allow(unused)] pub(crate) fn stage_commit( &self, mls_content: &AuthenticatedContent, proposal_store: &ProposalStore, crypto: &impl OpenMlsCrypto, ) -> Result { - let ciphersuite = self.ciphersuite(); - let (commit, proposal_queue, sender_index) = self.validate_commit(mls_content, proposal_store, crypto)?; let staged_diff = self.stage_diff(mls_content, &proposal_queue, sender_index, crypto)?; + let staged_state = PublicStagedCommitState { + staged_diff, + update_path_leaf_node: commit.path.as_ref().map(|p| p.leaf_node().clone()), + }; - let staged_commit_state = StagedCommitState::PublicState(Box::new(staged_diff)); + let staged_commit_state = StagedCommitState::PublicState(Box::new(staged_state)); Ok(StagedCommit::new(proposal_queue, staged_commit_state)) } @@ -244,14 +278,19 @@ impl PublicGroup { } /// Merges a [StagedCommit] into the public group state. - /// - /// This function should not fail and only returns a [`Result`], because it - /// might throw a `LibraryError`. - pub fn merge_commit(&mut self, staged_commit: StagedCommit) { + pub fn merge_commit( + &mut self, + storage: &Storage, + staged_commit: StagedCommit, + ) -> Result<(), MergeCommitError> { match staged_commit.into_state() { - StagedCommitState::PublicState(staged_diff) => self.merge_diff(*staged_diff), + StagedCommitState::PublicState(staged_state) => { + self.merge_diff(staged_state.staged_diff); + } StagedCommitState::GroupMember(_) => (), } - self.proposal_store.empty() + + self.proposal_store.empty(); + self.store(storage).map_err(MergeCommitError::StorageError) } } diff --git a/openmls/src/group/public_group/tests.rs b/openmls/src/group/public_group/tests.rs index 15a6eda321..6f39076b33 100644 --- a/openmls/src/group/public_group/tests.rs +++ b/openmls/src/group/public_group/tests.rs @@ -1,7 +1,4 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; -use rstest::*; -use rstest_reuse::{self, *}; +use openmls_traits::prelude::*; use crate::{ binary_tree::LeafNodeIndex, @@ -10,16 +7,16 @@ use crate::{ ProcessedMessageContent, ProtocolMessage, Sender, }, group::{ - config::CryptoConfig, test_core_group::setup_client, GroupId, MlsGroup, - MlsGroupCreateConfig, ProposalStore, StagedCommit, PURE_PLAINTEXT_WIRE_FORMAT_POLICY, + test_core_group::setup_client, GroupId, MlsGroup, MlsGroupCreateConfig, ProposalStore, + StagedCommit, PURE_PLAINTEXT_WIRE_FORMAT_POLICY, }, messages::proposals::Proposal, }; use super::{super::mls_group::StagedWelcome, PublicGroup}; -#[apply(ciphersuites_and_providers)] -fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn public_group(ciphersuite: Ciphersuite, provider: &Provider) { let group_id = GroupId::from_slice(b"Test Group"); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_pk) = @@ -33,7 +30,7 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Set plaintext wire format policy s.t. the public group can track changes. let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(PURE_PLAINTEXT_WIRE_FORMAT_POLICY) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -48,13 +45,13 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // === Create a public group that tracks the changes throughout this test === let verifiable_group_info = alice_group - .export_group_info(provider.crypto(), &alice_signer, false) + .export_group_info(provider, &alice_signer, false) .unwrap() .into_verifiable_group_info() .unwrap(); let ratchet_tree = alice_group.export_ratchet_tree(); let (mut public_group, _extensions) = PublicGroup::from_external( - provider.crypto(), + provider, ratchet_tree.into(), verifiable_group_info, ProposalStore::new(), @@ -75,7 +72,7 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ProtocolMessage::PublicMessage(public_message) => public_message, }; let processed_message = public_group - .process_message(provider.crypto(), public_message) + .process_message(provider, public_message) .unwrap(); // Further inspection of the message can take place here ... @@ -87,7 +84,9 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { // Merge the diff - public_group.merge_commit(*staged_commit) + public_group + .merge_commit(provider.storage(), *staged_commit) + .unwrap() } }; @@ -135,9 +134,11 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // The public group processes let ppm = public_group - .process_message(provider.crypto(), into_public_message(queued_messages)) + .process_message(provider, into_public_message(queued_messages)) + .unwrap(); + public_group + .merge_commit(provider.storage(), extract_staged_commit(ppm)) .unwrap(); - public_group.merge_commit(extract_staged_commit(ppm)); // Bob merges bob_group @@ -177,7 +178,7 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // The public group processes let ppm = public_group - .process_message(provider.crypto(), into_public_message(queued_messages)) + .process_message(provider, into_public_message(queued_messages)) .unwrap(); // We have to add the proposal to the public group's proposal store. match ppm.into_content() { @@ -201,7 +202,9 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Bob was removed assert_eq!(remove_proposal.removed(), LeafNodeIndex::new(1)); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal.clone()); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .expect("error writing to storage"); } else { unreachable!("Expected a Proposal."); } @@ -222,12 +225,11 @@ fn public_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // The public group processes let ppm = public_group - .process_message( - provider.crypto(), - into_public_message(queued_messages.clone()), - ) + .process_message(provider, into_public_message(queued_messages.clone())) + .unwrap(); + public_group + .merge_commit(provider.storage(), extract_staged_commit(ppm)) .unwrap(); - public_group.merge_commit(extract_staged_commit(ppm)); // Check that we receive the correct proposal if let Some(staged_commit) = charlie_group.pending_commit() { diff --git a/openmls/src/group/public_group/validation.rs b/openmls/src/group/public_group/validation.rs index 039ded3d28..b7969e0f42 100644 --- a/openmls/src/group/public_group/validation.rs +++ b/openmls/src/group/public_group/validation.rs @@ -121,6 +121,43 @@ impl PublicGroup { // === Proposals === + /// Validate that all group members support the types of all proposals. + pub(crate) fn validate_proposal_type_support( + &self, + proposal_queue: &ProposalQueue, + ) -> Result<(), ProposalValidationError> { + let mut leaves = self.treesync().full_leaves(); + let Some(first_leaf) = leaves.next() else { + return Ok(()); + }; + // Initialize the capabilities intersection with the capabilities of the + // first leaf node. + let mut capabilities_intersection = first_leaf + .capabilities() + .proposals() + .iter() + .collect::>(); + // Iterate over the remaining leaf nodes and intersect their capabilities + for leaf_node in leaves { + let leaf_capabilities_set = leaf_node.capabilities().proposals().iter().collect(); + capabilities_intersection = capabilities_intersection + .intersection(&leaf_capabilities_set) + .cloned() + .collect(); + } + + // Check that the types of all proposals are supported by all members + for proposal in proposal_queue.queued_proposals() { + let proposal_type = proposal.proposal().proposal_type(); + if matches!(proposal_type, ProposalType::Custom(_)) + && !capabilities_intersection.contains(&proposal_type) + { + return Err(ProposalValidationError::UnsupportedProposalType); + } + } + Ok(()) + } + /// Validate key uniqueness. This function implements the following checks: /// - ValSem101: Add Proposal: Signature public key in proposals must be unique among proposals & members /// - ValSem102: Add Proposal: Init key in proposals must be unique among proposals @@ -477,7 +514,10 @@ impl PublicGroup { let is_inline = p.proposal_or_ref_type() == ProposalOrRefType::Proposal; let is_allowed_type = matches!( p.proposal(), - Proposal::ExternalInit(_) | Proposal::Remove(_) | Proposal::PreSharedKey(_) + Proposal::ExternalInit(_) + | Proposal::Remove(_) + | Proposal::PreSharedKey(_) + | Proposal::Custom(_) ); is_inline && !is_allowed_type }); diff --git a/openmls/src/group/tests/external_add_proposal.rs b/openmls/src/group/tests/external_add_proposal.rs index 277e8b40a7..a284ec20b1 100644 --- a/openmls/src/group/tests/external_add_proposal.rs +++ b/openmls/src/group/tests/external_add_proposal.rs @@ -1,19 +1,17 @@ use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use rstest::*; -use rstest_reuse::{self, *}; +use openmls_test::openmls_test; use crate::{ binary_tree::LeafNodeIndex, framing::*, - group::{config::CryptoConfig, *}, + group::*, messages::{ external_proposals::*, proposals::{AddProposal, Proposal, ProposalType}, }, }; -use openmls_traits::types::Ciphersuite; +use openmls_traits::{types::Ciphersuite, OpenMlsProvider as _}; use super::utils::*; @@ -27,7 +25,7 @@ fn new_test_group( identity: &str, wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> (MlsGroup, CredentialWithKeyAndSigner) { let group_id = GroupId::from_slice(b"Test Group"); @@ -38,7 +36,7 @@ fn new_test_group( // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); ( @@ -58,7 +56,7 @@ fn new_test_group( fn validation_test_setup( wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ProposalValidationTestSetup { // === Alice creates a group === let (mut alice_group, alice_signer_with_keys) = @@ -75,7 +73,11 @@ fn validation_test_setup( ); let (_message, welcome, _group_info) = alice_group - .add_members(provider, &alice_signer_with_keys.signer, &[bob_key_package]) + .add_members( + provider, + &alice_signer_with_keys.signer, + &[bob_key_package.key_package().clone()], + ) .expect("error adding Bob to group"); alice_group @@ -108,8 +110,8 @@ fn validation_test_setup( } } -#[apply(ciphersuites_and_providers)] -fn external_add_proposal_should_succeed(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn external_add_proposal_should_succeed() { for policy in WIRE_FORMAT_POLICIES { let ProposalValidationTestSetup { alice_group, @@ -135,13 +137,14 @@ fn external_add_proposal_should_succeed(ciphersuite: Ciphersuite, provider: &imp charlie_credential.clone(), ); - let proposal = JoinProposal::new( - charlie_kp.clone(), - alice_group.group_id().clone(), - alice_group.epoch(), - &charlie_credential.signer, - ) - .unwrap(); + let proposal = + JoinProposal::new::<::StorageProvider>( + charlie_kp.key_package().clone(), + alice_group.group_id().clone(), + alice_group.epoch(), + &charlie_credential.signer, + ) + .unwrap(); // an external proposal is always plaintext and has sender type 'new_member_proposal' let verify_proposal = |msg: &PublicMessage| { @@ -162,9 +165,11 @@ fn external_add_proposal_should_succeed(ciphersuite: Ciphersuite, provider: &imp assert!(matches!(proposal.sender(), Sender::NewMemberProposal)); assert!(matches!( proposal.proposal(), - Proposal::Add(AddProposal { key_package }) if key_package == &charlie_kp + Proposal::Add(AddProposal { key_package }) if key_package == charlie_kp.key_package() )); - alice_group.store_pending_proposal(*proposal) + alice_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap() } _ => unreachable!(), } @@ -174,9 +179,9 @@ fn external_add_proposal_should_succeed(ciphersuite: Ciphersuite, provider: &imp .unwrap(); match msg.into_content() { - ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => { - bob_group.store_pending_proposal(*proposal) - } + ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => bob_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(), _ => unreachable!(), } @@ -221,11 +226,10 @@ fn external_add_proposal_should_succeed(ciphersuite: Ciphersuite, provider: &imp } } -#[apply(ciphersuites_and_providers)] -fn external_add_proposal_should_be_signed_by_key_package_it_references( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn external_add_proposal_should_be_signed_by_key_package_it_references< + Provider: OpenMlsProvider, +>() { let ProposalValidationTestSetup { alice_group, .. } = validation_test_setup(PURE_PLAINTEXT_WIRE_FORMAT_POLICY, ciphersuite, provider); let (mut alice_group, _alice_signer) = alice_group; @@ -250,13 +254,14 @@ fn external_add_proposal_should_be_signed_by_key_package_it_references( attacker_credential, ); - let invalid_proposal = JoinProposal::new( - charlie_kp, - alice_group.group_id().clone(), - alice_group.epoch(), - &charlie_credential.signer, - ) - .unwrap(); + let invalid_proposal = + JoinProposal::new::<::StorageProvider>( + charlie_kp.key_package().clone(), + alice_group.group_id().clone(), + alice_group.epoch(), + &charlie_credential.signer, + ) + .unwrap(); // fails because the message was not signed by the same credential as the one in the Add proposal assert!(matches!( @@ -268,11 +273,8 @@ fn external_add_proposal_should_be_signed_by_key_package_it_references( } // TODO #1093: move this test to a dedicated external proposal ValSem test module once all external proposals implemented -#[apply(ciphersuites_and_providers)] -fn new_member_proposal_sender_should_be_reserved_for_join_proposals( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn new_member_proposal_sender_should_be_reserved_for_join_proposals() { let ProposalValidationTestSetup { alice_group, bob_group, @@ -291,13 +293,14 @@ fn new_member_proposal_sender_should_be_reserved_for_join_proposals( any_credential.clone(), ); - let join_proposal = JoinProposal::new( - any_kp, - alice_group.group_id().clone(), - alice_group.epoch(), - &any_credential.signer, - ) - .unwrap(); + let join_proposal = + JoinProposal::new::<::StorageProvider>( + any_kp.key_package().clone(), + alice_group.group_id().clone(), + alice_group.epoch(), + &any_credential.signer, + ) + .unwrap(); if let MlsMessageBodyOut::PublicMessage(plaintext) = &join_proposal.body { // Make sure it's an add proposal... @@ -316,7 +319,9 @@ fn new_member_proposal_sender_should_be_reserved_for_join_proposals( } else { panic!() }; - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); // Remove proposal cannot have a 'new_member_proposal' sender let remove_proposal = alice_group @@ -332,7 +337,9 @@ fn new_member_proposal_sender_should_be_reserved_for_join_proposals( } else { panic!() }; - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); // Update proposal cannot have a 'new_member_proposal' sender let update_proposal = alice_group @@ -348,5 +355,7 @@ fn new_member_proposal_sender_should_be_reserved_for_join_proposals( } else { panic!() }; - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); } diff --git a/openmls/src/group/tests/external_remove_proposal.rs b/openmls/src/group/tests/external_remove_proposal.rs index ce8f0d1268..c6393d3d46 100644 --- a/openmls/src/group/tests/external_remove_proposal.rs +++ b/openmls/src/group/tests/external_remove_proposal.rs @@ -1,16 +1,8 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; -use rstest::*; -use rstest_reuse::{self, *}; -use tls_codec::Deserialize; +use openmls_test::openmls_test; -use crate::{ - credentials::BasicCredential, - framing::*, - group::{config::CryptoConfig, *}, - messages::external_proposals::*, -}; +use crate::{credentials::BasicCredential, framing::*, group::*, messages::external_proposals::*}; -use openmls_traits::types::Ciphersuite; +use openmls_traits::{types::Ciphersuite, OpenMlsProvider as _}; use super::utils::*; @@ -19,7 +11,7 @@ fn new_test_group( identity: &str, wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, external_senders: ExternalSendersExtension, ) -> (MlsGroup, CredentialWithKeyAndSigner) { let group_id = GroupId::from_slice(b"Test Group"); @@ -31,7 +23,7 @@ fn new_test_group( // Define the MlsGroup configuration let mls_group_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .with_group_context_extensions(Extensions::single(Extension::ExternalSenders( external_senders, ))) @@ -55,7 +47,7 @@ fn new_test_group( fn validation_test_setup( wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, external_senders: ExternalSendersExtension, ) -> (MlsGroup, CredentialWithKeyAndSigner) { // === Alice creates a group === @@ -78,7 +70,11 @@ fn validation_test_setup( ); alice_group - .add_members(provider, &alice_signer_when_keys.signer, &[bob_key_package]) + .add_members( + provider, + &alice_signer_when_keys.signer, + &[bob_key_package.key_package().clone()], + ) .expect("error adding Bob to group"); alice_group @@ -89,11 +85,8 @@ fn validation_test_setup( (alice_group, alice_signer_when_keys) } -#[apply(ciphersuites_and_providers)] -fn external_remove_proposal_should_remove_member( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn external_remove_proposal_should_remove_member() { // delivery service credentials. DS will craft an external remove proposal let ds_credential_with_key = generate_credential_with_key( "delivery-service".into(), @@ -129,14 +122,14 @@ fn external_remove_proposal_should_remove_member( let bob_index = alice_group .members() .find(|member| { - let credential = BasicCredential::try_from(&member.credential).unwrap(); + let credential = BasicCredential::try_from(member.credential.clone()).unwrap(); let identity = credential.identity(); identity == b"Bob" }) .map(|member| member.index) .unwrap(); // Now Delivery Service wants to (already) remove Bob - let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove( + let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove::( bob_index, alice_group.group_id().clone(), alice_group.epoch(), @@ -161,23 +154,26 @@ fn external_remove_proposal_should_remove_member( else { panic!("Not a remove proposal"); }; - alice_group.store_pending_proposal(*remove_proposal); + alice_group + .store_pending_proposal(provider.storage(), *remove_proposal) + .unwrap(); alice_group .commit_to_pending_proposals(provider, &alice_credential.signer) .unwrap(); alice_group.merge_pending_commit(provider).unwrap(); // Trying to do an external remove proposal of Bob now should fail as he no longer is in the group - let invalid_bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove( - // Bob is no longer in the group - bob_index, - alice_group.group_id().clone(), - alice_group.epoch(), - &ds_credential_with_key.signer, - SenderExtensionIndex::new(0), - ) - .unwrap() - .into(); + let invalid_bob_external_remove_proposal: MlsMessageIn = + ExternalProposal::new_remove::( + // Bob is no longer in the group + bob_index, + alice_group.group_id().clone(), + alice_group.epoch(), + &ds_credential_with_key.signer, + SenderExtensionIndex::new(0), + ) + .unwrap() + .into(); let processed_message = alice_group .process_message( provider, @@ -192,8 +188,10 @@ fn external_remove_proposal_should_remove_member( else { panic!("Not a remove proposal"); }; - alice_group.store_pending_proposal(*remove_proposal); - assert_eq!( + alice_group + .store_pending_proposal(provider.storage(), *remove_proposal) + .unwrap(); + assert!(matches!( alice_group .commit_to_pending_proposals(provider, &alice_credential.signer) .unwrap_err(), @@ -202,14 +200,13 @@ fn external_remove_proposal_should_remove_member( ProposalValidationError::UnknownMemberRemoval ) ) - ); + )); } -#[apply(ciphersuites_and_providers)] -fn external_remove_proposal_should_fail_when_invalid_external_senders_index( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn external_remove_proposal_should_fail_when_invalid_external_senders_index< + Provider: OpenMlsProvider, +>() { // delivery service credentials. DS will craft an external remove proposal let ds_credential_with_key = generate_credential_with_key( "delivery-service".into(), @@ -236,15 +233,11 @@ fn external_remove_proposal_should_fail_when_invalid_external_senders_index( // get Bob's index let bob_index = alice_group .members() - .find(|member| { - let identity = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()).unwrap(); - identity.as_slice() == b"Bob" - }) + .find(|member| member.credential.serialized_content() == b"Bob") .map(|member| member.index) .unwrap(); // Now Delivery Service wants to (already) remove Bob with invalid sender index - let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove( + let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove::( bob_index, alice_group.group_id().clone(), alice_group.epoch(), @@ -263,17 +256,14 @@ fn external_remove_proposal_should_fail_when_invalid_external_senders_index( .unwrap(), ) .unwrap_err(); - assert_eq!( + assert!(matches!( error, ProcessMessageError::ValidationError(ValidationError::UnauthorizedExternalSender) - ); + )); } -#[apply(ciphersuites_and_providers)] -fn external_remove_proposal_should_fail_when_invalid_signature( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn external_remove_proposal_should_fail_when_invalid_signature() { // delivery service credentials. DS will craft an external remove proposal let ds_credential_with_key = generate_credential_with_key( "delivery-service".into(), @@ -303,15 +293,11 @@ fn external_remove_proposal_should_fail_when_invalid_signature( // get Bob's index let bob_index = alice_group .members() - .find(|member| { - let identity = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()).unwrap(); - identity.as_slice() == b"Bob" - }) + .find(|member| member.credential.serialized_content() == b"Bob") .map(|member| member.index) .unwrap(); // Now Delivery Service wants to (already) remove Bob with invalid sender index - let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove( + let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove::( bob_index, alice_group.group_id().clone(), alice_group.epoch(), @@ -330,14 +316,11 @@ fn external_remove_proposal_should_fail_when_invalid_signature( .unwrap(), ) .unwrap_err(); - assert_eq!(error, ProcessMessageError::InvalidSignature); + assert!(matches!(error, ProcessMessageError::InvalidSignature)); } -#[apply(ciphersuites_and_providers)] -fn external_remove_proposal_should_fail_when_no_external_senders( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { +#[openmls_test] +fn external_remove_proposal_should_fail_when_no_external_senders() { let (mut alice_group, _) = validation_test_setup( PURE_PLAINTEXT_WIRE_FORMAT_POLICY, ciphersuite, @@ -354,15 +337,11 @@ fn external_remove_proposal_should_fail_when_no_external_senders( // get Bob's index let bob_index = alice_group .members() - .find(|member| { - let identity = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()).unwrap(); - identity.as_slice() == b"Bob" - }) + .find(|member| member.credential.serialized_content() == b"Bob") .map(|member| member.index) .unwrap(); // Now Delivery Service wants to remove Bob with invalid sender index but there's no extension - let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove( + let bob_external_remove_proposal: MlsMessageIn = ExternalProposal::new_remove::( bob_index, alice_group.group_id().clone(), alice_group.epoch(), @@ -381,8 +360,8 @@ fn external_remove_proposal_should_fail_when_no_external_senders( .unwrap(), ) .unwrap_err(); - assert_eq!( + assert!(matches!( error, ProcessMessageError::ValidationError(ValidationError::UnauthorizedExternalSender) - ); + )); } diff --git a/openmls/src/group/tests/kat_messages.rs b/openmls/src/group/tests/kat_messages.rs index b67cc3db60..19aee1a73a 100644 --- a/openmls/src/group/tests/kat_messages.rs +++ b/openmls/src/group/tests/kat_messages.rs @@ -4,7 +4,6 @@ //! See //! for more description on the test vectors. -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{random::OpenMlsRand, types::SignatureScheme, OpenMlsProvider}; use rand::{rngs::OsRng, RngCore}; use serde::{self, Deserialize, Serialize}; @@ -16,7 +15,6 @@ use crate::{ ciphersuite::Mac, framing::*, group::{ - config::CryptoConfig, tests::utils::{generate_credential_with_key, generate_key_package, randombytes}, *, }, @@ -139,7 +137,7 @@ pub fn generate_test_vector(ciphersuite: Ciphersuite) -> MessagesTestVector { // Let's create a group let mut alice_group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_key_and_signer .credential_with_key .clone(), @@ -172,10 +170,7 @@ pub fn generate_test_vector(ciphersuite: Ciphersuite) -> MessagesTestVector { ); LeafNode::generate_update( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::Mls10, - }, + ciphersuite, alice_credential_with_key_and_signer .credential_with_key .clone(), @@ -199,7 +194,7 @@ pub fn generate_test_vector(ciphersuite: Ciphersuite) -> MessagesTestVector { &provider, ); - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( &provider, &bob_credential_with_key_and_signer.signer, ciphersuite, @@ -361,12 +356,7 @@ pub fn generate_test_vector(ciphersuite: Ciphersuite) -> MessagesTestVector { .tls_serialize_detached() .unwrap(), - group_secrets: GroupSecrets::random_encoded( - ciphersuite, - provider.rand(), - ProtocolVersion::default(), - ) - .unwrap(), + group_secrets: GroupSecrets::random_encoded(ciphersuite, provider.rand()).unwrap(), ratchet_tree: alice_ratchet_tree.tls_serialize_detached().unwrap(), add_proposal: add_proposal.tls_serialize_detached().unwrap(), diff --git a/openmls/src/group/tests/kat_transcript_hashes.rs b/openmls/src/group/tests/kat_transcript_hashes.rs index 3925454cc5..5d7ec46c94 100644 --- a/openmls/src/group/tests/kat_transcript_hashes.rs +++ b/openmls/src/group/tests/kat_transcript_hashes.rs @@ -3,7 +3,6 @@ //! See //! for more description on the test vectors. -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{crypto::OpenMlsCrypto, random::OpenMlsRand, OpenMlsProvider}; use serde::{self, Deserialize, Serialize}; use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait}; @@ -20,7 +19,6 @@ use crate::{ messages::*, schedule::*, test_utils::*, - versions::ProtocolVersion, }; const TEST_VECTOR_PATH_READ: &str = "test_vectors/transcript-hashes.json"; @@ -90,14 +88,12 @@ pub fn run_test_vector(test_vector: TranscriptTestVector) { )); // ... and `authenticated_content.auth.confirmation_tag` is a valid MAC for `authenticated_content` with key `confirmation_key` and input `confirmed_transcript_hash_after`. - let confirmation_key = ConfirmationKey::from_secret(Secret::from_slice( - &test_vector.confirmation_key, - ProtocolVersion::default(), - ciphersuite, - )); + let confirmation_key = + ConfirmationKey::from_secret(Secret::from_slice(&test_vector.confirmation_key)); let got_confirmation_tag = confirmation_key .tag( provider.crypto(), + ciphersuite, &test_vector.confirmed_transcript_hash_after, ) .unwrap(); @@ -233,7 +229,11 @@ pub fn generate_test_vector(ciphersuite: Ciphersuite) -> TranscriptTestVector { // ... and the `confirmation_tag` ... let confirmation_tag = { confirmation_key - .tag(provider.crypto(), &confirmed_transcript_hash_after) + .tag( + provider.crypto(), + ciphersuite, + &confirmed_transcript_hash_after, + ) .unwrap() }; diff --git a/openmls/src/group/tests/test_commit_validation.rs b/openmls/src/group/tests/test_commit_validation.rs index 25b20d5e6a..186f116668 100644 --- a/openmls/src/group/tests/test_commit_validation.rs +++ b/openmls/src/group/tests/test_commit_validation.rs @@ -1,10 +1,7 @@ //! This module tests the validation of commits as defined in //! https://openmls.tech/book/message_validation.html#commit-message-validation -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{signatures::Signer, types::Ciphersuite}; -use rstest::*; -use rstest_reuse::{self, *}; +use openmls_traits::{prelude::*, signatures::Signer, types::Ciphersuite}; use tls_codec::{Deserialize, Serialize}; use super::utils::{ @@ -14,13 +11,12 @@ use crate::{ binary_tree::LeafNodeIndex, ciphersuite::signable::Signable, framing::*, - group::{config::CryptoConfig, *}, + group::*, messages::proposals::*, schedule::{ExternalPsk, PreSharedKeyId, Psk}, treesync::{ errors::ApplyUpdatePathError, node::parent_node::PlainUpdatePathNode, treekem::UpdatePath, }, - versions::ProtocolVersion, }; struct CommitValidationTestSetup { @@ -34,7 +30,7 @@ struct CommitValidationTestSetup { fn validation_test_setup( wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> CommitValidationTestSetup { let group_id = GroupId::from_slice(b"Test Group"); @@ -66,7 +62,7 @@ fn validation_test_setup( let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -83,7 +79,10 @@ fn validation_test_setup( .add_members( provider, &alice_credential.signer, - &[bob_key_package, charlie_key_package], + &[ + bob_key_package.key_package().clone(), + charlie_key_package.key_package().clone(), + ], ) .expect("error adding Bob to group"); @@ -125,8 +124,8 @@ fn validation_test_setup( } // ValSem200: Commit must not cover inline self Remove proposal -#[apply(ciphersuites_and_providers)] -fn test_valsem200(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem200() { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -163,7 +162,9 @@ fn test_valsem200(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // We have to clear the pending proposals so Alice doesn't try to commit to // her own remove. - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); // Now let's stick it in the commit. let serialized_message = alice_group @@ -219,6 +220,7 @@ fn test_valsem200(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { signed_plaintext .set_membership_tag( provider.crypto(), + ciphersuite, membership_key, alice_group.group().message_secrets().serialized_context(), ) @@ -231,10 +233,10 @@ fn test_valsem200(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could process unverified message despite self remove."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::AttemptedSelfRemoval) - ); + )); // Positive case bob_group @@ -243,8 +245,8 @@ fn test_valsem200(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem201: Path must be present, if at least one proposal requires a path -#[apply(ciphersuites_and_providers)] -fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem201() { let wire_format_policy = PURE_PLAINTEXT_WIRE_FORMAT_POLICY; // Test with PublicMessage let CommitValidationTestSetup { @@ -275,12 +277,12 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { generate_key_package(ciphersuite, Extensions::empty(), provider, dave_credential); queued(Proposal::Add(AddProposal { - key_package: dave_key_package, + key_package: dave_key_package.key_package().clone(), })) }; let psk_proposal = || { - let secret = Secret::random(ciphersuite, provider.rand(), None).unwrap(); + let secret = Secret::random(ciphersuite, provider.rand()).unwrap(); let rand = provider .rand() .random_vec(ciphersuite.hash_length()) @@ -291,9 +293,7 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { Psk::External(ExternalPsk::new(rand)), ) .unwrap(); - psk_id - .write_to_key_store(provider, ciphersuite, secret.as_slice()) - .unwrap(); + psk_id.store(provider, secret.as_slice()).unwrap(); queued(Proposal::PreSharedKey(PreSharedKeyProposal::new(psk_id))) }; @@ -340,9 +340,11 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { for (proposal, is_path_required) in cases { // create a commit containing the proposals - proposal - .into_iter() - .for_each(|p| alice_group.store_pending_proposal(p)); + proposal.into_iter().for_each(|p| { + alice_group + .store_pending_proposal(provider.storage(), p) + .unwrap() + }); let params = CreateCommitParams::builder() .framing_parameters(alice_group.framing_parameters()) @@ -368,6 +370,7 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { commit .set_membership_tag( provider.crypto(), + ciphersuite, membership_key, alice_group.group().message_secrets().serialized_context(), ) @@ -376,15 +379,16 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { if is_path_required { let commit_wo_path = erase_path( provider, + ciphersuite, commit.clone(), &alice_group, &alice_credential.signer, ); let processed_msg = bob_group.process_message(provider, commit_wo_path); - assert_eq!( + assert!(matches!( processed_msg.unwrap_err(), ProcessMessageError::InvalidCommit(StageCommitError::RequiredPathNotFound) - ); + )); } // Positive case @@ -392,14 +396,19 @@ fn test_valsem201(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert!(process_message_result.is_ok(), "{process_message_result:?}"); // cleanup & restore for next iteration - alice_group.clear_pending_proposals(); - alice_group.clear_pending_commit(); - bob_group.clear_pending_commit(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); + bob_group.clear_pending_commit(provider.storage()).unwrap(); } } fn erase_path( - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, + ciphersuite: Ciphersuite, mut plaintext: PublicMessage, alice_group: &MlsGroup, alice_signer: &impl Signer, @@ -422,14 +431,15 @@ fn erase_path( &original_plaintext, provider, alice_signer, + ciphersuite, ); plaintext.into() } // ValSem202: Path must be the right length -#[apply(ciphersuites_and_providers)] -fn test_valsem202(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem202() { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -473,6 +483,7 @@ fn test_valsem202(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { &original_plaintext, provider, &alice_credential.signer, + ciphersuite, ); let update_message_in = ProtocolMessage::from(plaintext); @@ -481,12 +492,12 @@ fn test_valsem202(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process unverified message despite path length mismatch."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError( ApplyUpdatePathError::PathLengthMismatch )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -504,8 +515,8 @@ fn test_valsem202(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem203: Path secrets must decrypt correctly -#[apply(ciphersuites_and_providers)] -fn test_valsem203(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem203() { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -551,6 +562,7 @@ fn test_valsem203(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { &original_plaintext, provider, &alice_credential.signer, + ciphersuite, ); let update_message_in = ProtocolMessage::from(plaintext); @@ -559,12 +571,12 @@ fn test_valsem203(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process unverified message despite scrambled ciphertexts."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError( ApplyUpdatePathError::UnableToDecrypt )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -582,8 +594,8 @@ fn test_valsem203(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem204: Public keys from Path must be verified and match the private keys from the direct path -#[apply(ciphersuites_and_providers)] -fn test_valsem204(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem204() { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -645,9 +657,7 @@ fn test_valsem204(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .map(|upn| { PlainUpdatePathNode::new( upn.encryption_key().clone(), - Secret::random(ciphersuite, provider.rand(), ProtocolVersion::default()) - .unwrap() - .into(), + Secret::random(ciphersuite, provider.rand()).unwrap().into(), ) }) .collect(); @@ -675,6 +685,7 @@ fn test_valsem204(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { &original_plaintext, provider, &alice_credential.signer, + ciphersuite, ); let update_message_in = ProtocolMessage::from(plaintext); @@ -683,12 +694,12 @@ fn test_valsem204(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process unverified message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::UpdatePathError( ApplyUpdatePathError::PathMismatch )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -706,8 +717,8 @@ fn test_valsem204(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem205: Confirmation tag must be successfully verified -#[apply(ciphersuites_and_providers)] -fn test_valsem205(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem205() { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -749,6 +760,7 @@ fn test_valsem205(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { plaintext .set_membership_tag( provider.crypto(), + ciphersuite, membership_key, alice_group.group().message_secrets().serialized_context(), ) @@ -760,10 +772,10 @@ fn test_valsem205(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process unverified message despite confirmation tag mismatch."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ConfirmationTagMismatch) - ); + )); // Positive case bob_group @@ -772,8 +784,11 @@ fn test_valsem205(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // this ensures that a member can process commits not containing all the stored proposals -#[apply(ciphersuites_and_providers)] -fn test_partial_proposal_commit(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_partial_proposal_commit( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Test with PublicMessage let CommitValidationTestSetup { mut alice_group, @@ -784,11 +799,7 @@ fn test_partial_proposal_commit(ciphersuite: Ciphersuite, provider: &impl OpenMl let charlie_index = alice_group .members() - .find(|m| { - let identity = - VLBytes::tls_deserialize_exact(m.credential.serialized_content()).unwrap(); - identity.as_slice() == b"Charlie" - }) + .find(|m| m.credential.serialized_content() == b"Charlie") .unwrap() .index; @@ -801,7 +812,9 @@ fn test_partial_proposal_commit(ciphersuite: Ciphersuite, provider: &impl OpenMl .process_message(provider, proposal_1.try_into_protocol_message().unwrap()) .unwrap(); match proposal_1.into_content() { - ProcessedMessageContent::ProposalMessage(p) => bob_group.store_pending_proposal(*p), + ProcessedMessageContent::ProposalMessage(p) => bob_group + .store_pending_proposal(provider.storage(), *p) + .unwrap(), _ => unreachable!(), } @@ -814,7 +827,9 @@ fn test_partial_proposal_commit(ciphersuite: Ciphersuite, provider: &impl OpenMl .process_message(provider, proposal_2.try_into_protocol_message().unwrap()) .unwrap(); match proposal_2.into_content() { - ProcessedMessageContent::ProposalMessage(p) => bob_group.store_pending_proposal(*p), + ProcessedMessageContent::ProposalMessage(p) => bob_group + .store_pending_proposal(provider.storage(), *p) + .unwrap(), _ => unreachable!(), } diff --git a/openmls/src/group/tests/test_encoding.rs b/openmls/src/group/tests/test_encoding.rs index 4495448c13..c48b89705d 100644 --- a/openmls/src/group/tests/test_encoding.rs +++ b/openmls/src/group/tests/test_encoding.rs @@ -1,15 +1,14 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::crypto::OpenMlsCrypto; use tls_codec::{Deserialize, Serialize}; use super::utils::*; use crate::{ binary_tree::LeafNodeIndex, framing::*, group::*, key_packages::*, messages::*, - schedule::psk::store::ResumptionPskStore, test_utils::*, *, + schedule::psk::store::ResumptionPskStore, test_utils::*, }; /// Creates a simple test setup for various encoding tests. -fn create_encoding_test_setup(provider: &impl OpenMlsProvider) -> TestSetup { +fn create_encoding_test_setup(provider: &impl crate::storage::OpenMlsProvider) -> TestSetup { // Create a test config for a single client supporting all possible // ciphersuites. let alice_config = TestClientConfig { @@ -51,8 +50,8 @@ fn create_encoding_test_setup(provider: &impl OpenMlsProvider) -> TestSetup { } /// This test tests encoding and decoding of application messages. -#[apply(providers)] -fn test_application_message_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_application_message_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients @@ -93,8 +92,8 @@ fn test_application_message_encoding(provider: &impl OpenMlsProvider) { } /// This test tests encoding and decoding of update proposals. -#[apply(providers)] -fn test_update_proposal_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_update_proposal_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients @@ -110,7 +109,7 @@ fn test_update_proposal_encoding(provider: &impl OpenMlsProvider) { .get(&group_state.ciphersuite()) .expect("An unexpected error occurred."); - let key_package_bundle = KeyPackageBundle::new( + let key_package_bundle = KeyPackageBundle::generate( provider, &credential_with_key_and_signer.signer, group_state.ciphersuite(), @@ -128,6 +127,7 @@ fn test_update_proposal_encoding(provider: &impl OpenMlsProvider) { update .set_membership_tag( provider.crypto(), + group_state.ciphersuite(), group_state.message_secrets().membership_key(), group_state.message_secrets().serialized_context(), ) @@ -146,8 +146,8 @@ fn test_update_proposal_encoding(provider: &impl OpenMlsProvider) { } /// This test tests encoding and decoding of add proposals. -#[apply(providers)] -fn test_add_proposal_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_add_proposal_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients @@ -163,7 +163,7 @@ fn test_add_proposal_encoding(provider: &impl OpenMlsProvider) { .get(&group_state.ciphersuite()) .expect("An unexpected error occurred."); - let key_package_bundle = KeyPackageBundle::new( + let key_package_bundle = KeyPackageBundle::generate( provider, &credential_with_key_and_signer.signer, group_state.ciphersuite(), @@ -181,6 +181,7 @@ fn test_add_proposal_encoding(provider: &impl OpenMlsProvider) { .into(); add.set_membership_tag( provider.crypto(), + group_state.ciphersuite(), group_state.message_secrets().membership_key(), group_state.message_secrets().serialized_context(), ) @@ -196,8 +197,8 @@ fn test_add_proposal_encoding(provider: &impl OpenMlsProvider) { } /// This test tests encoding and decoding of remove proposals. -#[apply(providers)] -fn test_remove_proposal_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_remove_proposal_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients @@ -224,6 +225,7 @@ fn test_remove_proposal_encoding(provider: &impl OpenMlsProvider) { remove .set_membership_tag( provider.crypto(), + group_state.ciphersuite(), group_state.message_secrets().membership_key(), group_state.message_secrets().serialized_context(), ) @@ -239,8 +241,8 @@ fn test_remove_proposal_encoding(provider: &impl OpenMlsProvider) { } /// This test tests encoding and decoding of commit messages. -#[apply(providers)] -fn test_commit_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_commit_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients @@ -256,7 +258,7 @@ fn test_commit_encoding(provider: &impl OpenMlsProvider) { .get(&group_state.ciphersuite()) .expect("An unexpected error occurred."); - let alice_key_package_bundle = KeyPackageBundle::new( + let alice_key_package_bundle = KeyPackageBundle::generate( provider, &alice_credential_with_key_and_signer.signer, group_state.ciphersuite(), @@ -324,6 +326,7 @@ fn test_commit_encoding(provider: &impl OpenMlsProvider) { commit .set_membership_tag( provider.crypto(), + group_state.ciphersuite(), group_state.message_secrets().membership_key(), group_state.message_secrets().serialized_context(), ) @@ -339,8 +342,8 @@ fn test_commit_encoding(provider: &impl OpenMlsProvider) { } } -#[apply(providers)] -fn test_welcome_message_encoding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_welcome_message_encoding(provider: &impl crate::storage::OpenMlsProvider) { let test_setup = create_encoding_test_setup(provider); let test_clients = test_setup.clients.borrow(); let alice = test_clients diff --git a/openmls/src/group/tests/test_external_commit_validation.rs b/openmls/src/group/tests/test_external_commit_validation.rs index e86bf86b4d..c1b1e66634 100644 --- a/openmls/src/group/tests/test_external_commit_validation.rs +++ b/openmls/src/group/tests/test_external_commit_validation.rs @@ -2,10 +2,7 @@ //! commit messages as defined in //! https://github.com/openmls/openmls/wiki/Message-validation -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; -use rstest::rstest; -use rstest_reuse::apply; +use openmls_traits::prelude::*; use tls_codec::{Deserialize, Serialize}; use self::utils::*; @@ -32,8 +29,8 @@ use crate::{ }; // ValSem240: External Commit, inline Proposals: There MUST be at least one ExternalInit proposal. -#[apply(ciphersuites_and_providers)] -fn test_valsem240(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem240() { let ECValidationTestSetup { mut alice_group, bob_credential, @@ -85,12 +82,12 @@ fn test_valsem240(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, ProtocolMessage::from(public_message_commit_bad)) .expect_err("Could process message despite missing external init proposal."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ExternalCommitValidation( ExternalCommitValidationError::NoExternalInitProposals )) - ); + )); // Positive case alice_group @@ -99,8 +96,8 @@ fn test_valsem240(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem241: External Commit, inline Proposals: There MUST be at most one ExternalInit proposal. -#[apply(ciphersuites_and_providers)] -fn test_valsem241(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem241() { // Test with PublicMessage let ECValidationTestSetup { mut alice_group, @@ -148,12 +145,12 @@ fn test_valsem241(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, ProtocolMessage::from(public_message_commit_bad)) .expect_err("Could process message despite second ext. init proposal in commit."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ExternalCommitValidation( ExternalCommitValidationError::MultipleExternalInitProposals )) - ); + )); // Positive case alice_group @@ -162,8 +159,8 @@ fn test_valsem241(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem242: External Commit must only cover inline proposal in allowlist (ExternalInit, Remove, PreSharedKey) -#[apply(ciphersuites_and_providers)] -fn test_valsem242(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem242() { // Test with PublicMessage let ECValidationTestSetup { mut alice_group, @@ -184,12 +181,16 @@ fn test_valsem242(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); alice_group - .add_members(provider, &alice_credential.signer, &[bob_key_package]) + .add_members( + provider, + &alice_credential.signer, + &[bob_key_package.key_package().clone()], + ) .unwrap(); alice_group.merge_pending_commit(provider).unwrap(); let verifiable_group_info = alice_group - .export_group_info(provider.crypto(), &alice_credential.signer, true) + .export_group_info(provider, &alice_credential.signer, true) .unwrap() .into_verifiable_group_info() .unwrap(); @@ -237,7 +238,7 @@ fn test_valsem242(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); ProposalOrRef::Proposal(Proposal::Add(AddProposal { - key_package: charlie_key_package, + key_package: charlie_key_package.key_package().clone(), })) }; @@ -295,12 +296,12 @@ fn test_valsem242(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, public_message_commit_bad) .unwrap_err(); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ExternalCommitValidation( ExternalCommitValidationError::InvalidInlineProposals )) - ); + )); // Positive case alice_group @@ -310,8 +311,8 @@ fn test_valsem242(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem244: External Commit must not include any proposals by reference -#[apply(ciphersuites_and_providers)] -fn test_valsem244(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem244() { // Test with PublicMessage let ECValidationTestSetup { mut alice_group, @@ -338,7 +339,7 @@ fn test_valsem244(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); let add_proposal = Proposal::Add(AddProposal { - key_package: bob_key_package, + key_package: bob_key_package.key_package().clone(), }); let proposal_ref = @@ -373,12 +374,12 @@ fn test_valsem244(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, ProtocolMessage::from(public_message_commit_bad)) .unwrap_err(); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ExternalCommitValidation( ExternalCommitValidationError::ReferencedProposal )) - ); + )); // Positive case alice_group @@ -387,8 +388,8 @@ fn test_valsem244(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem245: External Commit: MUST contain a path. -#[apply(ciphersuites_and_providers)] -fn test_valsem245(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem245() { // Test with PublicMessage let ECValidationTestSetup { mut alice_group, @@ -431,10 +432,10 @@ fn test_valsem245(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, ProtocolMessage::from(public_message_commit_bad)) .expect_err("Could process message despite missing path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::NoPath) - ); + )); // Positive case alice_group @@ -443,8 +444,8 @@ fn test_valsem245(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem246: External Commit: The signature of the PublicMessage MUST be verified with the credential of the KeyPackage in the included `path`. -#[apply(ciphersuites_and_providers)] -fn test_valsem246(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem246() { // Test with PublicMessage let ECValidationTestSetup { mut alice_group, @@ -477,7 +478,7 @@ fn test_valsem246(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); if let Some(ref mut path) = commit_bad.path { - path.set_leaf_node(bob_new_key_package.leaf_node().clone()) + path.set_leaf_node(bob_new_key_package.key_package().leaf_node().clone()) } let mut public_message_commit_bad = public_message_commit.clone(); @@ -504,7 +505,7 @@ fn test_valsem246(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // This shows that signature verification fails if the signature is not done // using the credential in the path. - assert_eq!(err, ProcessMessageError::InvalidSignature); + assert!(matches!(err, ProcessMessageError::InvalidSignature)); // This shows that the credential in the original path key package is actually bob's credential. let commit = if let FramedContentBody::Commit(commit) = public_message_commit.content() { @@ -534,6 +535,7 @@ fn test_valsem246(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .serialized_context() .to_vec(), provider.crypto(), + ciphersuite, ) .unwrap(); let verification_result: Result = @@ -555,8 +557,8 @@ fn test_valsem246(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // External Commit should work when group use ciphertext WireFormat -#[apply(ciphersuites_and_providers)] -fn test_pure_ciphertest(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_pure_ciphertest() { // Test with PrivateMessage let ECValidationTestSetup { mut alice_group, @@ -569,7 +571,7 @@ fn test_pure_ciphertest(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Have Alice export everything that bob needs. let verifiable_group_info = alice_group - .export_group_info(provider.crypto(), &alice_credential.signer, true) + .export_group_info(provider, &alice_credential.signer, true) .unwrap() .into_verifiable_group_info() .unwrap(); @@ -598,13 +600,12 @@ fn test_pure_ciphertest(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide } mod utils { - use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; + use openmls_traits::types::Ciphersuite; use tls_codec::{Deserialize, Serialize}; use crate::{ framing::{MlsMessageIn, PublicMessage, Sender}, group::{ - config::CryptoConfig, tests::utils::{generate_credential_with_key, CredentialWithKeyAndSigner}, MlsGroup, MlsGroupCreateConfig, WireFormatPolicy, }, @@ -623,7 +624,7 @@ mod utils { pub(super) fn validation_test_setup( wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ECValidationTestSetup { // Generate credentials with keys let alice_credential = generate_credential_with_key( @@ -638,7 +639,7 @@ mod utils { // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // Alice creates a group @@ -654,7 +655,7 @@ mod utils { // Have Alice export everything that bob needs. let verifiable_group_info = alice_group - .export_group_info(provider.crypto(), &alice_credential.signer, false) + .export_group_info(provider, &alice_credential.signer, false) .unwrap() .into_verifiable_group_info() .unwrap(); diff --git a/openmls/src/group/tests/test_framing.rs b/openmls/src/group/tests/test_framing.rs index 3803157d87..9dd1e8b43a 100644 --- a/openmls/src/group/tests/test_framing.rs +++ b/openmls/src/group/tests/test_framing.rs @@ -1,11 +1,8 @@ use std::io::Write; use itertools::iproduct; -use openmls_traits::{ - crypto::OpenMlsCrypto, random::OpenMlsRand, types::Ciphersuite, OpenMlsProvider, -}; -use rstest::*; -use rstest_reuse::{self, *}; +use openmls_traits::random::OpenMlsRand; + use tls_codec::Serialize; use super::utils::*; @@ -21,12 +18,10 @@ use crate::{ secret_tree::SecretTree, secret_tree::SecretType, sender_ratchet::SenderRatchetConfiguration, }, - versions::ProtocolVersion, - *, }; -#[apply(providers)] -fn padding(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn padding(provider: &impl crate::storage::OpenMlsProvider) { // Create a test config for a single client supporting all possible // ciphersuites. let alice_config = TestClientConfig { @@ -98,8 +93,8 @@ fn padding(provider: &impl OpenMlsProvider) { } /// Check that PrivateMessageContent's padding field is verified to be all-zero. -#[apply(ciphersuites_and_providers)] -fn bad_padding(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn bad_padding() { let tests = { // { 2^i } ∪ { 2^i +- 1 } let padding_sizes = [ @@ -159,11 +154,8 @@ fn bad_padding(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .unwrap(); let sender_secret_tree = { - let sender_encryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); + let sender_encryption_secret = + EncryptionSecret::from_slice(&encryption_secret_bytes[..]); SecretTree::new( sender_encryption_secret, @@ -173,11 +165,8 @@ fn bad_padding(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { }; let receiver_secret_tree = { - let receiver_encryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); + let receiver_encryption_secret = + EncryptionSecret::from_slice(&encryption_secret_bytes[..]); SecretTree::new( receiver_encryption_secret, @@ -285,7 +274,7 @@ fn bad_padding(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Derive the sender data key from the key schedule using the ciphertext. let sender_data_key = message_secrets .sender_data_secret() - .derive_aead_key(provider.crypto(), &ciphertext) + .derive_aead_key(provider.crypto(), ciphersuite, &ciphertext) .unwrap(); // Derive initial nonce from the key schedule using the ciphertext. let sender_data_nonce = message_secrets diff --git a/openmls/src/group/tests/test_framing_validation.rs b/openmls/src/group/tests/test_framing_validation.rs index 80db1821a7..39e3d0f815 100644 --- a/openmls/src/group/tests/test_framing_validation.rs +++ b/openmls/src/group/tests/test_framing_validation.rs @@ -1,19 +1,10 @@ //! This module tests the validation of message framing as defined in //! https://openmls.tech/book/message_validation.html#semantic-validation-of-message-framing -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::prelude::{openmls_types::Ciphersuite, *}; use tls_codec::{Deserialize, Serialize}; -use rstest::*; -use rstest_reuse::{self, *}; - -use crate::{ - binary_tree::LeafNodeIndex, - framing::*, - group::{config::CryptoConfig, *}, - key_packages::*, -}; +use crate::{binary_tree::LeafNodeIndex, framing::*, group::*, key_packages::*}; use super::utils::{ generate_credential_with_key, generate_key_package, CredentialWithKeyAndSigner, @@ -33,7 +24,7 @@ struct ValidationTestSetup { fn validation_test_setup( wire_format_policy: WireFormatPolicy, ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> ValidationTestSetup { let group_id = GroupId::from_slice(b"Test Group"); @@ -62,7 +53,7 @@ fn validation_test_setup( // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -80,7 +71,7 @@ fn validation_test_setup( .add_members( provider, &alice_credential.signer, - &[bob_key_package.clone()], + &[bob_key_package.key_package().clone()], ) .expect("Could not add member."); @@ -108,14 +99,14 @@ fn validation_test_setup( bob_group, _alice_credential: alice_credential, _bob_credential: bob_credential, - _alice_key_package: alice_key_package, - _bob_key_package: bob_key_package, + _alice_key_package: alice_key_package.key_package().clone(), + _bob_key_package: bob_key_package.key_package().clone(), } } // ValSem002 Group id -#[apply(ciphersuites_and_providers)] -fn test_valsem002(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem002() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -148,10 +139,10 @@ fn test_valsem002(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite wrong group ID."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::WrongGroupId) - ); + )); // Positive case bob_group @@ -160,8 +151,8 @@ fn test_valsem002(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem003 Epoch -#[apply(ciphersuites_and_providers)] -fn test_valsem003(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem003() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -215,20 +206,20 @@ fn test_valsem003(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let err = bob_group .process_message(provider, plaintext.clone()) .expect_err("Could parse message despite wrong epoch."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::WrongEpoch) - ); + )); // Set the epoch too low plaintext.set_epoch(current_epoch.as_u64() - 1); let err = bob_group .process_message(provider, plaintext) .expect_err("Could parse message despite wrong epoch."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::WrongEpoch) - ); + )); // Positive case let processed_msg = bob_group @@ -247,15 +238,15 @@ fn test_valsem003(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Processing a commit twice should fail i.e. an epoch can only be used once in a commit message let process_twice = bob_group.process_message(provider, original_message); - assert_eq!( + assert!(matches!( process_twice.unwrap_err(), ProcessMessageError::ValidationError(ValidationError::WrongEpoch) - ); + )); } // ValSem004 Sender: Member: check the member exists -#[apply(ciphersuites_and_providers)] -fn test_valsem004(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem004() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -287,6 +278,7 @@ fn test_valsem004(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { plaintext .set_membership_tag( provider.crypto(), + ciphersuite, alice_group.group().message_secrets().membership_key(), alice_group.group().message_secrets().serialized_context(), ) @@ -298,10 +290,10 @@ fn test_valsem004(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite wrong sender."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::UnknownMember) - ); + )); // Positive case bob_group @@ -310,8 +302,8 @@ fn test_valsem004(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem005 Application messages must use ciphertext -#[apply(ciphersuites_and_providers)] -fn test_valsem005(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem005() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -342,6 +334,7 @@ fn test_valsem005(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { plaintext .set_membership_tag( provider.crypto(), + ciphersuite, alice_group.group().message_secrets().membership_key(), alice_group.group().message_secrets().serialized_context(), ) @@ -353,10 +346,10 @@ fn test_valsem005(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite unencrypted application message."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::UnencryptedApplicationMessage) - ); + )); // Positive case bob_group @@ -365,8 +358,8 @@ fn test_valsem005(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem006 Ciphertext: decryption needs to work -#[apply(ciphersuites_and_providers)] -fn test_valsem006(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem006() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -399,12 +392,12 @@ fn test_valsem006(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite garbled ciphertext."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( MessageDecryptionError::AeadError )) - ); + )); // Positive case bob_group @@ -413,8 +406,8 @@ fn test_valsem006(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem007 Membership tag presence -#[apply(ciphersuites_and_providers)] -fn test_valsem007(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem007() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -447,10 +440,10 @@ fn test_valsem007(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite missing membership tag."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::MissingMembershipTag) - ); + )); // Positive case bob_group @@ -459,8 +452,8 @@ fn test_valsem007(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem008 Membership tag verification -#[apply(ciphersuites_and_providers)] -fn test_valsem008(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem008() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -487,8 +480,13 @@ fn test_valsem008(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let original_message = plaintext.clone(); plaintext.set_membership_tag_test(MembershipTag( - Mac::new(provider.crypto(), &Secret::default(), &[1, 2, 3]) - .expect("Could not compute membership tag."), + Mac::new( + provider.crypto(), + ciphersuite, + &Secret::default(), + &[1, 2, 3], + ) + .expect("Could not compute membership tag."), )); let message_in = ProtocolMessage::from(plaintext); @@ -497,10 +495,10 @@ fn test_valsem008(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could process message despite wrong membership tag."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::InvalidMembershipTag) - ); + )); // Positive case bob_group @@ -509,8 +507,8 @@ fn test_valsem008(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem009 Confirmation tag presence -#[apply(ciphersuites_and_providers)] -fn test_valsem009(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem009() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -541,6 +539,7 @@ fn test_valsem009(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { plaintext .set_membership_tag( provider.crypto(), + ciphersuite, alice_group.group().message_secrets().membership_key(), alice_group.group().message_secrets().serialized_context(), ) @@ -552,10 +551,10 @@ fn test_valsem009(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could parse message despite missing confirmation tag."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::MissingConfirmationTag) - ); + )); // Positive case bob_group @@ -564,8 +563,8 @@ fn test_valsem009(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } // ValSem010 Signature verification -#[apply(ciphersuites_and_providers)] -fn test_valsem010(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem010() { let ValidationTestSetup { mut alice_group, mut bob_group, @@ -598,6 +597,7 @@ fn test_valsem010(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { plaintext .set_membership_tag( provider.crypto(), + ciphersuite, alice_group.group().message_secrets().membership_key(), alice_group.group().message_secrets().serialized_context(), ) @@ -609,7 +609,7 @@ fn test_valsem010(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, message_in) .expect_err("Could process message despite wrong signature."); - assert_eq!(err, ProcessMessageError::InvalidSignature); + assert!(matches!(err, ProcessMessageError::InvalidSignature)); // Positive case bob_group diff --git a/openmls/src/group/tests/test_group.rs b/openmls/src/group/tests/test_group.rs index 2c9983801e..61984eb7bd 100644 --- a/openmls/src/group/tests/test_group.rs +++ b/openmls/src/group/tests/test_group.rs @@ -1,22 +1,17 @@ use framing::mls_content_in::FramedContentBodyIn; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::key_store::OpenMlsKeyStore; use tests::utils::{generate_credential_with_key, generate_key_package}; use crate::{ - ciphersuite::signable::Verifiable, - framing::*, - group::{config::CryptoConfig, *}, - key_packages::*, - schedule::psk::store::ResumptionPskStore, - test_utils::*, - tree::sender_ratchet::SenderRatchetConfiguration, - treesync::node::leaf_node::TreeInfoTbs, - *, + ciphersuite::signable::Verifiable, framing::*, group::*, key_packages::*, + schedule::psk::store::ResumptionPskStore, test_utils::*, + tree::sender_ratchet::SenderRatchetConfiguration, treesync::node::leaf_node::TreeInfoTbs, *, }; -#[apply(ciphersuites_and_providers)] -fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn create_commit_optional_path( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let group_aad = b"Alice's test group"; // Framing parameters let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -41,7 +36,7 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls // Alice creates a group let mut group_alice = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_keys.credential_with_key, ) .build(provider, &alice_credential_with_keys.signer) @@ -53,7 +48,7 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls let bob_add_proposal = group_alice .create_add_proposal( framing_parameters, - bob_key_package.clone(), + bob_key_package.key_package().clone(), &alice_credential_with_keys.signer, ) .expect("Could not create proposal."); @@ -92,7 +87,7 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls let bob_add_proposal = group_alice .create_add_proposal( framing_parameters, - bob_key_package.clone(), + bob_key_package.key_package().clone(), &alice_credential_with_keys.signer, ) .expect("Could not create proposal."); @@ -129,22 +124,13 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls .expect("error merging pending commit"); let ratchet_tree = group_alice.public_group().export_ratchet_tree(); - let bob_private_key = provider - .key_store() - .read::(bob_key_package.hpke_init_key().as_slice()) - .unwrap(); - let bob_key_package_bundle = KeyPackageBundle { - key_package: bob_key_package, - private_key: bob_private_key, - }; - // Bob creates group from Welcome let group_bob = StagedCoreWelcome::new_from_welcome( create_commit_result .welcome_option .expect("An unexpected error occurred."), Some(ratchet_tree.into()), - bob_key_package_bundle, + bob_key_package, provider, ResumptionPskStore::new(1024), ) @@ -161,7 +147,7 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls .own_leaf_node() .unwrap() .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_alice.own_tree_position()), provider, &alice_credential_with_keys.signer, @@ -208,8 +194,8 @@ fn create_commit_optional_path(ciphersuite: Ciphersuite, provider: &impl OpenMls .expect("error merging pending commit"); } -#[apply(ciphersuites_and_providers)] -fn basic_group_setup(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn basic_group_setup() { let group_aad = b"Alice's test group"; // Framing parameters let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -234,7 +220,7 @@ fn basic_group_setup(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) // Alice creates a group let group_alice = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_keys.credential_with_key, ) .build(provider, &alice_credential_with_keys.signer) @@ -244,7 +230,7 @@ fn basic_group_setup(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) let bob_add_proposal = group_alice .create_add_proposal( framing_parameters, - bob_key_package, + bob_key_package.key_package().clone(), &alice_credential_with_keys.signer, ) .expect("Could not create proposal."); @@ -284,8 +270,8 @@ fn basic_group_setup(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) /// - Charlie sends a message to the group /// - Charlie updates and commits /// - Charlie removes Bob -#[apply(ciphersuites_and_providers)] -fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn group_operations() { let group_aad = b"Alice's test group"; // Framing parameters let framing_parameters = FramingParameters::new(group_aad, WireFormat::PublicMessage); @@ -301,7 +287,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { generate_credential_with_key(b"Bob".to_vec(), ciphersuite.signature_algorithm(), provider); // Generate KeyPackages - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( provider, &bob_credential_with_keys.signer, ciphersuite, @@ -312,7 +298,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // === Alice creates a group === let mut group_alice = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, alice_credential_with_keys.credential_with_key.clone(), ) .build(provider, &alice_credential_with_keys.signer) @@ -426,7 +412,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .own_leaf_node() .unwrap() .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_bob.own_tree_position()), provider, &alice_credential_with_keys.signer, @@ -493,7 +479,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .own_leaf_node() .unwrap() .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_alice.own_tree_position()), provider, &alice_credential_with_keys.signer, @@ -556,7 +542,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .own_leaf_node() .unwrap() .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_bob.own_tree_position()), provider, &bob_credential_with_keys.signer, @@ -633,7 +619,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { provider, ); - let charlie_key_package_bundle = KeyPackageBundle::new( + let charlie_key_package_bundle = KeyPackageBundle::generate( provider, &charlie_credential_with_keys.signer, ciphersuite, @@ -786,7 +772,7 @@ fn group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .own_leaf_node() .unwrap() .updated( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, TreeInfoTbs::Update(group_charlie.own_tree_position()), provider, &charlie_credential_with_keys.signer, diff --git a/openmls/src/group/tests/test_past_secrets.rs b/openmls/src/group/tests/test_past_secrets.rs index 0eb9297ca0..f065919d83 100644 --- a/openmls/src/group/tests/test_past_secrets.rs +++ b/openmls/src/group/tests/test_past_secrets.rs @@ -1,19 +1,16 @@ //! This module contains tests regarding the use of [`MessageSecretsStore`] in [`MlsGroup`] -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; - -use rstest::*; -use rstest_reuse::{self, *}; - use super::utils::{generate_credential_with_key, generate_key_package}; use crate::{ framing::{MessageDecryptionError, MlsMessageIn, ProcessedMessageContent}, - group::{config::CryptoConfig, *}, + group::*, }; -#[apply(ciphersuites_and_providers)] -fn test_past_secrets_in_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_past_secrets_in_group( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Test this for different parameters for max_epochs in (0..10usize).step_by(2) { let group_id = GroupId::from_slice(b"Test Group"); @@ -42,7 +39,7 @@ fn test_past_secrets_in_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsP let mls_group_create_config = MlsGroupCreateConfig::builder() .max_past_epochs(max_epochs / 2) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -60,7 +57,7 @@ fn test_past_secrets_in_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsP .add_members( provider, &alice_credential_with_keys.signer, - &[bob_key_package], + &[bob_key_package.key_package().clone()], ) .expect("An unexpected error occurred."); @@ -131,12 +128,12 @@ fn test_past_secrets_in_group(ciphersuite: Ciphersuite, provider: &impl OpenMlsP let err = bob_group .process_message(provider, application_message.clone()) .expect_err("An unexpected error occurred."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::UnableToDecrypt( MessageDecryptionError::AeadError ),) - ); + )); } // The last messages should not fail diff --git a/openmls/src/group/tests/test_proposal_validation.rs b/openmls/src/group/tests/test_proposal_validation.rs index 42007e5d3c..904c886fcb 100644 --- a/openmls/src/group/tests/test_proposal_validation.rs +++ b/openmls/src/group/tests/test_proposal_validation.rs @@ -1,12 +1,11 @@ //! This module tests the validation of proposals as defined in //! https://openmls.tech/book/message_validation.html#semantic-validation-of-proposals-covered-by-a-commit -use openmls_rust_crypto::OpenMlsRustCrypto; +use crate::storage::OpenMlsProvider; use openmls_traits::{ - key_store::OpenMlsKeyStore, signatures::Signer, types::Ciphersuite, OpenMlsProvider, + prelude::{openmls_types::*, *}, + signatures::Signer, }; -use rstest::*; -use rstest_reuse::{self, *}; use tls_codec::{Deserialize, Serialize}; use super::utils::{ @@ -20,15 +19,19 @@ use crate::{ mls_content::FramedContentBody, validation::ProcessedMessageContent, AuthenticatedContent, FramedContent, MlsMessageIn, MlsMessageOut, ProtocolMessage, PublicMessage, Sender, }, - group::{config::CryptoConfig, *}, + group::*, key_packages::{errors::*, *}, messages::{ - proposals::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal}, + proposals::{ + AddProposal, CustomProposal, Proposal, ProposalOrRef, ProposalType, RemoveProposal, + UpdateProposal, + }, Commit, Welcome, }, - prelude::MlsMessageBodyIn, + prelude::{Capabilities, MlsMessageBodyIn}, schedule::PreSharedKeyId, - treesync::{errors::ApplyUpdatePathError, node::leaf_node::Capabilities}, + test_utils::frankenstein::FrankenKeyPackage, + treesync::errors::ApplyUpdatePathError, versions::ProtocolVersion, }; @@ -37,7 +40,7 @@ fn generate_credential_with_key_and_key_package( identity: Vec, ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider, -) -> (CredentialWithKeyAndSigner, KeyPackage) { +) -> (CredentialWithKeyAndSigner, KeyPackageBundle) { let credential_with_key_and_signer = generate_credential_with_key(identity, ciphersuite.signature_algorithm(), provider); @@ -52,17 +55,17 @@ fn generate_credential_with_key_and_key_package( } /// Helper function to create a group and try to add `members` to it. -fn create_group_with_members( +fn create_group_with_members( ciphersuite: Ciphersuite, alice_credential_with_key_and_signer: &CredentialWithKeyAndSigner, member_key_packages: &[KeyPackage], - provider: &impl OpenMlsProvider, -) -> Result<(MlsMessageIn, Welcome), AddMembersError> { + provider: &Provider, +) -> Result<(MlsMessageIn, Welcome), AddMembersError<::StorageError>> { let mut alice_group = MlsGroup::new_with_group_id( provider, &alice_credential_with_key_and_signer.signer, &MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(), GroupId::from_slice(b"Alice's Friends"), alice_credential_with_key_and_signer @@ -108,7 +111,7 @@ fn new_test_group( // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); ( @@ -148,7 +151,7 @@ fn validation_test_setup( .add_members( provider, &alice_credential_with_key_and_signer.signer, - &[bob_key_package], + &[bob_key_package.key_package().clone()], ) .unwrap(); @@ -184,6 +187,7 @@ fn validation_test_setup( fn insert_proposal_and_resign( provider: &impl OpenMlsProvider, + ciphersuite: Ciphersuite, mut proposal_or_ref: Vec, mut plaintext: PublicMessage, original_plaintext: &PublicMessage, @@ -206,6 +210,7 @@ fn insert_proposal_and_resign( original_plaintext, provider, signer, + ciphersuite, ); let membership_key = committer_group.group().message_secrets().membership_key(); @@ -213,6 +218,7 @@ fn insert_proposal_and_resign( signed_plaintext .set_membership_tag( provider.crypto(), + ciphersuite, membership_key, committer_group .group() @@ -236,8 +242,8 @@ enum KeyUniqueness { /// ValSem101: /// Add Proposal: /// Signature public key in proposals must be unique among proposals -#[apply(ciphersuites_and_providers)] -fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem101a() { for bob_and_charlie_share_keys in [ KeyUniqueness::NegativeSameKey, KeyUniqueness::PositiveDifferentKey, @@ -293,19 +299,22 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let res = create_group_with_members( ciphersuite, &alice_credential_with_keys, - &[bob_key_package, charlie_key_package], + &[ + bob_key_package.key_package().clone(), + charlie_key_package.key_package().clone(), + ], provider, ); match bob_and_charlie_share_keys { KeyUniqueness::NegativeSameKey => { let err = res.expect_err("was able to add users with the same signature key!"); - assert_eq!( + assert!(matches!( err, AddMembersError::CreateCommitError(CreateCommitError::ProposalValidationError( ProposalValidationError::DuplicateSignatureKey )) - ); + )); } KeyUniqueness::PositiveDifferentKey => { let _ = res.expect("failed to add users with different signature keypairs!"); @@ -335,7 +344,7 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .add_members( provider, &alice_credential_with_key_and_signer.signer, - &[charlie_key_package], + &[charlie_key_package.key_package().clone()], ) .expect("Error creating self-update") .tls_serialize_detached() @@ -353,14 +362,11 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // a different hpke key, different identity, but the same signature key. let dave_key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &charlie_credential_with_key.signer, CredentialWithKey { - credential: BasicCredential::new(b"Dave".to_vec()).unwrap().into(), + credential: BasicCredential::new(b"Dave".to_vec()).into(), signature_key: charlie_credential_with_key .credential_with_key .signature_key, @@ -369,11 +375,12 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .unwrap(); let second_add_proposal = Proposal::Add(AddProposal { - key_package: dave_key_package, + key_package: dave_key_package.key_package().clone(), }); let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(second_add_proposal)], plaintext, &original_plaintext, @@ -388,12 +395,12 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ProposalValidationError( ProposalValidationError::DuplicateSignatureKey )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -413,8 +420,8 @@ fn test_valsem101a(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// ValSem102: /// Add Proposal: /// HPKE init key in proposals must be unique among proposals -#[apply(ciphersuites_and_providers)] -fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem102() { for bob_and_charlie_share_keys in [ KeyUniqueness::NegativeSameKey, KeyUniqueness::PositiveDifferentKey, @@ -430,20 +437,23 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { match bob_and_charlie_share_keys { KeyUniqueness::NegativeSameKey => { // Create a new key package for bob with the init key from Charlie. - bob_key_package = KeyPackage::new_from_init_key( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - provider, - &bob_credential_with_key.signer, - bob_credential_with_key.credential_with_key.clone(), - Extensions::empty(), - Capabilities::default(), - Extensions::empty(), - charlie_key_package.hpke_init_key().to_owned(), - ) - .unwrap(); + let encryption_private_key = bob_key_package.encryption_private_key().clone(); + let mut franken_key_package = FrankenKeyPackage::from(bob_key_package); + franken_key_package.init_key = charlie_key_package + .key_package() + .hpke_init_key() + .as_slice() + .to_owned() + .into(); + franken_key_package.resign(&bob_credential_with_key.signer); + bob_key_package = { + let kp = KeyPackage::from(franken_key_package.clone()); + KeyPackageBundle::new( + kp, + charlie_key_package.init_private_key().clone(), + encryption_private_key.into(), + ) + }; } KeyUniqueness::PositiveDifferentKey => { // don't need to do anything since the keys are already @@ -456,19 +466,22 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let res = create_group_with_members( ciphersuite, &alice_credential_with_key, - &[bob_key_package, charlie_key_package], + &[ + bob_key_package.key_package().clone(), + charlie_key_package.key_package().clone(), + ], provider, ); match bob_and_charlie_share_keys { KeyUniqueness::NegativeSameKey => { let err = res.expect_err("was able to add users with the same HPKE init key!"); - assert_eq!( + assert!(matches!( err, AddMembersError::CreateCommitError(CreateCommitError::ProposalValidationError( ProposalValidationError::DuplicateInitKey )) - ); + )); } KeyUniqueness::PositiveDifferentKey => { let _ = res.expect("failed to add users with different HPKE init keys!"); @@ -498,7 +511,7 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .add_members( provider, &alice_credential_with_key_and_signer.signer, - &[charlie_key_package.clone()], + &[charlie_key_package.key_package().clone()], ) .expect("Error creating self-update") .tls_serialize_detached() @@ -515,22 +528,26 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Now let's create a second proposal and insert it into the commit. We want // a different signature key, different identity, but the same hpke init // key. - let (dave_credential_with_key_and_signer, mut dave_key_package) = + let (dave_credential_with_key_and_signer, dave_key_package) = generate_credential_with_key_and_key_package("Dave".into(), ciphersuite, provider); // Change the init key and re-sign. - dave_key_package.set_init_key(charlie_key_package.hpke_init_key().clone()); - let dave_key_package = dave_key_package.resign( - &dave_credential_with_key_and_signer.signer, - dave_credential_with_key_and_signer - .credential_with_key - .clone(), - ); + let mut franken_key_package = FrankenKeyPackage::from(dave_key_package); + franken_key_package.init_key = charlie_key_package + .key_package() + .hpke_init_key() + .as_slice() + .to_owned() + .into(); + franken_key_package.resign(&dave_credential_with_key_and_signer.signer); + let dave_key_package = KeyPackage::from(franken_key_package); + let second_add_proposal = Proposal::Add(AddProposal { key_package: dave_key_package, }); let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(second_add_proposal)], plaintext, &original_plaintext, @@ -545,12 +562,12 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process message despite modified encryption key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ProposalValidationError( ProposalValidationError::DuplicateInitKey )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -571,8 +588,8 @@ fn test_valsem102(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// Add Proposal: /// Signature public key in proposals must be unique among existing group /// members -#[apply(ciphersuites_and_providers)] -fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem101b() { for alice_and_bob_share_keys in [ KeyUniqueness::NegativeSameKey, KeyUniqueness::PositiveDifferentKey, @@ -604,7 +621,7 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } .map(|(name, keypair)| CredentialWithKeyAndSigner { credential_with_key: CredentialWithKey { - credential: BasicCredential::new(name.into()).unwrap().into(), + credential: BasicCredential::new(name.into()).into(), signature_key: keypair.to_public_vec().into(), }, signer: keypair, @@ -628,7 +645,7 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { provider, &alice_credential_with_key.signer, &MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(), GroupId::from_slice(b"Alice's Friends"), alice_credential_with_key.credential_with_key.clone(), @@ -641,22 +658,28 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .add_members( provider, &alice_credential_with_key.signer, - &[bob_key_package, target_key_package], + &[ + bob_key_package.key_package().clone(), + target_key_package.key_package().clone(), + ], ) .expect_err("was able to add user with same signature key as a group member!"); - assert_eq!( + assert!(matches!( err, AddMembersError::CreateCommitError(CreateCommitError::ProposalValidationError( ProposalValidationError::DuplicateSignatureKey )) - ); + )); } KeyUniqueness::PositiveDifferentKey => { alice_group .add_members( provider, &alice_credential_with_key.signer, - &[bob_key_package, target_key_package], + &[ + bob_key_package.key_package().clone(), + target_key_package.key_package().clone(), + ], ) .expect("failed to add user with different signature keypair!"); } @@ -665,17 +688,14 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .add_members( provider, &alice_credential_with_key.signer, - &[bob_key_package.clone()], + &[bob_key_package.key_package().clone()], ) .unwrap(); alice_group.merge_pending_commit(provider).unwrap(); let bob_index = alice_group .members() .find_map(|member| { - let identity = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()) - .unwrap(); - if identity.as_slice() == b"Bob" { + if member.credential.serialized_content() == b"Bob" { Some(member.index) } else { None @@ -686,7 +706,7 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .propose_remove_member(provider, &alice_credential_with_key.signer, bob_index) .unwrap(); alice_group - .add_members(provider, &alice_credential_with_key.signer, &[target_key_package]) + .add_members(provider, &alice_credential_with_key.signer, &[target_key_package.key_package().clone()]) .expect( "failed to add a user with the same identity as someone in the group (with a remove proposal)!", ); @@ -742,10 +762,7 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let dave_key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite provider, &dave_credential_bundle, ) @@ -829,7 +846,7 @@ fn test_valsem101b(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// Add Proposal: Encryption key must be unique in the tree /// ValSem104: /// Add Proposal: Init key and encryption key must be different -#[apply(ciphersuites_and_providers)] +#[openmls_test::openmls_test] fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { for alice_and_bob_share_keys in [ KeyUniqueness::NegativeSameKey, @@ -838,33 +855,24 @@ fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro // 0. Initialize Alice and Bob let (alice_credential_with_key, _) = generate_credential_with_key_and_key_package("Alice".into(), ciphersuite, provider); - let (bob_credential_with_key, mut bob_key_package) = + let (bob_credential_with_key, bob_key_package) = generate_credential_with_key_and_key_package("Bob".into(), ciphersuite, provider); - match alice_and_bob_share_keys { + let bob_key_package = match alice_and_bob_share_keys { KeyUniqueness::NegativeSameKey => { // Create a new key package for bob using the encryption key as init key. - bob_key_package = bob_key_package - .clone() - .into_with_init_key( - CryptoConfig::with_default_version(ciphersuite), - &bob_credential_with_key.signer, - InitKey::from( - bob_key_package - .leaf_node() - .encryption_key() - .as_slice() - .to_vec(), - ), - ) - .unwrap(); + let mut franken_key_package = FrankenKeyPackage::from(bob_key_package); + franken_key_package.init_key = franken_key_package.leaf_node.encryption_key.clone(); + franken_key_package.resign(&bob_credential_with_key.signer); + KeyPackage::from(franken_key_package) } KeyUniqueness::PositiveDifferentKey => { // don't need to do anything since all keys are already // different. + bob_key_package.key_package().clone() } KeyUniqueness::PositiveSameKeyWithRemove => unreachable!(), - } + }; // 1. Alice creates a group and tries to add Bob to it let res = create_group_with_members( @@ -878,12 +886,12 @@ fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro KeyUniqueness::NegativeSameKey => { let err = res.expect_err("was able to add user with colliding init and encryption keys!"); - assert_eq!( + assert!(matches!( err, AddMembersError::CreateCommitError(CreateCommitError::ProposalValidationError( ProposalValidationError::InitEncryptionKeyCollision )) - ); + )); } KeyUniqueness::PositiveDifferentKey => { let _ = res.expect("failed to add user with different HPKE init key!"); @@ -928,24 +936,14 @@ fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro .clone(); // Generate fresh key material for Dave. - let (dave_credential_with_key, _) = + let (dave_credential_with_key, dave_key_package) = generate_credential_with_key_and_key_package("Dave".into(), ciphersuite, provider); // Insert Bob's public key into Dave's KPB and resign. - let dave_key_package = KeyPackage::new_from_encryption_key( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - provider, - &dave_credential_with_key.signer, - dave_credential_with_key.credential_with_key.clone(), - Extensions::empty(), - Capabilities::default(), - Extensions::empty(), - bob_encryption_key, - ) - .unwrap(); + let mut franken_key_package = FrankenKeyPackage::from(dave_key_package); + franken_key_package.leaf_node.encryption_key = bob_encryption_key.as_slice().to_owned().into(); + franken_key_package.resign(&dave_credential_with_key.signer); + let dave_key_package = KeyPackage::from(franken_key_package); // Use the resulting KP to create an Add proposal. let add_proposal = Proposal::Add(AddProposal { @@ -956,6 +954,7 @@ fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro // encryption key. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(add_proposal)], plaintext, &original_plaintext, @@ -970,12 +969,12 @@ fn test_valsem103_valsem104(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ProposalValidationError( ProposalValidationError::DuplicateEncryptionKey )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -1014,8 +1013,8 @@ enum ProposalInclusion { /// ValSem105: /// Add Proposal: /// Ciphersuite & protocol version must match the group -#[apply(ciphersuites_and_providers)] -fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem105() { let _ = pretty_env_logger::try_init(); // Ciphersuite & protocol version validation includes checking the @@ -1054,9 +1053,11 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .. } = validation_test_setup(PURE_PLAINTEXT_WIRE_FORMAT_POLICY, ciphersuite, provider); - let (charlie_credential_with_key, mut charlie_key_package) = + let (charlie_credential_with_key, charlie_key_package) = generate_credential_with_key_and_key_package("Charlie".into(), ciphersuite, provider); + let mut franken_key_package = FrankenKeyPackage::from(charlie_key_package.clone()); + let kpi = KeyPackageIn::from(charlie_key_package.clone()); kpi.validate(provider.crypto(), ProtocolVersion::Mls10) .unwrap(); @@ -1067,79 +1068,63 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 } _ => Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, - }; + } as u16; match key_package_version { KeyPackageTestVersion::WrongCiphersuite => { - charlie_key_package.set_ciphersuite(wrong_ciphersuite) + franken_key_package.ciphersuite = wrong_ciphersuite; } KeyPackageTestVersion::WrongVersion => { - charlie_key_package.set_version(ProtocolVersion::Mls10Draft11); + franken_key_package.protocol_version = 999; } KeyPackageTestVersion::UnsupportedVersion => { - let mut new_leaf_node = charlie_key_package.leaf_node().clone(); - new_leaf_node - .capabilities_mut() - .set_versions(vec![ProtocolVersion::Mls10Draft11]); - charlie_key_package.set_leaf_node(new_leaf_node); + franken_key_package.leaf_node.capabilities.versions = vec![999]; } KeyPackageTestVersion::UnsupportedCiphersuite => { - let mut new_leaf_node = charlie_key_package.leaf_node().clone(); - new_leaf_node.capabilities_mut().set_ciphersuites(vec![ - Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into(), - ]); - charlie_key_package.set_leaf_node(new_leaf_node); + franken_key_package.leaf_node.capabilities.ciphersuites = + vec![Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into()]; } KeyPackageTestVersion::ValidTestCase => (), }; - let test_kp = charlie_key_package.resign( - &charlie_credential_with_key.signer, - charlie_credential_with_key.credential_with_key.clone(), - ); + franken_key_package.resign(&charlie_credential_with_key.signer); + let test_kp = KeyPackage::from(franken_key_package); let test_kp_2 = { - let (charlie_credential_with_key, mut charlie_key_package) = + let (charlie_credential_with_key, charlie_key_package) = generate_credential_with_key_and_key_package( "Charlie".into(), ciphersuite, provider, ); + let mut franken_key_package = FrankenKeyPackage::from(charlie_key_package.clone()); + // Let's just pick a ciphersuite that's not the one we're testing right now. let wrong_ciphersuite = match ciphersuite { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 => { Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 } _ => Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, - }; + } as u16; match key_package_version { KeyPackageTestVersion::WrongCiphersuite => { - charlie_key_package.set_ciphersuite(wrong_ciphersuite) + franken_key_package.ciphersuite = wrong_ciphersuite; } KeyPackageTestVersion::WrongVersion => { - charlie_key_package.set_version(ProtocolVersion::Mls10Draft11); + franken_key_package.protocol_version = 999; } KeyPackageTestVersion::UnsupportedVersion => { - let mut new_leaf_node = charlie_key_package.leaf_node().clone(); - new_leaf_node - .capabilities_mut() - .set_versions(vec![ProtocolVersion::Mls10Draft11]); - charlie_key_package.set_leaf_node(new_leaf_node); + franken_key_package.leaf_node.capabilities.versions = vec![999]; } KeyPackageTestVersion::UnsupportedCiphersuite => { - let mut new_leaf_node = charlie_key_package.leaf_node().clone(); - new_leaf_node.capabilities_mut().set_ciphersuites(vec![ - Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into(), - ]); - charlie_key_package.set_leaf_node(new_leaf_node); + franken_key_package.leaf_node.capabilities.ciphersuites = + vec![Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448.into()]; } KeyPackageTestVersion::ValidTestCase => (), }; - charlie_key_package.resign( - &charlie_credential_with_key.signer, - charlie_credential_with_key.credential_with_key.clone(), - ) + franken_key_package.resign(&charlie_credential_with_key.signer); + KeyPackage::from(franken_key_package) }; // Try to have Alice commit an Add with the test KeyPackage. @@ -1190,10 +1175,14 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { } }; // Reset alice's group state for the next test case. - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); } // Now we create a valid commit and add the proposal afterwards. Once by value, once by reference. - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); // Create the Commit. let serialized_update = alice_group @@ -1226,6 +1215,7 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Artificially add the proposal. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![proposal_or_ref], plaintext.clone(), &original_plaintext, @@ -1238,15 +1228,18 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // If we're including by reference, we have to sneak the proposal // into Bob's queue. if matches!(proposal_inclusion, ProposalInclusion::ByReference) { - bob_group.store_pending_proposal( - QueuedProposal::from_proposal_and_sender( - ciphersuite, - provider.crypto(), - add_proposal.clone(), - &Sender::build_member(alice_group.own_leaf_index()), + bob_group + .store_pending_proposal( + provider.storage(), + QueuedProposal::from_proposal_and_sender( + ciphersuite, + provider.crypto(), + add_proposal.clone(), + &Sender::build_member(alice_group.own_leaf_index()), + ) + .unwrap(), ) - .unwrap(), - ) + .unwrap() } // Have bob process the resulting plaintext @@ -1260,71 +1253,94 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // be longer due to the included Add proposal. Since we added // the Add artificially, we thus have a path length mismatch. KeyPackageTestVersion::ValidTestCase => { - let expected_error = ProcessMessageError::InvalidCommit( - StageCommitError::UpdatePathError(ApplyUpdatePathError::PathLengthMismatch), - ); - assert_eq!(err, expected_error); + assert!(matches!( + err, + ProcessMessageError::<::StorageError>::InvalidCommit( + StageCommitError::UpdatePathError( + ApplyUpdatePathError::PathLengthMismatch, + ), + ) + )); } KeyPackageTestVersion::WrongCiphersuite => { // In this case we need to differentiate, since we // manipulated the ciphersuite. The signature algorithm can // also have a mismatch and therefore invalidate the // signature, and/or the ciphersuite doesn't match. - let expected_error_1 = ProcessMessageError::InvalidCommit( - StageCommitError::ProposalValidationError( - ProposalValidationError::InvalidAddProposalCiphersuiteOrVersion, - ), - ); - let expected_error_2 = ProcessMessageError::ValidationError( - ValidationError::KeyPackageVerifyError( - KeyPackageVerifyError::InvalidLeafNodeSignature, - ), - ); - let expected_error_3 = ProcessMessageError::ValidationError( - ValidationError::InvalidAddProposalCiphersuite, - ); + assert!( - err == expected_error_1 - || err == expected_error_2 - || err == expected_error_3 + matches!( + err, + ProcessMessageError::<::StorageError>::InvalidCommit( + StageCommitError::ProposalValidationError( + ProposalValidationError::InvalidAddProposalCiphersuiteOrVersion, + ), + ) + ) || matches!( + err, + ProcessMessageError::<::StorageError>::ValidationError( + ValidationError::KeyPackageVerifyError( + KeyPackageVerifyError::InvalidLeafNodeSignature, + ), + ) + ) || matches!( + err, + ProcessMessageError::<::StorageError>::ValidationError( + ValidationError::InvalidAddProposalCiphersuite, + ) + ) ); } KeyPackageTestVersion::WrongVersion => { // We need to distinguish between the two cases where the // version is wrong, depending on whether it's a proposal by // value or by reference. - let expected_error_1 = ProcessMessageError::InvalidCommit( - StageCommitError::ProposalValidationError( - ProposalValidationError::InvalidAddProposalCiphersuiteOrVersion, - ), - ); - let expected_error_2 = ProcessMessageError::ValidationError( - ValidationError::KeyPackageVerifyError( - KeyPackageVerifyError::InvalidProtocolVersion, - ), + assert!( + matches!( + err, + ProcessMessageError::<::StorageError>::InvalidCommit( + StageCommitError::ProposalValidationError( + ProposalValidationError::InvalidAddProposalCiphersuiteOrVersion, + ), + ) + ) || matches!( + err, + ProcessMessageError::<::StorageError>::ValidationError( + ValidationError::KeyPackageVerifyError( + KeyPackageVerifyError::InvalidProtocolVersion, + ), + ) + ) ); - assert!(err == expected_error_1 || err == expected_error_2); } KeyPackageTestVersion::UnsupportedVersion => { - let expected_error_1 = ProcessMessageError::ValidationError( - ValidationError::KeyPackageVerifyError( - KeyPackageVerifyError::InvalidProtocolVersion, - ), - ); - let expected_error_2 = ProcessMessageError::InvalidCommit( - StageCommitError::ProposalValidationError( - ProposalValidationError::InsufficientCapabilities, - ), + assert!( + matches!( + err, + ProcessMessageError::<::StorageError>::ValidationError( + ValidationError::KeyPackageVerifyError( + KeyPackageVerifyError::InvalidProtocolVersion, + ), + ) + ) || matches!( + err, + ProcessMessageError::<::StorageError>::InvalidCommit( + StageCommitError::ProposalValidationError( + ProposalValidationError::InsufficientCapabilities, + ), + ) + ) ); - assert!(err == expected_error_1 || err == expected_error_2); } KeyPackageTestVersion::UnsupportedCiphersuite => { - let expected_error = ProcessMessageError::InvalidCommit( - StageCommitError::ProposalValidationError( - ProposalValidationError::InsufficientCapabilities, - ), - ); - assert_eq!(err, expected_error); + assert!(matches!( + err, + ProcessMessageError::<::StorageError>::InvalidCommit( + StageCommitError::ProposalValidationError( + ProposalValidationError::InsufficientCapabilities, + ), + ) + )); } }; @@ -1343,15 +1359,17 @@ fn test_valsem105(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .unwrap(); } - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); } } /// ValSem107: /// Remove Proposal: /// Removed member must be unique among proposals -#[apply(ciphersuites_and_providers)] -fn test_valsem107(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem107() { // Helper function to unwrap a commit with a single proposal from an mls message. fn unwrap_specific_commit(commit_ref_remove: MlsMessageOut) -> Commit { let serialized_message = commit_ref_remove.tls_serialize_detached().unwrap(); @@ -1421,7 +1439,9 @@ fn test_valsem107(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("error while trying to commit to colliding remove proposals"); // Clear commit to try another way of committing two identical removes. - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); // Now let's verify that both commits only contain one proposal. let (commit_inline_remove, _welcome, _group_info) = alice_group @@ -1497,8 +1517,8 @@ fn test_valsem107(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// ValSem108 /// Remove Proposal: /// Removed member must be an existing group member -#[apply(ciphersuites_and_providers)] -fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem108() { // Before we can test creation or reception of (invalid) proposals, we set // up a new group with Alice and Bob. let ProposalValidationTestSetup { @@ -1528,7 +1548,9 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .commit_to_pending_proposals(provider, &alice_credential_with_key_and_signer.signer) .expect("No error while committing empty proposals"); // FIXME: #1098 This shouldn't be necessary. Something is broken in the state logic. - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); // Creating the proposal should fail already because the member is not known. let err = alice_group @@ -1539,11 +1561,15 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ) .expect_err("Successfully created remove proposal for unknown member"); - assert_eq!(err, ProposeRemoveMemberError::UnknownMember); + assert!(matches!(err, ProposeRemoveMemberError::UnknownMember)); // Clear commit to try another way of committing a remove of a non-member. - alice_group.clear_pending_commit(); - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); let err = alice_group .remove_members( @@ -1553,12 +1579,12 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ) .expect_err("no error while trying to remove non-group-member"); - assert_eq!( + assert!(matches!( err, RemoveMembersError::CreateCommitError(CreateCommitError::ProposalValidationError( ProposalValidationError::UnknownMemberRemoval )) - ); + )); // We now have alice create a commit. Then we artificially add an invalid // remove proposal targeting a member that is not part of the group. @@ -1587,6 +1613,7 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // group. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(remove_proposal)], plaintext, &original_plaintext, @@ -1601,12 +1628,12 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ProposalValidationError( ProposalValidationError::UnknownMemberRemoval )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -1626,8 +1653,8 @@ fn test_valsem108(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// ValSem110 /// Update Proposal: /// Encryption key must be unique among existing members -#[apply(ciphersuites_and_providers)] -fn test_valsem110(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem110() { // Before we can test creation or reception of (invalid) proposals, we set // up a new group with Alice and Bob. let ProposalValidationTestSetup { @@ -1688,7 +1715,9 @@ fn test_valsem110(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("error processing proposal") .into_content() { - alice_group.store_pending_proposal(*proposal) + alice_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap() } else { panic!("Unexpected message type"); }; @@ -1698,19 +1727,23 @@ fn test_valsem110(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .commit_to_pending_proposals(provider, &alice_credential_with_key_and_signer.signer) .expect_err("no error while trying to commit to update proposal with differing identity"); - assert_eq!( + assert!(matches!( err, CommitToPendingProposalsError::CreateCommitError( CreateCommitError::ProposalValidationError( ProposalValidationError::DuplicateEncryptionKey ) ) - ); + )); // Clear commit to see if Bob will process a commit containing two colliding // keys. - alice_group.clear_pending_commit(); - alice_group.clear_pending_proposals(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); // We now have Alice create a commit. Then we artificially add an // update proposal with a colliding encryption key. @@ -1736,6 +1769,7 @@ fn test_valsem110(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Artificially add the proposal. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(update_proposal)], plaintext, &original_plaintext, @@ -1749,30 +1783,28 @@ fn test_valsem110(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // process the commit. let leaf_keypair = alice_group .group() - .read_epoch_keypairs(provider.key_store()) + .read_epoch_keypairs(provider.storage()) .into_iter() .find(|keypair| keypair.public_key() == &alice_encryption_key) .unwrap(); - leaf_keypair - .write_to_key_store(provider.key_store()) - .unwrap(); + leaf_keypair.write(provider.storage()).unwrap(); // Have bob process the resulting plaintext let err = bob_group .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::CommitterIncludedOwnUpdate) - ); + )); } /// ValSem111 /// Update Proposal: /// The sender of a full Commit must not include own update proposals -#[apply(ciphersuites_and_providers)] -fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem111() { // Before we can test creation or reception of (invalid) proposals, we set // up a new group with Alice and Bob. let ProposalValidationTestSetup { @@ -1799,7 +1831,7 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); let update_proposal = Proposal::Update(UpdateProposal { - leaf_node: update_kp.leaf_node().clone(), + leaf_node: update_kp.key_package().leaf_node().clone(), }); // We now have Alice create a commit. That commit should not contain any @@ -1842,6 +1874,7 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Let's insert the proposal into the commit. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Proposal(update_proposal.clone())], plaintext, &original_plaintext, @@ -1856,28 +1889,33 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::CommitterIncludedOwnUpdate) - ); + )); // Now we insert the proposal into Bob's proposal store so we can include it // in the commit by reference. - bob_group.store_pending_proposal( - QueuedProposal::from_proposal_and_sender( - ciphersuite, - provider.crypto(), - update_proposal.clone(), - &Sender::build_member(alice_group.own_leaf_index()), + bob_group + .store_pending_proposal( + provider.storage(), + QueuedProposal::from_proposal_and_sender( + ciphersuite, + provider.crypto(), + update_proposal.clone(), + &Sender::build_member(alice_group.own_leaf_index()), + ) + .expect("error creating queued proposal"), ) - .expect("error creating queued proposal"), - ); + .expect("error writing to storage"); // Now we can have Alice create a new commit and insert the proposal by // reference. // Wipe any pending commit first. - alice_group.clear_pending_commit(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); let commit = alice_group .self_update(provider, &alice_credential_with_key_and_signer.signer) @@ -1899,6 +1937,7 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Artificially add the proposal. let verifiable_plaintext = insert_proposal_and_resign( provider, + ciphersuite, vec![ProposalOrRef::Reference( ProposalRef::from_raw_proposal(ciphersuite, provider.crypto(), &update_proposal) .expect("error creating hash reference"), @@ -1916,12 +1955,12 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could process message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::InvalidCommit(StageCommitError::ProposalValidationError( ProposalValidationError::CommitterIncludedOwnUpdate )) - ); + )); let original_update_plaintext = MlsMessageIn::tls_deserialize(&mut serialized_update.as_slice()) @@ -1941,8 +1980,8 @@ fn test_valsem111(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { /// ValSem112 /// Update Proposal: /// The sender of a standalone update proposal must be of type member -#[apply(ciphersuites_and_providers)] -fn test_valsem112(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem112() { // Before we can test creation or reception of (invalid) proposals, we set // up a new group with Alice and Bob. let ProposalValidationTestSetup { @@ -1987,10 +2026,10 @@ fn test_valsem112(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .process_message(provider, update_message_in) .expect_err("Could parse message despite modified public key in path."); - assert_eq!( + assert!(matches!( err, ProcessMessageError::ValidationError(ValidationError::NotACommit) - ); + )); // We can't test with sender type External, since that currently panics // with `unimplemented`. @@ -2002,10 +2041,160 @@ fn test_valsem112(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("Unexpected error."); } +/// ValSem113 +/// All Proposals: The proposal type must be supported by all members of the +/// group +#[openmls_test::openmls_test] +fn valsem113() { + #[derive(Debug)] + enum TestMode { + Unsupported, + Supported, + } + + let custom_proposal_type = 0xFFFF; + let custom_proposal_payload = vec![0, 1, 2, 3]; + let custom_proposal = + CustomProposal::new(custom_proposal_type, custom_proposal_payload.clone()); + + let capabilities_with_support = Capabilities::new( + None, + None, + None, + Some(&[ProposalType::Custom(custom_proposal_type)]), + None, + ); + + let mls_group_config = MlsGroupJoinConfig::default(); + + // Generate credentials with keys + let alice_credential_with_keys = + generate_credential_with_key(b"alice".into(), ciphersuite.signature_algorithm(), provider); + + let bob_credential_with_keys = + generate_credential_with_key(b"bob".into(), ciphersuite.signature_algorithm(), provider); + + for test_mode in [TestMode::Unsupported, TestMode::Supported] { + // Before we can test creation or reception of a commit with an + // (unsupported) custom proposal, we set up a new group with Alice and Bob. + + // Generate Bob's KeyPackage depending on the test mode + let bob_key_package = if matches!(test_mode, TestMode::Unsupported) { + KeyPackageBuilder::new() + } else { + KeyPackageBuilder::new().leaf_node_capabilities(capabilities_with_support.clone()) + } + .build( + ciphersuite, + provider, + &bob_credential_with_keys.signer, + bob_credential_with_keys.credential_with_key.clone(), + ) + .unwrap(); + + // Create a group with the defined capabilities + let mut alice_group = if matches!(test_mode, TestMode::Unsupported) { + MlsGroup::builder() + } else { + MlsGroup::builder().with_capabilities(capabilities_with_support.clone()) + } + .ciphersuite(ciphersuite) + .build( + provider, + &alice_credential_with_keys.signer, + alice_credential_with_keys.credential_with_key.clone(), + ) + .unwrap(); + + // Add Bob + let (_mls_message, welcome, _group_info) = alice_group + .add_members( + provider, + &alice_credential_with_keys.signer, + &[bob_key_package.key_package().clone()], + ) + .unwrap(); + + alice_group.merge_pending_commit(provider).unwrap(); + + let staged_welcome = StagedWelcome::new_from_welcome( + provider, + &mls_group_config, + welcome.into_welcome().unwrap(), + Some(alice_group.export_ratchet_tree().into()), + ) + .unwrap(); + + let mut bob_group = staged_welcome.into_group(provider).unwrap(); + + // Have alice create a commit with a custom proposal. + let (custom_proposal_message, _proposal_ref) = alice_group + .propose_custom_proposal_by_reference( + provider, + &alice_credential_with_keys.signer, + custom_proposal.clone(), + ) + .unwrap(); + + let result = + alice_group.commit_to_pending_proposals(provider, &alice_credential_with_keys.signer); + + // If the proposal is unsupported, we expect an error here. + let (commit, _, _) = if matches!(test_mode, TestMode::Unsupported) { + assert!(matches!( + result, + Err(CommitToPendingProposalsError::CreateCommitError( + CreateCommitError::ProposalValidationError( + ProposalValidationError::UnsupportedProposalType + ) + )) + )); + continue; + } else { + result.expect("Error creating commit") + }; + + // Have bob process the custom proposal first. + let processed_message = bob_group + .process_message( + provider, + custom_proposal_message.into_protocol_message().unwrap(), + ) + .unwrap(); + + if let ProcessedMessageContent::ProposalMessage(proposal) = processed_message.into_content() + { + bob_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(); + } else { + panic!("Unexpected message type"); + } + + let result = bob_group.process_message(provider, commit.into_protocol_message().unwrap()); + + // If the proposal is unsupported, we expect an error here. + let _processed_message = if matches!(test_mode, TestMode::Unsupported) { + assert!(matches!( + result, + Err(ProcessMessageError::InvalidCommit( + StageCommitError::ProposalValidationError( + ProposalValidationError::UnsupportedProposalType + ) + )) + )); + continue; + } else { + // If the proposal is supported, we expect no error. + result.expect("Error processing commit") + }; + } +} + // --- PreSharedKey Proposals --- -#[apply(ciphersuites_and_providers)] -fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_valsem401_valsem402() { let ProposalValidationTestSetup { mut alice_group, alice_credential_with_key_and_signer, @@ -2013,11 +2202,14 @@ fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro .. } = validation_test_setup(PURE_PLAINTEXT_WIRE_FORMAT_POLICY, ciphersuite, provider); - let alice_provider = OpenMlsRustCrypto::default(); - let bob_provider = OpenMlsRustCrypto::default(); + let alice_provider = Provider::default(); + let bob_provider = Provider::default(); // TODO(#1354): This is currently not tested because we can't easily create invalid commits. - let bad_psks: [(Vec, ProcessMessageError); 0] = [ + let bad_psks: [( + Vec, + ProcessMessageError<::StorageError>, + ); 0] = [ // // ValSem401 // ( // vec![PreSharedKeyId::external( @@ -2096,12 +2288,8 @@ fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro let mut proposals = Vec::new(); for psk_id in psk_ids { - psk_id - .write_to_key_store(&alice_provider, ciphersuite, b"irrelevant") - .unwrap(); - psk_id - .write_to_key_store(&bob_provider, ciphersuite, b"irrelevant") - .unwrap(); + psk_id.store(&alice_provider, b"irrelevant").unwrap(); + psk_id.store(&bob_provider, b"irrelevant").unwrap(); let (psk_proposal, _) = alice_group .propose_external_psk( @@ -2121,8 +2309,12 @@ fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro ) .unwrap(); - alice_group.clear_pending_proposals(); - alice_group.clear_pending_commit(); + alice_group + .clear_pending_proposals(provider.storage()) + .unwrap(); + alice_group + .clear_pending_commit(provider.storage()) + .unwrap(); for psk_proposal in proposals.into_iter() { let processed_message = bob_group @@ -2131,7 +2323,9 @@ fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro match processed_message.into_content() { ProcessedMessageContent::ProposalMessage(queued_proposal) => { - bob_group.store_pending_proposal(*queued_proposal); + bob_group + .store_pending_proposal(provider.storage(), *queued_proposal) + .unwrap(); } _ => unreachable!(), } @@ -2144,7 +2338,9 @@ fn test_valsem401_valsem402(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro .unwrap_err(), ); - bob_group.clear_pending_proposals(); - bob_group.clear_pending_commit(); + bob_group + .clear_pending_proposals(provider.storage()) + .unwrap(); + bob_group.clear_pending_commit(provider.storage()).unwrap(); } } diff --git a/openmls/src/group/tests/test_remove_operation.rs b/openmls/src/group/tests/test_remove_operation.rs index b0263e271b..f287a5e13e 100644 --- a/openmls/src/group/tests/test_remove_operation.rs +++ b/openmls/src/group/tests/test_remove_operation.rs @@ -1,22 +1,12 @@ //! This module tests the classification of remove operations with RemoveOperation use super::utils::{generate_credential_with_key, generate_key_package}; -use crate::{ - framing::*, - group::{config::CryptoConfig, *}, - test_utils::*, - *, -}; -use openmls_rust_crypto::OpenMlsRustCrypto; +use crate::{framing::*, group::*}; +use openmls_traits::prelude::*; // Tests the different variants of the RemoveOperation enum. -#[apply(ciphersuites_and_providers)] -fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let _ = provider; - let alice_provider = OpenMlsRustCrypto::default(); - let bob_provider = OpenMlsRustCrypto::default(); - let charlie_provider = OpenMlsRustCrypto::default(); - +#[openmls_test::openmls_test] +fn test_remove_operation_variants() { // We define two test cases, one where the member is removed by another member // and one where the member leaves the group on its own #[derive(Debug, Clone, Copy)] @@ -32,43 +22,40 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open let alice_credential_with_key_and_signer = generate_credential_with_key( "Alice".into(), ciphersuite.signature_algorithm(), - &alice_provider, + provider, ); - let bob_credential_with_key_and_signer = generate_credential_with_key( - "Bob".into(), - ciphersuite.signature_algorithm(), - &bob_provider, - ); + let bob_credential_with_key_and_signer = + generate_credential_with_key("Bob".into(), ciphersuite.signature_algorithm(), provider); let charlie_credential_with_key_and_signer = generate_credential_with_key( "Charlie".into(), ciphersuite.signature_algorithm(), - &charlie_provider, + provider, ); // Generate KeyPackages let bob_key_package = generate_key_package( ciphersuite, Extensions::empty(), - &bob_provider, + provider, bob_credential_with_key_and_signer.clone(), ); let charlie_key_package = generate_key_package( ciphersuite, Extensions::empty(), - &charlie_provider, + provider, charlie_credential_with_key_and_signer, ); // Define the MlsGroup configuration let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === let mut alice_group = MlsGroup::new_with_group_id( - &alice_provider, + provider, &alice_credential_with_key_and_signer.signer, &mls_group_create_config, group_id, @@ -82,13 +69,16 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open let (_message, welcome, _group_info) = alice_group .add_members( - &alice_provider, + provider, &alice_credential_with_key_and_signer.signer, - &[bob_key_package, charlie_key_package], + &[ + bob_key_package.key_package().clone(), + charlie_key_package.key_package().clone(), + ], ) .expect("An unexpected error occurred."); alice_group - .merge_pending_commit(&alice_provider) + .merge_pending_commit(provider) .expect("error merging pending commit"); let welcome: MlsMessageIn = welcome.into(); @@ -97,23 +87,23 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open .expect("expected message to be a welcome"); let mut bob_group = StagedWelcome::new_from_welcome( - &bob_provider, + provider, mls_group_create_config.join_config(), welcome.clone(), Some(alice_group.export_ratchet_tree().into()), ) .expect("Error creating staged join from Welcome") - .into_group(&bob_provider) + .into_group(provider) .expect("Error creating group from staged join"); let mut charlie_group = StagedWelcome::new_from_welcome( - &charlie_provider, + provider, mls_group_create_config.join_config(), welcome, Some(alice_group.export_ratchet_tree().into()), ) .expect("Error creating staged join from Welcome") - .into_group(&charlie_provider) + .into_group(provider) .expect("Error creating group from staged join"); // === Remove operation === @@ -126,7 +116,7 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open // Alice removes Bob TestCase::Remove => alice_group .remove_members( - &alice_provider, + provider, &alice_credential_with_key_and_signer.signer, &[bob_index], ) @@ -135,21 +125,20 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open TestCase::Leave => { // Bob leaves the group let message = bob_group - .leave_group(&bob_provider, &bob_credential_with_key_and_signer.signer) + .leave_group(provider, &bob_credential_with_key_and_signer.signer) .expect("Could not leave group."); // Alice & Charlie store the pending proposal for group in [&mut alice_group, &mut charlie_group] { let processed_message = group - .process_message( - &charlie_provider, - message.clone().into_protocol_message().unwrap(), - ) + .process_message(provider, message.clone().into_protocol_message().unwrap()) .expect("Could not process message."); match processed_message.into_content() { ProcessedMessageContent::ProposalMessage(proposal) => { - group.store_pending_proposal(*proposal); + group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(); } _ => unreachable!(), } @@ -158,7 +147,7 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open // Alice commits to Bob's proposal alice_group .commit_to_pending_proposals( - &alice_provider, + provider, &alice_credential_with_key_and_signer.signer, ) .expect("An unexpected error occurred.") @@ -203,10 +192,7 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open // === Remove operation from Bob's perspective === let bob_processed_message = bob_group - .process_message( - &bob_provider, - message.clone().into_protocol_message().unwrap(), - ) + .process_message(provider, message.clone().into_protocol_message().unwrap()) .expect("Could not process message."); match bob_processed_message.into_content() { @@ -259,7 +245,7 @@ fn test_remove_operation_variants(ciphersuite: Ciphersuite, provider: &impl Open let protocol_message = message.into_protocol_message().unwrap(); let charlie_processed_message = charlie_group - .process_message(&charlie_provider, protocol_message) + .process_message(provider, protocol_message) .expect("Could not process message."); match charlie_processed_message.into_content() { diff --git a/openmls/src/group/tests/test_wire_format_policy.rs b/openmls/src/group/tests/test_wire_format_policy.rs index 2b00c3307c..2748785d37 100644 --- a/openmls/src/group/tests/test_wire_format_policy.rs +++ b/openmls/src/group/tests/test_wire_format_policy.rs @@ -1,15 +1,8 @@ //! This module tests the different values for `WireFormatPolicy` -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::{signatures::Signer, types::Ciphersuite}; -use rstest::*; -use rstest_reuse::{self, *}; - -use crate::{ - framing::*, - group::{config::CryptoConfig, *}, -}; +use crate::{framing::*, group::*}; use super::utils::{ generate_credential_with_key, generate_key_package, CredentialWithKeyAndSigner, @@ -18,7 +11,7 @@ use super::utils::{ // Creates a group with one member fn create_group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, wire_format_policy: WireFormatPolicy, ) -> (MlsGroup, CredentialWithKeyAndSigner) { let group_id = GroupId::from_slice(b"Test Group"); @@ -31,7 +24,7 @@ fn create_group( let mls_group_config = MlsGroupCreateConfig::builder() .wire_format_policy(wire_format_policy) .use_ratchet_tree_extension(true) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); ( @@ -50,7 +43,7 @@ fn create_group( // Takes an existing group, adds a new member and sends a message from the second member to the first one, returns that message fn receive_message( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, alice_group: &mut MlsGroup, alice_signer: &impl Signer, ) -> MlsMessageIn { @@ -67,7 +60,11 @@ fn receive_message( ); let (_message, welcome, _group_info) = alice_group - .add_members(provider, alice_signer, &[bob_key_package]) + .add_members( + provider, + alice_signer, + &[bob_key_package.key_package().clone()], + ) .expect("Could not add member."); alice_group @@ -95,8 +92,11 @@ fn receive_message( } // Test positive cases with all valid (pure & mixed) policies -#[apply(ciphersuites_and_providers)] -fn test_wire_policy_positive(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_wire_policy_positive( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { for wire_format_policy in WIRE_FORMAT_POLICIES.iter() { let (mut alice_group, alice_credential_with_key_and_signer) = create_group(ciphersuite, provider, *wire_format_policy); @@ -113,8 +113,11 @@ fn test_wire_policy_positive(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr } // Test negative cases with only icompatible policies -#[apply(ciphersuites_and_providers)] -fn test_wire_policy_negative(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_wire_policy_negative( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // All combinations that are not part of WIRE_FORMAT_POLICIES let incompatible_policies = vec![ WireFormatPolicy::new( @@ -138,6 +141,6 @@ fn test_wire_policy_negative(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr let err = alice_group .process_message(provider, message.try_into_protocol_message().unwrap()) .expect_err("An unexpected error occurred."); - assert_eq!(err, ProcessMessageError::IncompatibleWireFormat); + assert!(matches!(err, ProcessMessageError::IncompatibleWireFormat)); } } diff --git a/openmls/src/group/tests/utils.rs b/openmls/src/group/tests/utils.rs index bdf5e9b433..3c88fba601 100644 --- a/openmls/src/group/tests/utils.rs +++ b/openmls/src/group/tests/utils.rs @@ -6,21 +6,19 @@ use std::{cell::RefCell, collections::HashMap}; -use config::CryptoConfig; use openmls_basic_credential::SignatureKeyPair; use openmls_traits::crypto::OpenMlsCrypto; -use openmls_traits::{ - key_store::OpenMlsKeyStore, signatures::Signer, types::SignatureScheme, OpenMlsProvider, -}; +use openmls_traits::{signatures::Signer, types::SignatureScheme}; use rand::{rngs::OsRng, RngCore}; use tls_codec::Serialize; use crate::{ ciphersuite::signable::Signable, credentials::*, framing::*, group::*, key_packages::*, - messages::ConfirmationTag, schedule::psk::store::ResumptionPskStore, test_utils::*, - versions::ProtocolVersion, *, + messages::ConfirmationTag, schedule::psk::store::ResumptionPskStore, test_utils::*, *, }; +use self::storage::OpenMlsProvider; + /// Configuration of a client meant to be used in a test setup. #[derive(Clone)] pub(crate) struct TestClientConfig { @@ -79,7 +77,10 @@ pub(crate) struct TestSetup { const KEY_PACKAGE_COUNT: usize = 10; /// The setup function creates a set of groups and clients. -pub(crate) fn setup(config: TestSetupConfig, provider: &impl OpenMlsProvider) -> TestSetup { +pub(crate) fn setup( + config: TestSetupConfig, + provider: &impl crate::storage::OpenMlsProvider, +) -> TestSetup { let mut test_clients: HashMap<&'static str, RefCell> = HashMap::new(); let mut key_store: HashMap<(&'static str, Ciphersuite), Vec> = HashMap::new(); // Initialize the clients for which we have configurations. @@ -99,7 +100,7 @@ pub(crate) fn setup(config: TestSetupConfig, provider: &impl OpenMlsProvider) -> // Create a number of key packages. let mut key_packages = Vec::new(); for _ in 0..KEY_PACKAGE_COUNT { - let key_package_bundle: KeyPackageBundle = KeyPackageBundle::new( + let key_package_bundle: KeyPackageBundle = KeyPackageBundle::generate( provider, &credentia_with_key_and_signer.signer, ciphersuite, @@ -140,7 +141,7 @@ pub(crate) fn setup(config: TestSetupConfig, provider: &impl OpenMlsProvider) -> // Initialize the group state for the initial member. let core_group = CoreGroup::builder( GroupId::from_slice(&group_id.to_be_bytes()), - CryptoConfig::with_default_version(group_config.ciphersuite), + group_config.ciphersuite, credential_with_key_and_signer.credential_with_key.clone(), ) .with_config(group_config.config) @@ -299,19 +300,19 @@ fn test_random() { randombytes(0); } -#[apply(providers)] -fn test_setup(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_setup(provider: &impl crate::storage::OpenMlsProvider) { let test_client_config_a = TestClientConfig { name: "TestClientConfigA", - ciphersuites: vec![Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519], + ciphersuites: vec![Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519], }; let test_client_config_b = TestClientConfig { name: "TestClientConfigB", - ciphersuites: vec![Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519], + ciphersuites: vec![Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519], }; let group_config = CoreGroupConfig::default(); let test_group_config = TestGroupConfig { - ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, + ciphersuite: Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, config: group_config, members: vec![test_client_config_a.clone(), test_client_config_b.clone()], }; @@ -329,15 +330,15 @@ pub(crate) struct CredentialWithKeyAndSigner { } // Helper function to generate a CredentialWithKeyAndSigner -pub(crate) fn generate_credential_with_key( +pub(crate) fn generate_credential_with_key( identity: Vec, signature_scheme: SignatureScheme, - provider: &impl OpenMlsProvider, + provider: &Provider, ) -> CredentialWithKeyAndSigner { let (credential, signer) = { - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); let signature_keys = SignatureKeyPair::new(signature_scheme).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); (credential, signature_keys) }; @@ -354,19 +355,16 @@ pub(crate) fn generate_credential_with_key( } // Helper function to generate a KeyPackageBundle -pub(crate) fn generate_key_package( +pub(crate) fn generate_key_package( ciphersuite: Ciphersuite, extensions: Extensions, - provider: &impl OpenMlsProvider, + provider: &Provider, credential_with_keys: CredentialWithKeyAndSigner, -) -> KeyPackage { +) -> KeyPackageBundle { KeyPackage::builder() .key_package_extensions(extensions) .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &credential_with_keys.signer, credential_with_keys.credential_with_key, @@ -379,8 +377,9 @@ pub(crate) fn resign_message( alice_group: &MlsGroup, plaintext: PublicMessage, original_plaintext: &PublicMessage, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, signer: &impl Signer, + ciphersuite: Ciphersuite, ) -> PublicMessage { let serialized_context = alice_group .export_group_context() @@ -409,6 +408,7 @@ pub(crate) fn resign_message( signed_plaintext .set_membership_tag( provider.crypto(), + ciphersuite, membership_key, alice_group.group().message_secrets().serialized_context(), ) diff --git a/openmls/src/key_packages/errors.rs b/openmls/src/key_packages/errors.rs index babe1c1549..d437fdf6ca 100644 --- a/openmls/src/key_packages/errors.rs +++ b/openmls/src/key_packages/errors.rs @@ -48,16 +48,16 @@ pub enum KeyPackageExtensionSupportError { /// KeyPackage new error #[derive(Error, Debug, PartialEq, Clone)] -pub enum KeyPackageNewError { +pub enum KeyPackageNewError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), /// The ciphersuite does not match the signature scheme. #[error("The ciphersuite does not match the signature scheme.")] CiphersuiteSignatureSchemeMismatch, - /// Accessing the key store failed. - #[error("Accessing the key store failed.")] - KeyStoreError(KeyStoreError), + /// Accessing storage failed. + #[error("Accessing storage failed.")] + StorageError, /// See [`SignatureError`] for more details. #[error(transparent)] SignatureError(#[from] SignatureError), diff --git a/openmls/src/key_packages/key_package_in.rs b/openmls/src/key_packages/key_package_in.rs index 527b3d7c32..5eb6d674e7 100644 --- a/openmls/src/key_packages/key_package_in.rs +++ b/openmls/src/key_packages/key_package_in.rs @@ -18,6 +18,9 @@ use super::{ errors::KeyPackageVerifyError, InitKey, KeyPackage, KeyPackageTbs, SIGNATURE_KEY_PACKAGE_LABEL, }; +#[cfg(any(feature = "test-utils", test))] +use super::KeyPackageBundle; + /// Intermediary struct for deserialization of a [`KeyPackageIn`]. struct VerifiableKeyPackage { payload: KeyPackageTbs, @@ -239,6 +242,16 @@ impl From for KeyPackageIn { } } +#[cfg(any(feature = "test-utils", test))] +impl From for KeyPackageIn { + fn from(value: KeyPackageBundle) -> Self { + Self { + payload: value.key_package.payload.into(), + signature: value.key_package.signature, + } + } +} + #[cfg(any(feature = "test-utils", test))] impl From for KeyPackage { fn from(value: KeyPackageIn) -> Self { diff --git a/openmls/src/key_packages/mod.rs b/openmls/src/key_packages/mod.rs index eab82c048b..cf53695000 100644 --- a/openmls/src/key_packages/mod.rs +++ b/openmls/src/key_packages/mod.rs @@ -29,15 +29,14 @@ //! package bundle can be created as follows: //! //! ``` -//! use openmls::prelude::{*, tls_codec::*}; +//! use openmls::{prelude::{*, tls_codec::*}}; //! use openmls_rust_crypto::OpenMlsRustCrypto; //! use openmls_basic_credential::SignatureKeyPair; //! //! let ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; //! let provider = OpenMlsRustCrypto::default(); //! -//! let credential = BasicCredential::new("identity".into()) -//! .expect("Error creating a credential."); +//! let credential = BasicCredential::new("identity".into()); //! let signer = //! SignatureKeyPair::new(ciphersuite.signature_algorithm()) //! .expect("Error generating a signature key pair."); @@ -47,10 +46,7 @@ //! }; //! let key_package = KeyPackage::builder() //! .build( -//! CryptoConfig { -//! ciphersuite, -//! version: ProtocolVersion::default(), -//! }, +//! ciphersuite, //! &provider, //! &signer, //! credential_with_key, @@ -94,8 +90,6 @@ //! //! See [`KeyPackage`] for more details on how to use key packages. -#[cfg(test)] -use crate::treesync::node::encryption_keys::EncryptionKey; use crate::{ ciphersuite::{ hash_ref::{make_key_package_ref, KeyPackageRef}, @@ -105,10 +99,10 @@ use crate::{ credentials::*, error::LibraryError, extensions::{Extension, ExtensionType, Extensions, LastResortExtension}, - group::config::CryptoConfig, + storage::OpenMlsProvider, treesync::{ node::{ - encryption_keys::EncryptionKeyPair, + encryption_keys::{EncryptionKeyPair, EncryptionPrivateKey}, leaf_node::{Capabilities, LeafNodeSource, NewLeafNodeParams, TreeInfoTbs}, }, LeafNode, @@ -116,11 +110,7 @@ use crate::{ versions::ProtocolVersion, }; use openmls_traits::{ - crypto::OpenMlsCrypto, - key_store::{MlsEntity, MlsEntityId, OpenMlsKeyStore}, - signatures::Signer, - types::Ciphersuite, - OpenMlsProvider, + crypto::OpenMlsCrypto, signatures::Signer, storage::StorageProvider, types::Ciphersuite, }; use serde::{Deserialize, Serialize}; use tls_codec::{ @@ -207,10 +197,6 @@ impl SignedStruct for KeyPackage { const SIGNATURE_KEY_PACKAGE_LABEL: &str = "KeyPackageTBS"; -impl MlsEntity for KeyPackage { - const ID: MlsEntityId = MlsEntityId::KeyPackage; -} - /// Helper struct containing the results of building a new [`KeyPackage`]. pub(crate) struct KeyPackageCreationResult { pub key_package: KeyPackage, @@ -254,6 +240,12 @@ impl From> for InitKey { } } +impl From for InitKey { + fn from(key: HpkePublicKey) -> Self { + Self { key } + } +} + // Public `KeyPackage` functions. impl KeyPackage { /// Create a key package builder. @@ -265,31 +257,31 @@ impl KeyPackage { #[allow(clippy::too_many_arguments)] /// Create a new key package for the given `ciphersuite` and `identity`. - pub(crate) fn create( - config: CryptoConfig, - provider: &impl OpenMlsProvider, + pub(crate) fn create( + ciphersuite: Ciphersuite, + provider: &impl OpenMlsProvider, signer: &impl Signer, credential_with_key: CredentialWithKey, lifetime: Lifetime, extensions: Extensions, leaf_node_capabilities: Capabilities, leaf_node_extensions: Extensions, - ) -> Result> { - if config.ciphersuite.signature_algorithm() != signer.signature_scheme() { + ) -> Result { + if ciphersuite.signature_algorithm() != signer.signature_scheme() { return Err(KeyPackageNewError::CiphersuiteSignatureSchemeMismatch); } // Create a new HPKE key pair - let ikm = Secret::random(config.ciphersuite, provider.rand(), config.version) + let ikm = Secret::random(ciphersuite, provider.rand()) .map_err(LibraryError::unexpected_crypto_error)?; let init_key = provider .crypto() - .derive_hpke_keypair(config.ciphersuite.hpke_config(), ikm.as_slice()) + .derive_hpke_keypair(ciphersuite.hpke_config(), ikm.as_slice()) .map_err(|e| { KeyPackageNewError::LibraryError(LibraryError::unexpected_crypto_error(e)) })?; let (key_package, encryption_keypair) = Self::new_from_keys( - config, + ciphersuite, provider, signer, credential_with_key, @@ -317,9 +309,9 @@ impl KeyPackage { /// /// The caller is responsible for storing the new values. #[allow(clippy::too_many_arguments)] - fn new_from_keys( - config: CryptoConfig, - provider: &impl OpenMlsProvider, + fn new_from_keys( + ciphersuite: Ciphersuite, + provider: &impl OpenMlsProvider, signer: &impl Signer, credential_with_key: CredentialWithKey, lifetime: Lifetime, @@ -327,12 +319,12 @@ impl KeyPackage { capabilities: Capabilities, leaf_node_extensions: Extensions, init_key: InitKey, - ) -> Result<(Self, EncryptionKeyPair), KeyPackageNewError> { + ) -> Result<(Self, EncryptionKeyPair), KeyPackageNewError> { // We don't need the private key here. It's stored in the key store for // use later when creating a group with this key package. let new_leaf_node_params = NewLeafNodeParams { - config, + ciphersuite, leaf_node_source: LeafNodeSource::KeyPackage(lifetime), credential_with_key, capabilities, @@ -344,8 +336,8 @@ impl KeyPackage { LeafNode::new(provider, signer, new_leaf_node_params)?; let key_package_tbs = KeyPackageTbs { - protocol_version: config.version, - ciphersuite: config.ciphersuite, + protocol_version: ProtocolVersion::default(), + ciphersuite, init_key, leaf_node, extensions, @@ -356,19 +348,6 @@ impl KeyPackage { Ok((key_package, encryption_key_pair)) } - /// Delete this key package and its private key from the key store. - pub fn delete( - &self, - provider: &impl OpenMlsProvider, - ) -> Result<(), KeyStore::Error> { - provider - .key_store() - .delete::(self.hash_ref(provider.crypto()).unwrap().as_slice())?; - provider - .key_store() - .delete::(self.hpke_init_key().as_slice()) - } - /// Get a reference to the extensions of this key package. pub fn extensions(&self) -> &Extensions { &self.payload.extensions @@ -437,171 +416,6 @@ impl KeyPackage { } } -/// Helpers for testing. -#[cfg(any(feature = "test-utils", test))] -#[allow(clippy::too_many_arguments)] -impl KeyPackage { - /// Generate a new key package with a given init key - pub fn new_from_init_key( - config: CryptoConfig, - provider: &impl OpenMlsProvider, - signer: &impl Signer, - credential_with_key: CredentialWithKey, - extensions: Extensions, - leaf_node_capabilities: Capabilities, - leaf_node_extensions: Extensions, - init_key: InitKey, - ) -> Result> { - let (key_package, encryption_key_pair) = Self::new_from_keys( - config, - provider, - signer, - credential_with_key, - Lifetime::default(), - extensions, - leaf_node_capabilities, - leaf_node_extensions, - init_key, - )?; - - // Store the key package in the key store with the hash reference as id - // for retrieval when parsing welcome messages. - provider - .key_store() - .store( - key_package.hash_ref(provider.crypto())?.as_slice(), - &key_package, - ) - .map_err(KeyPackageNewError::KeyStoreError)?; - - // Store the encryption key pair in the key store. - encryption_key_pair - .write_to_key_store(provider.key_store()) - .map_err(KeyPackageNewError::KeyStoreError)?; - - Ok(key_package) - } - - /// Create new key package with a leaf node encryption key set to the - /// provided `encryption_key`. - #[cfg(test)] - #[allow(clippy::too_many_arguments)] - pub(crate) fn new_from_encryption_key( - config: CryptoConfig, - provider: &impl OpenMlsProvider, - signer: &impl Signer, - credential_with_key: CredentialWithKey, - extensions: Extensions, - leaf_node_capabilities: Capabilities, - leaf_node_extensions: Extensions, - encryption_key: EncryptionKey, - ) -> Result> { - // Create a new HPKE init key pair - let ikm = Secret::random(config.ciphersuite, provider.rand(), config.version).unwrap(); - let init_key = provider - .crypto() - .derive_hpke_keypair(config.ciphersuite.hpke_config(), ikm.as_slice()) - .map_err(|e| { - KeyPackageNewError::LibraryError(LibraryError::unexpected_crypto_error(e)) - })?; - - // Store the private part of the init_key into the key store. - // The key is the public key. - provider - .key_store() - .store::(&init_key.public, &init_key.private) - .map_err(KeyPackageNewError::KeyStoreError)?; - - // We don't need the private key here. It's stored in the key store for - // use later when creating a group with this key package. - let leaf_node = LeafNode::create_new_with_key( - encryption_key, - credential_with_key, - LeafNodeSource::KeyPackage(Lifetime::default()), - leaf_node_capabilities, - leaf_node_extensions, - TreeInfoTbs::KeyPackage, - signer, - ) - .unwrap(); - - let key_package = KeyPackageTbs { - protocol_version: config.version, - ciphersuite: config.ciphersuite, - init_key: init_key.public.into(), - leaf_node, - extensions, - }; - - let key_package = key_package.sign(signer)?; - - // Store the key package in the key store with the hash reference as id - // for retrieval when parsing welcome messages. - provider - .key_store() - .store( - key_package.hash_ref(provider.crypto())?.as_slice(), - &key_package, - ) - .map_err(KeyPackageNewError::KeyStoreError)?; - - Ok(key_package) - } - - pub fn into_with_init_key( - self, - config: CryptoConfig, - signer: &impl Signer, - init_key: InitKey, - ) -> Result { - let key_package_tbs = KeyPackageTbs { - protocol_version: config.version, - ciphersuite: config.ciphersuite, - init_key, - leaf_node: self.leaf_node().clone(), - extensions: self.extensions().clone(), - }; - - key_package_tbs.sign(signer) - } - - /// Resign this key package with another credential. - pub fn resign(mut self, signer: &impl Signer, credential_with_key: CredentialWithKey) -> Self { - self.payload - .leaf_node - .set_credential(credential_with_key.credential.clone()); - self.payload - .leaf_node - .set_signature_key(credential_with_key.signature_key.clone()); - - self.payload - .leaf_node - .resign(signer, credential_with_key, TreeInfoTbs::KeyPackage); - - self.payload.sign(signer).unwrap() - } - - /// Replace the public key in the KeyPackage. - pub fn set_init_key(&mut self, init_key: InitKey) { - self.payload.init_key = init_key - } - - /// Replace the version in the KeyPackage. - pub fn set_version(&mut self, version: ProtocolVersion) { - self.payload.protocol_version = version - } - - /// Replace the ciphersuite in the KeyPackage. - pub fn set_ciphersuite(&mut self, ciphersuite: Ciphersuite) { - self.payload.ciphersuite = ciphersuite - } - - /// Set the [`LeafNode`]. - pub fn set_leaf_node(&mut self, leaf_node: LeafNode) { - self.payload.leaf_node = leaf_node; - } -} - /// Builder that helps creating (and configuring) a [`KeyPackage`]. #[derive(Default, Debug, Clone, Serialize, Deserialize)] pub struct KeyPackageBuilder { @@ -667,16 +481,16 @@ impl KeyPackageBuilder { } } - pub(crate) fn build_without_key_storage( + pub(crate) fn build_without_storage( mut self, - config: CryptoConfig, - provider: &impl OpenMlsProvider, + ciphersuite: Ciphersuite, + provider: &impl OpenMlsProvider, signer: &impl Signer, credential_with_key: CredentialWithKey, - ) -> Result> { + ) -> Result { self.ensure_last_resort(); KeyPackage::create( - config, + ciphersuite, provider, signer, credential_with_key, @@ -688,20 +502,20 @@ impl KeyPackageBuilder { } /// Finalize and build the key package. - pub fn build( + pub fn build( mut self, - config: CryptoConfig, - provider: &impl OpenMlsProvider, + ciphersuite: Ciphersuite, + provider: &impl OpenMlsProvider, signer: &impl Signer, credential_with_key: CredentialWithKey, - ) -> Result> { + ) -> Result { self.ensure_last_resort(); let KeyPackageCreationResult { key_package, encryption_keypair, init_private_key, } = KeyPackage::create( - config, + ciphersuite, provider, signer, credential_with_key, @@ -713,78 +527,89 @@ impl KeyPackageBuilder { // Store the key package in the key store with the hash reference as id // for retrieval when parsing welcome messages. + let full_kp = KeyPackageBundle { + key_package, + private_init_key: init_private_key, + private_encryption_key: encryption_keypair.private_key().clone(), + }; provider - .key_store() - .store( - key_package.hash_ref(provider.crypto())?.as_slice(), - &key_package, - ) - .map_err(KeyPackageNewError::KeyStoreError)?; + .storage() + .write_key_package(&full_kp.key_package.hash_ref(provider.crypto())?, &full_kp) + .map_err(|_| KeyPackageNewError::StorageError)?; // Store the encryption key pair in the key store. encryption_keypair - .write_to_key_store(provider.key_store()) - .map_err(KeyPackageNewError::KeyStoreError)?; + .write(provider.storage()) + .map_err(|_| KeyPackageNewError::StorageError)?; - // Store the private part of the init_key into the key store. - // The key is the public key. - provider - .key_store() - .store::(key_package.hpke_init_key().as_slice(), &init_private_key) - .map_err(KeyPackageNewError::KeyStoreError)?; - - Ok(key_package) + Ok(full_kp) } } -/// A [`KeyPackageBundle`] contains a [`KeyPackage`] and the corresponding private -/// key. +/// A [`KeyPackageBundle`] contains a [`KeyPackage`] and the init and encryption +/// private key. +/// +/// This is stored to ensure the private key is handled together with the key +/// package. #[derive(Debug, Clone, Serialize, Deserialize)] -#[cfg_attr(test, derive(PartialEq))] -pub(crate) struct KeyPackageBundle { +pub struct KeyPackageBundle { pub(crate) key_package: KeyPackage, - pub(crate) private_key: HpkePrivateKey, + pub(crate) private_init_key: HpkePrivateKey, + pub(crate) private_encryption_key: EncryptionPrivateKey, } // Public `KeyPackageBundle` functions. impl KeyPackageBundle { /// Get a reference to the public part of this bundle, i.e. the [`KeyPackage`]. - pub(crate) fn key_package(&self) -> &KeyPackage { + pub fn key_package(&self) -> &KeyPackage { &self.key_package } - /// Get a reference to the private key. - pub fn private_key(&self) -> &HpkePrivateKey { - &self.private_key + /// Get a reference to the private init key. + pub fn init_private_key(&self) -> &HpkePrivateKey { + &self.private_init_key + } + + /// Get the encryption key pair. + pub(crate) fn encryption_key_pair(&self) -> EncryptionKeyPair { + EncryptionKeyPair::from(( + self.key_package.leaf_node().encryption_key().clone(), + self.private_encryption_key.clone(), + )) + } +} + +#[cfg(any(test, feature = "test-utils"))] +impl KeyPackageBundle { + /// Generate a new key package bundle with the private key. + pub fn new( + key_package: KeyPackage, + private_init_key: HpkePrivateKey, + private_encryption_key: EncryptionPrivateKey, + ) -> Self { + Self { + key_package, + private_init_key, + private_encryption_key, + } + } + + /// Get a reference to the private encryption key. + pub fn encryption_private_key(&self) -> &HpkePrivateKey { + self.private_encryption_key.key() } } #[cfg(test)] impl KeyPackageBundle { - pub(crate) fn new( + pub(crate) fn generate( provider: &impl OpenMlsProvider, signer: &impl Signer, ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, ) -> Self { - let key_package = KeyPackage::builder() - .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - provider, - signer, - credential_with_key, - ) - .unwrap(); - let private_key = provider - .key_store() - .read::(key_package.hpke_init_key().as_slice()) - .unwrap(); - Self { - key_package, - private_key, - } + KeyPackage::builder() + .build(ciphersuite, provider, signer, credential_with_key) + .unwrap() } } diff --git a/openmls/src/key_packages/test_key_packages.rs b/openmls/src/key_packages/test_key_packages.rs index 673a8f646c..a3b5bb07d5 100644 --- a/openmls/src/key_packages/test_key_packages.rs +++ b/openmls/src/key_packages/test_key_packages.rs @@ -1,25 +1,23 @@ use crate::test_utils::*; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; +use openmls_traits::prelude::*; + use tls_codec::Deserialize; -use crate::{extensions::*, key_packages::*}; +use crate::{extensions::*, key_packages::*, storage::OpenMlsProvider}; /// Helper function to generate key packages pub(crate) fn key_package( ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider, -) -> (KeyPackage, Credential, SignatureKeyPair) { - let credential = BasicCredential::new(b"Sasha".to_vec()).unwrap(); +) -> (KeyPackageBundle, Credential, SignatureKeyPair) { + let credential = BasicCredential::new(b"Sasha".to_vec()); let signer = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); // Generate a valid KeyPackage. let key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &signer, CredentialWithKey { @@ -32,21 +30,22 @@ pub(crate) fn key_package( (key_package, credential.into(), signer) } -#[apply(ciphersuites_and_providers)] -fn generate_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn generate_key_package() { let (key_package, _credential, _signature_keys) = key_package(ciphersuite, provider); - let kpi = KeyPackageIn::from(key_package); + let kpi = KeyPackageIn::from(key_package.key_package().clone()); assert!(kpi .validate(provider.crypto(), ProtocolVersion::Mls10) .is_ok()); } -#[apply(ciphersuites_and_providers)] -fn serialization(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn serialization() { let (key_package, _, _) = key_package(ciphersuite, provider); let encoded = key_package + .key_package() .tls_serialize_detached() .expect("An unexpected error occurred."); @@ -54,12 +53,12 @@ fn serialization(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { KeyPackageIn::tls_deserialize(&mut encoded.as_slice()) .expect("An unexpected error occurred."), ); - assert_eq!(key_package, decoded_key_package); + assert_eq!(key_package.key_package(), &decoded_key_package); } -#[apply(ciphersuites_and_providers)] -fn application_id_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let credential = BasicCredential::new(b"Sasha".to_vec()).unwrap(); +#[openmls_test::openmls_test] +fn application_id_extension() { + let credential = BasicCredential::new(b"Sasha".to_vec()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); // Generate a valid KeyPackage. @@ -69,10 +68,7 @@ fn application_id_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro ApplicationIdExtension::new(id), ))) .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &signature_keys, CredentialWithKey { @@ -82,7 +78,7 @@ fn application_id_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro ) .expect("An unexpected error occurred."); - let kpi = KeyPackageIn::from(key_package.clone()); + let kpi = KeyPackageIn::from(key_package.key_package().clone()); assert!(kpi .validate(provider.crypto(), ProtocolVersion::Mls10) .is_ok()); @@ -91,6 +87,7 @@ fn application_id_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro assert_eq!( Some(id), key_package + .key_package() .leaf_node() .extensions() .application_id() @@ -101,22 +98,19 @@ fn application_id_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsPro /// Test that the key package is correctly validated: /// - The protocol version is correct /// - The init key is not equal to the encryption key -#[apply(ciphersuites_and_providers)] -fn key_package_validation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn key_package_validation() { let (key_package_orig, _, _) = key_package(ciphersuite, provider); // === Protocol version === - let mut key_package = key_package_orig.clone(); - + let mut franken_key_package = + frankenstein::FrankenKeyPackage::from(key_package_orig.key_package().clone()); // Set an invalid protocol version - key_package.set_version(ProtocolVersion::Mls10Draft11); + franken_key_package.protocol_version = 999; - let encoded = key_package - .tls_serialize_detached() - .expect("An unexpected error occurred."); + let key_package_in = KeyPackageIn::from(franken_key_package); - let key_package_in = KeyPackageIn::tls_deserialize(&mut encoded.as_slice()).unwrap(); let err = key_package_in .validate(provider.crypto(), ProtocolVersion::Mls10) .unwrap_err(); @@ -126,23 +120,13 @@ fn key_package_validation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi // === Init/encryption key === - let mut key_package = key_package_orig; - + let mut franken_key_package = + frankenstein::FrankenKeyPackage::from(key_package_orig.key_package().clone()); // Set an invalid init key - key_package.set_init_key(InitKey::from( - key_package - .leaf_node() - .encryption_key() - .key() - .as_slice() - .to_vec(), - )); + franken_key_package.init_key = franken_key_package.leaf_node.encryption_key.clone(); - let encoded = key_package - .tls_serialize_detached() - .expect("An unexpected error occurred."); + let key_package_in = KeyPackageIn::from(franken_key_package); - let key_package_in = KeyPackageIn::tls_deserialize(&mut encoded.as_slice()).unwrap(); let err = key_package_in .validate(provider.crypto(), ProtocolVersion::Mls10) .unwrap_err(); @@ -153,19 +137,16 @@ fn key_package_validation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvi /// Test that a key package is correctly built with a last resort extension when /// the last resort flag is set during the build process. -#[apply(ciphersuites_and_providers)] -fn last_resort_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - let credential = Credential::from(BasicCredential::new(b"Sasha".to_vec()).unwrap()); +#[openmls_test::openmls_test] +fn last_resort_key_package() { + let credential = Credential::from(BasicCredential::new(b"Sasha".to_vec())); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); // build without any other extensions let key_package = KeyPackage::builder() .mark_as_last_resort() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &signature_keys, CredentialWithKey { @@ -174,17 +155,14 @@ fn last_resort_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv }, ) .expect("An unexpected error occurred."); - assert!(key_package.last_resort()); + assert!(key_package.key_package().last_resort()); // build with empty extensions let key_package = KeyPackage::builder() .key_package_extensions(Extensions::empty()) .mark_as_last_resort() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &signature_keys, CredentialWithKey { @@ -193,7 +171,7 @@ fn last_resort_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv }, ) .expect("An unexpected error occurred."); - assert!(key_package.last_resort()); + assert!(key_package.key_package().last_resort()); // build with extension let key_package = KeyPackage::builder() @@ -203,10 +181,7 @@ fn last_resort_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv ))) .mark_as_last_resort() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, + ciphersuite, provider, &signature_keys, CredentialWithKey { @@ -215,5 +190,5 @@ fn last_resort_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv }, ) .expect("An unexpected error occurred."); - assert!(key_package.last_resort()); + assert!(key_package.key_package().last_resort()); } diff --git a/openmls/src/lib.rs b/openmls/src/lib.rs index 0f4ba83996..e12ebb36de 100644 --- a/openmls/src/lib.rs +++ b/openmls/src/lib.rs @@ -4,7 +4,7 @@ //! up to parties and have them create a group. //! //! ``` -//! use openmls::prelude::{*, config::CryptoConfig, tls_codec::*}; +//! use openmls::{prelude::{*, tls_codec::*}}; //! use openmls_rust_crypto::{OpenMlsRustCrypto}; //! use openmls_basic_credential::SignatureKeyPair; //! @@ -22,8 +22,7 @@ //! signature_algorithm: SignatureScheme, //! provider: &impl OpenMlsProvider, //! ) -> (CredentialWithKey, SignatureKeyPair) { -//! let credential = BasicCredential::new(identity) -//! .expect("Error creating a credential."); +//! let credential = BasicCredential::new(identity); //! let signature_keys = //! SignatureKeyPair::new(signature_algorithm) //! .expect("Error generating a signature key pair."); @@ -31,7 +30,7 @@ //! // Store the signature key into the key store so OpenMLS has access //! // to it. //! signature_keys -//! .store(provider.key_store()) +//! .store(provider.storage()) //! .expect("Error storing signature keys in key store."); //! //! ( @@ -49,14 +48,11 @@ //! provider: &impl OpenMlsProvider, //! signer: &SignatureKeyPair, //! credential_with_key: CredentialWithKey, -//! ) -> KeyPackage { +//! ) -> KeyPackageBundle { //! // Create the key package //! KeyPackage::builder() //! .build( -//! CryptoConfig { -//! ciphersuite, -//! version: ProtocolVersion::default(), -//! }, +//! ciphersuite, //! provider, //! signer, //! credential_with_key, @@ -98,7 +94,7 @@ //! // The key package has to be retrieved from Maxim in some way. Most likely //! // via a server storing key packages for users. //! let (mls_message_out, welcome_out, group_info) = sasha_group -//! .add_members(provider, &sasha_signer, &[maxim_key_package]) +//! .add_members(provider, &sasha_signer, &[maxim_key_package.key_package().clone()]) //! .expect("Could not add members."); //! //! // Sasha merges the pending commit that adds Maxim. @@ -161,9 +157,6 @@ compile_error!("In order for OpenMLS to build for WebAssembly, JavaScript APIs m #[cfg(any(feature = "test-utils", test))] pub mod prelude_test; -#[cfg(any(feature = "test-utils", test))] -pub use rstest_reuse; - #[cfg(any(feature = "test-utils", test))] #[macro_use] pub mod test_utils; @@ -190,6 +183,10 @@ pub mod schedule; pub mod treesync; pub mod versions; +// implement storage traits +// public +pub mod storage; + // Private mod binary_tree; mod tree; diff --git a/openmls/src/messages/codec.rs b/openmls/src/messages/codec.rs new file mode 100644 index 0000000000..7a05f3f7d8 --- /dev/null +++ b/openmls/src/messages/codec.rs @@ -0,0 +1,139 @@ +//! # Codec +//! +//! This module contains the encoding and decoding logic for Proposals. + +use tls_codec::{Deserialize, DeserializeBytes, Serialize, Size}; + +use super::{ + proposals::{ + AppAckProposal, ExternalInitProposal, GroupContextExtensionProposal, PreSharedKeyProposal, + Proposal, ProposalType, ReInitProposal, RemoveProposal, + }, + proposals_in::{AddProposalIn, ProposalIn, UpdateProposalIn}, + CustomProposal, +}; + +impl Size for Proposal { + fn tls_serialized_len(&self) -> usize { + self.proposal_type().tls_serialized_len() + + match self { + Proposal::Add(p) => p.tls_serialized_len(), + Proposal::Update(p) => p.tls_serialized_len(), + Proposal::Remove(p) => p.tls_serialized_len(), + Proposal::PreSharedKey(p) => p.tls_serialized_len(), + Proposal::ReInit(p) => p.tls_serialized_len(), + Proposal::ExternalInit(p) => p.tls_serialized_len(), + Proposal::GroupContextExtensions(p) => p.tls_serialized_len(), + Proposal::AppAck(p) => p.tls_serialized_len(), + Proposal::Custom(p) => p.payload().tls_serialized_len(), + } + } +} + +impl Serialize for Proposal { + fn tls_serialize(&self, writer: &mut W) -> Result { + let written = self.proposal_type().tls_serialize(writer)?; + match self { + Proposal::Add(p) => p.tls_serialize(writer), + Proposal::Update(p) => p.tls_serialize(writer), + Proposal::Remove(p) => p.tls_serialize(writer), + Proposal::PreSharedKey(p) => p.tls_serialize(writer), + Proposal::ReInit(p) => p.tls_serialize(writer), + Proposal::ExternalInit(p) => p.tls_serialize(writer), + Proposal::GroupContextExtensions(p) => p.tls_serialize(writer), + Proposal::AppAck(p) => p.tls_serialize(writer), + Proposal::Custom(p) => p.payload().tls_serialize(writer), + } + .map(|l| written + l) + } +} + +impl Size for &ProposalIn { + fn tls_serialized_len(&self) -> usize { + self.proposal_type().tls_serialized_len() + + match self { + ProposalIn::Add(p) => p.tls_serialized_len(), + ProposalIn::Update(p) => p.tls_serialized_len(), + ProposalIn::Remove(p) => p.tls_serialized_len(), + ProposalIn::PreSharedKey(p) => p.tls_serialized_len(), + ProposalIn::ReInit(p) => p.tls_serialized_len(), + ProposalIn::ExternalInit(p) => p.tls_serialized_len(), + ProposalIn::GroupContextExtensions(p) => p.tls_serialized_len(), + ProposalIn::AppAck(p) => p.tls_serialized_len(), + ProposalIn::Custom(p) => p.payload().tls_serialized_len(), + } + } +} + +impl Size for ProposalIn { + fn tls_serialized_len(&self) -> usize { + (&self).tls_serialized_len() + } +} + +impl Serialize for &ProposalIn { + fn tls_serialize(&self, writer: &mut W) -> Result { + let written = self.proposal_type().tls_serialize(writer)?; + match self { + ProposalIn::Add(p) => p.tls_serialize(writer), + ProposalIn::Update(p) => p.tls_serialize(writer), + ProposalIn::Remove(p) => p.tls_serialize(writer), + ProposalIn::PreSharedKey(p) => p.tls_serialize(writer), + ProposalIn::ReInit(p) => p.tls_serialize(writer), + ProposalIn::ExternalInit(p) => p.tls_serialize(writer), + ProposalIn::GroupContextExtensions(p) => p.tls_serialize(writer), + ProposalIn::AppAck(p) => p.tls_serialize(writer), + ProposalIn::Custom(p) => p.payload().tls_serialize(writer), + } + .map(|l| written + l) + } +} + +impl Serialize for ProposalIn { + fn tls_serialize(&self, writer: &mut W) -> Result { + (&self).tls_serialize(writer) + } +} + +impl Deserialize for ProposalIn { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let proposal_type = ProposalType::tls_deserialize(bytes)?; + let proposal = match proposal_type { + ProposalType::Add => ProposalIn::Add(AddProposalIn::tls_deserialize(bytes)?), + ProposalType::Update => ProposalIn::Update(UpdateProposalIn::tls_deserialize(bytes)?), + ProposalType::Remove => ProposalIn::Remove(RemoveProposal::tls_deserialize(bytes)?), + ProposalType::PreSharedKey => { + ProposalIn::PreSharedKey(PreSharedKeyProposal::tls_deserialize(bytes)?) + } + ProposalType::Reinit => ProposalIn::ReInit(ReInitProposal::tls_deserialize(bytes)?), + ProposalType::ExternalInit => { + ProposalIn::ExternalInit(ExternalInitProposal::tls_deserialize(bytes)?) + } + ProposalType::GroupContextExtensions => ProposalIn::GroupContextExtensions( + GroupContextExtensionProposal::tls_deserialize(bytes)?, + ), + ProposalType::AppAck => ProposalIn::AppAck(AppAckProposal::tls_deserialize(bytes)?), + ProposalType::Custom(_) => { + let payload = Vec::::tls_deserialize(bytes)?; + let custom_proposal = CustomProposal::new(proposal_type.into(), payload); + ProposalIn::Custom(custom_proposal) + } + }; + Ok(proposal) + } +} + +impl DeserializeBytes for ProposalIn { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> + where + Self: Sized, + { + let mut bytes_ref = bytes; + let proposal = ProposalIn::tls_deserialize(&mut bytes_ref)?; + let remainder = &bytes[proposal.tls_serialized_len()..]; + Ok((proposal, remainder)) + } +} diff --git a/openmls/src/messages/external_proposals.rs b/openmls/src/messages/external_proposals.rs index 36a67e745c..e88b1fbead 100644 --- a/openmls/src/messages/external_proposals.rs +++ b/openmls/src/messages/external_proposals.rs @@ -14,6 +14,7 @@ use crate::{ }, key_packages::KeyPackage, messages::{AddProposal, Proposal}, + storage::{OpenMlsProvider, StorageProvider}, }; use openmls_traits::signatures::Signer; @@ -40,12 +41,12 @@ impl JoinProposal { /// * `epoch` - group's epoch /// * `signer` - of the sender to sign the message #[allow(clippy::new_ret_no_self)] - pub fn new( + pub fn new( key_package: KeyPackage, group_id: GroupId, epoch: GroupEpoch, signer: &impl Signer, - ) -> Result { + ) -> Result> { AuthenticatedContent::new_join_proposal( Proposal::Add(AddProposal { key_package }), group_id, @@ -69,13 +70,13 @@ impl ExternalProposal { /// * `signer` - of the sender to sign the message /// * `sender` - index of the sender of the proposal (in the [crate::extensions::ExternalSendersExtension] array /// from the Group Context) - pub fn new_remove( + pub fn new_remove( removed: LeafNodeIndex, group_id: GroupId, epoch: GroupEpoch, signer: &impl Signer, sender_index: SenderExtensionIndex, - ) -> Result { + ) -> Result> { AuthenticatedContent::new_external_proposal( Proposal::Remove(RemoveProposal { removed }), group_id, diff --git a/openmls/src/messages/mod.rs b/openmls/src/messages/mod.rs index daf70af0a6..395b566a80 100644 --- a/openmls/src/messages/mod.rs +++ b/openmls/src/messages/mod.rs @@ -32,6 +32,7 @@ use crate::{ #[cfg(test)] use openmls_traits::random::OpenMlsRand; +pub(crate) mod codec; pub mod external_proposals; pub mod group_info; pub mod proposals; @@ -322,7 +323,7 @@ impl PathSecret { ) -> Result { let node_secret = self .path_secret - .kdf_expand_label(crypto, "node", &[], ciphersuite.hash_length()) + .kdf_expand_label(crypto, ciphersuite, "node", &[], ciphersuite.hash_length()) .map_err(LibraryError::unexpected_crypto_error)?; let HpkeKeyPair { public, private } = crypto .derive_hpke_keypair(ciphersuite.hpke_config(), node_secret.as_slice()) @@ -339,7 +340,7 @@ impl PathSecret { ) -> Result { let path_secret = self .path_secret - .kdf_expand_label(crypto, "path", &[], ciphersuite.hash_length()) + .kdf_expand_label(crypto, ciphersuite, "path", &[], ciphersuite.hash_length()) .map_err(LibraryError::unexpected_crypto_error)?; Ok(Self { path_secret }) } @@ -375,14 +376,13 @@ impl PathSecret { pub(crate) fn decrypt( crypto: &impl OpenMlsCrypto, ciphersuite: Ciphersuite, - version: ProtocolVersion, ciphertext: &HpkeCiphertext, private_key: &EncryptionPrivateKey, group_context: &[u8], ) -> Result { // ValSem203: Path secrets must decrypt correctly private_key - .decrypt(crypto, ciphersuite, version, ciphertext, group_context) + .decrypt(crypto, ciphersuite, ciphertext, group_context) .map(|path_secret| Self { path_secret }) .map_err(|e| e.into()) } @@ -446,9 +446,7 @@ impl GroupSecrets { // Note: This also checks that no extraneous data was encrypted. let group_secrets = GroupSecrets::tls_deserialize_exact(group_secrets_plaintext) - .map_err(|_| GroupSecretsError::Malformed)? - // TODO(#1065) - .config(ciphersuite, ProtocolVersion::Mls10); + .map_err(|_| GroupSecretsError::Malformed)?; Ok(group_secrets) } @@ -466,19 +464,6 @@ impl GroupSecrets { } .tls_serialize_detached() } - - /// Set the config for the secrets, i.e. ciphersuite and MLS version. - pub(crate) fn config( - mut self, - ciphersuite: Ciphersuite, - mls_version: ProtocolVersion, - ) -> GroupSecrets { - self.joiner_secret.config(ciphersuite, mls_version); - if let Some(s) = &mut self.path_secret { - s.path_secret.config(ciphersuite, mls_version); - } - self - } } #[cfg(test)] @@ -486,7 +471,6 @@ impl GroupSecrets { pub fn random_encoded( ciphersuite: Ciphersuite, rng: &impl OpenMlsRand, - version: ProtocolVersion, ) -> Result, tls_codec::Error> { let psk_id = PreSharedKeyId::new( ciphersuite, @@ -500,10 +484,9 @@ impl GroupSecrets { let psks = vec![psk_id]; GroupSecrets::new_encoded( - &JoinerSecret::random(ciphersuite, rng, version), + &JoinerSecret::random(ciphersuite, rng), Some(&PathSecret { - path_secret: Secret::random(ciphersuite, rng, version) - .expect("Not enough randomness."), + path_secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), }), &psks, ) diff --git a/openmls/src/messages/proposals.rs b/openmls/src/messages/proposals.rs index ac2442f0f4..c13a236c0c 100644 --- a/openmls/src/messages/proposals.rs +++ b/openmls/src/messages/proposals.rs @@ -1,9 +1,6 @@ //! # Proposals //! //! This module defines all the different types of Proposals. -//! -//! To find out if a specific proposal type is supported, -//! [`ProposalType::is_supported()`] can be used. use std::io::{Read, Write}; @@ -71,7 +68,7 @@ use crate::{ /// | Value | Name | Recommended | Path Required | Reference | Notes | /// |:=======|:========|:============|:==============|:==========|:=============================| /// | 0x0008 | app_ack | Y | Y | RFC XXXX | draft-ietf-mls-extensions-00 | -#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize)] +#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug, Serialize, Deserialize, Hash)] #[allow(missing_docs)] pub enum ProposalType { Add, @@ -82,7 +79,7 @@ pub enum ProposalType { ExternalInit, GroupContextExtensions, AppAck, - Unknown(u16), + Custom(u16), } impl Size for ProposalType { @@ -124,21 +121,6 @@ impl DeserializeBytes for ProposalType { } impl ProposalType { - /// Check whether a proposal type is supported or not. Returns `true` - /// if a proposal is supported and `false` otherwise. - pub fn is_supported(&self) -> bool { - matches!( - self, - ProposalType::Add - | ProposalType::Update - | ProposalType::Remove - | ProposalType::PreSharedKey - | ProposalType::Reinit - | ProposalType::ExternalInit - | ProposalType::GroupContextExtensions - ) - } - /// Returns `true` if the proposal type requires a path and `false` pub fn is_path_required(&self) -> bool { matches!( @@ -159,7 +141,7 @@ impl From for ProposalType { 6 => ProposalType::ExternalInit, 7 => ProposalType::GroupContextExtensions, 8 => ProposalType::AppAck, - unknown => ProposalType::Unknown(unknown), + other => ProposalType::Custom(other), } } } @@ -175,7 +157,7 @@ impl From for u16 { ProposalType::ExternalInit => 6, ProposalType::GroupContextExtensions => 7, ProposalType::AppAck => 8, - ProposalType::Unknown(unknown) => unknown, + ProposalType::Custom(id) => id, } } } @@ -200,29 +182,22 @@ impl From for u16 { /// } Proposal; /// ``` #[allow(clippy::large_enum_variant)] -#[derive(Debug, PartialEq, Clone, Serialize, Deserialize, TlsSize, TlsSerialize)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[allow(missing_docs)] #[repr(u16)] pub enum Proposal { - #[tls_codec(discriminant = 1)] Add(AddProposal), - #[tls_codec(discriminant = 2)] Update(UpdateProposal), - #[tls_codec(discriminant = 3)] Remove(RemoveProposal), - #[tls_codec(discriminant = 4)] PreSharedKey(PreSharedKeyProposal), - #[tls_codec(discriminant = 5)] ReInit(ReInitProposal), - #[tls_codec(discriminant = 6)] ExternalInit(ExternalInitProposal), - #[tls_codec(discriminant = 7)] GroupContextExtensions(GroupContextExtensionProposal), // # Extensions // TODO(#916): `AppAck` is not in draft-ietf-mls-protocol-17 but // was moved to `draft-ietf-mls-extensions-00`. - #[tls_codec(discriminant = 8)] AppAck(AppAckProposal), + Custom(CustomProposal), } impl Proposal { @@ -237,6 +212,10 @@ impl Proposal { Proposal::ExternalInit(_) => ProposalType::ExternalInit, Proposal::GroupContextExtensions(_) => ProposalType::GroupContextExtensions, Proposal::AppAck(_) => ProposalType::AppAck, + Proposal::Custom(CustomProposal { + proposal_type, + payload: _, + }) => ProposalType::Custom(proposal_type.to_owned()), } } @@ -634,6 +613,43 @@ pub(crate) struct MessageRange { last_generation: u32, } +/// A custom proposal with semantics to be implemented by the application. +#[derive( + Debug, + PartialEq, + Clone, + Serialize, + Deserialize, + TlsSize, + TlsSerialize, + TlsDeserialize, + TlsDeserializeBytes, +)] +pub struct CustomProposal { + proposal_type: u16, + payload: Vec, +} + +impl CustomProposal { + /// Generate a new custom proposal. + pub fn new(proposal_type: u16, payload: Vec) -> Self { + Self { + proposal_type, + payload, + } + } + + /// Returns the proposal type of this [`CustomProposal`]. + pub fn proposal_type(&self) -> u16 { + self.proposal_type + } + + /// Returns the payload of this [`CustomProposal`]. + pub fn payload(&self) -> &[u8] { + &self.payload + } +} + #[cfg(test)] mod tests { use tls_codec::{Deserialize, Serialize}; @@ -652,7 +668,7 @@ mod tests { let got = ProposalType::tls_deserialize_exact(&test).unwrap(); match got { - ProposalType::Unknown(got_proposal_type) => { + ProposalType::Custom(got_proposal_type) => { assert_eq!(proposal_type, got_proposal_type); } other => panic!("Expected `ProposalType::Unknown`, got `{:?}`.", other), diff --git a/openmls/src/messages/proposals_in.rs b/openmls/src/messages/proposals_in.rs index 2c5b23579e..9a20874e07 100644 --- a/openmls/src/messages/proposals_in.rs +++ b/openmls/src/messages/proposals_in.rs @@ -1,9 +1,6 @@ //! # Proposals //! //! This module defines all the different types of Proposals. -//! -//! To find out if a specific proposal type is supported, -//! [`ProposalType::is_supported()`] can be used. use crate::{ ciphersuite::{hash_ref::ProposalRef, signable::Verifiable}, @@ -19,10 +16,13 @@ use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite}; use serde::{Deserialize, Serialize}; use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; -use super::proposals::{ - AddProposal, AppAckProposal, ExternalInitProposal, GroupContextExtensionProposal, - PreSharedKeyProposal, Proposal, ProposalOrRef, ProposalType, ReInitProposal, RemoveProposal, - UpdateProposal, +use super::{ + proposals::{ + AddProposal, AppAckProposal, ExternalInitProposal, GroupContextExtensionProposal, + PreSharedKeyProposal, Proposal, ProposalOrRef, ProposalType, ReInitProposal, + RemoveProposal, UpdateProposal, + }, + CustomProposal, }; /// Proposal. @@ -45,39 +45,22 @@ use super::proposals::{ /// } Proposal; /// ``` #[allow(clippy::large_enum_variant)] -#[derive( - Debug, - PartialEq, - Clone, - Serialize, - Deserialize, - TlsSize, - TlsSerialize, - TlsDeserialize, - TlsDeserializeBytes, -)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[allow(missing_docs)] #[repr(u16)] pub enum ProposalIn { - #[tls_codec(discriminant = 1)] Add(AddProposalIn), - #[tls_codec(discriminant = 2)] Update(UpdateProposalIn), - #[tls_codec(discriminant = 3)] Remove(RemoveProposal), - #[tls_codec(discriminant = 4)] PreSharedKey(PreSharedKeyProposal), - #[tls_codec(discriminant = 5)] ReInit(ReInitProposal), - #[tls_codec(discriminant = 6)] ExternalInit(ExternalInitProposal), - #[tls_codec(discriminant = 7)] GroupContextExtensions(GroupContextExtensionProposal), // # Extensions // TODO(#916): `AppAck` is not in draft-ietf-mls-protocol-17 but // was moved to `draft-ietf-mls-extensions-00`. - #[tls_codec(discriminant = 8)] AppAck(AppAckProposal), + Custom(CustomProposal), } impl ProposalIn { @@ -92,6 +75,9 @@ impl ProposalIn { ProposalIn::ExternalInit(_) => ProposalType::ExternalInit, ProposalIn::GroupContextExtensions(_) => ProposalType::GroupContextExtensions, ProposalIn::AppAck(_) => ProposalType::AppAck, + ProposalIn::Custom(custom_proposal) => { + ProposalType::Custom(custom_proposal.proposal_type()) + } } } @@ -125,6 +111,7 @@ impl ProposalIn { Proposal::GroupContextExtensions(group_context_extension) } ProposalIn::AppAck(app_ack) => Proposal::AppAck(app_ack), + ProposalIn::Custom(custom) => Proposal::Custom(custom), }) } } @@ -327,6 +314,7 @@ impl From for crate::messages::proposals::Proposal { Self::GroupContextExtensions(group_context_extension) } ProposalIn::AppAck(app_ack) => Self::AppAck(app_ack), + ProposalIn::Custom(other) => Self::Custom(other), } } } @@ -344,6 +332,7 @@ impl From for ProposalIn { Self::GroupContextExtensions(group_context_extension) } Proposal::AppAck(app_ack) => Self::AppAck(app_ack), + Proposal::Custom(other) => Self::Custom(other), } } } diff --git a/openmls/src/messages/tests/test_codec.rs b/openmls/src/messages/tests/test_codec.rs index a3f0c66d34..ba1f42a016 100644 --- a/openmls/src/messages/tests/test_codec.rs +++ b/openmls/src/messages/tests/test_codec.rs @@ -1,4 +1,3 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; use tls_codec::{Deserialize, Serialize}; use crate::{ @@ -11,8 +10,8 @@ use crate::{ /// Test the encoding for PreSharedKeyProposal, that also covers some of the /// other PSK-related structs -#[apply(providers)] -fn test_pre_shared_key_proposal_codec(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_pre_shared_key_proposal_codec() { // External let psk = PreSharedKeyId { psk: Psk::External(ExternalPsk::new(vec![1, 2, 3])), @@ -79,8 +78,8 @@ fn test_pre_shared_key_proposal_codec(provider: &impl OpenMlsProvider) { } /// Test the encoding for ReInitProposal, that also covers some of the /// other PSK-related structs -#[apply(ciphersuites_and_providers)] -fn test_reinit_proposal_codec(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_reinit_proposal_codec() { let orig = ReInitProposal { group_id: GroupId::random(provider.rand()), version: ProtocolVersion::default(), diff --git a/openmls/src/messages/tests/test_export_group_info.rs b/openmls/src/messages/tests/test_export_group_info.rs index 75ca13ffb2..54dd4f1c17 100644 --- a/openmls/src/messages/tests/test_export_group_info.rs +++ b/openmls/src/messages/tests/test_export_group_info.rs @@ -3,16 +3,13 @@ use tls_codec::{Deserialize, Serialize}; use crate::{ ciphersuite::signable::Verifiable, group::test_core_group::setup_alice_group, - messages::{ - group_info::{GroupInfo, VerifiableGroupInfo}, - *, - }, + messages::group_info::{GroupInfo, VerifiableGroupInfo}, test_utils::*, }; /// Tests the creation of an [UnverifiedGroupInfo] and verifies it was correctly signed. -#[apply(ciphersuites_and_providers)] -fn export_group_info(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn export_group_info() { // Alice creates a group let (group_alice, _, signer, pk) = setup_alice_group(ciphersuite, provider); diff --git a/openmls/src/messages/tests/test_proposals.rs b/openmls/src/messages/tests/test_proposals.rs index e35c5e4b8e..b77ba4358d 100644 --- a/openmls/src/messages/tests/test_proposals.rs +++ b/openmls/src/messages/tests/test_proposals.rs @@ -1,4 +1,3 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; use tls_codec::{Deserialize, Serialize}; use crate::{ @@ -13,8 +12,8 @@ use crate::{ /// This test encodes and decodes the `ProposalOrRef` struct and makes sure the /// decoded values are the same as the original -#[apply(ciphersuites_and_providers)] -fn proposals_codec(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn proposals_codec() { // Proposal let remove_proposal = RemoveProposal { diff --git a/openmls/src/messages/tests/test_welcome.rs b/openmls/src/messages/tests/test_welcome.rs index 09c76f65bf..f817bbb073 100644 --- a/openmls/src/messages/tests/test_welcome.rs +++ b/openmls/src/messages/tests/test_welcome.rs @@ -1,10 +1,5 @@ use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{ - crypto::OpenMlsCrypto, key_store::OpenMlsKeyStore, types::Ciphersuite, OpenMlsProvider, -}; -use rstest::*; -use rstest_reuse::{self, *}; +use openmls_traits::prelude::{openmls_types::*, *}; use tls_codec::{Deserialize, Serialize}; use crate::{ @@ -14,28 +9,28 @@ use crate::{ }, extensions::Extensions, group::{ - config::CryptoConfig, errors::WelcomeError, GroupContext, GroupId, MlsGroup, - MlsGroupCreateConfig, StagedWelcome, + errors::WelcomeError, GroupContext, GroupId, MlsGroup, MlsGroupCreateConfig, StagedWelcome, }, messages::{ group_info::{GroupInfoTBS, VerifiableGroupInfo}, ConfirmationTag, EncryptedGroupSecrets, GroupSecrets, GroupSecretsError, Welcome, }, - prelude::HpkePrivateKey, schedule::{ psk::{load_psks, store::ResumptionPskStore, PskSecret}, KeySchedule, }, treesync::node::encryption_keys::EncryptionKeyPair, - versions::ProtocolVersion, }; /// This test detects if the decryption of the encrypted group secrets fails due to a change in /// the encrypted group info. As the group info is part of the decryption context of the encrypted /// group info, it is not possible to generate a matching encrypted group context with different /// parameters. -#[apply(ciphersuites_and_providers)] -fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_welcome_context_mismatch( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let _ = pretty_env_logger::try_init(); // We need a ciphersuite that is different from the current one to create @@ -49,7 +44,7 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM let group_id = GroupId::random(provider.rand()); let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); let (alice_credential_with_key, _alice_kpb, alice_signer, _alice_signature_key) = @@ -58,7 +53,7 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM crate::group::test_core_group::setup_client("Bob", ciphersuite, provider); let bob_kp = bob_kpb.key_package(); - let bob_private_key = bob_kpb.private_key(); + let bob_private_key = bob_kpb.init_private_key(); // === Alice creates a group and adds Bob === let mut alice_group = MlsGroup::new_with_group_id( @@ -96,15 +91,14 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM ) .expect("Could not decrypt group secrets."); let group_secrets = GroupSecrets::tls_deserialize(&mut group_secrets_bytes.as_slice()) - .expect("Could not deserialize group secrets.") - .config(ciphersuite, ProtocolVersion::Mls10); + .expect("Could not deserialize group secrets."); let joiner_secret = group_secrets.joiner_secret; // Prepare the PskSecret let psk_secret = { let resumption_psk_store = ResumptionPskStore::new(1024); - let psks = load_psks(provider.key_store(), &resumption_psk_store, &[]).unwrap(); + let psks = load_psks(provider.storage(), &resumption_psk_store, &[]).unwrap(); PskSecret::new(provider.crypto(), ciphersuite, psks).unwrap() }; @@ -116,9 +110,9 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM // Derive welcome key & nonce from the key schedule let (welcome_key, welcome_nonce) = key_schedule - .welcome(provider.crypto()) + .welcome(provider.crypto(), ciphersuite) .expect("Using the key schedule in the wrong state") - .derive_welcome_key_nonce(provider.crypto()) + .derive_welcome_key_nonce(provider.crypto(), ciphersuite) .expect("Could not derive welcome key and nonce."); let group_info_bytes = welcome_key @@ -154,11 +148,9 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM welcome.encrypted_group_info = encrypted_verifiable_group_info.into(); // Create backup of encryption keypair, s.t. we can process the welcome a second time after failing. - let encryption_keypair = EncryptionKeyPair::read_from_key_store( - provider, - bob_kpb.key_package().leaf_node().encryption_key(), - ) - .unwrap(); + let encryption_keypair = + EncryptionKeyPair::read(provider, bob_kpb.key_package().leaf_node().encryption_key()) + .unwrap(); // Bob tries to join the group let err = StagedWelcome::new_from_welcome( @@ -169,30 +161,21 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM ) .expect_err("Created a staged join from an invalid Welcome."); - assert_eq!( + assert!(matches!( err, WelcomeError::GroupSecrets(GroupSecretsError::DecryptionFailed) - ); + )); // === Process the original Welcome === // We need to store the key package and its encryption key again because it // has been consumed already. provider - .key_store() - .store( - bob_kp.hash_ref(provider.crypto()).unwrap().as_slice(), - bob_kp, - ) - .unwrap(); - provider - .key_store() - .store::(bob_kp.hpke_init_key().as_slice(), bob_private_key) + .storage() + .write_key_package(&bob_kp.hash_ref(provider.crypto()).unwrap(), &bob_kpb) .unwrap(); - encryption_keypair - .write_to_key_store(provider.key_store()) - .unwrap(); + encryption_keypair.write(provider.storage()).unwrap(); let _group = StagedWelcome::new_from_welcome( provider, @@ -205,12 +188,12 @@ fn test_welcome_context_mismatch(ciphersuite: Ciphersuite, provider: &impl OpenM .expect("Error creating group from a valid staged join."); } -#[apply(ciphersuites_and_providers)] -fn test_welcome_msg(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_welcome_msg() { test_welcome_message(ciphersuite, provider); } -fn test_welcome_message(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +fn test_welcome_message(ciphersuite: Ciphersuite, provider: &impl crate::storage::OpenMlsProvider) { // We use this dummy group info in all test cases. let group_info_tbs = { let group_context = GroupContext::new( @@ -248,7 +231,7 @@ fn test_welcome_message(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide .crypto() .derive_hpke_keypair( ciphersuite.hpke_config(), - Secret::random(ciphersuite, provider.rand(), None) + Secret::random(ciphersuite, provider.rand()) .expect("Not enough randomness.") .as_slice(), ) diff --git a/openmls/src/prelude.rs b/openmls/src/prelude.rs index 3a0e984021..2957e3bcdb 100644 --- a/openmls/src/prelude.rs +++ b/openmls/src/prelude.rs @@ -2,7 +2,7 @@ //! Include this to get access to all the public functions of OpenMLS. // MlsGroup -pub use crate::group::{config::CryptoConfig, core_group::Member, ser::*, *}; +pub use crate::group::{core_group::Member, ser::*, *}; pub use crate::group::public_group::{errors::*, PublicGroup}; @@ -61,7 +61,4 @@ pub use tls_codec::{self, *}; pub use crate::error::*; // OpenMLS traits -pub use openmls_traits::{ - crypto::OpenMlsCrypto, key_store::OpenMlsKeyStore, random::OpenMlsRand, types::*, - OpenMlsProvider, -}; +pub use openmls_traits::{crypto::OpenMlsCrypto, random::OpenMlsRand, types::*, OpenMlsProvider}; diff --git a/openmls/src/schedule/errors.rs b/openmls/src/schedule/errors.rs index 7519a4c398..e4c8f8836c 100644 --- a/openmls/src/schedule/errors.rs +++ b/openmls/src/schedule/errors.rs @@ -17,12 +17,12 @@ pub enum PskError { /// More than 2^16 PSKs were provided. #[error("More than 2^16 PSKs were provided.")] TooManyKeys, - /// The PSK could not be found in the key store. - #[error("The PSK could not be found in the key store.")] + /// The PSK could not be found in the store. + #[error("The PSK could not be found in the store.")] KeyNotFound, - /// Failed to write PSK into keystore. - #[error("Failed to write PSK into keystore.")] - KeyStore, + /// Failed to write PSK into storage. + #[error("Failed to write PSK storage.")] + Storage, /// Type mismatch. #[error("Type mismatch. Expected {allowed:?}, got {got:?}.")] TypeMismatch { diff --git a/openmls/src/schedule/kat_key_schedule.rs b/openmls/src/schedule/kat_key_schedule.rs index fa177ddc70..5e8e7e9df5 100644 --- a/openmls/src/schedule/kat_key_schedule.rs +++ b/openmls/src/schedule/kat_key_schedule.rs @@ -6,7 +6,6 @@ //! If values are not present, they are encoded as empty strings. use log::info; -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{random::OpenMlsRand, types::HpkeKeyPair, OpenMlsProvider}; use serde::{self, Deserialize, Serialize}; use tls_codec::Serialize as TlsSerializeTrait; @@ -91,9 +90,9 @@ fn generate( // PSK secret can sometimes be the all zero vector let a: [u8; 1] = crypto.rand().random_array().unwrap(); let psk_secret = if a[0] > 127 { - PskSecret::from(Secret::random(ciphersuite, crypto.rand(), ProtocolVersion::Mls10).unwrap()) + PskSecret::from(Secret::random(ciphersuite, crypto.rand()).unwrap()) } else { - PskSecret::from(Secret::zero(ciphersuite, ProtocolVersion::Mls10)) + PskSecret::from(Secret::zero(ciphersuite)) }; let group_context = GroupContext::new( @@ -107,6 +106,7 @@ fn generate( let joiner_secret = JoinerSecret::new( crypto.crypto(), + ciphersuite, commit_secret.clone(), init_secret, &group_context.tls_serialize_detached().unwrap(), @@ -120,7 +120,7 @@ fn generate( ) .expect("Could not create KeySchedule."); let welcome_secret = key_schedule - .welcome(crypto.crypto()) + .welcome(crypto.crypto(), ciphersuite) .expect("An unexpected error occurred."); let serialized_group_context = group_context @@ -131,7 +131,7 @@ fn generate( .add_context(crypto.crypto(), &serialized_group_context) .expect("An unexpected error occurred."); let epoch_secrets = key_schedule - .epoch_secrets(crypto.crypto()) + .epoch_secrets(crypto.crypto(), ciphersuite) .expect("An unexpected error occurred."); // Calculate external HPKE key pair @@ -163,8 +163,7 @@ pub fn generate_test_vector( // Set up setting. let mut init_secret = - InitSecret::random(ciphersuite, provider.rand(), ProtocolVersion::default()) - .expect("Not enough randomness."); + InitSecret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let initial_init_secret = init_secret.clone(); let group_id = provider .rand() @@ -255,8 +254,8 @@ fn write_test_vectors() { write("test_vectors/key-schedule-new.json", &tests); } -#[apply(providers)] -fn read_test_vectors_key_schedule(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_key_schedule() { let _ = pretty_env_logger::try_init(); let tests: Vec = read_json!("../../test_vectors/key-schedule.json"); @@ -292,20 +291,12 @@ pub fn run_test_vector( " InitSecret from tve: {:?}", test_vector.initial_init_secret ); - let mut init_secret = InitSecret::from(Secret::from_slice( - &init_secret, - ProtocolVersion::default(), - ciphersuite, - )); + let mut init_secret = InitSecret::from(Secret::from_slice(&init_secret)); for (epoch_ctr, epoch) in test_vector.epochs.iter().enumerate() { let tree_hash = hex_to_bytes(&epoch.tree_hash); let secret = hex_to_bytes(&epoch.commit_secret); - let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice( - &secret, - ProtocolVersion::default(), - ciphersuite, - ))); + let commit_secret = CommitSecret::from(PathSecret::from(Secret::from_slice(&secret))); log::trace!(" CommitSecret from tve {:?}", epoch.commit_secret); let confirmed_transcript_hash = hex_to_bytes(&epoch.confirmed_transcript_hash); @@ -321,6 +312,7 @@ pub fn run_test_vector( let joiner_secret = JoinerSecret::new( provider.crypto(), + ciphersuite, commit_secret, &init_secret, &group_context.tls_serialize_detached().unwrap(), @@ -333,18 +325,14 @@ pub fn run_test_vector( return Err(KsTestVectorError::JoinerSecretMismatch); } - let psk_secret_inner = Secret::from_slice( - &hex_to_bytes(&epoch.psk_secret), - ProtocolVersion::Mls10, - ciphersuite, - ); + let psk_secret_inner = Secret::from_slice(&hex_to_bytes(&epoch.psk_secret)); let psk_secret = PskSecret::from(psk_secret_inner); let mut key_schedule = KeySchedule::init(ciphersuite, provider.crypto(), &joiner_secret, psk_secret) .expect("Could not create KeySchedule."); let welcome_secret = key_schedule - .welcome(provider.crypto()) + .welcome(provider.crypto(), ciphersuite) .expect("An unexpected error occurred."); if hex_to_bytes(&epoch.welcome_secret) != welcome_secret.as_slice() { @@ -373,7 +361,7 @@ pub fn run_test_vector( .expect("An unexpected error occurred."); let epoch_secrets = key_schedule - .epoch_secrets(provider.crypto()) + .epoch_secrets(provider.crypto(), ciphersuite) .expect("An unexpected error occurred."); init_secret = epoch_secrets.init_secret().clone(); diff --git a/openmls/src/schedule/kat_psk_secret.rs b/openmls/src/schedule/kat_psk_secret.rs index 191a8ffc86..f13d6c592f 100644 --- a/openmls/src/schedule/kat_psk_secret.rs +++ b/openmls/src/schedule/kat_psk_secret.rs @@ -80,9 +80,7 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result let psk_id = PreSharedKeyId::new_with_nonce(psk_type, psk.psk_nonce.clone()); - psk_id - .write_to_key_store(provider, ciphersuite, &psk.psk) - .unwrap(); + psk_id.store(provider, &psk.psk).unwrap(); psk_id }) .collect::>(); @@ -91,7 +89,7 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result let psk_secret = { let resumption_psk_store = ResumptionPskStore::new(1024); - let psks = load_psks(provider.key_store(), &resumption_psk_store, &psk_ids).unwrap(); + let psks = load_psks(provider.storage(), &resumption_psk_store, &psk_ids).unwrap(); PskSecret::new(provider.crypto(), ciphersuite, psks).unwrap() }; @@ -103,8 +101,8 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result } } -#[apply(providers)] -fn read_test_vectors_ps(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_ps() { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); diff --git a/openmls/src/schedule/mod.rs b/openmls/src/schedule/mod.rs index 12b1eed708..b82a0610ba 100644 --- a/openmls/src/schedule/mod.rs +++ b/openmls/src/schedule/mod.rs @@ -170,8 +170,14 @@ pub struct ResumptionPskSecret { impl ResumptionPskSecret { /// Derive an `ResumptionPsk` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "resumption")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "resumption")?; Ok(Self { secret }) } @@ -191,10 +197,14 @@ pub struct EpochAuthenticator { impl EpochAuthenticator { /// Derive an `EpochAuthenticator` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { let secret = epoch_secret .secret - .derive_secret(crypto, "authentication")?; + .derive_secret(crypto, ciphersuite, "authentication")?; Ok(Self { secret }) } @@ -224,17 +234,16 @@ impl From for CommitSecret { impl CommitSecret { /// Create a CommitSecret consisting of an all-zero string of length /// `hash_length`. - pub(crate) fn zero_secret(ciphersuite: Ciphersuite, version: ProtocolVersion) -> Self { + pub(crate) fn zero_secret(ciphersuite: Ciphersuite) -> Self { CommitSecret { - secret: Secret::zero(ciphersuite, version), + secret: Secret::zero(ciphersuite), } } #[cfg(any(feature = "test-utils", test))] pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsRand) -> Self { Self { - secret: Secret::random(ciphersuite, rng, None /* MLS version */) - .expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), } } @@ -263,14 +272,20 @@ impl From for InitSecret { fn hpke_info_from_version(version: ProtocolVersion) -> &'static str { match version { ProtocolVersion::Mls10 => "MLS 1.0 external init secret", - ProtocolVersion::Mls10Draft11 => "", + _ => "", } } impl InitSecret { /// Derive an `InitSecret` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "init")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "init")?; log_crypto!(trace, "Init secret: {:x?}", secret); Ok(InitSecret { secret }) } @@ -279,10 +294,9 @@ impl InitSecret { pub(crate) fn random( ciphersuite: Ciphersuite, rand: &impl OpenMlsRand, - version: ProtocolVersion, ) -> Result { Ok(InitSecret { - secret: Secret::random(ciphersuite, rand, version)?, + secret: Secret::random(ciphersuite, rand)?, }) } @@ -303,7 +317,7 @@ impl InitSecret { )?; Ok(( InitSecret { - secret: Secret::from_slice(&raw_init_secret, version, ciphersuite), + secret: Secret::from_slice(&raw_init_secret), }, kem_output, )) @@ -328,7 +342,7 @@ impl InitSecret { ) .map_err(LibraryError::unexpected_crypto_error)?; Ok(InitSecret { - secret: Secret::from_slice(&raw_init_secret, version, ciphersuite), + secret: Secret::from_slice(&raw_init_secret), }) } @@ -357,42 +371,36 @@ impl JoinerSecret { /// partial commit. pub(crate) fn new( crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, commit_secret_option: impl Into>, init_secret: &InitSecret, serialized_group_context: &[u8], ) -> Result { let intermediate_secret = init_secret.secret.hkdf_extract( crypto, + ciphersuite, commit_secret_option.into().as_ref().map(|cs| &cs.secret), )?; let secret = intermediate_secret.kdf_expand_label( crypto, + ciphersuite, "joiner", serialized_group_context, - intermediate_secret.ciphersuite().hash_length(), + ciphersuite.hash_length(), )?; log_crypto!(trace, "Joiner secret: {:x?}", secret); Ok(JoinerSecret { secret }) } - /// Set the config for the secret, i.e. ciphersuite and MLS version. - pub(crate) fn config(&mut self, ciphersuite: Ciphersuite, mls_version: ProtocolVersion) { - self.secret.config(ciphersuite, mls_version); - } - #[cfg(any(feature = "test-utils", test))] pub(crate) fn as_slice(&self) -> &[u8] { self.secret.as_slice() } #[cfg(test)] - pub(crate) fn random( - ciphersuite: Ciphersuite, - rand: &impl OpenMlsRand, - version: ProtocolVersion, - ) -> Self { + pub(crate) fn random(ciphersuite: Ciphersuite, rand: &impl OpenMlsRand) -> Self { Self { - secret: Secret::random(ciphersuite, rand, version).expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rand).expect("Not enough randomness."), } } } @@ -426,7 +434,7 @@ impl KeySchedule { " joiner_secret: {:x?}", joiner_secret.secret.as_slice() ); - let intermediate_secret = IntermediateSecret::new(crypto, joiner_secret, psk) + let intermediate_secret = IntermediateSecret::new(crypto, ciphersuite, joiner_secret, psk) .map_err(LibraryError::unexpected_crypto_error)?; Ok(Self { ciphersuite, @@ -441,6 +449,7 @@ impl KeySchedule { pub(crate) fn welcome( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ) -> Result { if self.state != State::Initial || self.intermediate_secret.is_none() { log::error!("Trying to derive a welcome secret while not in the initial state."); @@ -453,7 +462,11 @@ impl KeySchedule { .as_ref() .ok_or_else(|| LibraryError::custom("state machine error"))?; - Ok(WelcomeSecret::new(crypto, intermediate_secret)?) + Ok(WelcomeSecret::new( + crypto, + ciphersuite, + intermediate_secret, + )?) } /// Add the group context to the key schedule. @@ -502,6 +515,7 @@ impl KeySchedule { pub(crate) fn epoch_secrets( &mut self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ) -> Result { if self.state != State::Context || self.epoch_secret.is_none() { log::error!("Trying to derive the epoch secrets while not in the right state."); @@ -515,7 +529,7 @@ impl KeySchedule { None => return Err(LibraryError::custom("state machine error").into()), }; - Ok(EpochSecrets::new(crypto, epoch_secret)?) + Ok(EpochSecrets::new(crypto, ciphersuite, epoch_secret)?) } } @@ -530,11 +544,14 @@ impl IntermediateSecret { /// PSK. fn new( crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, joiner_secret: &JoinerSecret, psk: PskSecret, ) -> Result { log_crypto!(trace, "PSK input: {:x?}", psk.as_slice()); - let secret = joiner_secret.secret.hkdf_extract(crypto, psk.secret())?; + let secret = joiner_secret + .secret + .hkdf_extract(crypto, ciphersuite, psk.secret())?; log_crypto!(trace, "Intermediate secret: {:x?}", secret); Ok(Self { secret }) } @@ -548,11 +565,12 @@ impl WelcomeSecret { /// Derive a `WelcomeSecret` from to decrypt a `Welcome` message. fn new( crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, intermediate_secret: &IntermediateSecret, ) -> Result { let secret = intermediate_secret .secret - .derive_secret(crypto, "welcome")?; + .derive_secret(crypto, ciphersuite, "welcome")?; log_crypto!(trace, "Welcome secret: {:x?}", secret); Ok(WelcomeSecret { secret }) } @@ -562,34 +580,42 @@ impl WelcomeSecret { pub(crate) fn derive_welcome_key_nonce( self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ) -> Result<(AeadKey, AeadNonce), CryptoError> { - let welcome_nonce = self.derive_aead_nonce(crypto)?; - let welcome_key = self.derive_aead_key(crypto)?; + let welcome_nonce = self.derive_aead_nonce(crypto, ciphersuite)?; + let welcome_key = self.derive_aead_key(crypto, ciphersuite)?; Ok((welcome_key, welcome_nonce)) } /// Derive a new AEAD key from a `WelcomeSecret`. - fn derive_aead_key(&self, crypto: &impl OpenMlsCrypto) -> Result { - log::trace!( - "WelcomeSecret.derive_aead_key with {}", - self.secret.ciphersuite() - ); + fn derive_aead_key( + &self, + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + ) -> Result { + log::trace!("WelcomeSecret.derive_aead_key with {}", ciphersuite); let aead_secret = self.secret.kdf_expand_label( crypto, + ciphersuite, "key", b"", - self.secret.ciphersuite().aead_key_length(), + ciphersuite.aead_key_length(), )?; - Ok(AeadKey::from_secret(aead_secret)) + Ok(AeadKey::from_secret(aead_secret, ciphersuite)) } /// Derive a new AEAD nonce from a `WelcomeSecret`. - fn derive_aead_nonce(&self, crypto: &impl OpenMlsCrypto) -> Result { + fn derive_aead_nonce( + &self, + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + ) -> Result { let nonce_secret = self.secret.kdf_expand_label( crypto, + ciphersuite, "nonce", b"", - self.secret.ciphersuite().aead_nonce_length(), + ciphersuite.aead_nonce_length(), )?; Ok(AeadNonce::from_secret(nonce_secret)) } @@ -617,6 +643,7 @@ impl EpochSecret { ) -> Result { let secret = intermediate_secret.secret.kdf_expand_label( crypto, + ciphersuite, "epoch", serialized_group_context, ciphersuite.hash_length(), @@ -634,9 +661,15 @@ pub(crate) struct EncryptionSecret { impl EncryptionSecret { /// Derive an encryption secret from a reference to an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { Ok(EncryptionSecret { - secret: epoch_secret.secret.derive_secret(crypto, "encryption")?, + secret: epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "encryption")?, }) } @@ -658,8 +691,7 @@ impl EncryptionSecret { #[cfg(test)] pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsRand) -> Self { EncryptionSecret { - secret: Secret::random(ciphersuite, rng, None /* MLS version */) - .expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), } } @@ -670,13 +702,9 @@ impl EncryptionSecret { #[cfg(any(feature = "test-utils", test))] /// Create a new secret from a byte vector. - pub(crate) fn from_slice( - bytes: &[u8], - mls_version: ProtocolVersion, - ciphersuite: Ciphersuite, - ) -> Self { + pub(crate) fn from_slice(bytes: &[u8]) -> Self { Self { - secret: Secret::from_slice(bytes, mls_version, ciphersuite), + secret: Secret::from_slice(bytes), } } } @@ -690,8 +718,14 @@ pub(crate) struct ExporterSecret { impl ExporterSecret { /// Derive an `ExporterSecret` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "exporter")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "exporter")?; Ok(ExporterSecret { secret }) } @@ -714,8 +748,8 @@ impl ExporterSecret { let context_hash = &crypto.hash(ciphersuite.hash_algorithm(), context)?; Ok(self .secret - .derive_secret(crypto, label)? - .kdf_expand_label(crypto, "exported", context_hash, key_length)? + .derive_secret(crypto, ciphersuite, label)? + .kdf_expand_label(crypto, ciphersuite, "exported", context_hash, key_length)? .as_slice() .to_vec()) } @@ -730,8 +764,14 @@ pub(crate) struct ExternalSecret { impl ExternalSecret { /// Derive an `ExternalSecret` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "external")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "external")?; Ok(Self { secret }) } @@ -759,14 +799,20 @@ pub(crate) struct ConfirmationKey { impl ConfirmationKey { /// Derive an `ConfirmationKey` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { log::debug!("Computing confirmation key."); log_crypto!( trace, " epoch_secret {:x?}", epoch_secret.secret.as_slice() ); - let secret = epoch_secret.secret.derive_secret(crypto, "confirm")?; + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "confirm")?; Ok(Self { secret }) } @@ -781,6 +827,7 @@ impl ConfirmationKey { pub(crate) fn tag( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, confirmed_transcript_hash: &[u8], ) -> Result { log::debug!("Computing confirmation tag."); @@ -788,6 +835,7 @@ impl ConfirmationKey { log_crypto!(trace, " transcript hash {:x?}", confirmed_transcript_hash); Ok(ConfirmationTag(Mac::new( crypto, + ciphersuite, &self.secret, confirmed_transcript_hash, )?)) @@ -805,8 +853,7 @@ impl ConfirmationKey { impl ConfirmationKey { pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsRand) -> Self { Self { - secret: Secret::random(ciphersuite, rng, None /* MLS version */) - .expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), } } @@ -824,8 +871,14 @@ pub(crate) struct MembershipKey { impl MembershipKey { /// Derive an `MembershipKey` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "membership")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "membership")?; Ok(Self { secret }) } @@ -839,11 +892,13 @@ impl MembershipKey { pub(crate) fn tag_message( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, tbm_payload: AuthenticatedContentTbm, ) -> Result { Ok(MembershipTag( Mac::new( crypto, + ciphersuite, &self.secret, &tbm_payload .into_bytes() @@ -866,8 +921,7 @@ impl MembershipKey { #[cfg(any(feature = "test-utils", test))] pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsRand) -> Self { Self { - secret: Secret::random(ciphersuite, rng, None /* MLS version */) - .expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), } } } @@ -886,15 +940,24 @@ fn ciphertext_sample(ciphersuite: Ciphersuite, ciphertext: &[u8]) -> &[u8] { /// A key that can be used to derive an `AeadKey` and an `AeadNonce`. #[derive(Serialize, Deserialize)] #[cfg_attr(test, derive(PartialEq))] -#[cfg_attr(any(feature = "test-utils", test), derive(Debug, Clone))] +#[cfg_attr( + any(feature = "test-utils", feature = "crypto-debug", test), + derive(Debug, Clone) +)] pub(crate) struct SenderDataSecret { secret: Secret, } impl SenderDataSecret { /// Derive an `ExporterSecret` from an `EpochSecret`. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: &EpochSecret) -> Result { - let secret = epoch_secret.secret.derive_secret(crypto, "sender data")?; + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: &EpochSecret, + ) -> Result { + let secret = epoch_secret + .secret + .derive_secret(crypto, ciphersuite, "sender data")?; Ok(SenderDataSecret { secret }) } @@ -902,20 +965,22 @@ impl SenderDataSecret { pub(crate) fn derive_aead_key( &self, crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, ciphertext: &[u8], ) -> Result { - let ciphertext_sample = ciphertext_sample(self.secret.ciphersuite(), ciphertext); + let ciphertext_sample = ciphertext_sample(ciphersuite, ciphertext); log::debug!( "SenderDataSecret::derive_aead_key ciphertext sample: {:x?}", ciphertext_sample ); let secret = self.secret.kdf_expand_label( crypto, + ciphersuite, "key", ciphertext_sample, - self.secret.ciphersuite().aead_key_length(), + ciphersuite.aead_key_length(), )?; - Ok(AeadKey::from_secret(secret)) + Ok(AeadKey::from_secret(secret, ciphersuite)) } /// Derive a new AEAD nonce from a `SenderDataSecret`. @@ -932,6 +997,7 @@ impl SenderDataSecret { ); let nonce_secret = self.secret.kdf_expand_label( crypto, + ciphersuite, "nonce", ciphertext_sample, ciphersuite.aead_nonce_length(), @@ -942,8 +1008,7 @@ impl SenderDataSecret { #[cfg(any(feature = "test-utils", test))] pub(crate) fn random(ciphersuite: Ciphersuite, rng: &impl OpenMlsRand) -> Self { Self { - secret: Secret::random(ciphersuite, rng, None /* MLS version */) - .expect("Not enough randomness."), + secret: Secret::random(ciphersuite, rng).expect("Not enough randomness."), } } @@ -954,13 +1019,9 @@ impl SenderDataSecret { #[cfg(any(feature = "test-utils", test))] /// Create a new secret from a byte vector. - pub(crate) fn from_slice( - bytes: &[u8], - mls_version: ProtocolVersion, - ciphersuite: Ciphersuite, - ) -> Self { + pub(crate) fn from_slice(bytes: &[u8]) -> Self { Self { - secret: Secret::from_slice(bytes, mls_version, ciphersuite), + secret: Secret::from_slice(bytes), } } } @@ -1075,27 +1136,31 @@ impl EpochSecrets { /// Derive `EpochSecrets` from an `EpochSecret`. /// If the `with_init_secret` argument is `true`, the init secret is derived and /// part of the `EpochSecrets`. Otherwise not. - fn new(crypto: &impl OpenMlsCrypto, epoch_secret: EpochSecret) -> Result { + fn new( + crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, + epoch_secret: EpochSecret, + ) -> Result { log::debug!( "Computing EpochSecrets from epoch secret with {}", - epoch_secret.secret.ciphersuite() + ciphersuite ); log_crypto!( trace, " epoch_secret: {:x?}", epoch_secret.secret.as_slice() ); - let sender_data_secret = SenderDataSecret::new(crypto, &epoch_secret)?; - let encryption_secret = EncryptionSecret::new(crypto, &epoch_secret)?; - let exporter_secret = ExporterSecret::new(crypto, &epoch_secret)?; - let epoch_authenticator = EpochAuthenticator::new(crypto, &epoch_secret)?; - let external_secret = ExternalSecret::new(crypto, &epoch_secret)?; - let confirmation_key = ConfirmationKey::new(crypto, &epoch_secret)?; - let membership_key = MembershipKey::new(crypto, &epoch_secret)?; - let resumption_psk = ResumptionPskSecret::new(crypto, &epoch_secret)?; + let sender_data_secret = SenderDataSecret::new(crypto, ciphersuite, &epoch_secret)?; + let encryption_secret = EncryptionSecret::new(crypto, ciphersuite, &epoch_secret)?; + let exporter_secret = ExporterSecret::new(crypto, ciphersuite, &epoch_secret)?; + let epoch_authenticator = EpochAuthenticator::new(crypto, ciphersuite, &epoch_secret)?; + let external_secret = ExternalSecret::new(crypto, ciphersuite, &epoch_secret)?; + let confirmation_key = ConfirmationKey::new(crypto, ciphersuite, &epoch_secret)?; + let membership_key = MembershipKey::new(crypto, ciphersuite, &epoch_secret)?; + let resumption_psk = ResumptionPskSecret::new(crypto, ciphersuite, &epoch_secret)?; log::trace!(" Computing init secret."); - let init_secret = InitSecret::new(crypto, epoch_secret)?; + let init_secret = InitSecret::new(crypto, ciphersuite, epoch_secret)?; Ok(EpochSecrets { init_secret, @@ -1116,15 +1181,13 @@ impl EpochSecrets { /// external init. pub(crate) fn with_init_secret( crypto: &impl OpenMlsCrypto, + ciphersuite: Ciphersuite, init_secret: InitSecret, ) -> Result { let epoch_secret = EpochSecret { - secret: Secret::zero( - init_secret.secret.ciphersuite(), - init_secret.secret.version(), - ), + secret: Secret::zero(ciphersuite), }; - let mut epoch_secrets = Self::new(crypto, epoch_secret)?; + let mut epoch_secrets = Self::new(crypto, ciphersuite, epoch_secret)?; epoch_secrets.init_secret = init_secret; Ok(epoch_secrets) } diff --git a/openmls/src/schedule/psk.rs b/openmls/src/schedule/psk.rs index 16671eb202..a725bb8285 100644 --- a/openmls/src/schedule/psk.rs +++ b/openmls/src/schedule/psk.rs @@ -1,10 +1,6 @@ //! # Preshared keys. -use openmls_traits::{ - key_store::{MlsEntity, MlsEntityId, OpenMlsKeyStore}, - random::OpenMlsRand, - OpenMlsProvider, -}; +use openmls_traits::{random::OpenMlsRand, storage::StorageProvider as StorageProviderTrait}; use serde::{Deserialize, Serialize}; use tls_codec::{Serialize as TlsSerializeTrait, VLBytes}; @@ -12,6 +8,7 @@ use super::*; use crate::{ group::{GroupEpoch, GroupId}, schedule::psk::store::ResumptionPskStore, + storage::{OpenMlsProvider, StorageProvider}, }; /// Resumption PSK usage. @@ -97,10 +94,6 @@ pub(crate) struct PskBundle { secret: Secret, } -impl MlsEntity for PskBundle { - const ID: MlsEntityId = MlsEntityId::PskBundle; -} - /// Resumption PSK. #[derive( Clone, @@ -284,40 +277,21 @@ impl PreSharedKeyId { /// Save this `PreSharedKeyId` in the keystore. /// /// Note: The nonce is not saved as it must be unique for each time it's being applied. - pub fn write_to_key_store( + pub fn store( &self, - provider: &impl OpenMlsProvider, - ciphersuite: Ciphersuite, + provider: &Provider, psk: &[u8], ) -> Result<(), PskError> { - let keystore_id = self.keystore_id()?; - let psk_bundle = { - let secret = Secret::from_slice(psk, ProtocolVersion::default(), ciphersuite); + let secret = Secret::from_slice(psk); PskBundle { secret } }; provider - .key_store() - .store(&keystore_id, &psk_bundle) - .map_err(|_| PskError::KeyStore) - } - - pub(crate) fn keystore_id(&self) -> Result, LibraryError> { - let psk_id_with_empty_nonce = PreSharedKeyId { - psk: self.psk.clone(), - psk_nonce: VLBytes::new(vec![]), - }; - - log::trace!( - "keystore id: {:x?}", - psk_id_with_empty_nonce.tls_serialize_detached() - ); - - psk_id_with_empty_nonce - .tls_serialize_detached() - .map_err(LibraryError::missing_bound_check) + .storage() + .write_psk(&self.psk, &psk_bundle) + .map_err(|_| PskError::Storage) } // ----- Validation ---------------------------------------------------------------------------- @@ -414,19 +388,17 @@ impl PskSecret { // Check that we don't have too many PSKs let num_psks = u16::try_from(psks.len()).map_err(|_| PskError::TooManyKeys)?; - let mls_version = ProtocolVersion::default(); - // Following comments are from `draft-ietf-mls-protocol-19`. // // psk_secret_[0] = 0 - let mut psk_secret = Secret::zero(ciphersuite, mls_version); + let mut psk_secret = Secret::zero(ciphersuite); for (index, (psk_id, psk)) in psks.into_iter().enumerate() { // psk_extracted_[i] = KDF.Extract(0, psk_[i]) let psk_extracted = { - let zero_secret = Secret::zero(ciphersuite, mls_version); + let zero_secret = Secret::zero(ciphersuite); zero_secret - .hkdf_extract(crypto, &psk) + .hkdf_extract(crypto, ciphersuite, &psk) .map_err(LibraryError::unexpected_crypto_error)? }; @@ -437,13 +409,19 @@ impl PskSecret { .map_err(LibraryError::missing_bound_check)?; psk_extracted - .kdf_expand_label(crypto, "derived psk", &psk_label, ciphersuite.hash_length()) + .kdf_expand_label( + crypto, + ciphersuite, + "derived psk", + &psk_label, + ciphersuite.hash_length(), + ) .map_err(LibraryError::unexpected_crypto_error)? }; // psk_secret_[i] = KDF.Extract(psk_input_[i-1], psk_secret_[i-1]) psk_secret = psk_input - .hkdf_extract(crypto, &psk_secret) + .hkdf_extract(crypto, ciphersuite, &psk_secret) .map_err(LibraryError::unexpected_crypto_error)?; } @@ -455,7 +433,7 @@ impl PskSecret { &self.secret } - #[cfg(any(feature = "test-utils", test))] + #[cfg(any(feature = "test-utils", feature = "crypto-debug", test))] pub(crate) fn as_slice(&self) -> &[u8] { self.secret.as_slice() } @@ -468,8 +446,8 @@ impl From for PskSecret { } } -pub(crate) fn load_psks<'p>( - key_store: &impl OpenMlsKeyStore, +pub(crate) fn load_psks<'p, Storage: StorageProvider>( + storage: &Storage, resumption_psk_store: &ResumptionPskStore, psk_ids: &'p [PreSharedKeyId], ) -> Result, PskError> { @@ -487,7 +465,10 @@ pub(crate) fn load_psks<'p>( } } Psk::External(_) => { - if let Some(psk_bundle) = key_store.read::(&psk_id.keystore_id()?) { + let psk_bundle: Option = storage + .psk(psk_id.psk()) + .map_err(|_| PskError::KeyNotFound)?; + if let Some(psk_bundle) = psk_bundle { psk_bundles.push((psk_id, psk_bundle.secret)); } else { return Err(PskError::KeyNotFound); diff --git a/openmls/src/schedule/unit_tests.rs b/openmls/src/schedule/unit_tests.rs index 8931eee112..64305a899d 100644 --- a/openmls/src/schedule/unit_tests.rs +++ b/openmls/src/schedule/unit_tests.rs @@ -1,18 +1,15 @@ //! Key Schedule Unit Tests -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{random::OpenMlsRand, OpenMlsProvider}; use super::PskSecret; use crate::{ ciphersuite::Secret, schedule::psk::{store::ResumptionPskStore, *}, - test_utils::*, - versions::ProtocolVersion, }; -#[apply(ciphersuites_and_providers)] -fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_psks() { // Create a new PSK secret from multiple PSKs. let prng = provider.rand(); @@ -29,24 +26,16 @@ fn test_psks(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .collect::>(); for (secret, psk_id) in (0..33) - .map(|_| { - Secret::from_slice( - &prng.random_vec(55).expect("An unexpected error occurred."), - ProtocolVersion::Mls10, - ciphersuite, - ) - }) + .map(|_| Secret::from_slice(&prng.random_vec(55).expect("An unexpected error occurred."))) .zip(psk_ids.clone()) { - psk_id - .write_to_key_store(provider, ciphersuite, secret.as_slice()) - .unwrap(); + psk_id.store(provider, secret.as_slice()).unwrap(); } let _psk_secret = { let resumption_psk_store = ResumptionPskStore::new(1024); - let psks = load_psks(provider.key_store(), &resumption_psk_store, &psk_ids).unwrap(); + let psks = load_psks(provider.storage(), &resumption_psk_store, &psk_ids).unwrap(); PskSecret::new(provider.crypto(), ciphersuite, psks).unwrap() }; diff --git a/openmls/src/storage.rs b/openmls/src/storage.rs new file mode 100644 index 0000000000..854a7a666e --- /dev/null +++ b/openmls/src/storage.rs @@ -0,0 +1,242 @@ +//! OpenMLS Storage +//! +//! This module serves two purposes: +//! +//! - It implements the Key, Entity and type traits from `openmls_traits::storage::traits`. +//! - It defines traits that specialize the Storage and Provider traits from `openmls_traits`. +//! This way, the Rust compiler knows that the concrete types match when we use the Provider in +//! the code. + +use openmls_traits::storage::{traits, Entity, Key, CURRENT_VERSION}; + +use crate::binary_tree::LeafNodeIndex; +use crate::group::{MlsGroupJoinConfig, MlsGroupState}; +use crate::{ + ciphersuite::hash_ref::ProposalRef, + group::{GroupContext, GroupId, InterimTranscriptHash, QueuedProposal}, + messages::ConfirmationTag, + treesync::{LeafNode, TreeSync}, +}; +use crate::{ + group::{past_secrets::MessageSecretsStore, GroupEpoch}, + prelude::KeyPackageBundle, + schedule::{ + psk::{store::ResumptionPskStore, PskBundle}, + GroupEpochSecrets, Psk, + }, + treesync::{node::encryption_keys::EncryptionKeyPair, EncryptionKey}, +}; + +/// A convenience trait for the current version of the storage. +/// Throughout the code, this one should be used instead of `openmls_traits::storage::StorageProvider`. +pub trait StorageProvider: openmls_traits::storage::StorageProvider {} + +impl> StorageProvider for P {} + +/// A convenience trait for the OpenMLS provider that defines the storage provider +/// for the current version of storage. +/// Throughout the code, this one should be used instead of `openmls_traits::OpenMlsProvider`. +pub trait OpenMlsProvider: + openmls_traits::OpenMlsProvider +{ + /// The storage to use + type Storage: StorageProvider; + /// The storage error type + type StorageError: std::error::Error; +} + +impl< + Error: std::error::Error + PartialEq, + SP: StorageProvider, + OP: openmls_traits::OpenMlsProvider, + > OpenMlsProvider for OP +{ + type Storage = SP; + type StorageError = Error; +} + +// Implementations for the Entity and Key traits + +impl Entity for QueuedProposal {} +impl traits::QueuedProposal for QueuedProposal {} + +impl Entity for TreeSync {} +impl traits::TreeSync for TreeSync {} + +impl Key for GroupId {} +impl traits::GroupId for GroupId {} + +impl Key for ProposalRef {} +impl Entity for ProposalRef {} +impl traits::ProposalRef for ProposalRef {} +impl traits::HashReference for ProposalRef {} + +impl Entity for GroupContext {} +impl traits::GroupContext for GroupContext {} + +impl Entity for InterimTranscriptHash {} +impl traits::InterimTranscriptHash for InterimTranscriptHash {} + +impl Entity for ConfirmationTag {} +impl traits::ConfirmationTag for ConfirmationTag {} + +impl Entity for KeyPackageBundle {} +impl traits::KeyPackage for KeyPackageBundle {} + +impl Key for EncryptionKey {} +impl traits::EncryptionKey for EncryptionKey {} + +impl Entity for EncryptionKeyPair {} +impl traits::HpkeKeyPair for EncryptionKeyPair {} + +impl Entity for LeafNodeIndex {} +impl traits::LeafNodeIndex for LeafNodeIndex {} + +impl Entity for GroupEpochSecrets {} +impl traits::GroupEpochSecrets for GroupEpochSecrets {} + +impl Entity for MessageSecretsStore {} +impl traits::MessageSecrets for MessageSecretsStore {} + +impl Entity for ResumptionPskStore {} +impl traits::ResumptionPskStore for ResumptionPskStore {} + +impl Entity for MlsGroupJoinConfig {} +impl traits::MlsGroupJoinConfig for MlsGroupJoinConfig {} + +impl Entity for MlsGroupState {} +impl traits::GroupState for MlsGroupState {} + +impl Entity for LeafNode {} +impl traits::LeafNode for LeafNode {} + +// Crypto + +impl Key for GroupEpoch {} +impl traits::EpochKey for GroupEpoch {} + +impl Key for Psk {} +impl traits::PskId for Psk {} + +impl Entity for PskBundle {} +impl traits::PskBundle for PskBundle {} + +#[cfg(test)] +mod test { + use crate::{group::test_core_group::setup_client, prelude::KeyPackageBuilder}; + + use super::*; + + use openmls_rust_crypto::{MemoryStorage, OpenMlsRustCrypto}; + use openmls_traits::{ + storage::{traits as type_traits, StorageProvider, V_TEST}, + types::{Ciphersuite, HpkePrivateKey}, + OpenMlsProvider, + }; + use serde::{Deserialize, Serialize}; + + // Test upgrade path + // Assume we have a new key package bundle representation. + #[derive(Serialize, Deserialize)] + struct NewKeyPackageBundle { + ciphersuite: Ciphersuite, + key_package: crate::key_packages::KeyPackage, + private_init_key: HpkePrivateKey, + private_encryption_key: crate::treesync::node::encryption_keys::EncryptionPrivateKey, + } + + impl Entity for NewKeyPackageBundle {} + impl type_traits::KeyPackage for NewKeyPackageBundle {} + + impl Key for EncryptionKey {} + impl type_traits::EncryptionKey for EncryptionKey {} + + impl Entity for EncryptionKeyPair {} + impl type_traits::HpkeKeyPair for EncryptionKeyPair {} + + impl Key for ProposalRef {} + impl type_traits::HashReference for ProposalRef {} + + #[test] + fn key_packages_key_upgrade() { + // Store an old version + let provider = OpenMlsRustCrypto::default(); + + let (credential_with_key, _kpb, signer, _pk) = setup_client( + "Alice", + Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, + &provider, + ); + + // build and store key package bundle + let key_package_bundle = KeyPackageBuilder::new() + .build( + Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, + &provider, + &signer, + credential_with_key, + ) + .unwrap(); + + let key_package = key_package_bundle.key_package(); + let key_package_ref = key_package.hash_ref(provider.crypto()).unwrap(); + + // TODO #1566: Serialize the old storage. This should become a kat test file + + // ---- migration starts here ---- + let new_storage_provider = MemoryStorage::default(); + + // first, read the old data + let read_key_package_bundle: crate::prelude::KeyPackageBundle = + >::key_package( + provider.storage(), + &key_package_ref, + ) + .unwrap() + .unwrap(); + + // then, build the new data from the old data + let new_key_package_bundle = NewKeyPackageBundle { + ciphersuite: read_key_package_bundle.key_package().ciphersuite(), + key_package: read_key_package_bundle.key_package().clone(), + private_init_key: read_key_package_bundle.init_private_key().clone(), + private_encryption_key: read_key_package_bundle.private_encryption_key.clone(), + }; + + // insert the data in the new format + >::write_key_package( + &new_storage_provider, + &key_package_ref, + &new_key_package_bundle, + ) + .unwrap(); + + // read the new value from storage + let read_new_key_package_bundle: NewKeyPackageBundle = + >::key_package( + &new_storage_provider, + &key_package_ref, + ) + .unwrap() + .unwrap(); + + // compare it to the old_storage + + assert_eq!( + &read_new_key_package_bundle.key_package, + key_package_bundle.key_package() + ); + assert_eq!( + read_new_key_package_bundle.ciphersuite, + key_package_bundle.key_package().ciphersuite() + ); + assert_eq!( + &read_new_key_package_bundle.private_encryption_key, + &key_package_bundle.private_encryption_key + ); + assert_eq!( + &read_new_key_package_bundle.private_init_key, + &key_package_bundle.private_init_key + ); + } +} diff --git a/openmls/src/test_utils/frankenstein/codec.rs b/openmls/src/test_utils/frankenstein/codec.rs new file mode 100644 index 0000000000..428c4e63a3 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/codec.rs @@ -0,0 +1,279 @@ +use std::io::{Read, Write}; + +use tls_codec::*; + +use super::{ + extensions::{ + FrankenApplicationIdExtension, FrankenExtension, FrankenExtensionType, + FrankenExternalPubExtension, FrankenExternalSendersExtension, FrankenRatchetTreeExtension, + FrankenRequiredCapabilitiesExtension, + }, + FrankenAddProposal, FrankenAppAckProposal, FrankenCustomProposal, FrankenExternalInitProposal, + FrankenPreSharedKeyProposal, FrankenProposal, FrankenProposalType, FrankenReInitProposal, + FrankenRemoveProposal, FrankenUpdateProposal, +}; + +fn vlbytes_len_len(length: usize) -> usize { + if length < 0x40 { + 1 + } else if length < 0x3fff { + 2 + } else if length < 0x3fff_ffff { + 4 + } else { + 8 + } +} + +impl Size for FrankenProposalType { + fn tls_serialized_len(&self) -> usize { + 2 + } +} + +impl Deserialize for FrankenProposalType { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let mut proposal_type = [0u8; 2]; + bytes.read_exact(&mut proposal_type)?; + + Ok(FrankenProposalType::from(u16::from_be_bytes(proposal_type))) + } +} + +impl Serialize for FrankenProposalType { + fn tls_serialize(&self, writer: &mut W) -> Result { + writer.write_all(&u16::from(*self).to_be_bytes())?; + + Ok(2) + } +} + +impl DeserializeBytes for FrankenProposalType { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let mut bytes_ref = bytes; + let proposal_type = FrankenProposalType::tls_deserialize(&mut bytes_ref)?; + let remainder = &bytes[proposal_type.tls_serialized_len()..]; + Ok((proposal_type, remainder)) + } +} + +impl Size for FrankenProposal { + fn tls_serialized_len(&self) -> usize { + self.proposal_type().tls_serialized_len() + + match self { + FrankenProposal::Add(p) => p.tls_serialized_len(), + FrankenProposal::Update(p) => p.tls_serialized_len(), + FrankenProposal::Remove(p) => p.tls_serialized_len(), + FrankenProposal::PreSharedKey(p) => p.tls_serialized_len(), + FrankenProposal::ReInit(p) => p.tls_serialized_len(), + FrankenProposal::ExternalInit(p) => p.tls_serialized_len(), + FrankenProposal::GroupContextExtensions(p) => p.tls_serialized_len(), + FrankenProposal::AppAck(p) => p.tls_serialized_len(), + FrankenProposal::Custom(p) => p.tls_serialized_len(), + } + } +} + +impl Serialize for FrankenProposal { + fn tls_serialize(&self, writer: &mut W) -> Result { + let written = self.proposal_type().tls_serialize(writer)?; + match self { + FrankenProposal::Add(p) => p.tls_serialize(writer), + FrankenProposal::Update(p) => p.tls_serialize(writer), + FrankenProposal::Remove(p) => p.tls_serialize(writer), + FrankenProposal::PreSharedKey(p) => p.tls_serialize(writer), + FrankenProposal::ReInit(p) => p.tls_serialize(writer), + FrankenProposal::ExternalInit(p) => p.tls_serialize(writer), + FrankenProposal::GroupContextExtensions(p) => p.tls_serialize(writer), + FrankenProposal::AppAck(p) => p.tls_serialize(writer), + FrankenProposal::Custom(p) => p.payload.tls_serialize(writer), + } + .map(|l| written + l) + } +} + +impl Deserialize for FrankenProposal { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let proposal_type = FrankenProposalType::tls_deserialize(bytes)?; + let proposal = match proposal_type { + FrankenProposalType::Add => { + FrankenProposal::Add(FrankenAddProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::Update => { + FrankenProposal::Update(FrankenUpdateProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::Remove => { + FrankenProposal::Remove(FrankenRemoveProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::PreSharedKey => { + FrankenProposal::PreSharedKey(FrankenPreSharedKeyProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::Reinit => { + FrankenProposal::ReInit(FrankenReInitProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::ExternalInit => { + FrankenProposal::ExternalInit(FrankenExternalInitProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::GroupContextExtensions => FrankenProposal::GroupContextExtensions( + Vec::::tls_deserialize(bytes)?, + ), + FrankenProposalType::AppAck => { + FrankenProposal::AppAck(FrankenAppAckProposal::tls_deserialize(bytes)?) + } + FrankenProposalType::Custom(_) => { + let payload = VLBytes::tls_deserialize(bytes)?; + let custom_proposal = FrankenCustomProposal { + proposal_type: proposal_type.into(), + payload, + }; + FrankenProposal::Custom(custom_proposal) + } + }; + Ok(proposal) + } +} + +impl DeserializeBytes for FrankenProposal { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> + where + Self: Sized, + { + let mut bytes_ref = bytes; + let proposal = FrankenProposal::tls_deserialize(&mut bytes_ref)?; + let remainder = &bytes[proposal.tls_serialized_len()..]; + Ok((proposal, remainder)) + } +} + +impl Size for FrankenExtensionType { + fn tls_serialized_len(&self) -> usize { + 2 + } +} + +impl Deserialize for FrankenExtensionType { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + let mut extension_type = [0u8; 2]; + bytes.read_exact(&mut extension_type)?; + + Ok(FrankenExtensionType::from(u16::from_be_bytes( + extension_type, + ))) + } +} + +impl DeserializeBytes for FrankenExtensionType { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let mut bytes_ref = bytes; + let extension_type = FrankenExtensionType::tls_deserialize(&mut bytes_ref)?; + let remainder = &bytes[extension_type.tls_serialized_len()..]; + Ok((extension_type, remainder)) + } +} + +impl Serialize for FrankenExtensionType { + fn tls_serialize(&self, writer: &mut W) -> Result { + writer.write_all(&u16::from(*self).to_be_bytes())?; + + Ok(2) + } +} + +impl Size for FrankenExtension { + fn tls_serialized_len(&self) -> usize { + let extension_type_length = 2; + let extension_data_len = match self { + FrankenExtension::ApplicationId(e) => e.tls_serialized_len(), + FrankenExtension::RatchetTree(e) => e.tls_serialized_len(), + FrankenExtension::RequiredCapabilities(e) => e.tls_serialized_len(), + FrankenExtension::ExternalPub(e) => e.tls_serialized_len(), + FrankenExtension::ExternalSenders(e) => e.tls_serialized_len(), + FrankenExtension::LastResort => 0, + FrankenExtension::Unknown(_, e) => e.tls_serialized_len(), + }; + let vlbytes_len_len = vlbytes_len_len(extension_data_len); + extension_type_length + vlbytes_len_len + extension_data_len + } +} + +impl Serialize for FrankenExtension { + fn tls_serialize(&self, writer: &mut W) -> Result { + let written = self.extension_type().tls_serialize(writer)?; + let extension_data_len = self.tls_serialized_len(); + let mut extension_data = Vec::with_capacity(extension_data_len); + + let _ = match self { + FrankenExtension::ApplicationId(e) => e.tls_serialize(&mut extension_data), + FrankenExtension::RatchetTree(e) => e.tls_serialize(&mut extension_data), + FrankenExtension::RequiredCapabilities(e) => e.tls_serialize(&mut extension_data), + FrankenExtension::ExternalPub(e) => e.tls_serialize(&mut extension_data), + FrankenExtension::ExternalSenders(e) => e.tls_serialize(&mut extension_data), + FrankenExtension::LastResort => Ok(0), + FrankenExtension::Unknown(_, e) => extension_data + .write_all(e.as_slice()) + .map(|_| e.tls_serialized_len()) + .map_err(|_| tls_codec::Error::EndOfStream), + }?; + + Serialize::tls_serialize(&extension_data, writer).map(|l| l + written) + } +} + +impl Deserialize for FrankenExtension { + fn tls_deserialize(bytes: &mut R) -> Result { + // Read the extension type and extension data. + let extension_type = FrankenExtensionType::tls_deserialize(bytes)?; + let extension_data = VLBytes::tls_deserialize(bytes)?; + + // Now deserialize the extension itself from the extension data. + let mut extension_data = extension_data.as_slice(); + Ok(match extension_type { + FrankenExtensionType::ApplicationId => FrankenExtension::ApplicationId( + FrankenApplicationIdExtension::tls_deserialize(&mut extension_data)?, + ), + FrankenExtensionType::RatchetTree => FrankenExtension::RatchetTree( + FrankenRatchetTreeExtension::tls_deserialize(&mut extension_data)?, + ), + FrankenExtensionType::RequiredCapabilities => FrankenExtension::RequiredCapabilities( + FrankenRequiredCapabilitiesExtension::tls_deserialize(&mut extension_data)?, + ), + FrankenExtensionType::ExternalPub => FrankenExtension::ExternalPub( + FrankenExternalPubExtension::tls_deserialize(&mut extension_data)?, + ), + FrankenExtensionType::ExternalSenders => FrankenExtension::ExternalSenders( + FrankenExternalSendersExtension::tls_deserialize(&mut extension_data)?, + ), + FrankenExtensionType::LastResort => FrankenExtension::LastResort, + FrankenExtensionType::Unknown(unknown) => { + FrankenExtension::Unknown(unknown, extension_data.to_vec().into()) + } + }) + } +} + +impl DeserializeBytes for FrankenExtension { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> + where + Self: Sized, + { + let mut bytes_ref = bytes; + let extension = FrankenExtension::tls_deserialize(&mut bytes_ref)?; + let remainder = &bytes[extension.tls_serialized_len()..]; + Ok((extension, remainder)) + } +} diff --git a/openmls/src/test_utils/frankenstein/commit.rs b/openmls/src/test_utils/frankenstein/commit.rs new file mode 100644 index 0000000000..c635a0dfd9 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/commit.rs @@ -0,0 +1,45 @@ +use tls_codec::*; + +use super::{FrankenLeafNode, FrankenProposal}; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenCommit { + pub proposals: Vec, + pub path: Option, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenProposalOrRef { + #[tls_codec(discriminant = 1)] + Proposal(FrankenProposal), + Reference(VLBytes), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenUpdatePathIn { + pub leaf_node: FrankenLeafNode, + pub nodes: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenUpdatePathNode { + pub(super) public_key: VLBytes, + pub(super) encrypted_path_secrets: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenHpkeCiphertext { + pub kem_output: VLBytes, + pub ciphertext: VLBytes, +} diff --git a/openmls/src/test_utils/frankenstein/credentials.rs b/openmls/src/test_utils/frankenstein/credentials.rs new file mode 100644 index 0000000000..6c672209e3 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/credentials.rs @@ -0,0 +1,9 @@ +use tls_codec::*; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenCredential { + credential_type: u16, + serialized_credential_content: VLBytes, +} diff --git a/openmls/src/test_utils/frankenstein/extensions.rs b/openmls/src/test_utils/frankenstein/extensions.rs new file mode 100644 index 0000000000..146cf6e5a1 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/extensions.rs @@ -0,0 +1,133 @@ +use tls_codec::*; + +use super::{FrankenCredential, FrankenLeafNode}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FrankenExtensionType { + ApplicationId, + RatchetTree, + RequiredCapabilities, + ExternalPub, + ExternalSenders, + LastResort, + Unknown(u16), +} + +impl From for FrankenExtensionType { + fn from(a: u16) -> Self { + match a { + 1 => FrankenExtensionType::ApplicationId, + 2 => FrankenExtensionType::RatchetTree, + 3 => FrankenExtensionType::RequiredCapabilities, + 4 => FrankenExtensionType::ExternalPub, + 5 => FrankenExtensionType::ExternalSenders, + 10 => FrankenExtensionType::LastResort, + unknown => FrankenExtensionType::Unknown(unknown), + } + } +} + +impl From for u16 { + fn from(value: FrankenExtensionType) -> Self { + match value { + FrankenExtensionType::ApplicationId => 1, + FrankenExtensionType::RatchetTree => 2, + FrankenExtensionType::RequiredCapabilities => 3, + FrankenExtensionType::ExternalPub => 4, + FrankenExtensionType::ExternalSenders => 5, + FrankenExtensionType::LastResort => 10, + FrankenExtensionType::Unknown(unknown) => unknown, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[repr(u16)] +pub enum FrankenExtension { + ApplicationId(FrankenApplicationIdExtension), + RatchetTree(FrankenRatchetTreeExtension), + RequiredCapabilities(FrankenRequiredCapabilitiesExtension), + ExternalPub(FrankenExternalPubExtension), + ExternalSenders(FrankenExternalSendersExtension), + LastResort, + Unknown(u16, VLBytes), +} + +impl FrankenExtension { + pub const fn extension_type(&self) -> FrankenExtensionType { + match self { + FrankenExtension::ApplicationId(_) => FrankenExtensionType::ApplicationId, + FrankenExtension::RatchetTree(_) => FrankenExtensionType::RatchetTree, + FrankenExtension::RequiredCapabilities(_) => FrankenExtensionType::RequiredCapabilities, + FrankenExtension::ExternalPub(_) => FrankenExtensionType::ExternalPub, + FrankenExtension::ExternalSenders(_) => FrankenExtensionType::ExternalSenders, + FrankenExtension::LastResort => FrankenExtensionType::LastResort, + FrankenExtension::Unknown(kind, _) => FrankenExtensionType::Unknown(*kind), + } + } +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenApplicationIdExtension { + pub key_id: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenRatchetTreeExtension { + pub ratchet_tree: Vec>, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenNode { + #[tls_codec(discriminant = 1)] + LeafNode(FrankenLeafNode), + #[tls_codec(discriminant = 2)] + ParentNode(FrankenParentNode), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenParentNode { + pub encryption_key: VLBytes, + pub parent_hash: VLBytes, + pub unmerged_leaves: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenRequiredCapabilitiesExtension { + pub extension_types: Vec, + pub proposal_types: Vec, + pub credential_types: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenExternalPubExtension { + external_pub: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenExternalSendersExtension { + external_senders: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenExternalSender { + pub signature_key: VLBytes, + pub credential: FrankenCredential, +} diff --git a/openmls/src/test_utils/frankenstein/framing.rs b/openmls/src/test_utils/frankenstein/framing.rs new file mode 100644 index 0000000000..8bc98bd74b --- /dev/null +++ b/openmls/src/test_utils/frankenstein/framing.rs @@ -0,0 +1,189 @@ +use tls_codec::*; + +use crate::{ + framing::{ + MlsMessageIn, MlsMessageOut, PrivateMessage, PrivateMessageIn, PublicMessage, + PublicMessageIn, + }, + messages::Welcome, +}; + +use super::{ + commit::FrankenCommit, group_info::FrankenGroupInfo, FrankenKeyPackage, FrankenProposal, +}; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenMlsMessage { + pub version: u16, + pub body: FrankenMlsMessageBody, +} + +#[allow(clippy::large_enum_variant)] +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u16)] +pub enum FrankenMlsMessageBody { + #[tls_codec(discriminant = 1)] + PublicMessage(FrankenPublicMessage), + #[tls_codec(discriminant = 2)] + PrivateMessage(FrankenPrivateMessage), + #[tls_codec(discriminant = 3)] + Welcome(FrankenWelcome), + #[tls_codec(discriminant = 4)] + GroupInfo(FrankenGroupInfo), + #[tls_codec(discriminant = 5)] + KeyPackage(FrankenKeyPackage), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenPublicMessage { + pub content: FrankenFramedContent, + pub auth: FrankenFramedContentAuthData, + pub membership_tag: Option, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenFramedContent { + pub group_id: VLBytes, + pub epoch: u64, + pub sender: FrankenSender, + pub authenticated_data: VLBytes, + pub body: FrankenFramedContentBody, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenSender { + #[tls_codec(discriminant = 1)] + Member(u32), + External(u32), + NewMemberProposal, + NewMemberCommit, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenFramedContentBody { + #[tls_codec(discriminant = 1)] + Application(VLBytes), + #[tls_codec(discriminant = 2)] + Proposal(FrankenProposal), + #[tls_codec(discriminant = 3)] + Commit(FrankenCommit), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenPrivateMessage { + pub group_id: VLBytes, + pub epoch: VLBytes, + pub content_type: FrankenContentType, + pub authenticated_data: VLBytes, + pub encrypted_sender_data: VLBytes, + pub ciphertext: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenWelcome { + pub cipher_suite: u16, + pub secrets: Vec, + pub encrypted_group_info: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenFramedContentAuthData { + pub signature: VLBytes, + pub confirmation_tag: Option, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenContentType { + Application = 1, + Proposal = 2, + Commit = 3, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenEncryptedGroupSecrets { + pub new_member: VLBytes, + pub encrypted_group_secrets: VLBytes, +} + +impl From for FrankenMlsMessage { + fn from(ln: MlsMessageOut) -> Self { + FrankenMlsMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for MlsMessageOut { + fn from(fln: FrankenMlsMessage) -> Self { + MlsMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + .into() + } +} + +impl From for FrankenPublicMessage { + fn from(ln: PublicMessage) -> Self { + FrankenPublicMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for PublicMessage { + fn from(fln: FrankenPublicMessage) -> Self { + PublicMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + .into() + } +} + +impl From for FrankenPrivateMessage { + fn from(ln: PrivateMessage) -> Self { + FrankenPrivateMessage::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for PrivateMessage { + fn from(fln: FrankenPrivateMessage) -> Self { + PrivateMessageIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + .into() + } +} + +impl From for FrankenWelcome { + fn from(ln: Welcome) -> Self { + FrankenWelcome::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for Welcome { + fn from(fln: FrankenWelcome) -> Self { + Welcome::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap() + } +} diff --git a/openmls/src/test_utils/frankenstein/group_info.rs b/openmls/src/test_utils/frankenstein/group_info.rs new file mode 100644 index 0000000000..1f116e1ea2 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/group_info.rs @@ -0,0 +1,105 @@ +use std::ops::{Deref, DerefMut}; + +use tls_codec::*; + +use openmls_basic_credential::SignatureKeyPair; +use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider}; + +use crate::{ + ciphersuite::{ + signable::{Signable, SignedStruct}, + signature::{OpenMlsSignaturePublicKey, Signature}, + }, + messages::group_info::GroupInfo, +}; + +use super::extensions::FrankenExtension; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenGroupInfo { + pub payload: FrankenGroupInfoTbs, + pub signature: VLBytes, +} + +impl FrankenGroupInfo { + // Re-sign both the KeyPackage and the enclosed LeafNode + pub fn resign(&mut self, signer: &impl Signer) { + let new_self = self.payload.clone().sign(signer).unwrap(); + let _ = std::mem::replace(self, new_self); + } +} + +impl Deref for FrankenGroupInfo { + type Target = FrankenGroupInfoTbs; + + fn deref(&self) -> &Self::Target { + &self.payload + } +} + +impl DerefMut for FrankenGroupInfo { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.payload + } +} + +impl SignedStruct for FrankenGroupInfo { + fn from_payload(payload: FrankenGroupInfoTbs, signature: Signature) -> Self { + Self { + payload, + signature: signature.as_slice().to_owned().into(), + } + } +} + +const SIGNATURE_GROUP_INFO_LABEL: &str = "GroupInfoTBS"; + +impl Signable for FrankenGroupInfoTbs { + type SignedOutput = FrankenGroupInfo; + + fn unsigned_payload(&self) -> Result, tls_codec::Error> { + self.tls_serialize_detached() + } + + fn label(&self) -> &str { + SIGNATURE_GROUP_INFO_LABEL + } +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenGroupInfoTbs { + pub group_context: FrankenGroupContext, + pub extensions: Vec, + pub confirmation_tag: VLBytes, + pub signer: u32, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenGroupContext { + protocol_version: u16, + ciphersuite: u16, + group_id: VLBytes, + epoch: u64, + tree_hash: VLBytes, + confirmed_transcript_hash: VLBytes, + extensions: Vec, +} + +impl From for FrankenGroupInfo { + fn from(ln: GroupInfo) -> Self { + FrankenGroupInfo::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for GroupInfo { + fn from(fln: FrankenGroupInfo) -> Self { + GroupInfo::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap() + } +} diff --git a/openmls/src/test_utils/frankenstein/key_package.rs b/openmls/src/test_utils/frankenstein/key_package.rs new file mode 100644 index 0000000000..d6a41429f3 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/key_package.rs @@ -0,0 +1,169 @@ +use std::ops::{Deref, DerefMut}; + +use openmls_basic_credential::SignatureKeyPair; +use openmls_test::openmls_test; +use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider}; +use tls_codec::*; + +use super::{extensions::FrankenExtension, leaf_node::FrankenLeafNode}; +use crate::{ + ciphersuite::{ + signable::{Signable, SignedStruct}, + signature::{OpenMlsSignaturePublicKey, Signature}, + }, + credentials::{BasicCredential, CredentialWithKey}, + key_packages::{KeyPackage, KeyPackageIn}, + prelude::KeyPackageBundle, + test_utils::OpenMlsRustCrypto, + versions::ProtocolVersion, +}; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenKeyPackage { + pub payload: FrankenKeyPackageTbs, + pub signature: VLBytes, +} + +impl FrankenKeyPackage { + // Re-sign both the KeyPackage and the enclosed LeafNode + pub fn resign(&mut self, signer: &impl Signer) { + self.payload.leaf_node.resign(signer); + let new_self = self.payload.clone().sign(signer).unwrap(); + let _ = std::mem::replace(self, new_self); + } + + // Only re-sign the KeyPackage + pub fn resign_only_key_package(&mut self, signer: &impl Signer) { + let new_self = self.payload.clone().sign(signer).unwrap(); + let _ = std::mem::replace(self, new_self); + } +} + +impl Deref for FrankenKeyPackage { + type Target = FrankenKeyPackageTbs; + + fn deref(&self) -> &Self::Target { + &self.payload + } +} + +impl DerefMut for FrankenKeyPackage { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.payload + } +} + +impl SignedStruct for FrankenKeyPackage { + fn from_payload(payload: FrankenKeyPackageTbs, signature: Signature) -> Self { + Self { + payload, + signature: signature.as_slice().to_owned().into(), + } + } +} + +const SIGNATURE_KEY_PACKAGE_LABEL: &str = "KeyPackageTBS"; + +impl Signable for FrankenKeyPackageTbs { + type SignedOutput = FrankenKeyPackage; + + fn unsigned_payload(&self) -> Result, tls_codec::Error> { + self.tls_serialize_detached() + } + + fn label(&self) -> &str { + SIGNATURE_KEY_PACKAGE_LABEL + } +} + +impl From for FrankenKeyPackage { + fn from(kp: KeyPackage) -> Self { + FrankenKeyPackage::tls_deserialize(&mut kp.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for FrankenKeyPackage { + fn from(kp: KeyPackageBundle) -> Self { + FrankenKeyPackage::tls_deserialize( + &mut kp + .key_package() + .tls_serialize_detached() + .unwrap() + .as_slice(), + ) + .unwrap() + } +} + +impl From for KeyPackage { + fn from(fkp: FrankenKeyPackage) -> Self { + KeyPackageIn::tls_deserialize(&mut fkp.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + .into() + } +} + +impl From for KeyPackageIn { + fn from(fkp: FrankenKeyPackage) -> Self { + KeyPackageIn::tls_deserialize(&mut fkp.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenKeyPackageTbs { + pub protocol_version: u16, + pub ciphersuite: u16, + pub init_key: VLBytes, + pub leaf_node: FrankenLeafNode, + pub extensions: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenLifetime { + pub not_before: u64, + pub not_after: u64, +} + +#[openmls_test] +fn test_franken_key_package() { + let config = ciphersuite; + + let (credential, signer) = { + let credential = BasicCredential::new(b"test identity".to_vec()); + let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); + + (credential, signature_keys) + }; + let signature_key = OpenMlsSignaturePublicKey::new( + signer.to_public_vec().into(), + ciphersuite.signature_algorithm(), + ) + .unwrap(); + + let credential_with_key = CredentialWithKey { + credential: credential.into(), + signature_key: signature_key.into(), + }; + + let kp = KeyPackage::builder() + .build(config, provider, &signer, credential_with_key) + .unwrap(); + + let ser = kp.key_package().tls_serialize_detached().unwrap(); + let fkp = FrankenKeyPackage::tls_deserialize(&mut ser.as_slice()).unwrap(); + + let ser2 = fkp.tls_serialize_detached().unwrap(); + assert_eq!(ser, ser2); + + let kp2 = KeyPackage::from(fkp.clone()); + assert_eq!(kp.key_package(), &kp2); +} diff --git a/openmls/src/test_utils/frankenstein/leaf_node.rs b/openmls/src/test_utils/frankenstein/leaf_node.rs new file mode 100644 index 0000000000..1862073335 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/leaf_node.rs @@ -0,0 +1,122 @@ +use std::ops::{Deref, DerefMut}; + +use openmls_basic_credential::SignatureKeyPair; +use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider}; +use tls_codec::*; + +use super::{extensions::FrankenExtension, key_package::FrankenLifetime, FrankenCredential}; +use crate::{ + ciphersuite::{ + signable::{Signable, SignedStruct}, + signature::Signature, + }, + treesync::{node::leaf_node::LeafNodeIn, LeafNode}, +}; + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenLeafNode { + pub payload: FrankenLeafNodeTbs, + pub signature: VLBytes, +} + +impl FrankenLeafNode { + // Re-sign the LeafNode + pub fn resign(&mut self, signer: &impl Signer) { + let new_self = self.payload.clone().sign(signer).unwrap(); + let _ = std::mem::replace(self, new_self); + } +} + +impl Deref for FrankenLeafNode { + type Target = FrankenLeafNodeTbs; + + fn deref(&self) -> &Self::Target { + &self.payload + } +} + +impl DerefMut for FrankenLeafNode { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.payload + } +} + +impl SignedStruct for FrankenLeafNode { + fn from_payload(payload: FrankenLeafNodeTbs, signature: Signature) -> Self { + Self { + payload, + signature: signature.as_slice().to_owned().into(), + } + } +} + +const LEAF_NODE_SIGNATURE_LABEL: &str = "LeafNodeTBS"; + +impl Signable for FrankenLeafNodeTbs { + type SignedOutput = FrankenLeafNode; + + fn unsigned_payload(&self) -> Result, tls_codec::Error> { + self.tls_serialize_detached() + } + + fn label(&self) -> &str { + LEAF_NODE_SIGNATURE_LABEL + } +} + +impl From for FrankenLeafNode { + fn from(ln: LeafNode) -> Self { + FrankenLeafNode::tls_deserialize(&mut ln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + } +} + +impl From for LeafNode { + fn from(fln: FrankenLeafNode) -> Self { + LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()) + .unwrap() + .into() + } +} + +impl From for LeafNodeIn { + fn from(fln: FrankenLeafNode) -> Self { + LeafNodeIn::tls_deserialize(&mut fln.tls_serialize_detached().unwrap().as_slice()).unwrap() + } +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenLeafNodeTbs { + pub encryption_key: VLBytes, + pub signature_key: VLBytes, + pub credential: FrankenCredential, + pub capabilities: FrankenCapabilities, + pub leaf_node_source: FrankenLeafNodeSource, + pub extensions: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenCapabilities { + pub versions: Vec, + pub ciphersuites: Vec, + pub extensions: Vec, + pub proposals: Vec, + pub credentials: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenLeafNodeSource { + #[tls_codec(discriminant = 1)] + KeyPackage(FrankenLifetime), + Update, + Commit(VLBytes), +} diff --git a/openmls/src/test_utils/frankenstein/mod.rs b/openmls/src/test_utils/frankenstein/mod.rs new file mode 100644 index 0000000000..f26393ea48 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/mod.rs @@ -0,0 +1,21 @@ +//! This module contains the Frankenstein test utilities. +//! +//! The Frankenstein test utilities are used to build and manipulate test +//! structures in a way that is not possible with the public API. This is +//! useful for testing and fuzzing. + +mod codec; +mod commit; +mod credentials; +mod extensions; +mod framing; +mod group_info; +mod key_package; +mod leaf_node; +mod proposals; + +pub use self::credentials::*; +pub use self::framing::*; +pub use self::key_package::*; +pub use self::leaf_node::*; +pub use self::proposals::*; diff --git a/openmls/src/test_utils/frankenstein/proposals.rs b/openmls/src/test_utils/frankenstein/proposals.rs new file mode 100644 index 0000000000..6359c11ec5 --- /dev/null +++ b/openmls/src/test_utils/frankenstein/proposals.rs @@ -0,0 +1,197 @@ +use tls_codec::*; + +use super::{extensions::FrankenExtension, FrankenKeyPackage, FrankenLeafNode}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FrankenProposalType { + Add, + Update, + Remove, + PreSharedKey, + Reinit, + ExternalInit, + GroupContextExtensions, + AppAck, + Custom(u16), +} + +impl From for FrankenProposalType { + fn from(value: u16) -> Self { + match value { + 1 => FrankenProposalType::Add, + 2 => FrankenProposalType::Update, + 3 => FrankenProposalType::Remove, + 4 => FrankenProposalType::PreSharedKey, + 5 => FrankenProposalType::Reinit, + 6 => FrankenProposalType::ExternalInit, + 7 => FrankenProposalType::GroupContextExtensions, + 8 => FrankenProposalType::AppAck, + other => FrankenProposalType::Custom(other), + } + } +} + +impl From for u16 { + fn from(value: FrankenProposalType) -> Self { + match value { + FrankenProposalType::Add => 1, + FrankenProposalType::Update => 2, + FrankenProposalType::Remove => 3, + FrankenProposalType::PreSharedKey => 4, + FrankenProposalType::Reinit => 5, + FrankenProposalType::ExternalInit => 6, + FrankenProposalType::GroupContextExtensions => 7, + FrankenProposalType::AppAck => 8, + FrankenProposalType::Custom(id) => id, + } + } +} + +impl FrankenProposal { + pub fn proposal_type(&self) -> FrankenProposalType { + match self { + FrankenProposal::Add(_) => FrankenProposalType::Add, + FrankenProposal::Update(_) => FrankenProposalType::Update, + FrankenProposal::Remove(_) => FrankenProposalType::Remove, + FrankenProposal::PreSharedKey(_) => FrankenProposalType::PreSharedKey, + FrankenProposal::ReInit(_) => FrankenProposalType::Reinit, + FrankenProposal::ExternalInit(_) => FrankenProposalType::ExternalInit, + FrankenProposal::GroupContextExtensions(_) => { + FrankenProposalType::GroupContextExtensions + } + FrankenProposal::AppAck(_) => FrankenProposalType::AppAck, + FrankenProposal::Custom(FrankenCustomProposal { + proposal_type, + payload: _, + }) => FrankenProposalType::Custom(proposal_type.to_owned()), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +#[repr(u16)] +pub enum FrankenProposal { + Add(FrankenAddProposal), + Update(FrankenUpdateProposal), + Remove(FrankenRemoveProposal), + PreSharedKey(FrankenPreSharedKeyProposal), + ReInit(FrankenReInitProposal), + ExternalInit(FrankenExternalInitProposal), + GroupContextExtensions(Vec), + AppAck(FrankenAppAckProposal), + Custom(FrankenCustomProposal), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenAddProposal { + pub key_package: FrankenKeyPackage, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenUpdateProposal { + pub leaf_node: FrankenLeafNode, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenRemoveProposal { + pub removed: u32, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenPreSharedKeyProposal { + pub psk: FrankenPreSharedKeyId, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenPreSharedKeyId { + pub psk: FrankenPsk, + pub psk_nonce: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenPsk { + #[tls_codec(discriminant = 1)] + External(FrankenExternalPsk), + #[tls_codec(discriminant = 2)] + Resumption(FrankenResumptionPsk), +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenExternalPsk { + pub psk_id: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenResumptionPsk { + pub usage: FrankenResumptionPskUsage, + pub psk_group_id: VLBytes, + pub psk_epoch: u64, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +#[repr(u8)] +pub enum FrankenResumptionPskUsage { + Application = 1, + Reinit = 2, + Branch = 3, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenReInitProposal { + pub group_id: VLBytes, + pub version: u16, + pub ciphersuite: u16, + pub extensions: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenExternalInitProposal { + pub kem_output: VLBytes, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenAppAckProposal { + pub received_ranges: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenMessageRange { + pub sender: VLBytes, + pub first_generation: u32, + pub last_generation: u32, +} + +#[derive( + Debug, Clone, PartialEq, Eq, TlsSerialize, TlsDeserialize, TlsDeserializeBytes, TlsSize, +)] +pub struct FrankenCustomProposal { + pub proposal_type: u16, + pub payload: VLBytes, +} diff --git a/openmls/src/test_utils/mod.rs b/openmls/src/test_utils/mod.rs index 2ee2ba605b..85fa9dd757 100644 --- a/openmls/src/test_utils/mod.rs +++ b/openmls/src/test_utils/mod.rs @@ -9,10 +9,11 @@ use std::{ }; use openmls_basic_credential::SignatureKeyPair; -use openmls_traits::{key_store::OpenMlsKeyStore, types::HpkeKeyPair}; -pub use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; -pub use rstest::*; -pub use rstest_reuse::{self, *}; +pub use openmls_traits::{ + storage::StorageProvider as StorageProviderTrait, + types::{Ciphersuite, HpkeKeyPair}, + OpenMlsProvider, +}; use serde::{self, de::DeserializeOwned, Serialize}; #[cfg(test)] @@ -21,11 +22,12 @@ pub use crate::utils::*; use crate::{ ciphersuite::{HpkePrivateKey, OpenMlsSignaturePublicKey}, credentials::{Credential, CredentialType, CredentialWithKey}, - key_packages::KeyPackage, - prelude::{CryptoConfig, KeyPackageBuilder}, + key_packages::{KeyPackage, KeyPackageBuilder}, + prelude::KeyPackageBundle, treesync::node::encryption_keys::{EncryptionKeyPair, EncryptionPrivateKey}, }; +pub mod frankenstein; pub mod test_framework; pub(crate) fn write(file_name: &str, obj: impl Serialize) { @@ -97,9 +99,7 @@ pub fn hex_to_bytes_option(hex: Option) -> Vec { #[cfg(test)] pub(crate) struct GroupCandidate { pub identity: Vec, - pub key_package: KeyPackage, - pub encryption_keypair: EncryptionKeyPair, - pub init_keypair: HpkeKeyPair, + pub key_package: KeyPackageBundle, pub signature_keypair: SignatureKeyPair, pub credential_with_key_and_signer: CredentialWithKeyAndSigner, } @@ -111,16 +111,16 @@ pub(crate) fn generate_group_candidate( provider: &impl OpenMlsProvider, use_store: bool, ) -> GroupCandidate { - use crate::credentials::BasicCredential; + use crate::{credentials::BasicCredential, prelude::KeyPackageBundle}; let credential_with_key_and_signer = { - let credential = BasicCredential::new(identity.to_vec()).unwrap(); + let credential = BasicCredential::new(identity.to_vec()); let signature_keypair = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); // Store if there is a key store. if use_store { - signature_keypair.store(provider.key_store()).unwrap(); + signature_keypair.store(provider.storage()).unwrap(); } let signature_pkey = OpenMlsSignaturePublicKey::new( @@ -138,63 +138,38 @@ pub(crate) fn generate_group_candidate( } }; - let (key_package, encryption_keypair, init_keypair) = { + let key_package = { let builder = KeyPackageBuilder::new(); if use_store { - let key_package = builder + builder .build( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, provider, &credential_with_key_and_signer.signer, credential_with_key_and_signer.credential_with_key.clone(), ) - .unwrap(); - - let encryption_keypair = EncryptionKeyPair::read_from_key_store( - provider, - key_package.leaf_node().encryption_key(), - ) - .unwrap(); - let init_keypair = { - let private = provider - .key_store() - .read::(key_package.hpke_init_key().as_slice()) - .unwrap(); - - HpkeKeyPair { - private, - public: key_package.hpke_init_key().as_slice().to_vec(), - } - }; - - (key_package, encryption_keypair, init_keypair) + .unwrap() } else { // We don't want to store anything. So... let provider = OpenMlsRustCrypto::default(); let key_package_creation_result = builder - .build_without_key_storage( - CryptoConfig::with_default_version(ciphersuite), + .build_without_storage( + ciphersuite, &provider, &credential_with_key_and_signer.signer, credential_with_key_and_signer.credential_with_key.clone(), ) .unwrap(); - let init_keypair = HpkeKeyPair { - private: key_package_creation_result.init_private_key, - public: key_package_creation_result - .key_package - .hpke_init_key() - .as_slice() - .to_vec(), - }; - - ( + KeyPackageBundle::new( key_package_creation_result.key_package, - key_package_creation_result.encryption_keypair, - init_keypair, + key_package_creation_result.init_private_key, + key_package_creation_result + .encryption_keypair + .private_key() + .clone(), ) } }; @@ -202,78 +177,17 @@ pub(crate) fn generate_group_candidate( GroupCandidate { identity: identity.as_ref().to_vec(), key_package, - encryption_keypair, - init_keypair, signature_keypair: credential_with_key_and_signer.signer.clone(), credential_with_key_and_signer, } } -// === Define provider per platform === - -// This provider is currently used on all platforms -#[cfg(feature = "libcrux-provider")] -pub use openmls_libcrux_crypto::Provider as OpenMlsLibcrux; -pub use openmls_rust_crypto::OpenMlsRustCrypto; - -// === providers === - -#[template] -#[export] -#[cfg_attr(feature = "libcrux-provider", rstest(provider, - case::rust_crypto(&OpenMlsRustCrypto::default()), - case::libcrux(&OpenMlsLibcrux::default()), - ) -)] -#[cfg_attr(not(feature = "libcrux-provider"),rstest(provider, - case::rust_crypto(&OpenMlsRustCrypto::default()), - ) -)] -#[allow(non_snake_case)] -#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)] -pub fn providers(provider: &impl OpenMlsProvider) {} - -// === Ciphersuites === - -// For now we support all ciphersuites, regardless of the provider - -#[template] -#[export] -#[rstest( - ciphersuite, - case::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519( - Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 - ), - case::MLS_128_DHKEMP256_AES128GCM_SHA256_P256( - Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 - ), - case::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519( - Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 - ) -)] -#[allow(non_snake_case)] -#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)] -pub fn ciphersuites(ciphersuite: Ciphersuite) {} - -// === Ciphersuites & providers === - -#[template] -#[export] -#[cfg_attr(feature = "libcrux-provider", rstest(ciphersuite, provider, - case::rust_crypto_MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, &OpenMlsRustCrypto::default()), - case::rust_crypto_MLS_128_DHKEMP256_AES128GCM_SHA256_P256(Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, &OpenMlsRustCrypto::default()), - case::rust_crypto_MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, &OpenMlsRustCrypto::default()), - case::libcrux_MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, &$crate::test_utils::OpenMlsLibcrux::default()), - case::libcrux_MLS_128_DHKEMP256_AES128GCM_SHA256_P256(Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, &$crate::test_utils::OpenMlsLibcrux::default()), - case::libcrux_MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, &$crate::test_utils::OpenMlsLibcrux::default()), - ) -)] -#[cfg_attr(not(feature = "libcrux-provider"),rstest(ciphersuite, provider, - case::rust_crypto_MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, &OpenMlsRustCrypto::default()), - case::rust_crypto_MLS_128_DHKEMP256_AES128GCM_SHA256_P256(Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, &OpenMlsRustCrypto::default()), - case::rust_crypto_MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519(Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, &OpenMlsRustCrypto::default()), - ) -)] -#[allow(non_snake_case)] -#[cfg_attr(target_arch = "wasm32", openmls::wasm::test)] -pub fn ciphersuites_and_providers(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) {} +#[cfg(all( + feature = "libcrux-provider", + not(any( + target_arch = "wasm32", + all(target_arch = "x86", target_os = "windows") + )) +))] +pub type OpenMlsLibcrux = openmls_libcrux_crypto::Provider; +pub type OpenMlsRustCrypto = openmls_rust_crypto::OpenMlsRustCrypto; diff --git a/openmls/src/test_utils/test_framework/client.rs b/openmls/src/test_utils/test_framework/client.rs index 6091408c61..ebdb3bc47c 100644 --- a/openmls/src/test_utils/test_framework/client.rs +++ b/openmls/src/test_utils/test_framework/client.rs @@ -4,23 +4,24 @@ use std::{collections::HashMap, sync::RwLock}; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{ - key_store::OpenMlsKeyStore, types::{Ciphersuite, HpkeKeyPair, SignatureScheme}, - OpenMlsProvider, + OpenMlsProvider as _, }; use tls_codec::{Deserialize, Serialize}; +use super::OpenMlsRustCrypto; + use crate::{ binary_tree::array_representation::LeafNodeIndex, ciphersuite::hash_ref::KeyPackageRef, credentials::*, extensions::*, framing::*, - group::{config::CryptoConfig, *}, + group::*, key_packages::*, messages::{group_info::GroupInfo, *}, + storage::OpenMlsProvider, treesync::{ node::{leaf_node::Capabilities, Node}, LeafNode, RatchetTree, RatchetTreeIn, @@ -36,29 +37,29 @@ use super::{errors::ClientError, ActionType}; /// containing its `CredentialWithKey`s. The `key_package_bundles` field /// contains generated `KeyPackageBundle`s that are waiting to be used for new /// groups. -pub struct Client { +pub struct Client { /// Name of the client. pub identity: Vec, /// Ciphersuites supported by the client. pub credentials: HashMap, - pub crypto: OpenMlsRustCrypto, + pub provider: Provider, pub groups: RwLock>, } -impl Client { +impl Client { /// Generate a fresh key package and return it. /// The first ciphersuite determines the /// credential used to generate the `KeyPackage`. pub fn get_fresh_key_package( &self, ciphersuite: Ciphersuite, - ) -> Result { + ) -> Result> { let credential_with_key = self .credentials .get(&ciphersuite) .ok_or(ClientError::CiphersuiteNotSupported)?; let keys = SignatureKeyPair::read( - self.crypto.key_store(), + self.provider.storage(), credential_with_key.signature_key.as_slice(), ciphersuite.signature_algorithm(), ) @@ -66,17 +67,14 @@ impl Client { let key_package = KeyPackage::builder() .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - &self.crypto, + ciphersuite, + &self.provider, &keys, credential_with_key.clone(), ) .unwrap(); - Ok(key_package) + Ok(key_package.key_package) } /// Create a group with the given [MlsGroupCreateConfig] and [Ciphersuite], and return the created [GroupId]. @@ -86,20 +84,21 @@ impl Client { &self, mls_group_create_config: MlsGroupCreateConfig, ciphersuite: Ciphersuite, - ) -> Result { + ) -> Result> { let credential_with_key = self .credentials .get(&ciphersuite) - .ok_or(ClientError::CiphersuiteNotSupported)?; + .ok_or(ClientError::CiphersuiteNotSupported); + let credential_with_key = credential_with_key?; let signer = SignatureKeyPair::read( - self.crypto.key_store(), + self.provider.storage(), credential_with_key.signature_key.as_slice(), ciphersuite.signature_algorithm(), ) .unwrap(); let group_state = MlsGroup::new( - &self.crypto, + &self.provider, &signer, &mls_group_create_config, credential_with_key.clone(), @@ -121,14 +120,14 @@ impl Client { mls_group_config: MlsGroupJoinConfig, welcome: Welcome, ratchet_tree: Option, - ) -> Result<(), ClientError> { + ) -> Result<(), ClientError> { let staged_join = StagedWelcome::new_from_welcome( - &self.crypto, + &self.provider, &mls_group_config, welcome, ratchet_tree, )?; - let new_group = staged_join.into_group(&self.crypto)?; + let new_group = staged_join.into_group(&self.provider)?; self.groups .write() .expect("An unexpected error occurred.") @@ -144,29 +143,31 @@ impl Client { message: &ProtocolMessage, sender_id: &[u8], authentication_service: &AS, - ) -> Result<(), ClientError> { + ) -> Result<(), ClientError> { let mut group_states = self.groups.write().expect("An unexpected error occurred."); let group_id = message.group_id(); let group_state = group_states .get_mut(group_id) .ok_or(ClientError::NoMatchingGroup)?; if sender_id == self.identity && message.content_type() == ContentType::Commit { - group_state.merge_pending_commit(&self.crypto)? + group_state.merge_pending_commit(&self.provider)? } else { if message.content_type() == ContentType::Commit { // Clear any potential pending commits. - group_state.clear_pending_commit(); + group_state.clear_pending_commit(self.provider.storage())?; } // Process the message. - let processed_message = group_state.process_message(&self.crypto, message.clone())?; + let processed_message = group_state.process_message(&self.provider, message.clone())?; match processed_message.into_content() { ProcessedMessageContent::ApplicationMessage(_) => {} ProcessedMessageContent::ProposalMessage(staged_proposal) => { - group_state.store_pending_proposal(*staged_proposal); + group_state + .store_pending_proposal(self.provider.storage(), *staged_proposal)?; } ProcessedMessageContent::ExternalJoinProposalMessage(staged_proposal) => { - group_state.store_pending_proposal(*staged_proposal); + group_state + .store_pending_proposal(self.provider.storage(), *staged_proposal)?; } ProcessedMessageContent::StagedCommitMessage(staged_commit) => { for credential in staged_commit.credentials_to_verify() { @@ -177,7 +178,7 @@ impl Client { return Err(ClientError::NoMatchingCredential); } } - group_state.merge_staged_commit(&self.crypto, *staged_commit)?; + group_state.merge_staged_commit(&self.provider, *staged_commit)?; } } } @@ -188,7 +189,10 @@ impl Client { /// Get the credential and the index of each group member of the group with /// the given id. Returns an error if no group exists with the given group /// id. - pub fn get_members_of_group(&self, group_id: &GroupId) -> Result, ClientError> { + pub fn get_members_of_group( + &self, + group_id: &GroupId, + ) -> Result, ClientError> { let groups = self.groups.read().expect("An unexpected error occurred."); let group = groups.get(group_id).ok_or(ClientError::NoMatchingGroup)?; let members = group.members().collect(); @@ -206,7 +210,10 @@ impl Client { action_type: ActionType, group_id: &GroupId, leaf_node: Option, - ) -> Result<(MlsMessageOut, Option, Option), ClientError> { + ) -> Result< + (MlsMessageOut, Option, Option), + ClientError, + > { let mut groups = self.groups.write().expect("An unexpected error occurred."); let group = groups .get_mut(group_id) @@ -215,16 +222,16 @@ impl Client { // key store. let signature_pk = group.own_leaf().unwrap().signature_key(); let signer = SignatureKeyPair::read( - self.crypto.key_store(), + self.provider.storage(), signature_pk.as_slice(), group.ciphersuite().signature_algorithm(), ) .unwrap(); let (msg, welcome_option, group_info) = match action_type { - ActionType::Commit => group.self_update(&self.crypto, &signer)?, + ActionType::Commit => group.self_update(&self.provider, &signer)?, ActionType::Proposal => ( group - .propose_self_update(&self.crypto, &signer, leaf_node) + .propose_self_update(&self.provider, &signer, leaf_node) .map(|(out, _)| out)?, None, None, @@ -248,7 +255,10 @@ impl Client { action_type: ActionType, group_id: &GroupId, key_packages: &[KeyPackage], - ) -> Result<(Vec, Option, Option), ClientError> { + ) -> Result< + (Vec, Option, Option), + ClientError, + > { let mut groups = self.groups.write().expect("An unexpected error occurred."); let group = groups .get_mut(group_id) @@ -257,7 +267,7 @@ impl Client { // key store. let signature_pk = group.own_leaf().unwrap().signature_key(); let signer = SignatureKeyPair::read( - self.crypto.key_store(), + self.provider.storage(), signature_pk.as_slice(), group.ciphersuite().signature_algorithm(), ) @@ -265,7 +275,7 @@ impl Client { let action_results = match action_type { ActionType::Commit => { let (messages, welcome_message, group_info) = - group.add_members(&self.crypto, &signer, key_packages)?; + group.add_members(&self.provider, &signer, key_packages)?; ( vec![messages], Some( @@ -280,7 +290,7 @@ impl Client { let mut messages = Vec::new(); for key_package in key_packages { let message = group - .propose_add_member(&self.crypto, &signer, key_package) + .propose_add_member(&self.provider, &signer, key_package) .map(|(out, _)| out)?; messages.push(message); } @@ -301,7 +311,10 @@ impl Client { action_type: ActionType, group_id: &GroupId, targets: &[LeafNodeIndex], - ) -> Result<(Vec, Option, Option), ClientError> { + ) -> Result< + (Vec, Option, Option), + ClientError, + > { let mut groups = self.groups.write().expect("An unexpected error occurred."); let group = groups .get_mut(group_id) @@ -310,7 +323,7 @@ impl Client { // key store. let signature_pk = group.own_leaf().unwrap().signature_key(); let signer = SignatureKeyPair::read( - self.crypto.key_store(), + self.provider.storage(), signature_pk.as_slice(), group.ciphersuite().signature_algorithm(), ) @@ -318,7 +331,7 @@ impl Client { let action_results = match action_type { ActionType::Commit => { let (message, welcome_option, group_info) = - group.remove_members(&self.crypto, &signer, targets)?; + group.remove_members(&self.provider, &signer, targets)?; ( vec![message], welcome_option.map(|w| w.into_welcome().expect("Unexpected message type.")), @@ -329,7 +342,7 @@ impl Client { let mut messages = Vec::new(); for target in targets { let message = group - .propose_remove_member(&self.crypto, &signer, *target) + .propose_remove_member(&self.provider, &signer, *target) .map(|(out, _)| out)?; messages.push(message); } @@ -345,7 +358,7 @@ impl Client { let group = groups.get(group_id).unwrap(); let leaf = group.own_leaf(); leaf.map(|l| { - let credential = BasicCredential::try_from(l.credential()).unwrap(); + let credential = BasicCredential::try_from(l.credential().clone()).unwrap(); credential.identity().to_vec() }) } diff --git a/openmls/src/test_utils/test_framework/errors.rs b/openmls/src/test_utils/test_framework/errors.rs index b95640c9e8..370fe9ac49 100644 --- a/openmls/src/test_utils/test_framework/errors.rs +++ b/openmls/src/test_utils/test_framework/errors.rs @@ -1,12 +1,11 @@ -use openmls_traits::key_store::OpenMlsKeyStore; use thiserror::Error; use crate::{error::LibraryError, group::errors::*}; -use openmls_rust_crypto::{MemoryKeyStore, MemoryKeyStoreError}; +use openmls_rust_crypto::MemoryStorage; /// Setup error -#[derive(Error, Debug, PartialEq, Clone)] -pub enum SetupError { +#[derive(Error, Debug, PartialEq)] +pub enum SetupError { #[error("")] UnknownGroupId, #[error("")] @@ -21,10 +20,10 @@ pub enum SetupError { NoFreshKeyPackage, /// See [`ClientError`] for more details. #[error(transparent)] - ClientError(#[from] ClientError), + ClientError(#[from] ClientError), /// See [`ExportSecretError`] for more details. #[error(transparent)] - ExportSecretError(#[from] ExportSecretError), + ExportSecretError(#[from] ExportSecretError), /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), @@ -38,8 +37,8 @@ pub enum SetupGroupError { } /// Errors that can occur when processing messages with the client. -#[derive(Error, Debug, PartialEq, Clone)] -pub enum ClientError { +#[derive(Error, Debug, PartialEq)] +pub enum ClientError { #[error("")] NoMatchingKeyPackage, #[error("")] @@ -52,51 +51,52 @@ pub enum ClientError { NoCiphersuite, /// See [`WelcomeError`] for more details. #[error(transparent)] - FailedToJoinGroup(#[from] WelcomeError), + FailedToJoinGroup(#[from] WelcomeError), + /// See [`tls_codec::Error`] for more details. #[error(transparent)] - TlsCodecError(#[from] tls_codec::Error), + TlsCodecError(tls_codec::Error), /// See [`ProcessMessageError`] for more details. #[error(transparent)] - ProcessMessageError(#[from] ProcessMessageError), + ProcessMessageError(#[from] ProcessMessageError), /// See [`MlsGroupStateError`] for more details. #[error(transparent)] - MlsGroupStateError(#[from] MlsGroupStateError), + MlsGroupStateError(#[from] MlsGroupStateError), /// See [`AddMembersError`] for more details. #[error(transparent)] - AddMembersError(#[from] AddMembersError), + AddMembersError(#[from] AddMembersError), /// See [`RemoveMembersError`] for more details. #[error(transparent)] - RemoveMembersError(#[from] RemoveMembersError), + RemoveMembersError(#[from] RemoveMembersError), /// See [`ProposeAddMemberError`] for more details. #[error(transparent)] - ProposeAddMemberError(#[from] ProposeAddMemberError), + ProposeAddMemberError(#[from] ProposeAddMemberError), /// See [`ProposeRemoveMemberError`] for more details. #[error(transparent)] - ProposeRemoveMemberError(#[from] ProposeRemoveMemberError), + ProposeRemoveMemberError(#[from] ProposeRemoveMemberError), /// See [`ExportSecretError`] for more details. #[error(transparent)] - ExportSecretError(#[from] ExportSecretError), + ExportSecretError(#[from] ExportSecretError), /// See [`NewGroupError`] for more details. #[error(transparent)] - NewGroupError(#[from] NewGroupError), + NewGroupError(#[from] NewGroupError), /// See [`SelfUpdateError`] for more details. #[error(transparent)] - SelfUpdateError(#[from] SelfUpdateError), + SelfUpdateError(#[from] SelfUpdateError), /// See [`ProposeSelfUpdateError`] for more details. #[error(transparent)] - ProposeSelfUpdateError(#[from] ProposeSelfUpdateError), + ProposeSelfUpdateError(#[from] ProposeSelfUpdateError), /// See [`MergePendingCommitError`] for more details. #[error(transparent)] - MergePendingCommitError(#[from] MergePendingCommitError), + MergePendingCommitError(#[from] MergePendingCommitError), /// See [`MergeCommitError`] for more details. #[error(transparent)] - MergeCommitError(#[from] MergeCommitError), - /// See [`MemoryKeyStoreError`] for more details. + MergeCommitError(#[from] MergeCommitError), + /// See [`StorageError>`] for more details. #[error(transparent)] - KeyStoreError(#[from] MemoryKeyStoreError), + KeyStoreError(#[from] StorageError), /// See [`LibraryError`] for more details. #[error(transparent)] - LibraryError(#[from] LibraryError), + LibraryError(LibraryError), #[error("")] Unknown, } diff --git a/openmls/src/test_utils/test_framework/mod.rs b/openmls/src/test_utils/test_framework/mod.rs index e150606583..383283f529 100644 --- a/openmls/src/test_utils/test_framework/mod.rs +++ b/openmls/src/test_utils/test_framework/mod.rs @@ -21,6 +21,8 @@ //! can be manipulated manually via the `Client` struct, which contains their //! group states. +use crate::storage::OpenMlsProvider; +use crate::test_utils::OpenMlsRustCrypto; use crate::{ binary_tree::array_representation::LeafNodeIndex, ciphersuite::{hash_ref::KeyPackageRef, *}, @@ -33,20 +35,15 @@ use crate::{ }; use ::rand::{rngs::OsRng, RngCore}; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{ crypto::OpenMlsCrypto, - key_store::OpenMlsKeyStore, types::{Ciphersuite, HpkeKeyPair, SignatureScheme}, - OpenMlsProvider, + OpenMlsProvider as _, }; use std::{collections::HashMap, sync::RwLock}; use tls_codec::*; -#[cfg(not(target_arch = "wasm32"))] -use rayon::prelude::*; - pub mod client; pub mod errors; @@ -105,9 +102,9 @@ pub enum CodecUse { /// that the `MlsGroupTestSetup` can only be initialized with a fixed number of /// clients and that `create_clients` has to be called before it can be /// otherwise used. -pub struct MlsGroupTestSetup { +pub struct MlsGroupTestSetup { // The clients identity is its position in the vector in be_bytes. - pub clients: RwLock, RwLock>>, + pub clients: RwLock, RwLock>>>, pub groups: RwLock>, // This maps key package hashes to client ids. pub waiting_for_welcome: RwLock, Vec>>, @@ -133,7 +130,7 @@ pub struct MlsGroupTestSetup { // context that the `MlsGroupTestSetup` lives in, because otherwise the // references don't live long enough. -impl MlsGroupTestSetup { +impl MlsGroupTestSetup { /// Create a new `MlsGroupTestSetup` with the given default /// `MlsGroupCreateConfig` and the given number of clients. For lifetime /// reasons, `create_clients` has to be called in addition with the same @@ -146,14 +143,13 @@ impl MlsGroupTestSetup { let mut clients = HashMap::new(); for i in 0..number_of_clients { let identity = i.to_be_bytes().to_vec(); - // For now, everyone supports all ciphersuites. - let crypto = OpenMlsRustCrypto::default(); + let provider = Provider::default(); let mut credentials = HashMap::new(); - for ciphersuite in crypto.crypto().supported_ciphersuites().iter() { - let credential = BasicCredential::new(identity.clone()).unwrap(); + for ciphersuite in provider.crypto().supported_ciphersuites().iter() { + let credential = BasicCredential::new(identity.clone()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); - signature_keys.store(crypto.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); let signature_key = OpenMlsSignaturePublicKey::new( signature_keys.public().into(), signature_keys.signature_scheme(), @@ -171,7 +167,7 @@ impl MlsGroupTestSetup { let client = Client { identity: identity.clone(), credentials, - crypto, + provider, groups: RwLock::new(HashMap::new()), }; clients.insert(identity, RwLock::new(client)); @@ -193,16 +189,16 @@ impl MlsGroupTestSetup { /// error if the client does not support the given ciphersuite. pub fn get_fresh_key_package( &self, - client: &Client, + client: &Client, ciphersuite: Ciphersuite, - ) -> Result { + ) -> Result> { let key_package = client.get_fresh_key_package(ciphersuite)?; self.waiting_for_welcome .write() .expect("An unexpected error occurred.") .insert( key_package - .hash_ref(client.crypto.crypto())? + .hash_ref(client.provider.crypto())? .as_slice() .to_vec(), client.identity.clone(), @@ -248,7 +244,11 @@ impl MlsGroupTestSetup { /// distribute the commit adding the members to the group. This function /// will throw an error if no key package was previously created for the /// client by `get_fresh_key_package`. - pub fn deliver_welcome(&self, welcome: Welcome, group: &Group) -> Result<(), SetupError> { + pub fn deliver_welcome( + &self, + welcome: Welcome, + group: &Group, + ) -> Result<(), SetupError> { // Serialize and de-serialize the Welcome if the bit was set. let welcome = match self.use_codec { CodecUse::SerializedMessages => { @@ -293,14 +293,17 @@ impl MlsGroupTestSetup { group: &mut Group, message: &MlsMessageIn, authentication_service: &AS, - ) -> Result<(), ClientError> { + ) -> Result<(), ClientError> { // Test serialization if mandated by config let message: ProtocolMessage = match self.use_codec { CodecUse::SerializedMessages => { let mls_message_out: MlsMessageOut = message.clone().into(); - let serialized_message = mls_message_out.tls_serialize_detached()?; + let serialized_message = mls_message_out + .tls_serialize_detached() + .map_err(ClientError::TlsCodecError)?; - MlsMessageIn::tls_deserialize(&mut serialized_message.as_slice())? + MlsMessageIn::tls_deserialize(&mut serialized_message.as_slice()) + .map_err(ClientError::TlsCodecError)? } CodecUse::StructMessages => message.clone(), } @@ -346,16 +349,11 @@ impl MlsGroupTestSetup { .map( |Member { index, credential, .. - }| { - let identity = - VLBytes::tls_deserialize_exact(credential.serialized_content()).unwrap(); - (index.usize(), identity.as_slice().to_vec()) - }, + }| { (index.usize(), credential.serialized_content().to_vec()) }, ) .collect(); group.public_tree = sender_group.export_ratchet_tree(); - group.exporter_secret = - sender_group.export_secret(sender.crypto.crypto(), "test", &[], 32)?; + group.exporter_secret = sender_group.export_secret(&sender.provider, "test", &[], 32)?; Ok(()) } @@ -371,9 +369,6 @@ impl MlsGroupTestSetup { ) { let clients = self.clients.read().expect("An unexpected error occurred."); - #[cfg(not(target_arch = "wasm32"))] - let group_members = group.members.par_iter(); - #[cfg(target_arch = "wasm32")] let group_members = group.members.iter(); let messages = group_members @@ -389,7 +384,7 @@ impl MlsGroupTestSetup { assert_eq!(group_state.export_ratchet_tree(), group.public_tree); assert_eq!( group_state - .export_secret(m.crypto.crypto(), "test", &[], 32) + .export_secret(&m.provider, "test", &[], 32) .expect("An unexpected error occurred."), group.exporter_secret ); @@ -397,13 +392,13 @@ impl MlsGroupTestSetup { // key store. let signature_pk = group_state.own_leaf().unwrap().signature_key(); let signer = SignatureKeyPair::read( - m.crypto.key_store(), + m.provider.storage(), signature_pk.as_slice(), group_state.ciphersuite().signature_algorithm(), ) .unwrap(); let message = group_state - .create_message(&m.crypto, &signer, "Hello World!".as_bytes()) + .create_message(&m.provider, &signer, "Hello World!".as_bytes()) .expect("Error composing message while checking group states."); Some((m_id.to_vec(), message)) } else { @@ -427,7 +422,7 @@ impl MlsGroupTestSetup { &self, group: &Group, number_of_members: usize, - ) -> Result>, SetupError> { + ) -> Result>, SetupError> { let clients = self.clients.read().expect("An unexpected error occurred."); if number_of_members + group.members.len() > clients.len() { return Err(SetupError::NotEnoughClients); @@ -461,7 +456,10 @@ impl MlsGroupTestSetup { /// does not support the given ciphersuite. TODO #310: Fix to always work /// reliably, probably by introducing a mapping from ciphersuite to the set /// of client ids supporting it. - pub fn create_group(&self, ciphersuite: Ciphersuite) -> Result { + pub fn create_group( + &self, + ciphersuite: Ciphersuite, + ) -> Result> { // Pick a random group creator. let clients = self.clients.read().expect("An unexpected error occurred."); let group_creator_id = ((OsRng.next_u32() as usize) % clients.len()) @@ -482,8 +480,7 @@ impl MlsGroupTestSetup { .get(&group_id) .expect("An unexpected error occurred."); let public_tree = group.export_ratchet_tree(); - let exporter_secret = - group.export_secret(group_creator.crypto.crypto(), "test", &[], 32)?; + let exporter_secret = group.export_secret(&group_creator.provider, "test", &[], 32)?; let member_ids = vec![(0, group_creator_id)]; let group = Group { group_id: group_id.clone(), @@ -503,7 +500,7 @@ impl MlsGroupTestSetup { target_group_size: usize, ciphersuite: Ciphersuite, authentication_service: AS, - ) -> Result { + ) -> Result> { // Create the initial group. let group_id = self.create_group(ciphersuite)?; @@ -543,7 +540,7 @@ impl MlsGroupTestSetup { client_id: &[u8], leaf_node: Option, authentication_service: &AS, - ) -> Result<(), SetupError> { + ) -> Result<(), SetupError> { let clients = self.clients.read().expect("An unexpected error occurred."); let client = clients .get(client_id) @@ -577,7 +574,7 @@ impl MlsGroupTestSetup { adder_id: &[u8], addees: Vec>, authentication_service: &AS, - ) -> Result<(), SetupError> { + ) -> Result<(), SetupError> { let clients = self.clients.read().expect("An unexpected error occurred."); let adder = clients .get(adder_id) @@ -623,7 +620,7 @@ impl MlsGroupTestSetup { remover_id: &[u8], target_members: &[LeafNodeIndex], authentication_service: AS, - ) -> Result<(), SetupError> { + ) -> Result<(), SetupError> { let clients = self.clients.read().expect("An unexpected error occurred."); let remover = clients .get(remover_id) @@ -653,7 +650,7 @@ impl MlsGroupTestSetup { &self, group: &mut Group, authentication_service: &AS, - ) -> Result<(), SetupError> { + ) -> Result<(), SetupError> { // Who's going to do it? let member_id = group.random_group_member(); println!("Member performing the operation: {member_id:?}"); diff --git a/openmls/src/tree/secret_tree.rs b/openmls/src/tree/secret_tree.rs index 9ce1f81ea9..700266d675 100644 --- a/openmls/src/tree/secret_tree.rs +++ b/openmls/src/tree/secret_tree.rs @@ -74,6 +74,7 @@ impl From<&PublicMessage> for SecretType { /// to the `DeriveTreeSecret` defined in Section 10.1 of the MLS specification. #[inline] pub(crate) fn derive_tree_secret( + ciphersuite: Ciphersuite, secret: &Secret, label: &str, generation: u32, @@ -88,7 +89,13 @@ pub(crate) fn derive_tree_secret( ); log_crypto!(trace, "Input secret {:x?}", secret.as_slice()); - let secret = secret.kdf_expand_label(crypto, label, &generation.to_be_bytes(), length)?; + let secret = secret.kdf_expand_label( + crypto, + ciphersuite, + label, + &generation.to_be_bytes(), + length, + )?; log_crypto!(trace, "Derived secret {:x?}", secret.as_slice()); Ok(secret) } @@ -229,10 +236,20 @@ impl SecretTree { log::trace!("Deriving leaf node secrets for leaf {index:?}"); - let handshake_ratchet_secret = - node_secret.kdf_expand_label(crypto, "handshake", b"", ciphersuite.hash_length())?; - let application_ratchet_secret = - node_secret.kdf_expand_label(crypto, "application", b"", ciphersuite.hash_length())?; + let handshake_ratchet_secret = node_secret.kdf_expand_label( + crypto, + ciphersuite, + "handshake", + b"", + ciphersuite.hash_length(), + )?; + let application_ratchet_secret = node_secret.kdf_expand_label( + crypto, + ciphersuite, + "application", + b"", + ciphersuite.hash_length(), + )?; log_crypto!( trace, @@ -397,8 +414,10 @@ impl SecretTree { log_crypto!(trace, "Node secret: {:x?}", node_secret.as_slice()); let left_index = left(index_in_tree); let right_index = right(index_in_tree); - let left_secret = node_secret.kdf_expand_label(crypto, "tree", b"left", hash_len)?; - let right_secret = node_secret.kdf_expand_label(crypto, "tree", b"right", hash_len)?; + let left_secret = + node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"left", hash_len)?; + let right_secret = + node_secret.kdf_expand_label(crypto, ciphersuite, "tree", b"right", hash_len)?; log_crypto!( trace, "Left node ({}) secret: {:x?}", diff --git a/openmls/src/tree/sender_ratchet.rs b/openmls/src/tree/sender_ratchet.rs index 1d4d27861d..240a87465f 100644 --- a/openmls/src/tree/sender_ratchet.rs +++ b/openmls/src/tree/sender_ratchet.rs @@ -129,6 +129,7 @@ impl RatchetSecret { return Err(SecretTreeError::RatchetTooLong); } let nonce = derive_tree_secret( + ciphersuite, &self.secret, "nonce", self.generation, @@ -136,6 +137,7 @@ impl RatchetSecret { crypto, )?; let key = derive_tree_secret( + ciphersuite, &self.secret, "key", self.generation, @@ -143,6 +145,7 @@ impl RatchetSecret { crypto, )?; self.secret = derive_tree_secret( + ciphersuite, &self.secret, "secret", self.generation, @@ -153,7 +156,10 @@ impl RatchetSecret { self.generation += 1; Ok(( generation, - (AeadKey::from_secret(key), AeadNonce::from_secret(nonce)), + ( + AeadKey::from_secret(key, ciphersuite), + AeadNonce::from_secret(nonce), + ), )) } diff --git a/openmls/src/tree/tests_and_kats/kats/kat_encryption.rs b/openmls/src/tree/tests_and_kats/kats/kat_encryption.rs index 56615c3ec0..0b430b6c55 100644 --- a/openmls/src/tree/tests_and_kats/kats/kat_encryption.rs +++ b/openmls/src/tree/tests_and_kats/kats/kat_encryption.rs @@ -81,8 +81,10 @@ use itertools::izip; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{signatures::Signer, types::SignatureScheme, OpenMlsProvider}; +use openmls_traits::{ + signatures::Signer, + types::{Ciphersuite, SignatureScheme}, +}; use serde::{self, Deserialize, Serialize}; use thiserror::Error; @@ -96,13 +98,13 @@ use crate::{ group::*, messages::proposals::{Proposal, RemoveProposal}, schedule::{EncryptionSecret, SenderDataSecret}, - test_utils::*, + storage::OpenMlsProvider, + test_utils::bytes_to_hex, tree::{ secret_tree::{SecretTree, SecretType}, sender_ratchet::SenderRatchetConfiguration, }, utils::random_u64, - versions::ProtocolVersion, }; #[derive(Serialize, Deserialize, Debug, Clone, Default)] @@ -142,9 +144,9 @@ fn generate_credential( signature_algorithm: SignatureScheme, provider: &impl OpenMlsProvider, ) -> (CredentialWithKey, SignatureKeyPair) { - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); ( CredentialWithKey { @@ -160,8 +162,6 @@ fn group( ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider, ) -> (CoreGroup, CredentialWithKey, SignatureKeyPair) { - use crate::group::config::CryptoConfig; - let (credential_with_key, signer) = generate_credential( "Kreator".into(), ciphersuite.signature_algorithm(), @@ -170,7 +170,7 @@ fn group( let group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, credential_with_key.clone(), ) .build(provider, &signer) @@ -185,21 +185,15 @@ fn receiver_group( provider: &impl OpenMlsProvider, group_id: GroupId, ) -> (CoreGroup, CredentialWithKey, SignatureKeyPair) { - use crate::group::config::CryptoConfig; - let (credential_with_key, signer) = generate_credential( "Receiver".into(), ciphersuite.signature_algorithm(), provider, ); - let group = CoreGroup::builder( - group_id, - CryptoConfig::with_default_version(ciphersuite), - credential_with_key.clone(), - ) - .build(provider, &signer) - .unwrap(); + let group = CoreGroup::builder(group_id, ciphersuite, credential_with_key.clone()) + .build(provider, &signer) + .unwrap(); (group, credential_with_key, signer) } @@ -220,12 +214,7 @@ fn build_handshake_messages( group.context_mut().set_epoch(epoch.into()); let framing_parameters = FramingParameters::new(&[1, 2, 3, 4], WireFormat::PrivateMessage); let membership_key = MembershipKey::from_secret( - Secret::random( - group.ciphersuite(), - provider.rand(), - None, /* MLS version */ - ) - .expect("Not enough randomness."), + Secret::random(group.ciphersuite(), provider.rand()).expect("Not enough randomness."), ); let content = AuthenticatedContentIn::from( AuthenticatedContent::member_proposal( @@ -244,6 +233,7 @@ fn build_handshake_messages( plaintext .set_membership_tag( provider.crypto(), + group.ciphersuite(), &membership_key, &group.context().tls_serialize_detached().unwrap(), ) @@ -280,12 +270,7 @@ fn build_application_messages( let epoch = random_u64(); group.context_mut().set_epoch(epoch.into()); let membership_key = MembershipKey::from_secret( - Secret::random( - group.ciphersuite(), - provider.rand(), - None, /* MLS version */ - ) - .expect("Not enough randomness."), + Secret::random(group.ciphersuite(), provider.rand()).expect("Not enough randomness."), ); let content = AuthenticatedContent::new_application( sender_index, @@ -299,6 +284,7 @@ fn build_application_messages( plaintext .set_membership_tag( provider.crypto(), + group.ciphersuite(), &membership_key, &group.context().tls_serialize_detached().unwrap(), ) @@ -329,7 +315,8 @@ pub fn generate_test_vector( n_leaves: u32, ciphersuite: Ciphersuite, ) -> EncryptionTestVector { - use openmls_traits::random::OpenMlsRand; + use openmls_rust_crypto::OpenMlsRustCrypto; + use openmls_traits::prelude::*; use crate::binary_tree::array_representation::TreeSize; @@ -348,7 +335,7 @@ pub fn generate_test_vector( .random_vec(77) .expect("An unexpected error occurred."); let sender_data_key = sender_data_secret - .derive_aead_key(crypto.crypto(), &ciphertext) + .derive_aead_key(crypto.crypto(), ciphersuite, &ciphertext) .expect("Could not derive AEAD key."); // Derive initial nonce from the key schedule using the ciphertext. let sender_data_nonce = sender_data_secret @@ -361,11 +348,8 @@ pub fn generate_test_vector( }; let (mut group, _, signer) = group(ciphersuite, &crypto); - *group.message_secrets_test_mut().sender_data_secret_mut() = SenderDataSecret::from_slice( - sender_data_secret_bytes, - ProtocolVersion::default(), - ciphersuite, - ); + *group.message_secrets_test_mut().sender_data_secret_mut() = + SenderDataSecret::from_slice(sender_data_secret_bytes); let mut leaves = Vec::new(); for leaf in 0..n_leaves { @@ -373,18 +357,10 @@ pub fn generate_test_vector( // It doesn't matter who the receiver is, as long as it's not the same // as the sender, so we don't get into trouble with the secret tree. let receiver_leaf = LeafNodeIndex::new(u32::from(leaf == 0)); - let encryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); + let encryption_secret = EncryptionSecret::from_slice(&encryption_secret_bytes[..]); let size = TreeSize::from_leaf_count(n_leaves); let encryption_secret_tree = SecretTree::new(encryption_secret, size, sender_leaf); - let decryption_secret = EncryptionSecret::from_slice( - &encryption_secret_bytes[..], - ProtocolVersion::default(), - ciphersuite, - ); + let decryption_secret = EncryptionSecret::from_slice(&encryption_secret_bytes[..]); let mut decryption_secret_tree = SecretTree::new(decryption_secret, size, receiver_leaf); *group.message_secrets_test_mut().secret_tree_mut() = encryption_secret_tree; @@ -458,15 +434,16 @@ pub fn generate_test_vector( #[test] fn write_test_vectors() { + use openmls_traits::prelude::*; + let _ = pretty_env_logger::try_init(); - use openmls_traits::crypto::OpenMlsCrypto; let mut tests = Vec::new(); const NUM_LEAVES: u32 = 10; const NUM_GENERATIONS: u32 = 15; log::debug!("Generating new test vectors ..."); - for &ciphersuite in OpenMlsRustCrypto::default() + for &ciphersuite in openmls_rust_crypto::OpenMlsRustCrypto::default() .crypto() .supported_ciphersuites() .iter() @@ -477,7 +454,7 @@ fn write_test_vectors() { } } - write("test_vectors/kat_encryption_openmls-new.json", &tests); + crate::test_utils::write("test_vectors/kat_encryption_openmls-new.json", &tests); } #[cfg(any(feature = "test-utils", test))] @@ -490,6 +467,7 @@ pub fn run_test_vector( use crate::{ binary_tree::array_representation::TreeSize, schedule::{message_secrets::MessageSecrets, ConfirmationKey, MembershipKey}, + test_utils::hex_to_bytes, }; let n_leaves = test_vector.n_leaves; @@ -500,15 +478,13 @@ pub fn run_test_vector( let ciphersuite = Ciphersuite::try_from(test_vector.cipher_suite).expect("Invalid ciphersuite"); log::debug!("Running test vector with {:?}", ciphersuite); - let sender_data_secret = SenderDataSecret::from_slice( - hex_to_bytes(&test_vector.sender_data_secret).as_slice(), - ProtocolVersion::default(), - ciphersuite, - ); + let sender_data_secret = + SenderDataSecret::from_slice(hex_to_bytes(&test_vector.sender_data_secret).as_slice()); let sender_data_key = sender_data_secret .derive_aead_key( provider.crypto(), + ciphersuite, &hex_to_bytes(&test_vector.sender_data_info.ciphertext), ) .expect("Could not derive AEAD key."); @@ -539,11 +515,7 @@ pub fn run_test_vector( let receiver_leaf = LeafNodeIndex::new(u32::from(leaf_index == 0)); let mut secret_tree = SecretTree::new( - EncryptionSecret::from_slice( - hex_to_bytes(&test_vector.encryption_secret).as_slice(), - ProtocolVersion::default(), - ciphersuite, - ), + EncryptionSecret::from_slice(hex_to_bytes(&test_vector.encryption_secret).as_slice()), size, receiver_leaf, ); @@ -619,8 +591,6 @@ pub fn run_test_vector( *group.message_secrets_test_mut().sender_data_secret_mut() = SenderDataSecret::from_slice( hex_to_bytes(&test_vector.sender_data_secret).as_slice(), - ProtocolVersion::default(), - ciphersuite, ); // We have to take the fresh_secret_tree here because the secret_for_decryption @@ -702,8 +672,6 @@ pub fn run_test_vector( *group.message_secrets_test_mut().sender_data_secret_mut() = SenderDataSecret::from_slice( hex_to_bytes(&test_vector.sender_data_secret).as_slice(), - ProtocolVersion::default(), - ciphersuite, ); // Swap secret tree @@ -779,11 +747,7 @@ pub fn run_test_vector( mls_ciphertext_handshake.group_id().clone(), ); *group.message_secrets_test_mut().sender_data_secret_mut() = - SenderDataSecret::from_slice( - &hex_to_bytes(&test_vector.sender_data_secret), - ProtocolVersion::default(), - ciphersuite, - ); + SenderDataSecret::from_slice(&hex_to_bytes(&test_vector.sender_data_secret)); // Swap secret tree let _ = group @@ -833,36 +797,23 @@ pub fn run_test_vector( Ok(()) } -#[apply(providers)] -fn read_test_vectors_encryption(provider: &impl OpenMlsProvider) { +#[test] +fn read_test_vectors_encryption() { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); + // The ciphersuite is defined in here and libcrux can't do all of them yet. + let provider = openmls_rust_crypto::OpenMlsRustCrypto::default(); let tests: Vec = read_json!("../../../../test_vectors/kat_encryption_openmls.json"); for test_vector in tests { - match run_test_vector(test_vector, provider) { + match run_test_vector(test_vector, &provider) { Ok(_) => {} Err(e) => panic!("Error while checking encryption test vector.\n{e:?}"), } } - // mlspp test vectors - let tv_files = [ - /* - mlspp test vectors are not compatible for now because they don't implement - the new wire_format field in framing yet. This is tracked in #495. - "test_vectors/mlspp/mlspp_encryption_1_10.json", - "test_vectors/mlspp/mlspp_encryption_2_10.json", - "test_vectors/mlspp/mlspp_encryption_3_10.json", - */ - ]; - for &tv_file in tv_files.iter() { - let tv: EncryptionTestVector = read(tv_file); - run_test_vector(tv, provider).expect("Error while checking key schedule test vector."); - } - log::trace!("Finished test vector verification"); } diff --git a/openmls/src/tree/tests_and_kats/kats/kat_message_protection.rs b/openmls/src/tree/tests_and_kats/kats/kat_message_protection.rs index 428b064109..f972186ff0 100644 --- a/openmls/src/tree/tests_and_kats/kats/kat_message_protection.rs +++ b/openmls/src/tree/tests_and_kats/kats/kat_message_protection.rs @@ -62,8 +62,7 @@ //! * When protecting the Commit message, add the supplied confirmation tag use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::SignatureScheme, OpenMlsProvider}; +use openmls_traits::types::SignatureScheme; use serde::{self, Deserialize, Serialize}; use crate::{ @@ -108,11 +107,11 @@ pub struct MessageProtectionTest { fn generate_credential( identity: Vec, signature_algorithm: SignatureScheme, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> (CredentialWithKey, SignatureKeyPair) { - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); ( CredentialWithKey { @@ -126,10 +125,8 @@ fn generate_credential( #[cfg(any(feature = "test-utils", test))] fn group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> (CoreGroup, CredentialWithKey, SignatureKeyPair) { - use crate::group::config::CryptoConfig; - let (credential_with_key, signer) = generate_credential( "Kreator".into(), ciphersuite.signature_algorithm(), @@ -138,7 +135,7 @@ fn group( let group = CoreGroup::builder( GroupId::random(provider.rand()), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, credential_with_key.clone(), ) .build(provider, &signer) @@ -150,24 +147,18 @@ fn group( #[cfg(any(feature = "test-utils", test))] fn receiver_group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, group_id: GroupId, ) -> (CoreGroup, CredentialWithKey, SignatureKeyPair) { - use crate::group::config::CryptoConfig; - let (credential_with_key, signer) = generate_credential( "Receiver".into(), ciphersuite.signature_algorithm(), provider, ); - let group = CoreGroup::builder( - group_id, - CryptoConfig::with_default_version(ciphersuite), - credential_with_key.clone(), - ) - .build(provider, &signer) - .unwrap(); + let group = CoreGroup::builder(group_id, ciphersuite, credential_with_key.clone()) + .build(provider, &signer) + .unwrap(); (group, credential_with_key, signer) } @@ -175,7 +166,7 @@ fn receiver_group( #[cfg(test)] pub fn run_test_vector( test: MessageProtectionTest, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> Result<(), String> { use openmls_traits::crypto::OpenMlsCrypto; use tls_codec::{Deserialize, Serialize}; @@ -183,7 +174,6 @@ pub fn run_test_vector( use crate::{ binary_tree::array_representation::TreeSize, extensions::Extensions, - group::config::CryptoConfig, messages::{proposals_in::ProposalIn, CommitIn, ConfirmationTag}, prelude::KeyPackageBundle, prelude_test::{Mac, Secret}, @@ -224,7 +214,7 @@ pub fn run_test_vector( // Make the group think it has two members. fn setup_group( - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ciphersuite: Ciphersuite, test: &MessageProtectionTest, sender: bool, @@ -239,7 +229,7 @@ pub fn run_test_vector( ); // Set up the group, unfortunately we can't do without. - let credential = BasicCredential::new(b"This is not needed".to_vec()).unwrap(); + let credential = BasicCredential::new(b"This is not needed".to_vec()); let signature_private_key = hex_to_bytes(&test.signature_priv); let random_own_signature_key = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); @@ -252,7 +242,7 @@ pub fn run_test_vector( let mut group = CoreGroup::builder( group_context.group_id().clone(), - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, CredentialWithKey { credential: credential.into(), signature_key: random_own_signature_key.into(), @@ -261,9 +251,9 @@ pub fn run_test_vector( .build(provider, &signer) .unwrap(); - let credential = BasicCredential::new("Fake user".into()).unwrap(); + let credential = BasicCredential::new("Fake user".into()); let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm()).unwrap(); - let bob_key_package_bundle = KeyPackageBundle::new( + let bob_key_package_bundle = KeyPackageBundle::generate( provider, &signature_keys, ciphersuite, @@ -303,11 +293,8 @@ pub fn run_test_vector( // Inject the test values into the group - let encryption_secret = EncryptionSecret::from_slice( - &hex_to_bytes(&test.encryption_secret), - group_context.protocol_version(), - ciphersuite, - ); + let encryption_secret = + EncryptionSecret::from_slice(&hex_to_bytes(&test.encryption_secret)); let own_index = LeafNodeIndex::new(0); let sender_index = LeafNodeIndex::new(1); let secret_tree = SecretTree::new(encryption_secret.clone(), TreeSize::new(2), own_index); @@ -328,16 +315,9 @@ pub fn run_test_vector( message_secrets.replace_secret_tree(secret_tree); } message_secrets.set_serialized_context(serialized_group_context); - *message_secrets.sender_data_secret_mut() = SenderDataSecret::from_slice( - &hex_to_bytes(&test.sender_data_secret), - ProtocolVersion::Mls10, - ciphersuite, - ); - message_secrets.set_membership_key(Secret::from_slice( - &hex_to_bytes(&test.membership_key), - ProtocolVersion::Mls10, - ciphersuite, - )); + *message_secrets.sender_data_secret_mut() = + SenderDataSecret::from_slice(&hex_to_bytes(&test.sender_data_secret)); + message_secrets.set_membership_key(Secret::from_slice(&hex_to_bytes(&test.membership_key))); group } @@ -352,7 +332,7 @@ pub fn run_test_vector( fn test_proposal_pub( mut group: CoreGroup, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ciphersuite: Ciphersuite, proposal: ProposalIn, proposal_pub: MlsMessageIn, @@ -374,7 +354,7 @@ pub fn run_test_vector( .parse_message(decrypted_message, group.message_secrets_store()) .unwrap(); let processed_message: AuthenticatedContent = processed_unverified_message - .verify(ciphersuite, provider.crypto(), ProtocolVersion::Mls10) + .verify(ciphersuite, provider, ProtocolVersion::Mls10) .unwrap() .0; match processed_message.content().to_owned() { @@ -393,7 +373,7 @@ pub fn run_test_vector( fn test_proposal_priv( mut group: CoreGroup, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, proposal: ProposalIn, proposal_priv: MlsMessageIn, ) { @@ -463,6 +443,7 @@ pub fn run_test_vector( my_proposal_pub .set_membership_tag( provider.crypto(), + ciphersuite, sender_group.message_secrets().membership_key(), sender_group.message_secrets().serialized_context(), ) @@ -488,7 +469,7 @@ pub fn run_test_vector( fn test_commit_pub( mut group: CoreGroup, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ciphersuite: Ciphersuite, commit: CommitIn, commit_pub: MlsMessageIn, @@ -510,7 +491,7 @@ pub fn run_test_vector( .parse_message(decrypted_message, group.message_secrets_store()) .unwrap(); let processed_message: AuthenticatedContent = processed_unverified_message - .verify(ciphersuite, provider.crypto(), ProtocolVersion::Mls10) + .verify(ciphersuite, provider, ProtocolVersion::Mls10) .unwrap() .0; match processed_message.content().to_owned() { @@ -531,7 +512,7 @@ pub fn run_test_vector( fn test_commit_priv( mut group: CoreGroup, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ciphersuite: Ciphersuite, commit: CommitIn, commit_priv: MlsMessageIn, @@ -553,7 +534,7 @@ pub fn run_test_vector( .parse_message(decrypted_message, group.message_secrets_store()) .unwrap(); let processed_message: AuthenticatedContent = processed_unverified_message - .verify(ciphersuite, provider.crypto(), ProtocolVersion::Mls10) + .verify(ciphersuite, provider, ProtocolVersion::Mls10) .unwrap() .0; match processed_message.content().to_owned() { @@ -618,6 +599,7 @@ pub fn run_test_vector( my_commit_pub_msg .set_membership_tag( provider.crypto(), + ciphersuite, sender_group.message_secrets().membership_key(), sender_group.message_secrets().serialized_context(), ) @@ -641,7 +623,7 @@ pub fn run_test_vector( fn test_application_priv( mut group: CoreGroup, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, application: Vec, application_priv: MlsMessageIn, ) { @@ -695,8 +677,8 @@ pub fn run_test_vector( Ok(()) } -#[apply(providers)] -fn read_test_vectors_mp(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_mp(provider: &impl crate::storage::OpenMlsProvider) { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); diff --git a/openmls/src/tree/tests_and_kats/kats/secret_tree.rs b/openmls/src/tree/tests_and_kats/kats/secret_tree.rs index 585b12eee8..1e33bf78bf 100644 --- a/openmls/src/tree/tests_and_kats/kats/secret_tree.rs +++ b/openmls/src/tree/tests_and_kats/kats/secret_tree.rs @@ -51,8 +51,6 @@ use serde::{Deserialize, Serialize}; -use crate::test_utils::*; - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct SenderData { sender_data_secret: String, @@ -80,14 +78,17 @@ pub struct SecretTree { } #[cfg(test)] -pub fn run_test_vector(test: SecretTree, provider: &impl OpenMlsProvider) -> Result<(), String> { - use openmls_traits::crypto::OpenMlsCrypto; +pub fn run_test_vector( + test: SecretTree, + provider: &Provider, +) -> Result<(), String> { + use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite}; use crate::{ binary_tree::{array_representation::TreeSize, LeafNodeIndex}, schedule::{EncryptionSecret, SenderDataSecret}, + test_utils::hex_to_bytes, tree::secret_tree::{SecretTree, SecretType}, - versions::ProtocolVersion, }; let ciphersuite = Ciphersuite::try_from(test.cipher_suite).unwrap(); @@ -104,14 +105,13 @@ pub fn run_test_vector(test: SecretTree, provider: &impl OpenMlsProvider) -> Res // Check sender data let sender_data_secret = hex_to_bytes(&test.sender_data.sender_data_secret); - let sender_data_secret = - SenderDataSecret::from_slice(&sender_data_secret, ProtocolVersion::Mls10, ciphersuite); + let sender_data_secret = SenderDataSecret::from_slice(&sender_data_secret); let sender_data_ciphertext = hex_to_bytes(&test.sender_data.ciphertext); let sender_data_key = hex_to_bytes(&test.sender_data.key); let sender_data_nonce = hex_to_bytes(&test.sender_data.nonce); let my_sender_data_key = sender_data_secret - .derive_aead_key(provider.crypto(), &sender_data_ciphertext) + .derive_aead_key(provider.crypto(), ciphersuite, &sender_data_ciphertext) .unwrap(); assert_eq!(&sender_data_key, my_sender_data_key.as_slice()); let my_sender_data_nonce = sender_data_secret @@ -131,11 +131,7 @@ pub fn run_test_vector(test: SecretTree, provider: &impl OpenMlsProvider) -> Res log::trace!(" Testing generation {generation}"); let mut secret_tree = SecretTree::new( - EncryptionSecret::from_slice( - &encryption_secret, - ProtocolVersion::Mls10, - ciphersuite, - ), + EncryptionSecret::from_slice(&encryption_secret), TreeSize::new(num_leaves as u32), LeafNodeIndex::new(leaf_index as u32), ); @@ -187,8 +183,8 @@ pub fn run_test_vector(test: SecretTree, provider: &impl OpenMlsProvider) -> Res Ok(()) } -#[apply(providers)] -fn read_test_vectors_st(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_st() { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); diff --git a/openmls/src/tree/tests_and_kats/unit_tests/test_secret_tree.rs b/openmls/src/tree/tests_and_kats/unit_tests/test_secret_tree.rs index 33d3f06381..f4e944c0e3 100644 --- a/openmls/src/tree/tests_and_kats/unit_tests/test_secret_tree.rs +++ b/openmls/src/tree/tests_and_kats/unit_tests/test_secret_tree.rs @@ -1,4 +1,3 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::random::OpenMlsRand; use crate::{ @@ -6,13 +5,12 @@ use crate::{ schedule::EncryptionSecret, test_utils::*, tree::{secret_tree::*, sender_ratchet::SenderRatchetConfiguration}, - versions::ProtocolVersion, }; use std::collections::HashMap; // This tests the boundaries of the generations from a SecretTree -#[apply(ciphersuites_and_providers)] -fn test_boundaries(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_boundaries() { let configuration = &SenderRatchetConfiguration::default(); let encryption_secret = EncryptionSecret::random(ciphersuite, provider.rand()); let mut secret_tree = SecretTree::new( @@ -156,8 +154,8 @@ fn test_boundaries(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // This tests if the generation gets incremented correctly and that the returned // values are unique. -#[apply(ciphersuites_and_providers)] -fn increment_generation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn increment_generation() { const SIZE: usize = 100; const MAX_GENERATIONS: usize = 10; @@ -223,8 +221,8 @@ fn increment_generation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide } } -#[apply(ciphersuites_and_providers)] -fn secret_tree(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn secret_tree() { let leaf_index = 0u32; let generation = 0; let n_leaves = 10u32; @@ -235,8 +233,6 @@ fn secret_tree(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .rand() .random_vec(ciphersuite.hash_length()) .expect("An unexpected error occurred.")[..], - ProtocolVersion::default(), - ciphersuite, ), TreeSize::new(n_leaves), LeafNodeIndex::new(1u32), diff --git a/openmls/src/tree/tests_and_kats/unit_tests/test_sender_ratchet.rs b/openmls/src/tree/tests_and_kats/unit_tests/test_sender_ratchet.rs index 61ba6d6079..7bf8de9144 100644 --- a/openmls/src/tree/tests_and_kats/unit_tests/test_sender_ratchet.rs +++ b/openmls/src/tree/tests_and_kats/unit_tests/test_sender_ratchet.rs @@ -1,16 +1,12 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; - use crate::{ - ciphersuite::Secret, test_utils::*, tree::secret_tree::SecretTreeError, - tree::sender_ratchet::*, versions::ProtocolVersion, + ciphersuite::Secret, test_utils::*, tree::secret_tree::SecretTreeError, tree::sender_ratchet::*, }; // Test the maximum forward ratcheting -#[apply(ciphersuites_and_providers)] -fn test_max_forward_distance(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_max_forward_distance() { let configuration = &SenderRatchetConfiguration::default(); - let secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10) - .expect("Not enough randomness."); + let secret = Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let mut ratchet1 = DecryptionRatchet::new(secret.clone()); let mut ratchet2 = DecryptionRatchet::new(secret); @@ -44,11 +40,10 @@ fn test_max_forward_distance(ciphersuite: Ciphersuite, provider: &impl OpenMlsPr } // Test out-of-order generations -#[apply(ciphersuites_and_providers)] -fn test_out_of_order_generations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_out_of_order_generations() { let configuration = &SenderRatchetConfiguration::default(); - let secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10) - .expect("Not enough randomness."); + let secret = Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let mut ratchet1 = DecryptionRatchet::new(secret); // Ratchet forward twice the size of the window @@ -82,13 +77,12 @@ fn test_out_of_order_generations(ciphersuite: Ciphersuite, provider: &impl OpenM } // Test forward secrecy -#[apply(ciphersuites_and_providers)] -fn test_forward_secrecy(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_forward_secrecy() { // Encryption Ratchets are forward-secret by default, since they don't store // any keys. Thus, we can only test FS on Decryption Ratchets. let configuration = &SenderRatchetConfiguration::default(); - let secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10) - .expect("Not enough randomness."); + let secret = Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let mut ratchet = DecryptionRatchet::new(secret); // Let's ratchet once and see if the ratchet keeps any keys around. @@ -138,8 +132,7 @@ fn test_forward_secrecy(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide fn sender_ratchet_generation_overflow() { let provider = OpenMlsRustCrypto::default(); let ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519; - let secret = Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10) - .expect("Not enough randomness."); + let secret = Secret::random(ciphersuite, provider.rand()).expect("Not enough randomness."); let mut ratchet = RatchetSecret::initial_ratchet_secret(secret); ratchet.set_generation(u32::MAX - 1); let _ = ratchet diff --git a/openmls/src/treesync/diff.rs b/openmls/src/treesync/diff.rs index 4b7a83901e..71a1624413 100644 --- a/openmls/src/treesync/diff.rs +++ b/openmls/src/treesync/diff.rs @@ -271,7 +271,7 @@ impl<'a> TreeSyncDiff<'a> { leaf_index: LeafNodeIndex, ) -> Result { let path_secret = PathSecret::from( - Secret::random(ciphersuite, provider.rand(), None) + Secret::random(ciphersuite, provider.rand()) .map_err(LibraryError::unexpected_crypto_error)?, ); diff --git a/openmls/src/treesync/mod.rs b/openmls/src/treesync/mod.rs index 462dee3c3f..d1813b59ac 100644 --- a/openmls/src/treesync/mod.rs +++ b/openmls/src/treesync/mod.rs @@ -18,12 +18,6 @@ // Finally, this module contains the [`treekem`] module, which allows the // encryption and decryption of updates to the tree. -#[cfg(test)] -use openmls_rust_crypto::OpenMlsRustCrypto; -#[cfg(test)] -use rstest::*; -#[cfg(test)] -use rstest_reuse::apply; #[cfg(any(feature = "test-utils", test))] use std::fmt; @@ -64,7 +58,7 @@ use crate::{ credentials::CredentialWithKey, error::LibraryError, extensions::Extensions, - group::{config::CryptoConfig, GroupId, Member}, + group::{GroupId, Member}, key_packages::Lifetime, messages::{PathSecret, PathSecretError}, schedule::CommitSecret, @@ -347,14 +341,14 @@ impl fmt::Display for RatchetTree { } } -/// The [`TreeSync`] struct holds an [`MlsBinaryTree`] instance, which contains +/// The [`TreeSync`] struct holds an `MlsBinaryTree` instance, which contains /// the state that is synced across the group, as well as the [`LeafNodeIndex`] /// pointing to the leaf of this group member and the current hash of the tree. /// /// It follows the same pattern of tree and diff as the underlying -/// [`MlsBinaryTree`], where the [`TreeSync`] instance is immutable safe for -/// merging a [`TreeSyncDiff`], which can be created, staged and merged (see -/// [`TreeSyncDiff`]). +/// `MlsBinaryTree`, where the [`TreeSync`] instance is immutable safe for +/// merging a `TreeSyncDiff`, which can be created, staged and merged (see +/// `TreeSyncDiff`). /// /// [`TreeSync`] instance guarantee a few invariants that are checked upon /// creating a new instance from an imported set of nodes, as well as when @@ -374,14 +368,14 @@ impl TreeSync { pub(crate) fn new( provider: &impl OpenMlsProvider, signer: &impl Signer, - config: CryptoConfig, + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, life_time: Lifetime, capabilities: Capabilities, extensions: Extensions, ) -> Result<(Self, CommitSecret, EncryptionKeyPair), LibraryError> { let new_leaf_node_params = NewLeafNodeParams { - config, + ciphersuite, credential_with_key, // Creation of a group is considered to be from a key package. leaf_node_source: LeafNodeSource::KeyPackage(life_time), @@ -392,11 +386,11 @@ impl TreeSync { let (leaf, encryption_key_pair) = LeafNode::new(provider, signer, new_leaf_node_params)?; let node = Node::LeafNode(leaf); - let path_secret: PathSecret = Secret::random(config.ciphersuite, provider.rand(), None) + let path_secret: PathSecret = Secret::random(ciphersuite, provider.rand()) .map_err(LibraryError::unexpected_crypto_error)? .into(); let commit_secret: CommitSecret = path_secret - .derive_path_secret(provider.crypto(), config.ciphersuite)? + .derive_path_secret(provider.crypto(), ciphersuite)? .into(); let nodes = vec![TreeSyncNode::from(node).into()]; let tree = MlsBinaryTree::new(nodes) @@ -406,7 +400,7 @@ impl TreeSync { tree_hash: vec![], }; // Populate tree hash caches. - tree_sync.populate_parent_hashes(provider.crypto(), config.ciphersuite)?; + tree_sync.populate_parent_hashes(provider.crypto(), ciphersuite)?; Ok((tree_sync, commit_secret, encryption_key_pair)) } @@ -728,7 +722,7 @@ mod test { RatchetTree::trimmed(vec![None]); } - #[apply(ciphersuites_and_providers)] + #[openmls_test::openmls_test] fn test_ratchet_tree_trailing_blank_nodes( ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider, diff --git a/openmls/src/treesync/node/encryption_keys.rs b/openmls/src/treesync/node/encryption_keys.rs index a3edf9510f..0aab483a4d 100644 --- a/openmls/src/treesync/node/encryption_keys.rs +++ b/openmls/src/treesync/node/encryption_keys.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use openmls_traits::{ crypto::OpenMlsCrypto, - key_store::{MlsEntity, MlsEntityId, OpenMlsKeyStore}, + storage::{StorageProvider as StorageProviderTrait, CURRENT_VERSION}, types::{Ciphersuite, HpkeCiphertext, HpkeKeyPair}, OpenMlsProvider, }; @@ -12,8 +12,7 @@ use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBy use crate::{ ciphersuite::{hpke, HpkePrivateKey, HpkePublicKey, Secret}, error::LibraryError, - group::config::CryptoConfig, - versions::ProtocolVersion, + storage::StorageProvider, }; /// [`EncryptionKey`] contains an HPKE public key that allows the encryption of @@ -46,16 +45,6 @@ impl EncryptionKey { self.key.as_slice() } - /// Helper function to prefix the given (serialized) [`EncryptionKey`] with - /// the `ENCRYPTION_KEY_LABEL`. - /// - /// Returns the resulting bytes. - fn to_bytes_with_prefix(&self) -> Vec { - let mut key_store_index = ENCRYPTION_KEY_LABEL.to_vec(); - key_store_index.extend_from_slice(self.as_slice()); - key_store_index - } - /// Encrypt to this HPKE public key. pub(crate) fn encrypt( &self, @@ -80,7 +69,7 @@ impl EncryptionKey { Clone, Serialize, Deserialize, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, )] #[cfg_attr(test, derive(PartialEq, Eq))] -pub(crate) struct EncryptionPrivateKey { +pub struct EncryptionPrivateKey { key: HpkePrivateKey, } @@ -119,7 +108,6 @@ impl EncryptionPrivateKey { &self, crypto: &impl OpenMlsCrypto, ciphersuite: Ciphersuite, - version: ProtocolVersion, ciphertext: &HpkeCiphertext, group_context: &[u8], ) -> Result { @@ -132,11 +120,11 @@ impl EncryptionPrivateKey { ciphersuite, crypto, ) - .map(|secret_bytes| Secret::from_slice(&secret_bytes, version, ciphersuite)) + .map(|secret_bytes| Secret::from_slice(&secret_bytes)) } } -#[cfg(test)] +#[cfg(any(test, feature = "test-utils"))] impl EncryptionPrivateKey { pub(crate) fn key(&self) -> &HpkePrivateKey { &self.key @@ -158,45 +146,49 @@ pub(crate) struct EncryptionKeyPair { private_key: EncryptionPrivateKey, } -const ENCRYPTION_KEY_LABEL: &[u8; 19] = b"leaf_encryption_key"; - impl EncryptionKeyPair { - /// Write the [`EncryptionKeyPair`] to the key store of the `provider`. This - /// function is meant to store standalone keypairs, not ones that are - /// already in use with an MLS group. + /// Write the [`EncryptionKeyPair`] to the store of the `provider`. /// - /// Returns a key store error if access to the key store fails. - pub(crate) fn write_to_key_store( + /// This must only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + pub(crate) fn write( &self, - store: &KeyStore, - ) -> Result<(), KeyStore::Error> { - store.store(&self.public_key().to_bytes_with_prefix(), self) + store: &Storage, + ) -> Result<(), Storage::Error> { + store.write_encryption_key_pair(self.public_key(), self) } /// Read the [`EncryptionKeyPair`] from the key store of the `provider`. This /// function is meant to read standalone keypairs, not ones that are /// already in use with an MLS group. /// + /// This must only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + /// /// Returns `None` if the keypair cannot be read from the store. - pub(crate) fn read_from_key_store( + pub(crate) fn read( provider: &impl OpenMlsProvider, encryption_key: &EncryptionKey, ) -> Option { provider - .key_store() - .read(&encryption_key.to_bytes_with_prefix()) + .storage() + .encryption_key_pair(encryption_key) + .ok() + .flatten() } - /// Delete the [`EncryptionKeyPair`] from the key store of the `provider`. - /// This function is meant to delete standalone keypairs, not ones that are - /// already in use with an MLS group. + /// Delete the [`EncryptionKeyPair`] from the store of the `provider`. /// - /// Returns a key store error if access to the key store fails. - pub(crate) fn delete_from_key_store( + /// This must only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + pub(crate) fn delete>( &self, - store: &KeyStore, - ) -> Result<(), KeyStore::Error> { - store.delete::(&self.public_key().to_bytes_with_prefix()) + store: &Storage, + ) -> Result<(), Storage::Error> { + store.delete_encryption_key_pair(self.public_key()) } pub(crate) fn public_key(&self) -> &EncryptionKey { @@ -209,13 +201,13 @@ impl EncryptionKeyPair { pub(crate) fn random( provider: &impl OpenMlsProvider, - config: CryptoConfig, + ciphersuite: Ciphersuite, ) -> Result { - let ikm = Secret::random(config.ciphersuite, provider.rand(), config.version) + let ikm = Secret::random(ciphersuite, provider.rand()) .map_err(LibraryError::unexpected_crypto_error)?; Ok(provider .crypto() - .derive_hpke_keypair(config.ciphersuite.hpke_config(), ikm.as_slice()) + .derive_hpke_keypair(ciphersuite.hpke_config(), ikm.as_slice()) .map_err(LibraryError::unexpected_crypto_error)? .into()) } @@ -229,7 +221,7 @@ pub mod test_utils { provider: &impl OpenMlsProvider, encryption_key: &EncryptionKey, ) -> HpkeKeyPair { - let keys = EncryptionKeyPair::read_from_key_store(provider, encryption_key).unwrap(); + let keys = EncryptionKeyPair::read(provider, encryption_key).unwrap(); HpkeKeyPair { private: keys.private_key.key, @@ -240,7 +232,7 @@ pub mod test_utils { pub fn write_keys_from_key_store(provider: &impl OpenMlsProvider, encryption_key: HpkeKeyPair) { let keypair = EncryptionKeyPair::from(encryption_key); - keypair.write_to_key_store(provider.key_store()).unwrap(); + keypair.write(provider.storage()).unwrap(); } } @@ -287,7 +279,3 @@ impl From<(EncryptionKey, EncryptionPrivateKey)> for EncryptionKeyPair { } } } - -impl MlsEntity for EncryptionKeyPair { - const ID: MlsEntityId = MlsEntityId::EncryptionKeyPair; -} diff --git a/openmls/src/treesync/node/leaf_node.rs b/openmls/src/treesync/node/leaf_node.rs index e7bbbb6dac..c90a7b69aa 100644 --- a/openmls/src/treesync/node/leaf_node.rs +++ b/openmls/src/treesync/node/leaf_node.rs @@ -1,13 +1,11 @@ //! This module contains the [`LeafNode`] struct and its implementation. -use openmls_traits::{signatures::Signer, types::Ciphersuite, OpenMlsProvider}; +use openmls_traits::{signatures::Signer, types::Ciphersuite}; use serde::{Deserialize, Serialize}; use tls_codec::{ Serialize as TlsSerializeTrait, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes, }; -#[cfg(test)] -use openmls_traits::key_store::OpenMlsKeyStore; #[cfg(test)] use thiserror::Error; @@ -21,10 +19,11 @@ use crate::{ credentials::{Credential, CredentialWithKey}, error::LibraryError, extensions::{ExtensionType, Extensions}, - group::{config::CryptoConfig, GroupId}, + group::GroupId, key_packages::{KeyPackage, Lifetime}, + prelude::KeyPackageBundle, + storage::OpenMlsProvider, treesync::errors::PublicTreeError, - versions::ProtocolVersion, }; use crate::treesync::errors::LeafNodeValidationError; @@ -35,7 +34,7 @@ mod codec; pub use capabilities::*; pub(crate) struct NewLeafNodeParams { - pub(crate) config: CryptoConfig, + pub(crate) ciphersuite: Ciphersuite, pub(crate) credential_with_key: CredentialWithKey, pub(crate) leaf_node_source: LeafNodeSource, pub(crate) capabilities: Capabilities, @@ -91,7 +90,7 @@ impl LeafNode { new_leaf_node_params: NewLeafNodeParams, ) -> Result<(Self, EncryptionKeyPair), LibraryError> { let NewLeafNodeParams { - config, + ciphersuite, credential_with_key, leaf_node_source, capabilities, @@ -100,7 +99,7 @@ impl LeafNode { } = new_leaf_node_params; // Create a new encryption key pair. - let encryption_key_pair = EncryptionKeyPair::random(provider, config)?; + let encryption_key_pair = EncryptionKeyPair::random(provider, ciphersuite)?; let leaf_node = Self::new_with_key( encryption_key_pair.public_key().clone(), @@ -175,15 +174,15 @@ impl LeafNode { /// This function can be used when generating an update. In most other cases /// a leaf node should be generated as part of a new [`KeyPackage`]. #[cfg(test)] - pub(crate) fn updated( + pub(crate) fn updated( &self, - config: CryptoConfig, + ciphersuite: Ciphersuite, tree_info_tbs: TreeInfoTbs, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result> { + ) -> Result> { Self::generate_update( - config, + ciphersuite, CredentialWithKey { credential: self.payload.credential.clone(), signature_key: self.payload.signature_key.clone(), @@ -204,20 +203,20 @@ impl LeafNode { /// This function can be used when generating an update. In most other cases /// a leaf node should be generated as part of a new [`KeyPackage`]. #[cfg(test)] - pub(crate) fn generate_update( - config: CryptoConfig, + pub(crate) fn generate_update( + ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, capabilities: Capabilities, extensions: Extensions, tree_info_tbs: TreeInfoTbs, - provider: &impl OpenMlsProvider, + provider: &Provider, signer: &impl Signer, - ) -> Result> { + ) -> Result> { // Note that this function is supposed to be used in the public API only // because it is interacting with the key store. let new_leaf_node_params = NewLeafNodeParams { - config, + ciphersuite, credential_with_key, leaf_node_source: LeafNodeSource::Update, capabilities, @@ -229,8 +228,8 @@ impl LeafNode { // Store the encryption key pair in the key store. encryption_key_pair - .write_to_key_store(provider.key_store()) - .map_err(LeafNodeGenerationError::KeyStoreError)?; + .write(provider.storage()) + .map_err(LeafNodeGenerationError::StorageError)?; Ok(leaf_node) } @@ -284,7 +283,6 @@ impl LeafNode { group_id: &GroupId, leaf_index: LeafNodeIndex, ciphersuite: Ciphersuite, - protocol_version: ProtocolVersion, provider: &impl OpenMlsProvider, signer: &impl Signer, ) -> Result { @@ -293,26 +291,19 @@ impl LeafNode { .capabilities .ciphersuites .contains(&ciphersuite.into()) - || !self.capabilities().versions.contains(&protocol_version) { debug_assert!( false, "Ciphersuite or protocol version is not supported by this leaf node.\ - \ncapabilities: {:?}\nprotocol version: {:?}\nciphersuite: {:?}", - self.payload.capabilities, protocol_version, ciphersuite + \ncapabilities: {:?}\nciphersuite: {:?}", + self.payload.capabilities, ciphersuite ); return Err(LibraryError::custom( "Ciphersuite or protocol version is not supported by this leaf node.", ) .into()); } - let key_pair = EncryptionKeyPair::random( - provider, - CryptoConfig { - ciphersuite, - version: protocol_version, - }, - )?; + let key_pair = EncryptionKeyPair::random(provider, ciphersuite)?; self.update_and_re_sign( key_pair.public_key().clone(), @@ -397,73 +388,6 @@ impl LeafNode { } } -#[cfg(test)] -impl LeafNode { - /// Expose [`new_with_key`] for tests. - pub(crate) fn create_new_with_key( - encryption_key: EncryptionKey, - credential_with_key: CredentialWithKey, - leaf_node_source: LeafNodeSource, - capabilities: Capabilities, - extensions: Extensions, - tree_info_tbs: TreeInfoTbs, - signer: &impl Signer, - ) -> Result { - Self::new_with_key( - encryption_key, - credential_with_key, - leaf_node_source, - capabilities, - extensions, - tree_info_tbs, - signer, - ) - } - - /// Return a mutable reference to [`Capabilities`]. - pub fn capabilities_mut(&mut self) -> &mut Capabilities { - &mut self.payload.capabilities - } -} - -#[cfg(any(feature = "test-utils", test))] -impl LeafNode { - /// Replace the credential in the KeyPackage. - pub(crate) fn set_credential(&mut self, credential: Credential) { - self.payload.credential = credential; - } - - /// Replace the signature key in the KeyPackage. - pub(crate) fn set_signature_key(&mut self, signature_key: SignaturePublicKey) { - self.payload.signature_key = signature_key; - } - - /// Resign the node - pub(crate) fn resign( - &mut self, - signer: &impl Signer, - credential_with_key: CredentialWithKey, - tree_info_tbs: TreeInfoTbs, - ) { - let leaf_node_tbs = LeafNodeTbs::new( - self.payload.encryption_key.clone(), - credential_with_key, - self.payload.capabilities.clone(), - self.payload.leaf_node_source.clone(), - self.payload.extensions.clone(), - tree_info_tbs, - ) - .unwrap(); - - let leaf_node = leaf_node_tbs - .sign(signer) - .map_err(|_| LibraryError::custom("Signing failed")) - .unwrap(); - self.payload = leaf_node.payload; - self.signature = leaf_node.signature; - } -} - /// The payload of a [`LeafNode`] /// /// ```text @@ -728,6 +652,12 @@ impl From for LeafNode { } } +impl From for LeafNode { + fn from(key_package: KeyPackageBundle) -> Self { + key_package.key_package().leaf_node().clone() + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum VerifiableLeafNode { KeyPackage(VerifiableKeyPackageLeafNode), @@ -917,11 +847,12 @@ impl SignedStruct for LeafNode { #[cfg(test)] #[derive(Error, Debug, PartialEq, Clone)] -pub enum LeafNodeGenerationError { +pub enum LeafNodeGenerationError { /// See [`LibraryError`] for more details. #[error(transparent)] LibraryError(#[from] LibraryError), - /// Error storing leaf private key in key store. - #[error("Error storing leaf private key in key store.")] - KeyStoreError(KeyStoreError), + + /// Error storing leaf private key in storage. + #[error("Error storing leaf private key.")] + StorageError(StorageError), } diff --git a/openmls/src/treesync/node/leaf_node/capabilities.rs b/openmls/src/treesync/node/leaf_node/capabilities.rs index ffdfe55f59..5dfc510828 100644 --- a/openmls/src/treesync/node/leaf_node/capabilities.rs +++ b/openmls/src/treesync/node/leaf_node/capabilities.rs @@ -211,6 +211,7 @@ pub(super) fn default_ciphersuites() -> Vec { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519, Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256, Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519, + Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519, ] } @@ -255,7 +256,7 @@ mod tests { #[test] fn that_unknown_capabilities_are_de_serialized_correctly() { - let versions = vec![ProtocolVersion::Mls10, ProtocolVersion::Mls10Draft11]; + let versions = vec![ProtocolVersion::Mls10, ProtocolVersion::Other(999)]; let ciphersuites = vec![ Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519.into(), Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256.into(), @@ -276,7 +277,7 @@ mod tests { ExtensionType::Unknown(0xFAFA), ]; - let proposals = vec![ProposalType::Unknown(0x7A7A)]; + let proposals = vec![ProposalType::Custom(0x7A7A)]; let credentials = vec![ CredentialType::Basic, diff --git a/openmls/src/treesync/tests_and_kats/kats/kat_tree_operations.rs b/openmls/src/treesync/tests_and_kats/kats/kat_tree_operations.rs index 727fd5dd7b..4aaf4dc955 100644 --- a/openmls/src/treesync/tests_and_kats/kats/kat_tree_operations.rs +++ b/openmls/src/treesync/tests_and_kats/kats/kat_tree_operations.rs @@ -34,7 +34,6 @@ use crate::{ messages::{proposals::Proposal, proposals_in::ProposalIn, ConfirmationTag}, test_utils::*, treesync::{node::NodeIn, RatchetTree, RatchetTreeIn, TreeSync}, - versions::ProtocolVersion, }; #[derive(Deserialize)] @@ -79,7 +78,8 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result let initial_confirmation_tag = ConfirmationTag( Mac::new( provider.crypto(), - &Secret::random(ciphersuite, provider.rand(), ProtocolVersion::Mls10).unwrap(), + ciphersuite, + &Secret::random(ciphersuite, provider.rand()).unwrap(), &[], ) .unwrap(), @@ -134,8 +134,8 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result Ok(()) } -#[apply(providers)] -fn read_test_vectors_tree_operations(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_tree_operations() { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); diff --git a/openmls/src/treesync/tests_and_kats/kats/kat_tree_validation.rs b/openmls/src/treesync/tests_and_kats/kats/kat_tree_validation.rs index b0537466ec..475b0072a7 100644 --- a/openmls/src/treesync/tests_and_kats/kats/kat_tree_validation.rs +++ b/openmls/src/treesync/tests_and_kats/kats/kat_tree_validation.rs @@ -136,8 +136,8 @@ fn run_test_vector(test: TestElement, provider: &impl OpenMlsProvider) -> Result Ok(()) } -#[apply(providers)] -fn read_test_vectors_tree_validation(provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn read_test_vectors_tree_validation() { let _ = pretty_env_logger::try_init(); log::debug!("Reading test vectors ..."); diff --git a/openmls/src/treesync/tests_and_kats/kats/kat_treekem.rs b/openmls/src/treesync/tests_and_kats/kats/kat_treekem.rs index 9df787d8cc..f33aaf4b43 100644 --- a/openmls/src/treesync/tests_and_kats/kats/kat_treekem.rs +++ b/openmls/src/treesync/tests_and_kats/kats/kat_treekem.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use log::{debug, trace}; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::{crypto::OpenMlsCrypto, types::Ciphersuite, OpenMlsProvider}; use serde::{Deserialize, Serialize}; use tls_codec::{Deserialize as TlsDeserializeTrait, Serialize as TlsSerializeTrait}; @@ -14,13 +13,12 @@ use crate::{ messages::PathSecret, prelude_test::Secret, schedule::CommitSecret, - test_utils::hex_to_bytes, + test_utils::{hex_to_bytes, OpenMlsRustCrypto}, treesync::{ node::encryption_keys::EncryptionKeyPair, treekem::{DecryptPathParams, UpdatePath, UpdatePathIn}, TreeSync, }, - versions::ProtocolVersion, }; #[derive(Serialize, Deserialize, Debug, Clone)] @@ -125,11 +123,7 @@ pub fn run_test_vector(test: TreeKemTest, provider: &impl OpenMlsProvider) { )]; for path_secret in path_secrets_test { - let my_path_secret = PathSecret::from(Secret::from_slice( - &path_secret.path_secret, - ProtocolVersion::Mls10, - ciphersuite, - )); + let my_path_secret = PathSecret::from(Secret::from_slice(&path_secret.path_secret)); let keypair = my_path_secret .derive_key_pair(provider.crypto(), ciphersuite) .unwrap(); @@ -299,7 +293,6 @@ pub fn run_test_vector(test: TreeKemTest, provider: &impl OpenMlsProvider) { } let params = DecryptPathParams { - version: ProtocolVersion::Mls10, update_path: update_path.nodes(), sender_leaf_index: LeafNodeIndex::new(path_test.sender), exclusion_list: &HashSet::default(), @@ -336,7 +329,6 @@ fn apply_update_path( leaf_node_info_test: &LeafNodeInfoTest, ) -> CommitSecret { let params = DecryptPathParams { - version: ProtocolVersion::Mls10, update_path: update_path.nodes(), sender_leaf_index: LeafNodeIndex::new(sender), exclusion_list: &HashSet::default(), @@ -366,11 +358,7 @@ fn apply_update_path( .as_ref() .unwrap(); - let path_secret = PathSecret::from(Secret::from_slice( - &hex_to_bytes(expected_path_secret), - ProtocolVersion::Mls10, - ciphersuite, - )); + let path_secret = PathSecret::from(Secret::from_slice(&hex_to_bytes(expected_path_secret))); path_secret .derive_key_pair(provider.crypto(), ciphersuite) diff --git a/openmls/src/treesync/tests_and_kats/tests.rs b/openmls/src/treesync/tests_and_kats/tests.rs index 6b16ffadeb..210a2986ad 100644 --- a/openmls/src/treesync/tests_and_kats/tests.rs +++ b/openmls/src/treesync/tests_and_kats/tests.rs @@ -1,6 +1,3 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; -use tls_codec::*; - use crate::{ group::{ tests::utils::{generate_credential_with_key, CredentialWithKeyAndSigner}, @@ -8,7 +5,7 @@ use crate::{ }, key_packages::KeyPackage, prelude::*, - test_utils::*, + storage::OpenMlsProvider, }; mod test_diff; @@ -16,32 +13,27 @@ mod test_unmerged_leaves; /// Pathological example taken from ... /// https://github.com/mlswg/mls-protocol/issues/690#issue-1244086547. -#[apply(ciphersuites_and_providers)] -fn that_commit_secret_is_derived_from_end_of_update_path_not_root( - ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, -) { - let _ = provider; // get rid of warning - let crypto_config = CryptoConfig::with_default_version(ciphersuite); +#[openmls_test::openmls_test] +fn that_commit_secret_is_derived_from_end_of_update_path_not_root() { let mls_group_create_config = MlsGroupCreateConfig::builder() - .crypto_config(crypto_config) + .ciphersuite(ciphersuite) .use_ratchet_tree_extension(true) .build(); - struct Member { + struct Member { id: Vec, credential_with_key_and_signer: CredentialWithKeyAndSigner, key_package: KeyPackage, - // FIXME: the own_leaf_index from the group is beeing computed incorrectly, so we can't use + // FIXME: the own_leaf_index from the group is being computed incorrectly, so we can't use // the provider from the function parameter. #1221 - provider: OpenMlsRustCrypto, + provider: Provider, } - fn create_member( + fn create_member( ciphersuite: Ciphersuite, - provider: OpenMlsRustCrypto, + provider: Provider, name: Vec, - ) -> Member { + ) -> Member { let credential_with_key_and_signer = generate_credential_with_key( name.clone(), ciphersuite.signature_algorithm(), @@ -49,7 +41,7 @@ fn that_commit_secret_is_derived_from_end_of_update_path_not_root( ); let key_package = KeyPackage::builder() .build( - CryptoConfig::with_default_version(ciphersuite), + ciphersuite, &provider, &credential_with_key_and_signer.signer, credential_with_key_and_signer.credential_with_key.clone(), @@ -59,7 +51,7 @@ fn that_commit_secret_is_derived_from_end_of_update_path_not_root( Member { id: name, credential_with_key_and_signer, - key_package, + key_package: key_package.key_package().clone(), provider, } } @@ -68,9 +60,7 @@ fn that_commit_secret_is_derived_from_end_of_update_path_not_root( group .members() .find_map(|member| { - let identity = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()).unwrap(); - if identity.as_slice() == target_id { + if member.credential.serialized_content() == target_id { Some(member.index) } else { None @@ -79,10 +69,10 @@ fn that_commit_secret_is_derived_from_end_of_update_path_not_root( .unwrap() } - let alice = create_member(ciphersuite, OpenMlsRustCrypto::default(), "alice".into()); - let bob = create_member(ciphersuite, OpenMlsRustCrypto::default(), "bob".into()); - let charlie = create_member(ciphersuite, OpenMlsRustCrypto::default(), "charlie".into()); - let dave = create_member(ciphersuite, OpenMlsRustCrypto::default(), "dave".into()); + let alice = create_member(ciphersuite, Provider::default(), "alice".into()); + let bob = create_member(ciphersuite, Provider::default(), "bob".into()); + let charlie = create_member(ciphersuite, Provider::default(), "charlie".into()); + let dave = create_member(ciphersuite, Provider::default(), "dave".into()); // `A` creates a group with `B`, `C`, and `D` ... let mut alice_group = MlsGroup::new( diff --git a/openmls/src/treesync/tests_and_kats/tests/test_diff.rs b/openmls/src/treesync/tests_and_kats/tests/test_diff.rs index 0a74405e4b..5e6f486236 100644 --- a/openmls/src/treesync/tests_and_kats/tests/test_diff.rs +++ b/openmls/src/treesync/tests_and_kats/tests/test_diff.rs @@ -1,7 +1,4 @@ -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{types::Ciphersuite, OpenMlsProvider}; -use rstest::*; -use rstest_reuse::apply; +use openmls_traits::prelude::*; use crate::{ credentials::test_utils::new_credential, @@ -10,14 +7,14 @@ use crate::{ }; // Verifies that when we add a leaf to a tree with blank leaf nodes, the leaf will be added at the leftmost free leaf index -#[apply(ciphersuites_and_providers)] -fn test_free_leaf_computation(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test::openmls_test] +fn test_free_leaf_computation() { let (c_0, sk_0) = new_credential(provider, b"leaf0", ciphersuite.signature_algorithm()); - let kpb_0 = KeyPackageBundle::new(provider, &sk_0, ciphersuite, c_0); + let kpb_0 = KeyPackageBundle::generate(provider, &sk_0, ciphersuite, c_0); let (c_3, sk_3) = new_credential(provider, b"leaf3", ciphersuite.signature_algorithm()); - let kpb_3 = KeyPackageBundle::new(provider, &sk_3, ciphersuite, c_3); + let kpb_3 = KeyPackageBundle::generate(provider, &sk_3, ciphersuite, c_3); // Build a rudimentary tree with two populated and two empty leaf nodes. let ratchet_tree = RatchetTree::trimmed(vec![ @@ -37,7 +34,7 @@ fn test_free_leaf_computation(ciphersuite: Ciphersuite, provider: &impl OpenMlsP // Create and add a new leaf. It should go to leaf index 1 let (c_2, signer_2) = new_credential(provider, b"leaf2", ciphersuite.signature_algorithm()); - let kpb_2 = KeyPackageBundle::new(provider, &signer_2, ciphersuite, c_2); + let kpb_2 = KeyPackageBundle::generate(provider, &signer_2, ciphersuite, c_2); let mut diff = tree.empty_diff(); let free_leaf_index = diff.free_leaf_index(); diff --git a/openmls/src/treesync/treekem.rs b/openmls/src/treesync/treekem.rs index 25dc3c5daa..3c6660a69b 100644 --- a/openmls/src/treesync/treekem.rs +++ b/openmls/src/treesync/treekem.rs @@ -32,7 +32,6 @@ use crate::{ messages::{proposals::AddProposal, EncryptedGroupSecrets, GroupSecrets, PathSecret}, schedule::{psk::PreSharedKeyId, CommitSecret, JoinerSecret}, treesync::node::NodeReference, - versions::ProtocolVersion, }; impl<'a> TreeSyncDiff<'a> { @@ -133,7 +132,6 @@ impl<'a> TreeSyncDiff<'a> { let path_secret = PathSecret::decrypt( crypto, ciphersuite, - params.version, ciphertext, decryption_key, params.group_context, @@ -231,7 +229,6 @@ impl<'a> TreeSyncDiff<'a> { } pub(crate) struct DecryptPathParams<'a> { - pub(crate) version: ProtocolVersion, pub(crate) update_path: &'a [UpdatePathNode], pub(crate) sender_leaf_index: LeafNodeIndex, pub(crate) exclusion_list: &'a HashSet<&'a LeafNodeIndex>, diff --git a/openmls/src/versions.rs b/openmls/src/versions.rs index 0f33a72190..e3a7d51cfc 100644 --- a/openmls/src/versions.rs +++ b/openmls/src/versions.rs @@ -3,9 +3,12 @@ //! Only MLS 1.0 is currently supported. use serde::{Deserialize, Serialize}; -use std::fmt; +use std::{fmt, io::Read}; use thiserror::Error; -use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; +use tls_codec::{ + Deserialize as TlsDeserializeTrait, DeserializeBytes, Error, Serialize as TlsSerializeTrait, + Size, +}; // Public types @@ -18,26 +21,12 @@ use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize}; /// (65535) /// } ProtocolVersion; /// ``` -#[derive( - Debug, - Copy, - Clone, - PartialEq, - Eq, - PartialOrd, - Ord, - Serialize, - Deserialize, - TlsDeserialize, - TlsDeserializeBytes, - TlsSerialize, - TlsSize, -)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] #[repr(u16)] #[allow(missing_docs)] pub enum ProtocolVersion { Mls10 = 1, - Mls10Draft11 = 200, // pre RFC version + Other(u16), } /// There's only one version right now, which is the default. @@ -47,26 +36,58 @@ impl Default for ProtocolVersion { } } -impl TryFrom for ProtocolVersion { - type Error = VersionError; - +impl From for ProtocolVersion { /// Convert an integer to the corresponding protocol version. - /// - /// Returns an error if the protocol version is not supported. - fn try_from(v: u16) -> Result { + fn from(v: u16) -> Self { match v { - 1 => Ok(ProtocolVersion::Mls10), - 200 => Ok(ProtocolVersion::Mls10Draft11), - _ => Err(VersionError::UnsupportedMlsVersion), + 1 => ProtocolVersion::Mls10, + _ => ProtocolVersion::Other(v), } } } +impl TlsSerializeTrait for ProtocolVersion { + fn tls_serialize(&self, writer: &mut W) -> Result { + match self { + ProtocolVersion::Mls10 => { + let v = 1u16; + v.tls_serialize(writer) + } + ProtocolVersion::Other(v) => v.tls_serialize(writer), + } + } +} + +impl TlsDeserializeTrait for ProtocolVersion { + fn tls_deserialize(bytes: &mut R) -> Result + where + Self: Sized, + { + u16::tls_deserialize(bytes).map(ProtocolVersion::from) + } +} + +impl DeserializeBytes for ProtocolVersion { + fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), Error> + where + Self: Sized, + { + let (v, bytes) = u16::tls_deserialize_bytes(bytes)?; + Ok((ProtocolVersion::from(v), bytes)) + } +} + +impl Size for ProtocolVersion { + fn tls_serialized_len(&self) -> usize { + 2 + } +} + impl fmt::Display for ProtocolVersion { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match &self { ProtocolVersion::Mls10 => write!(f, "MLS 1.0"), - ProtocolVersion::Mls10Draft11 => write!(f, "MLS 1.0 (Draft 11)"), + ProtocolVersion::Other(v) => write!(f, "Other version: {}", v), } } } diff --git a/openmls/tests/book_code.rs b/openmls/tests/book_code.rs index e6ba03ac14..65ede1c720 100644 --- a/openmls/tests/book_code.rs +++ b/openmls/tests/book_code.rs @@ -1,19 +1,16 @@ use openmls::{ - prelude::{config::CryptoConfig, tls_codec::*, *}, + prelude::{tls_codec::*, CustomProposal, *}, test_utils::*, *, }; use openmls_basic_credential::SignatureKeyPair; -use openmls_rust_crypto::OpenMlsRustCrypto; -use openmls_traits::{signatures::Signer, types::SignatureScheme, OpenMlsProvider}; -use tls_codec::VLBytes; +use openmls_test::openmls_test; +use openmls_traits::{signatures::Signer, types::SignatureScheme}; #[test] fn create_provider_rust_crypto() { // ANCHOR: create_provider_rust_crypto - use openmls_rust_crypto::OpenMlsRustCrypto; - - let provider = OpenMlsRustCrypto::default(); + let provider: OpenMlsRustCrypto = OpenMlsRustCrypto::default(); // ANCHOR_END: create_provider_rust_crypto // Suppress warning. @@ -23,14 +20,14 @@ fn create_provider_rust_crypto() { fn generate_credential( identity: Vec, signature_algorithm: SignatureScheme, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, ) -> (CredentialWithKey, SignatureKeyPair) { // ANCHOR: create_basic_credential - let credential = BasicCredential::new(identity).unwrap(); + let credential = BasicCredential::new(identity); // ANCHOR_END: create_basic_credential // ANCHOR: create_credential_keys let signature_keys = SignatureKeyPair::new(signature_algorithm).unwrap(); - signature_keys.store(provider.key_store()).unwrap(); + signature_keys.store(provider.storage()).unwrap(); // ANCHOR_END: create_credential_keys ( @@ -46,19 +43,14 @@ fn generate_key_package( ciphersuite: Ciphersuite, credential_with_key: CredentialWithKey, extensions: Extensions, - provider: &impl OpenMlsProvider, + provider: &impl crate::storage::OpenMlsProvider, signer: &impl Signer, -) -> KeyPackage { +) -> KeyPackageBundle { // ANCHOR: create_key_package // Create the key package KeyPackage::builder() .key_package_extensions(extensions) - .build( - CryptoConfig::with_default_version(ciphersuite), - provider, - signer, - credential_with_key, - ) + .build(ciphersuite, provider, signer, credential_with_key) .unwrap() // ANCHOR_END: create_key_package } @@ -77,8 +69,8 @@ fn generate_key_package( /// - Alice removes Charlie and adds Bob /// - Bob leaves /// - Test saving the group state -#[apply(ciphersuites_and_providers)] -fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn book_operations() { // Generate credentials with keys let (alice_credential, alice_signature_keys) = generate_credential("Alice".into(), ciphersuite.signature_algorithm(), provider); @@ -126,7 +118,7 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ), ]))) .expect("error adding external senders extension to group context extensions") - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .capabilities(Capabilities::new( None, // Defaults to the group's protocol version None, // Defaults to the group's ciphersuite @@ -203,7 +195,7 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { 10, // out_of_order_tolerance 2000, // maximum_forward_distance )) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .use_ratchet_tree_extension(true) .build(provider, &alice_signature_keys, alice_credential.clone()) .expect("An unexpected error occurred."); @@ -218,7 +210,11 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // === Alice adds Bob === // ANCHOR: alice_adds_bob let (mls_message_out, welcome, group_info) = alice_group - .add_members(provider, &alice_signature_keys, &[bob_key_package]) + .add_members( + provider, + &alice_signature_keys, + &[bob_key_package.key_package().clone()], + ) .expect("Could not add members."); // ANCHOR_END: alice_adds_bob @@ -255,10 +251,10 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let id0 = VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let id1 = VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(id0.as_slice(), b"Alice"); - assert_eq!(id1.as_slice(), b"Bob"); + let id0 = members[0].credential.serialized_content(); + let id1 = members[1].credential.serialized_content(); + assert_eq!(id0, b"Alice"); + assert_eq!(id1, b"Bob"); // ANCHOR: mls_group_config_example let mls_group_config = MlsGroupJoinConfig::builder() @@ -286,7 +282,7 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // ANCHOR: alice_exports_group_info let verifiable_group_info = alice_group - .export_group_info(provider.crypto(), &alice_signature_keys, true) + .export_group_info(provider, &alice_signature_keys, true) .expect("Cannot export group info") .into_verifiable_group_info() .expect("Could not get group info"); @@ -395,8 +391,8 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that both groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that both groups have the same public tree @@ -436,7 +432,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { &alice_credential.credential ); // Store proposal - alice_group.store_pending_proposal(*staged_proposal.clone()); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .unwrap(); } else { unreachable!("Expected a Proposal."); } @@ -446,7 +444,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { staged_proposal.sender(), Sender::Member(member) if *member == alice_group.own_leaf_index() )); - bob_group.store_pending_proposal(*staged_proposal); + bob_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a QueuedProposal."); } @@ -491,8 +491,8 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that both groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that both groups have the same public tree @@ -511,7 +511,11 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); let (queued_message, welcome, _group_info) = bob_group - .add_members(provider, &bob_signature_keys, &[charlie_key_package]) + .add_members( + provider, + &bob_signature_keys, + &[charlie_key_package.key_package().clone()], + ) .unwrap(); let alice_processed_message = alice_group @@ -564,15 +568,12 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice, Bob & Charlie are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - let credential2 = - VLBytes::tls_deserialize_exact(members[2].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); - assert_eq!(credential2.as_slice(), b"Charlie"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + let credential2 = members[2].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); + assert_eq!(credential2, b"Charlie"); assert_eq!(members.len(), 3); // === Charlie sends a message to the group === @@ -652,12 +653,12 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that all groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - charlie_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + charlie_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that all groups have the same public tree @@ -681,26 +682,18 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { index: _, credential, .. - }| { - let credential = - VLBytes::tls_deserialize_exact(credential.serialized_content()).unwrap(); - credential.as_slice() == b"Bob" - }, + }| { credential.serialized_content() == b"Bob" }, ) .expect("Couldn't find Bob in the list of group members."); // Make sure that this is Bob's actual KP reference. - let bob_cred = - VLBytes::tls_deserialize_exact(bob_member.credential.serialized_content()).unwrap(); - let bob_group_cred = VLBytes::tls_deserialize_exact( - bob_group - .own_leaf() - .unwrap() - .credential() - .serialized_content(), - ) - .unwrap(); - assert_eq!(bob_cred.as_slice(), bob_group_cred.as_slice()); + let bob_cred = bob_member.credential.serialized_content(); + let bob_group_cred = bob_group + .own_leaf() + .unwrap() + .credential() + .serialized_content(); + assert_eq!(bob_cred, bob_group_cred); // === Charlie removes Bob === // ANCHOR: charlie_removes_bob @@ -821,12 +814,10 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert!(!bob_group.is_active()); let members = bob_group.members().collect::>(); assert_eq!(members.len(), 2); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Charlie"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Charlie"); // ANCHOR_END: getting_removed // Make sure that all groups have the same public tree @@ -840,12 +831,10 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice & Charlie are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Charlie"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Charlie"); // Check that Bob can no longer send messages assert!(bob_group @@ -891,7 +880,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Charlie was removed assert_eq!(remove_proposal.removed(), charlie_group.own_leaf_index()); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal.clone()); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .unwrap(); } else { unreachable!("Expected a Proposal."); } @@ -908,17 +899,25 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Create AddProposal and remove it // ANCHOR: rollback_proposal_by_ref let (_mls_message_out, proposal_ref) = alice_group - .propose_add_member(provider, &alice_signature_keys, &bob_key_package) + .propose_add_member( + provider, + &alice_signature_keys, + bob_key_package.key_package(), + ) .expect("Could not create proposal to add Bob"); alice_group - .remove_pending_proposal(proposal_ref) + .remove_pending_proposal(provider.storage(), proposal_ref) .expect("The proposal was not found"); // ANCHOR_END: rollback_proposal_by_ref // Create AddProposal and process it // ANCHOR: propose_add let (mls_message_out, _proposal_ref) = alice_group - .propose_add_member(provider, &alice_signature_keys, &bob_key_package) + .propose_add_member( + provider, + &alice_signature_keys, + bob_key_package.key_package(), + ) .expect("Could not create proposal to add Bob"); // ANCHOR_END: propose_add @@ -953,7 +952,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { Sender::Member(member) if *member == alice_group.own_leaf_index() )); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } // ANCHOR_END: inspect_add_proposal else { @@ -996,12 +997,10 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); let welcome: MlsMessageIn = welcome_option.expect("Welcome was not returned").into(); let welcome = welcome @@ -1024,24 +1023,20 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); // Make sure the group contains two members assert_eq!(bob_group.members().count(), 2); // Check that Alice & Bob are the members of the group let members = bob_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); // === Alice sends a message to the group === let message_alice = b"Hi, I'm Alice!"; @@ -1085,7 +1080,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { assert_eq!(sender_cred_from_msg, sender_cred_from_group); assert_eq!( &sender_cred_from_msg, - alice_group.credential().expect("Expected a credential.") + alice_group + .credential::() + .expect("Expected a credential.") ); } else { unreachable!("Expected an ApplicationMessage."); @@ -1113,18 +1110,20 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { alice_processed_message.into_content() { // Store proposal - alice_group.store_pending_proposal(*staged_proposal); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a QueuedProposal."); } // Should fail because you cannot remove yourself from a group - assert_eq!( + assert!(matches!( bob_group.commit_to_pending_proposals(provider, &bob_signature_keys), Err(CommitToPendingProposalsError::CreateCommitError( CreateCommitError::CannotRemoveSelf )) - ); + )); let (queued_message, _welcome_option, _group_info) = alice_group .commit_to_pending_proposals(provider, &alice_signature_keys) @@ -1202,9 +1201,8 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Check that Alice is the only member of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); + let credential0 = members[0].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); // === Re-Add Bob with external Add proposal === @@ -1218,13 +1216,14 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { ); // ANCHOR: external_join_proposal - let proposal = JoinProposal::new( - bob_key_package, - alice_group.group_id().clone(), - alice_group.epoch(), - &bob_signature_keys, - ) - .expect("Could not create external Add proposal"); + let proposal = + JoinProposal::new::<::StorageProvider>( + bob_key_package.key_package().clone(), + alice_group.group_id().clone(), + alice_group.epoch(), + &bob_signature_keys, + ) + .expect("Could not create external Add proposal"); // ANCHOR_END: external_join_proposal // ANCHOR: decrypt_external_join_proposal @@ -1238,7 +1237,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("Could not process message."); match alice_processed_message.into_content() { ProcessedMessageContent::ExternalJoinProposalMessage(proposal) => { - alice_group.store_pending_proposal(*proposal); + alice_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(); let (_commit, welcome, _group_info) = alice_group .commit_to_pending_proposals(provider, &alice_signature_keys) .expect("Could not commit"); @@ -1272,9 +1273,8 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let bob_index = alice_group .members() .find_map(|member| { - let credential = - VLBytes::tls_deserialize_exact(member.credential.serialized_content()).unwrap(); - if credential.as_slice() == b"Bob" { + let credential = member.credential.serialized_content(); + if credential == b"Bob" { Some(member.index) } else { None @@ -1283,7 +1283,7 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .unwrap(); // ANCHOR: external_remove_proposal - let proposal = ExternalProposal::new_remove( + let proposal = ExternalProposal::new_remove::( bob_index, alice_group.group_id().clone(), alice_group.epoch(), @@ -1304,7 +1304,9 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("Could not process message."); match alice_processed_message.into_content() { ProcessedMessageContent::ProposalMessage(proposal) => { - alice_group.store_pending_proposal(*proposal); + alice_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(); assert_eq!(alice_group.members().count(), 2); alice_group .commit_to_pending_proposals(provider, &alice_signature_keys) @@ -1331,7 +1333,11 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // Add Bob to the group let (_queued_message, welcome, _group_info) = alice_group - .add_members(provider, &alice_signature_keys, &[bob_key_package]) + .add_members( + provider, + &alice_signature_keys, + &[bob_key_package.key_package().clone()], + ) .expect("Could not add Bob"); // Merge Commit @@ -1359,33 +1365,31 @@ fn book_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { .expect("Could not create group from StagedWelcome"); assert_eq!( - alice_group.export_secret(provider.crypto(), "before load", &[], 32), - bob_group.export_secret(provider.crypto(), "before load", &[], 32) + alice_group + .export_secret(provider, "before load", &[], 32) + .unwrap(), + bob_group + .export_secret(provider, "before load", &[], 32) + .unwrap() ); - // Check that the state flag gets reset when saving - assert_eq!(bob_group.state_changed(), InnerState::Changed); - //save(&mut bob_group); - - bob_group - .save(provider.key_store()) - .expect("Could not write group state to file"); - - // Check that the state flag gets reset when saving - assert_eq!(bob_group.state_changed(), InnerState::Persisted); - - let bob_group = - MlsGroup::load(&group_id, provider.key_store()).expect("Could not load group from file"); + bob_group = MlsGroup::load(provider.storage(), &group_id) + .expect("An error occurred while loading the group") + .expect("No group with provided group id exists"); // Make sure the state is still the same assert_eq!( - alice_group.export_secret(provider.crypto(), "after load", &[], 32), - bob_group.export_secret(provider.crypto(), "after load", &[], 32) + alice_group + .export_secret(provider, "after load", &[], 32) + .unwrap(), + bob_group + .export_secret(provider, "after load", &[], 32) + .unwrap() ); } -#[apply(ciphersuites_and_providers)] -fn test_empty_input_errors(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_empty_input_errors() { let group_id = GroupId::from_slice(b"Test Group"); // Generate credentials with keys @@ -1405,18 +1409,132 @@ fn test_empty_input_errors(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv ) .expect("An unexpected error occurred."); - assert_eq!( + assert!(matches!( alice_group .add_members(provider, &alice_signature_keys, &[]) .expect_err("No EmptyInputError when trying to pass an empty slice to `add_members`."), AddMembersError::EmptyInput(EmptyInputError::AddMembers) - ); - assert_eq!( + )); + assert!(matches!( alice_group .remove_members(provider, &alice_signature_keys, &[]) .expect_err( "No EmptyInputError when trying to pass an empty slice to `remove_members`." ), RemoveMembersError::EmptyInput(EmptyInputError::RemoveMembers) + )); +} + +#[openmls_test] +fn custom_proposal_usage( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { + // Generate credentials with keys + let (alice_credential_with_key, alice_signer) = + generate_credential(b"alice".into(), ciphersuite.signature_algorithm(), provider); + + let (bob_credential_with_key, bob_signer) = + generate_credential(b"bob".into(), ciphersuite.signature_algorithm(), provider); + + // ANCHOR: custom_proposal_type + // Define a custom proposal type + let custom_proposal_type = 0xFFFF; + + // Define capabilities supporting the custom proposal type + let capabilities = Capabilities::new( + None, + None, + None, + Some(&[ProposalType::Custom(custom_proposal_type)]), + None, ); + + // Generate KeyPackage that signals support for the custom proposal type + let bob_key_package = KeyPackageBuilder::new() + .leaf_node_capabilities(capabilities.clone()) + .build(ciphersuite, provider, &bob_signer, bob_credential_with_key) + .unwrap(); + + // Create a group that supports the custom proposal type + let mut alice_group = MlsGroup::builder() + .with_capabilities(capabilities.clone()) + .ciphersuite(ciphersuite) + .build(provider, &alice_signer, alice_credential_with_key) + .unwrap(); + // ANCHOR_END: custom_proposal_type + + // Add Bob + let (_mls_message, welcome, _group_info) = alice_group + .add_members( + provider, + &alice_signer, + &[bob_key_package.key_package().clone()], + ) + .unwrap(); + + alice_group.merge_pending_commit(provider).unwrap(); + + let staged_welcome = StagedWelcome::new_from_welcome( + provider, + &MlsGroupJoinConfig::default(), + welcome.into_welcome().unwrap(), + Some(alice_group.export_ratchet_tree().into()), + ) + .unwrap(); + + let mut bob_group = staged_welcome.into_group(provider).unwrap(); + + // ANCHOR: custom_proposal_usage + // Create a custom proposal based on an example payload and the custom + // proposal type defined above + let custom_proposal_payload = vec![0, 1, 2, 3]; + let custom_proposal = + CustomProposal::new(custom_proposal_type, custom_proposal_payload.clone()); + + let (custom_proposal_message, _proposal_ref) = alice_group + .propose_custom_proposal_by_reference(provider, &alice_signer, custom_proposal.clone()) + .unwrap(); + + // Have bob process the custom proposal. + let processed_message = bob_group + .process_message( + provider, + custom_proposal_message.into_protocol_message().unwrap(), + ) + .unwrap(); + + let ProcessedMessageContent::ProposalMessage(proposal) = processed_message.into_content() + else { + panic!("Unexpected message type"); + }; + + bob_group + .store_pending_proposal(provider.storage(), *proposal) + .unwrap(); + + // Commit to the proposal + let (commit, _, _) = alice_group + .commit_to_pending_proposals(provider, &alice_signer) + .unwrap(); + + let processed_message = bob_group + .process_message(provider, commit.into_protocol_message().unwrap()) + .unwrap(); + + let staged_commit = match processed_message.into_content() { + ProcessedMessageContent::StagedCommitMessage(staged_commit) => staged_commit, + _ => panic!("Unexpected message type"), + }; + + // Check that the proposal is present in the staged commit + assert!(staged_commit.queued_proposals().any(|qp| { + let Proposal::Custom(custom_proposal) = qp.proposal() else { + return false; + }; + custom_proposal.proposal_type() == custom_proposal_type + && custom_proposal.payload() == custom_proposal_payload + })); + + // ANCHOR_END: custom_proposal_usage } diff --git a/openmls/tests/key_store.rs b/openmls/tests/key_store.rs deleted file mode 100644 index 391bf7e03d..0000000000 --- a/openmls/tests/key_store.rs +++ /dev/null @@ -1,31 +0,0 @@ -//! A couple of simple tests on how to interact with the key store. -use openmls::{prelude::*, test_utils::*, *}; -use openmls_basic_credential::SignatureKeyPair; - -#[apply(ciphersuites_and_providers)] -fn test_store_key_package(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { - // ANCHOR: key_store_store - // First we generate a credential and key package for our user. - let credential = BasicCredential::new(b"User ID".to_vec()).unwrap(); - let signature_keys = SignatureKeyPair::new(ciphersuite.into()).unwrap(); - - let key_package = KeyPackage::builder() - .build( - CryptoConfig::with_default_version(ciphersuite), - provider, - &signature_keys, - CredentialWithKey { - credential: credential.into(), - signature_key: signature_keys.to_public_vec().into(), - }, - ) - .unwrap(); - // ANCHOR_END: key_store_store - - // ANCHOR: key_store_delete - // Delete the key package - key_package - .delete(provider) - .expect("Error deleting key package"); - // ANCHOR_END: key_store_delete -} diff --git a/openmls/tests/store.rs b/openmls/tests/store.rs new file mode 100644 index 0000000000..251d21670f --- /dev/null +++ b/openmls/tests/store.rs @@ -0,0 +1,61 @@ +//! A couple of simple tests on how to interact with the key store. +use openmls::prelude::*; +use openmls_basic_credential::SignatureKeyPair; +use openmls_test::openmls_test; +use openmls_traits::storage::StorageProvider as _; + +#[openmls_test] +fn test_store_key_package() { + // ANCHOR: store_store + // First we generate a credential and key package for our user. + let credential = BasicCredential::new(b"User ID".to_vec()); + let signature_keys = SignatureKeyPair::new(ciphersuite.into()).unwrap(); + + // This key package includes the private init and encryption key as well. + // See [`KeyPackageBundle`]. + let key_package = KeyPackage::builder() + .build( + ciphersuite, + provider, + &signature_keys, + CredentialWithKey { + credential: credential.into(), + signature_key: signature_keys.to_public_vec().into(), + }, + ) + .unwrap(); + // ANCHOR_END: store_store + + // ANCHOR: hash_ref + // Build the hash reference. + // This is the key for key packages. + let hash_ref = key_package + .key_package() + .hash_ref(provider.crypto()) + .unwrap(); + // ANCHOR_END: hash_ref + + // ANCHOR: store_read + // Read the key package + let read_key_package: Option = provider + .storage() + .key_package(&hash_ref) + .expect("Error reading key package"); + assert_eq!( + read_key_package.unwrap().key_package(), + key_package.key_package() + ); + // ANCHOR_END: store_read + + // ANCHOR: store_delete + // Delete the key package + let hash_ref = key_package + .key_package() + .hash_ref(provider.crypto()) + .unwrap(); + provider + .storage() + .delete_key_package(&hash_ref) + .expect("Error deleting key package"); + // ANCHOR_END: store_delete +} diff --git a/openmls/tests/test_decryption_key_index.rs b/openmls/tests/test_decryption_key_index.rs index c3e4752d1e..538e54ba26 100644 --- a/openmls/tests/test_decryption_key_index.rs +++ b/openmls/tests/test_decryption_key_index.rs @@ -1,21 +1,20 @@ //! Test decryption key index computation in larger trees. use openmls::{ prelude::*, - test_utils::{ - test_framework::{noop_authentication_service, ActionType, CodecUse, MlsGroupTestSetup}, - *, + test_utils::test_framework::{ + noop_authentication_service, ActionType, CodecUse, MlsGroupTestSetup, }, - *, }; +use openmls_test::openmls_test; -#[apply(ciphersuites)] -fn decryption_key_index_computation(ciphersuite: Ciphersuite) { +#[openmls_test] +fn decryption_key_index_computation() { println!("Testing ciphersuite {ciphersuite:?}"); // Some basic setup functions for the MlsGroup. let mls_group_create_config = MlsGroupCreateConfig::test_default(ciphersuite); let number_of_clients = 20; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( mls_group_create_config, number_of_clients, CodecUse::StructMessages, diff --git a/openmls/tests/test_external_commit.rs b/openmls/tests/test_external_commit.rs index 31cd3652d0..a358c168d1 100644 --- a/openmls/tests/test_external_commit.rs +++ b/openmls/tests/test_external_commit.rs @@ -2,19 +2,18 @@ use openmls::{ credentials::test_utils::new_credential, messages::group_info::VerifiableGroupInfo, prelude::{tls_codec::*, *}, - test_utils::*, - *, }; use openmls_basic_credential::SignatureKeyPair; +use openmls_test::openmls_test; fn create_alice_group( ciphersuite: Ciphersuite, - provider: &impl OpenMlsProvider, + provider: &impl openmls::storage::OpenMlsProvider, use_ratchet_tree_extension: bool, ) -> (MlsGroup, CredentialWithKey, SignatureKeyPair) { let group_config = MlsGroupCreateConfig::builder() .use_ratchet_tree_extension(use_ratchet_tree_extension) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); let (credential_with_key, signature_keys) = @@ -31,15 +30,15 @@ fn create_alice_group( (group, credential_with_key, signature_keys) } -#[apply(ciphersuites_and_providers)] -fn test_external_commit(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_external_commit() { // Alice creates a new group ... let (alice_group, _, alice_signer) = create_alice_group(ciphersuite, provider, false); // ... and exports a group info (with ratchet_tree). let verifiable_group_info = { let group_info = alice_group - .export_group_info(provider.crypto(), &alice_signer, true) + .export_group_info(provider, &alice_signer, true) .unwrap(); let serialized_group_info = group_info.tls_serialize_detached().unwrap(); @@ -52,7 +51,7 @@ fn test_external_commit(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide let verifiable_group_info_broken = { let group_info = alice_group - .export_group_info(provider.crypto(), &alice_signer, true) + .export_group_info(provider, &alice_signer, true) .unwrap(); let serialized_group_info = { @@ -106,17 +105,17 @@ fn test_external_commit(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide ) .unwrap_err(); - assert_eq!( + assert!(matches!( got_error, ExternalCommitError::PublicGroupError( CreationFromExternalError::InvalidGroupInfoSignature ) - ); + )); } } -#[apply(ciphersuites_and_providers)] -fn test_group_info(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_group_info() { // Alice creates a new group ... let (mut alice_group, _, alice_signer) = create_alice_group(ciphersuite, provider, true); @@ -195,8 +194,11 @@ fn test_group_info(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { bob_group.merge_pending_commit(provider).unwrap(); } -#[apply(ciphersuites_and_providers)] -fn test_not_present_group_info(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_not_present_group_info( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { // Alice creates a new group ... let (mut alice_group, _, alice_signer) = create_alice_group(ciphersuite, provider, false); diff --git a/openmls/tests/test_interop_scenarios.rs b/openmls/tests/test_interop_scenarios.rs index 6e54699445..baf3e0ab12 100644 --- a/openmls/tests/test_interop_scenarios.rs +++ b/openmls/tests/test_interop_scenarios.rs @@ -3,9 +3,8 @@ use openmls::{ test_utils::test_framework::{ noop_authentication_service, ActionType, CodecUse, MlsGroupTestSetup, }, - test_utils::*, - *, }; +use openmls_test::openmls_test; // The following tests correspond to the interop test scenarios detailed here: // https://github.com/mlswg/mls-implementations/blob/master/test-scenarios.md @@ -17,11 +16,11 @@ use openmls::{ // B->A: KeyPackage // A->B: Welcome // ***: Verify group state -#[apply(ciphersuites)] -fn one_to_one_join(ciphersuite: Ciphersuite) { +#[openmls_test] +fn one_to_one_join() { println!("Testing ciphersuite {ciphersuite:?}"); let number_of_clients = 2; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, @@ -68,12 +67,12 @@ fn one_to_one_join(ciphersuite: Ciphersuite) { // A->B: Add(C), Commit // A->C: Welcome // ***: Verify group state -#[apply(ciphersuites)] -fn three_party_join(ciphersuite: Ciphersuite) { +#[openmls_test] +fn three_party_join() { println!("Testing ciphersuite {ciphersuite:?}"); let number_of_clients = 3; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, @@ -135,12 +134,12 @@ fn three_party_join(ciphersuite: Ciphersuite) { // A->B: Welcome // A->C: Welcome // ***: Verify group state -#[apply(ciphersuites)] -fn multiple_joins(ciphersuite: Ciphersuite) { +#[openmls_test] +fn multiple_joins() { println!("Testing ciphersuite {ciphersuite:?}"); let number_of_clients = 3; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, @@ -188,12 +187,12 @@ fn multiple_joins(ciphersuite: Ciphersuite) { // A->B: Welcome // A->B: Update, Commit // ***: Verify group state -#[apply(ciphersuites)] -fn update(ciphersuite: Ciphersuite) { +#[openmls_test] +fn update() { println!("Testing ciphersuite {ciphersuite:?}"); let number_of_clients = 2; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, @@ -236,12 +235,12 @@ fn update(ciphersuite: Ciphersuite) { // A->C: Welcome // A->B: Remove(B), Commit // ***: Verify group state -#[apply(ciphersuites)] -fn remove(ciphersuite: Ciphersuite) { +#[openmls_test] +fn remove() { println!("Testing ciphersuite {ciphersuite:?}"); let number_of_clients = 2; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, @@ -291,13 +290,13 @@ fn remove(ciphersuite: Ciphersuite) { // * All members update // * While the group size is >1, a randomly-chosen group member removes a // randomly-chosen other group member -#[apply(ciphersuites)] -fn large_group_lifecycle(ciphersuite: Ciphersuite) { +#[openmls_test] +fn large_group_lifecycle() { println!("Testing ciphersuite {ciphersuite:?}"); // "Large" is 20 for now. let number_of_clients = 20; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( MlsGroupCreateConfig::test_default(ciphersuite), number_of_clients, CodecUse::StructMessages, diff --git a/openmls/tests/test_managed_api.rs b/openmls/tests/test_managed_api.rs index 089f242b09..32b9d9dd1f 100644 --- a/openmls/tests/test_managed_api.rs +++ b/openmls/tests/test_managed_api.rs @@ -3,16 +3,15 @@ use openmls::{ test_utils::test_framework::{ noop_authentication_service, ActionType, CodecUse, MlsGroupTestSetup, }, - test_utils::*, - *, }; +use openmls_test::openmls_test; -#[apply(ciphersuites)] -fn test_mls_group_api(ciphersuite: Ciphersuite) { +#[openmls_test] +fn test_mls_group_api() { // Some basic setup functions for the MlsGroup. let mls_group_create_config = MlsGroupCreateConfig::test_default(ciphersuite); let number_of_clients = 20; - let setup = MlsGroupTestSetup::new( + let setup = MlsGroupTestSetup::::new( mls_group_create_config, number_of_clients, CodecUse::SerializedMessages, diff --git a/openmls/tests/test_mls_group.rs b/openmls/tests/test_mls_group.rs index ac3ce3cf8e..f86cc90dbe 100644 --- a/openmls/tests/test_mls_group.rs +++ b/openmls/tests/test_mls_group.rs @@ -1,31 +1,25 @@ use openmls::{ - prelude::{config::CryptoConfig, test_utils::new_credential, tls_codec::*, *}, - test_utils::*, - *, + prelude::{test_utils::new_credential, *}, + storage::OpenMlsProvider, }; +use openmls_traits::OpenMlsProvider as _; -use openmls_traits::{key_store::OpenMlsKeyStore, signatures::Signer, OpenMlsProvider}; -use tls_codec::VLBytes; +use openmls_test::openmls_test; +use openmls_traits::signatures::Signer; -fn generate_key_package( +fn generate_key_package( ciphersuite: Ciphersuite, extensions: Extensions, - provider: &impl OpenMlsProvider, + provider: &Provider, credential_with_key: CredentialWithKey, signer: &impl Signer, ) -> KeyPackage { KeyPackage::builder() .key_package_extensions(extensions) - .build( - CryptoConfig { - ciphersuite, - version: ProtocolVersion::default(), - }, - provider, - signer, - credential_with_key, - ) + .build(ciphersuite, provider, signer, credential_with_key) .unwrap() + .key_package() + .clone() } /// This test simulates various group operations like Add, Update, Remove in a @@ -42,8 +36,8 @@ fn generate_key_package( /// - Alice removes Charlie and adds Bob /// - Bob leaves /// - Test saving the group state -#[apply(ciphersuites_and_providers)] -fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn mls_group_operations() { for wire_format_policy in WIRE_FORMAT_POLICIES.iter() { let group_id = GroupId::from_slice(b"Test Group"); @@ -70,7 +64,7 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(*wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -117,12 +111,10 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); let welcome: MlsMessageIn = welcome.into(); let welcome = welcome @@ -175,7 +167,7 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide assert_eq!( &sender, alice_group - .credential() + .credential::() .expect("An unexpected error occurred.") ); } else { @@ -217,8 +209,8 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that both groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that both groups have the same public tree @@ -253,7 +245,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide &alice_credential.credential ); // Store proposal - alice_group.store_pending_proposal(*staged_proposal.clone()); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .unwrap(); } else { unreachable!("Expected a Proposal."); } @@ -264,7 +258,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide Sender::Member(member) if *member == alice_group.own_leaf_index() )); - bob_group.store_pending_proposal(*staged_proposal); + bob_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a QueuedProposal."); } @@ -300,8 +296,8 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that both groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that both groups have the same public tree @@ -374,15 +370,12 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice, Bob & Charlie are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - let credential2 = - VLBytes::tls_deserialize_exact(members[2].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); - assert_eq!(credential2.as_slice(), b"Charlie"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + let credential2 = members[2].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); + assert_eq!(credential2, b"Charlie"); // === Charlie sends a message to the group === let message_charlie = b"Hi, I'm Charlie!"; @@ -463,12 +456,12 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that all groups have the same state assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - bob_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + bob_group.export_secret(provider, "", &[], 32).unwrap() ); assert_eq!( - alice_group.export_secret(provider.crypto(), "", &[], 32), - charlie_group.export_secret(provider.crypto(), "", &[], 32) + alice_group.export_secret(provider, "", &[], 32).unwrap(), + charlie_group.export_secret(provider, "", &[], 32).unwrap() ); // Make sure that all groups have the same public tree @@ -575,12 +568,10 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice & Charlie are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Charlie"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Charlie"); // Check that Bob can no longer send messages assert!(bob_group @@ -621,7 +612,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Charlie was removed assert_eq!(remove_proposal.removed(), members[1].index); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal.clone()); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal.clone()) + .unwrap(); } else { unreachable!("Expected a Proposal."); } @@ -670,7 +663,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide Sender::Member(member) if *member == members[0].index )); // Store proposal - charlie_group.store_pending_proposal(*staged_proposal); + charlie_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a QueuedProposal."); } @@ -711,12 +706,10 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); let welcome: MlsMessageIn = welcome_option.expect("Welcome was not returned").into(); let welcome = welcome @@ -739,24 +732,20 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice & Bob are the members of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); // Make sure the group contains two members assert_eq!(bob_group.members().count(), 2); // Check that Alice & Bob are the members of the group let members = bob_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential0 = members[0].credential.serialized_content(); + let credential1 = members[1].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); + assert_eq!(credential1, b"Bob"); // === Alice sends a message to the group === let message_alice = b"Hi, I'm Alice!"; @@ -784,7 +773,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice sent the message assert_eq!( &sender, - alice_group.credential().expect("Expected a credential") + alice_group + .credential::() + .expect("Expected a credential") ); } else { unreachable!("Expected an ApplicationMessage."); @@ -811,18 +802,20 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide alice_processed_message.into_content() { // Store proposal - alice_group.store_pending_proposal(*staged_proposal); + alice_group + .store_pending_proposal(provider.storage(), *staged_proposal) + .unwrap(); } else { unreachable!("Expected a QueuedProposal."); } // Should fail because you cannot remove yourself from a group - assert_eq!( + assert!(matches!( bob_group.commit_to_pending_proposals(provider, &bob_signer), Err(CommitToPendingProposalsError::CreateCommitError( CreateCommitError::CannotRemoveSelf )) - ); + )); let (queued_message, _welcome_option, _group_info) = alice_group .commit_to_pending_proposals(provider, &alice_signer) @@ -892,9 +885,8 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide // Check that Alice is the only member of the group let members = alice_group.members().collect::>(); - let credential0 = - VLBytes::tls_deserialize_exact(members[0].credential.serialized_content()).unwrap(); - assert_eq!(credential0.as_slice(), b"Alice"); + let credential0 = members[0].credential.serialized_content(); + assert_eq!(credential0, b"Alice"); // === Save the group state === @@ -912,13 +904,9 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide .add_members(provider, &alice_signer, &[bob_key_package]) .expect("Could not add Bob"); - // Test saving & loading the group state when there is a pending commit - alice_group - .save(provider.key_store()) - .expect("Could not save group state."); - - let _test_group = MlsGroup::load(&group_id, provider.key_store()) - .expect("Could not load the group state."); + let _test_group = MlsGroup::load(provider.storage(), &group_id) + .expect("Could not load the group state due to an error.") + .expect("Could not load the group state because the group does not exist."); // Merge Commit alice_group @@ -941,33 +929,32 @@ fn mls_group_operations(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvide .expect("Could not create group from staged join"); assert_eq!( - alice_group.export_secret(provider.crypto(), "before load", &[], 32), - bob_group.export_secret(provider.crypto(), "before load", &[], 32) + alice_group + .export_secret(provider, "before load", &[], 32) + .unwrap(), + bob_group + .export_secret(provider, "before load", &[], 32) + .unwrap() ); - // Check that the state flag gets reset when saving - assert_eq!(bob_group.state_changed(), InnerState::Changed); - - bob_group - .save(provider.key_store()) - .expect("Could not write group state to file"); - - // Check that the state flag gets reset when saving - assert_eq!(bob_group.state_changed(), InnerState::Persisted); - - let bob_group = MlsGroup::load(&group_id, provider.key_store()) - .expect("Could not load group from file"); + bob_group = MlsGroup::load(provider.storage(), &group_id) + .expect("Could not load group from file because of an error") + .expect("Could not load group from file because there is no group with given id"); // Make sure the state is still the same assert_eq!( - alice_group.export_secret(provider.crypto(), "after load", &[], 32), - bob_group.export_secret(provider.crypto(), "after load", &[], 32) + alice_group + .export_secret(provider, "after load", &[], 32) + .unwrap(), + bob_group + .export_secret(provider, "after load", &[], 32) + .unwrap() ); } } -#[apply(ciphersuites_and_providers)] -fn addition_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn addition_order() { for wire_format_policy in WIRE_FORMAT_POLICIES.iter() { let group_id = GroupId::from_slice(b"Test Group"); // Generate credentials with keys @@ -1000,7 +987,7 @@ fn addition_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { let mls_group_config = MlsGroupCreateConfig::builder() .wire_format_policy(*wire_format_policy) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -1058,19 +1045,20 @@ fn addition_order(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { // in the original API call. After merging, bob should be at index 1 and // charlie at index 2. let members = alice_group.members().collect::>(); - let credential1 = - VLBytes::tls_deserialize_exact(members[1].credential.serialized_content()).unwrap(); - let credential2 = - VLBytes::tls_deserialize_exact(members[2].credential.serialized_content()).unwrap(); - assert_eq!(credential1.as_slice(), b"Bob"); + let credential1 = members[1].credential.serialized_content(); + let credential2 = members[2].credential.serialized_content(); + assert_eq!(credential1, b"Bob"); assert_eq!(members[1].index, LeafNodeIndex::new(1)); - assert_eq!(credential2.as_slice(), b"Charlie"); + assert_eq!(credential2, b"Charlie"); assert_eq!(members[2].index, LeafNodeIndex::new(2)); } } -#[apply(ciphersuites_and_providers)] -fn test_empty_input_errors(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn test_empty_input_errors( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { let group_id = GroupId::from_slice(b"Test Group"); // Generate credentials with keys @@ -1090,25 +1078,28 @@ fn test_empty_input_errors(ciphersuite: Ciphersuite, provider: &impl OpenMlsProv ) .expect("An unexpected error occurred."); - assert_eq!( + assert!(matches!( alice_group .add_members(provider, &alice_signer, &[]) .expect_err("No EmptyInputError when trying to pass an empty slice to `add_members`."), AddMembersError::EmptyInput(EmptyInputError::AddMembers) - ); - assert_eq!( + )); + assert!(matches!( alice_group .remove_members(provider, &alice_signer, &[]) .expect_err( "No EmptyInputError when trying to pass an empty slice to `remove_members`." ), RemoveMembersError::EmptyInput(EmptyInputError::RemoveMembers) - ); + )); } // This tests the ratchet tree extension usage flag in the configuration -#[apply(ciphersuites_and_providers)] -fn mls_group_ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl OpenMlsProvider) { +#[openmls_test] +fn mls_group_ratchet_tree_extension( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { for wire_format_policy in WIRE_FORMAT_POLICIES.iter() { let group_id = GroupId::from_slice(b"Test Group"); @@ -1133,7 +1124,7 @@ fn mls_group_ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl Op let mls_group_create_config = MlsGroupCreateConfig::builder() .wire_format_policy(*wire_format_policy) .use_ratchet_tree_extension(true) - .crypto_config(CryptoConfig::with_default_version(ciphersuite)) + .ciphersuite(ciphersuite) .build(); // === Alice creates a group === @@ -1216,6 +1207,90 @@ fn mls_group_ratchet_tree_extension(ciphersuite: Ciphersuite, provider: &impl Op ) .expect_err("Could join a group without a ratchet tree"); - assert_eq!(error, WelcomeError::MissingRatchetTree); + assert!(matches!(error, WelcomeError::MissingRatchetTree)); } } + +/// Test that the a group context extensions proposal is correctly applied when valid, and rejected when not. +#[openmls_test] +fn group_context_extensions_proposal( + ciphersuite: Ciphersuite, + provider: &impl crate::storage::OpenMlsProvider, +) { + let (alice_credential_with_key, alice_signer) = + new_credential(provider, b"Alice", ciphersuite.signature_algorithm()); + + // === Alice creates a group === + let mut alice_group = MlsGroup::builder() + .ciphersuite(ciphersuite) + .build(provider, &alice_signer, alice_credential_with_key) + .expect("error creating group using builder"); + + // No required capabilities, so no specifically required extensions. + assert!(alice_group.extensions().required_capabilities().is_none()); + + let new_extensions = Extensions::single(Extension::RequiredCapabilities( + RequiredCapabilitiesExtension::new(&[ExtensionType::RequiredCapabilities], &[], &[]), + )); + + let new_extensions_2 = Extensions::single(Extension::RequiredCapabilities( + RequiredCapabilitiesExtension::new(&[ExtensionType::RatchetTree], &[], &[]), + )); + + alice_group + .propose_group_context_extensions(provider, new_extensions.clone(), &alice_signer) + .expect("failed to build group context extensions proposal"); + + assert_eq!(alice_group.pending_proposals().count(), 1); + + alice_group + .commit_to_pending_proposals(provider, &alice_signer) + .expect("failed to commit to pending proposals"); + + alice_group + .merge_pending_commit(provider) + .expect("error merging pending commit"); + + let required_capabilities = alice_group + .extensions() + .required_capabilities() + .expect("couldn't get required_capabilities"); + + // has required_capabilities as required capability + assert!(required_capabilities.extension_types() == [ExtensionType::RequiredCapabilities]); + + // === committing to two group context extensions should fail + + alice_group + .propose_group_context_extensions(provider, new_extensions, &alice_signer) + .expect("failed to build group context extensions proposal"); + + // the proposals need to be different or they will be deduplicated + alice_group + .propose_group_context_extensions(provider, new_extensions_2, &alice_signer) + .expect("failed to build group context extensions proposal"); + + assert_eq!(alice_group.pending_proposals().count(), 2); + + alice_group + .commit_to_pending_proposals(provider, &alice_signer) + .expect_err( + "expected error when committing to multiple group context extensions proposals", + ); + + // === can't update required required_capabilities to extensions that existing group members + // are not capable of + + // contains unsupported extension + let new_extensions = Extensions::single(Extension::RequiredCapabilities( + RequiredCapabilitiesExtension::new(&[ExtensionType::Unknown(0xf042)], &[], &[]), + )); + + alice_group + .propose_group_context_extensions(provider, new_extensions, &alice_signer) + .expect_err("expected an error building GCE proposal with bad required_capabilities"); + + // TODO: we need to test that processing a commit with multiple group context extensions + // proposal also fails. however, we can't generate this commit, because our functions for + // constructing commits does not permit it. See #1476 +} diff --git a/openmls_rust_crypto/Cargo.toml b/openmls_rust_crypto/Cargo.toml index 6bb3ef2994..b9594cd36c 100644 --- a/openmls_rust_crypto/Cargo.toml +++ b/openmls_rust_crypto/Cargo.toml @@ -11,10 +11,10 @@ readme = "README.md" [dependencies] openmls_traits = { version = "0.2.0", path = "../traits" } -openmls_memory_keystore = { version = "0.2.0", path = "../memory_keystore" } +openmls_memory_storage = { version = "0.2.0", path = "../memory_storage" } hpke = { version = "0.2.0", package = "hpke-rs", default-features = false, features = [ - "hazmat", - "serialization", + "hazmat", + "serialization", ] } # Rust Crypto dependencies sha2 = { version = "0.10" } diff --git a/openmls_rust_crypto/src/lib.rs b/openmls_rust_crypto/src/lib.rs index 358070c38a..d9e437330e 100644 --- a/openmls_rust_crypto/src/lib.rs +++ b/openmls_rust_crypto/src/lib.rs @@ -3,7 +3,7 @@ //! This is an implementation of the [`OpenMlsProvider`] trait to use with //! OpenMLS. -pub use openmls_memory_keystore::{MemoryKeyStore, MemoryKeyStoreError}; +pub use openmls_memory_storage::{MemoryStorage, MemoryStorageError}; use openmls_traits::OpenMlsProvider; mod provider; @@ -12,13 +12,17 @@ pub use provider::*; #[derive(Default, Debug)] pub struct OpenMlsRustCrypto { crypto: RustCrypto, - key_store: MemoryKeyStore, + key_store: MemoryStorage, } impl OpenMlsProvider for OpenMlsRustCrypto { type CryptoProvider = RustCrypto; type RandProvider = RustCrypto; - type KeyStoreProvider = MemoryKeyStore; + type StorageProvider = MemoryStorage; + + fn storage(&self) -> &Self::StorageProvider { + &self.key_store + } fn crypto(&self) -> &Self::CryptoProvider { &self.crypto @@ -27,8 +31,4 @@ impl OpenMlsProvider for OpenMlsRustCrypto { fn rand(&self) -> &Self::RandProvider { &self.crypto } - - fn key_store(&self) -> &Self::KeyStoreProvider { - &self.key_store - } } diff --git a/openmls_rust_crypto/src/provider.rs b/openmls_rust_crypto/src/provider.rs index 1556999f5f..23222c8827 100644 --- a/openmls_rust_crypto/src/provider.rs +++ b/openmls_rust_crypto/src/provider.rs @@ -47,6 +47,9 @@ fn kem_mode(kem: HpkeKemType) -> hpke_types::KemAlgorithm { HpkeKemType::DhKemP521 => hpke_types::KemAlgorithm::DhKemP521, HpkeKemType::DhKem25519 => hpke_types::KemAlgorithm::DhKem25519, HpkeKemType::DhKem448 => hpke_types::KemAlgorithm::DhKem448, + HpkeKemType::XWingKemDraft2 => { + unimplemented!("XWingKemDraft1 is not supported by the RustCrypto provider.") + } } } diff --git a/openmls_test/Cargo.toml b/openmls_test/Cargo.toml new file mode 100644 index 0000000000..be54b59fc4 --- /dev/null +++ b/openmls_test/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "openmls_test" +version = "0.1.0" +edition = "2021" +authors = ["Franziskus Kiefer "] + +[lib] +proc-macro = true + +[features] +# This needs to be enabled explicity to allow disabling on some platforms +libcrux-provider = ["dep:openmls_libcrux_crypto"] + +[dependencies] +syn = { version = "2.0", features = ["full", "visit"] } +proc-macro2 = { version = "1.0.10", features = ["span-locations"] } +ansi_term = "0.12.1" +quote = "1.0" +rstest = { version = "0.17" } +rstest_reuse = { version = "0.5" } +openmls_rust_crypto = { version = "=0.2.0", path = "../openmls_rust_crypto" } +openmls_libcrux_crypto = { version = "=0.1.0", path = "../libcrux_crypto", optional = true } +openmls_traits = { path = "../traits" } diff --git a/openmls_test/Readme.md b/openmls_test/Readme.md new file mode 100644 index 0000000000..5d27abd7d7 --- /dev/null +++ b/openmls_test/Readme.md @@ -0,0 +1,3 @@ +# Test macro + +This crate implements a proc macro for testing in OpenMLS. diff --git a/openmls_test/src/lib.rs b/openmls_test/src/lib.rs new file mode 100644 index 0000000000..8908ec7500 --- /dev/null +++ b/openmls_test/src/lib.rs @@ -0,0 +1,95 @@ +use openmls_rust_crypto::OpenMlsRustCrypto; +use openmls_traits::{crypto::OpenMlsCrypto, OpenMlsProvider}; +use proc_macro::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_macro_input, ItemFn}; + +#[proc_macro_attribute] +pub fn openmls_test(_attr: TokenStream, item: TokenStream) -> TokenStream { + let func = parse_macro_input!(item as ItemFn); + + let attrs = func.attrs; + let sig = func.sig; + let fn_name = sig.ident; + let body = func.block.stmts; + + let rc = OpenMlsRustCrypto::default(); + + let rc_ciphersuites = rc.crypto().supported_ciphersuites(); + + let mut test_funs = Vec::new(); + + for ciphersuite in rc_ciphersuites { + let val = ciphersuite as u16; + let ciphersuite_name = format!("{:?}", ciphersuite); + let name = format_ident!("{}_rustcrypto_{}", fn_name, ciphersuite_name); + let test_fun = quote! { + #(#attrs)* + #[allow(non_snake_case)] + #[test] + fn #name() { + use openmls_rust_crypto::OpenMlsRustCrypto; + use openmls_traits::{types::Ciphersuite, crypto::OpenMlsCrypto}; + + type Provider = OpenMlsRustCrypto; + + let ciphersuite = Ciphersuite::try_from(#val).unwrap(); + let provider = OpenMlsRustCrypto::default(); + let provider = &provider; + #(#body)* + } + }; + + test_funs.push(test_fun); + } + + #[cfg(all( + feature = "libcrux-provider", + not(any( + target_arch = "wasm32", + all(target_arch = "x86", target_os = "windows") + )) + ))] + { + let libcrux = openmls_libcrux_crypto::Provider::default(); + let libcrux_ciphersuites = libcrux.crypto().supported_ciphersuites(); + + for ciphersuite in libcrux_ciphersuites { + let val = ciphersuite as u16; + let ciphersuite_name = format!("{:?}", ciphersuite); + let name = format_ident!("{}_libcrux_{}", fn_name, ciphersuite_name); + let test_fun = quote! { + #(#attrs)* + #[allow(non_snake_case)] + #[test] + fn #name() { + use openmls_libcrux_crypto::Provider as OpenMlsLibcrux; + use openmls_traits::{types::Ciphersuite, prelude::*}; + + type Provider = OpenMlsLibcrux; + + let ciphersuite = Ciphersuite::try_from(#val).unwrap(); + let provider = OpenMlsLibcrux::default(); + + // When cross-compiling the supported ciphersuites may be wrong. + // They are set at compile-time. + if provider.crypto().supports(ciphersuite).is_err() { + eprintln!("Skipping unsupported ciphersuite {ciphersuite:?}."); + return; + } + + let provider = &provider; + #(#body)* + } + }; + + test_funs.push(test_fun); + } + } + + let out = quote! { + #(#test_funs)* + }; + + out.into() +} diff --git a/traits/src/key_store.rs b/traits/src/key_store.rs deleted file mode 100644 index c0c4dadba8..0000000000 --- a/traits/src/key_store.rs +++ /dev/null @@ -1,53 +0,0 @@ -//! # OpenMLS Key Store Trait - -/// Sealed list of struct openmls manages (create/read/delete) through [OpenMlsKeyStore] -pub enum MlsEntityId { - SignatureKeyPair, - HpkePrivateKey, - KeyPackage, - PskBundle, - EncryptionKeyPair, - GroupState, -} - -/// To implement by any struct owned by openmls aiming to be persisted in [OpenMlsKeyStore] -pub trait MlsEntity: serde::Serialize + serde::de::DeserializeOwned { - /// Identifier used to downcast the actual entity within an [OpenMlsKeyStore] method. - /// In case for example you need to select a SQL table depending on the entity type - const ID: MlsEntityId; -} - -/// Blanket impl for when you have to lookup a list of entities from the keystore -impl MlsEntity for Vec -where - T: MlsEntity + std::fmt::Debug, -{ - const ID: MlsEntityId = T::ID; -} - -/// The Key Store trait -pub trait OpenMlsKeyStore { - /// The error type returned by the [`OpenMlsKeyStore`]. - type Error: std::error::Error + std::fmt::Debug + PartialEq; - - /// Store a value `v` that implements the [`MlsEntity`] trait for - /// serialization for ID `k`. - /// - /// Returns an error if storing fails. - fn store(&self, k: &[u8], v: &V) -> Result<(), Self::Error> - where - Self: Sized; - - /// Read and return a value stored for ID `k` that implements the - /// [`MlsEntity`] trait for deserialization. - /// - /// Returns [`None`] if no value is stored for `k` or reading fails. - fn read(&self, k: &[u8]) -> Option - where - Self: Sized; - - /// Delete a value stored for ID `k`. - /// - /// Returns an error if storing fails. - fn delete(&self, k: &[u8]) -> Result<(), Self::Error>; -} diff --git a/traits/src/random.rs b/traits/src/random.rs index 483890a2a1..424961a497 100644 --- a/traits/src/random.rs +++ b/traits/src/random.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; pub trait OpenMlsRand { - type Error: std::error::Error + Debug + Clone + PartialEq; + type Error: std::error::Error + Debug; /// Fill an array with random bytes. fn random_array(&self) -> Result<[u8; N], Self::Error>; diff --git a/traits/src/storage.rs b/traits/src/storage.rs new file mode 100644 index 0000000000..f0b138fe67 --- /dev/null +++ b/traits/src/storage.rs @@ -0,0 +1,652 @@ +//! This module describes the storage provider and type traits. +//! The concept is that the type traits are implemented by OpenMLS, and the storage provider +//! implements the [`StorageProvider`] trait. The trait mostly defines getters and setters, but +//! also a few methods that append to lists (which behave similar to setters). + +use serde::{de::DeserializeOwned, Serialize}; +/// The storage version used by OpenMLS +pub const CURRENT_VERSION: u16 = 1; + +/// For testing there is a test version defined here. +/// +/// THIS VERSION MUST NEVER BE USED OUTSIDE OF TESTS. +#[cfg(any(test, feature = "test-utils"))] +pub const V_TEST: u16 = u16::MAX; + +/// StorageProvider describes the storage backing OpenMLS and persists the state of OpenMLS groups. +/// +/// The getters for individual values usually return a `Result, E>`, where `Err(_)` +/// indicates that some sort of IO or internal error occurred, and `Ok(None)` indicates that no +/// error occurred, but no value exists. +/// Many getters for lists return a `Result, E>`. In this case, if there was no error but +/// the value doesn't exist, an empty vector should be returned. +/// +/// More details can be taken from the comments on the respective method. +pub trait StorageProvider { + /// An opaque error returned by all methods on this trait. + type Error: core::fmt::Debug + std::error::Error + PartialEq; + + /// Get the version of this provider. + fn version() -> u16 { + VERSION + } + + // + // --- setters/writers/enqueuers for group state --- + // + + /// Writes the MlsGroupJoinConfig for the group with given id to storage + fn write_mls_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + config: &MlsGroupJoinConfig, + ) -> Result<(), Self::Error>; + + /// Writes the AAD for the group with given id to storage + fn write_aad>( + &self, + group_id: &GroupId, + aad: &[u8], + ) -> Result<(), Self::Error>; + + /// Adds an own leaf node for the group with given id to storage + fn append_own_leaf_node< + GroupId: traits::GroupId, + LeafNode: traits::LeafNode, + >( + &self, + group_id: &GroupId, + leaf_node: &LeafNode, + ) -> Result<(), Self::Error>; + + /// Clears the own leaf node for the group with given id to storage + fn clear_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Enqueue a proposal. + /// + /// A good way to implement this could be to add a proposal to a proposal store, indexed by the + /// proposal reference, and adding the reference to a per-group proposal queue list. + fn queue_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + proposal: &QueuedProposal, + ) -> Result<(), Self::Error>; + + /// Write the TreeSync tree. + fn write_tree, TreeSync: traits::TreeSync>( + &self, + group_id: &GroupId, + tree: &TreeSync, + ) -> Result<(), Self::Error>; + + /// Write the interim transcript hash. + fn write_interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + interim_transcript_hash: &InterimTranscriptHash, + ) -> Result<(), Self::Error>; + + /// Write the group context. + fn write_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + group_context: &GroupContext, + ) -> Result<(), Self::Error>; + + /// Write the confirmation tag. + fn write_confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + confirmation_tag: &ConfirmationTag, + ) -> Result<(), Self::Error>; + + /// Writes the MlsGroupState for group with given id. + fn write_group_state< + GroupState: traits::GroupState, + GroupId: traits::GroupId, + >( + &self, + group_id: &GroupId, + group_state: &GroupState, + ) -> Result<(), Self::Error>; + + /// Writes the MessageSecretsStore for the group with the given id. + fn write_message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + message_secrets: &MessageSecrets, + ) -> Result<(), Self::Error>; + + /// Writes the ResumptionPskStore for the group with the given id. + fn write_resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + resumption_psk_store: &ResumptionPskStore, + ) -> Result<(), Self::Error>; + + /// Writes the own leaf index inside the group for the group with the given id. + fn write_own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + own_leaf_index: &LeafNodeIndex, + ) -> Result<(), Self::Error>; + + /// Returns the MlsGroupState for group with given id. + /// Sets whether to use the RatchetTreeExtension for the group with the given id. + fn set_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + value: bool, + ) -> Result<(), Self::Error>; + + /// Writes the GroupEpochSecrets for the group with the given id. + fn write_group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + group_epoch_secrets: &GroupEpochSecrets, + ) -> Result<(), Self::Error>; + + // + // --- setters/writers/enqueuers for crypto objects --- + // + + /// Store a signature key. + /// + /// The signature key pair is not known to OpenMLS. This may be used by the + /// application + fn write_signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + signature_key_pair: &SignatureKeyPair, + ) -> Result<(), Self::Error>; + + /// Store an HPKE encryption key pair. + /// This includes the private and public key + /// + /// This is only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + fn write_encryption_key_pair< + EncryptionKey: traits::EncryptionKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + public_key: &EncryptionKey, + key_pair: &HpkeKeyPair, + ) -> Result<(), Self::Error>; + + /// Store a list of HPKE encryption key pairs for a given epoch. + /// This includes the private and public keys. + fn write_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + key_pairs: &[HpkeKeyPair], + ) -> Result<(), Self::Error>; + + /// Store key packages. + /// + /// Store a key package. This includes the private init key. + /// The encryption key is stored separately with `write_encryption_key_pair`. + /// + /// Note that it is recommended to store a list of the hash references as well + /// in order to iterate over key packages. OpenMLS does not have a reference + /// for them. + // ANCHOR: write_key_package + fn write_key_package< + HashReference: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &HashReference, + key_package: &KeyPackage, + ) -> Result<(), Self::Error>; + // ANCHOR_END: write_key_package + + /// Store a PSK. + /// + /// This stores PSKs based on the PSK id. + /// + /// PSKs are only read by OpenMLS. The application is responsible for managing + /// and storing PSKs. + fn write_psk, PskBundle: traits::PskBundle>( + &self, + psk_id: &PskId, + psk: &PskBundle, + ) -> Result<(), Self::Error>; + + // + // --- getters for group state --- + // + + /// Returns the MlsGroupJoinConfig for the group with given id + fn mls_group_join_config< + GroupId: traits::GroupId, + MlsGroupJoinConfig: traits::MlsGroupJoinConfig, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the own leaf nodes for the group with given id + fn own_leaf_nodes, LeafNode: traits::LeafNode>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the AAD for the group with given id + /// If the value has not been set, returns an empty vector. + fn aad>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns references of all queued proposals for the group with group id `group_id`, or an empty vector of none are stored. + fn queued_proposal_refs< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns all queued proposals for the group with group id `group_id`, or an empty vector of none are stored. + fn queued_proposals< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + QueuedProposal: traits::QueuedProposal, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the TreeSync tree for the group with group id `group_id`. + fn treesync, TreeSync: traits::TreeSync>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the group context for the group with group id `group_id`. + fn group_context< + GroupId: traits::GroupId, + GroupContext: traits::GroupContext, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the interim transcript hash for the group with group id `group_id`. + fn interim_transcript_hash< + GroupId: traits::GroupId, + InterimTranscriptHash: traits::InterimTranscriptHash, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the confirmation tag for the group with group id `group_id`. + fn confirmation_tag< + GroupId: traits::GroupId, + ConfirmationTag: traits::ConfirmationTag, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the group state for the group with group id `group_id`. + fn group_state, GroupId: traits::GroupId>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the MessageSecretsStore for the group with the given id. + fn message_secrets< + GroupId: traits::GroupId, + MessageSecrets: traits::MessageSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the ResumptionPskStore for the group with the given id. + fn resumption_psk_store< + GroupId: traits::GroupId, + ResumptionPskStore: traits::ResumptionPskStore, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the own leaf index inside the group for the group with the given id. + fn own_leaf_index< + GroupId: traits::GroupId, + LeafNodeIndex: traits::LeafNodeIndex, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns whether to use the RatchetTreeExtension for the group with the given id. + fn use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + /// Returns the GroupEpochSecrets for the group with the given id. + fn group_epoch_secrets< + GroupId: traits::GroupId, + GroupEpochSecrets: traits::GroupEpochSecrets, + >( + &self, + group_id: &GroupId, + ) -> Result, Self::Error>; + + // + // --- getter for crypto objects --- + // + + /// Get a signature key based on the public key. + /// + /// The signature key pair is not known to OpenMLS. This may be used by the + /// application + fn signature_key_pair< + SignaturePublicKey: traits::SignaturePublicKey, + SignatureKeyPair: traits::SignatureKeyPair, + >( + &self, + public_key: &SignaturePublicKey, + ) -> Result, Self::Error>; + + /// Get an HPKE encryption key pair based on the public key. + /// + /// This is only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + fn encryption_key_pair< + HpkeKeyPair: traits::HpkeKeyPair, + EncryptionKey: traits::EncryptionKey, + >( + &self, + public_key: &EncryptionKey, + ) -> Result, Self::Error>; + + /// Get a list of HPKE encryption key pairs for a given epoch. + /// This includes the private and public keys. + fn encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + HpkeKeyPair: traits::HpkeKeyPair, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result, Self::Error>; + + /// Get a key package based on its hash reference. + fn key_package< + KeyPackageRef: traits::HashReference, + KeyPackage: traits::KeyPackage, + >( + &self, + hash_ref: &KeyPackageRef, + ) -> Result, Self::Error>; + + /// Get a PSK based on the PSK identifier. + fn psk, PskId: traits::PskId>( + &self, + psk_id: &PskId, + ) -> Result, Self::Error>; + + // + // --- deleters for group state --- + // + + /// Removes an individual proposal from the proposal queue of the group with the provided id + fn remove_proposal< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + proposal_ref: &ProposalRef, + ) -> Result<(), Self::Error>; + + /// Deletes the AAD for the given id from storage + fn delete_aad>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes own leaf nodes for the given id from storage + fn delete_own_leaf_nodes>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the MlsGroupJoinConfig for the given id from storage + fn delete_group_config>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the tree from storage + fn delete_tree>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the confirmation tag from storage + fn delete_confirmation_tag>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the MlsGroupState for group with given id. + fn delete_group_state>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the group context for the group with given id + fn delete_context>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the interim transcript hash for the group with given id + fn delete_interim_transcript_hash>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the MessageSecretsStore for the group with the given id. + fn delete_message_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the ResumptionPskStore for the group with the given id. + fn delete_all_resumption_psk_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the own leaf index inside the group for the group with the given id. + fn delete_own_leaf_index>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes any preference about whether to use the RatchetTreeExtension for the group with the given id. + fn delete_use_ratchet_tree_extension>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Deletes the GroupEpochSecrets for the group with the given id. + fn delete_group_epoch_secrets>( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + /// Clear the proposal queue for the grou pwith the given id. + fn clear_proposal_queue< + GroupId: traits::GroupId, + ProposalRef: traits::ProposalRef, + >( + &self, + group_id: &GroupId, + ) -> Result<(), Self::Error>; + + // + // --- deleters for crypto objects --- + // + + /// Delete a signature key pair based on its public key + /// + /// The signature key pair is not known to OpenMLS. This may be used by the + /// application + fn delete_signature_key_pair>( + &self, + public_key: &SignaturePublicKey, + ) -> Result<(), Self::Error>; + + /// Delete an encryption key pair for a public key. + /// + /// This is only be used for encryption key pairs that are generated for + /// update leaf nodes. All other encryption key pairs are stored as part + /// of the key package or the epoch encryption key pairs. + fn delete_encryption_key_pair>( + &self, + public_key: &EncryptionKey, + ) -> Result<(), Self::Error>; + + /// Delete a list of HPKE encryption key pairs for a given epoch. + /// This includes the private and public keys. + fn delete_encryption_epoch_key_pairs< + GroupId: traits::GroupId, + EpochKey: traits::EpochKey, + >( + &self, + group_id: &GroupId, + epoch: &EpochKey, + leaf_index: u32, + ) -> Result<(), Self::Error>; + + /// Delete a key package based on the hash reference. + /// + /// This function only deletes the key package. + /// The corresponding encryption keys must be deleted separately. + fn delete_key_package>( + &self, + hash_ref: &KeyPackageRef, + ) -> Result<(), Self::Error>; + + /// Delete a PSK based on an identifier. + fn delete_psk>( + &self, + psk_id: &PskKey, + ) -> Result<(), Self::Error>; +} + +// base traits for keys and values + +// ANCHOR: key_trait +/// Key is a trait implemented by all types that serve as a key (in the database sense) to in the +/// storage. For example, a GroupId is a key to the stored entities for the group with that id. +/// The point of a key is not to be stored, it's to address something that is stored. +pub trait Key: Serialize {} +// ANCHOR_END: key_trait + +// ANCHOR: entity_trait +/// Entity is a trait implemented by the values being stored. +pub trait Entity: Serialize + DeserializeOwned {} +// ANCHOR_END: entity_trait + +impl Entity for bool {} +impl Entity for u8 {} + +// in the following we define specific traits for Keys and Entities. That way +// we can don't sacrifice type safety in the implementations of the storage provider. +// note that there are types that are used both as keys and as entities. + +// ANCHOR: traits +/// Each trait in this module corresponds to a type. Some are used as keys, some as +/// entities, and some both. Therefore, the Key and/or Entity traits also need to be implemented. +pub mod traits { + use super::{Entity, Key}; + + // traits for keys, one per data type + pub trait GroupId: Key {} + pub trait SignaturePublicKey: Key {} + pub trait HashReference: Key {} + pub trait PskId: Key {} + pub trait EncryptionKey: Key {} + pub trait EpochKey: Key {} + + // traits for entity, one per type + pub trait QueuedProposal: Entity {} + pub trait TreeSync: Entity {} + pub trait GroupContext: Entity {} + pub trait InterimTranscriptHash: Entity {} + pub trait ConfirmationTag: Entity {} + pub trait SignatureKeyPair: Entity {} + pub trait PskBundle: Entity {} + pub trait HpkeKeyPair: Entity {} + pub trait GroupState: Entity {} + pub trait GroupEpochSecrets: Entity {} + pub trait LeafNodeIndex: Entity {} + pub trait GroupUseRatchetTreeExtension: Entity {} + pub trait MessageSecrets: Entity {} + pub trait ResumptionPskStore: Entity {} + pub trait KeyPackage: Entity {} + pub trait MlsGroupJoinConfig: Entity {} + pub trait LeafNode: Entity {} + + // traits for types that implement both + pub trait ProposalRef: Entity + Key {} +} +// ANCHOR_END: traits + +impl Entity for Vec {} diff --git a/traits/src/traits.rs b/traits/src/traits.rs index 6ad592b535..96514da36c 100644 --- a/traits/src/traits.rs +++ b/traits/src/traits.rs @@ -4,11 +4,21 @@ //! API of OpenMLS. pub mod crypto; -pub mod key_store; pub mod random; pub mod signatures; +pub mod storage; pub mod types; +/// A prelude to include to get all traits in scope and expose `openmls_types`. +pub mod prelude { + pub use super::crypto::OpenMlsCrypto as _; + pub use super::random::OpenMlsRand as _; + pub use super::signatures::Signer as _; + pub use super::storage::StorageProvider as _; + pub use super::types as openmls_types; + pub use super::OpenMlsProvider as _; +} + /// The OpenMLS Crypto Provider Trait /// /// An implementation of this trait must be passed in to the public OpenMLS API @@ -16,14 +26,14 @@ pub mod types; pub trait OpenMlsProvider { type CryptoProvider: crypto::OpenMlsCrypto; type RandProvider: random::OpenMlsRand; - type KeyStoreProvider: key_store::OpenMlsKeyStore; + type StorageProvider: storage::StorageProvider<{ storage::CURRENT_VERSION }>; + + // Get the storage provider. + fn storage(&self) -> &Self::StorageProvider; /// Get the crypto provider. fn crypto(&self) -> &Self::CryptoProvider; /// Get the randomness provider. fn rand(&self) -> &Self::RandProvider; - - /// Get the key store provider. - fn key_store(&self) -> &Self::KeyStoreProvider; } diff --git a/traits/src/types.rs b/traits/src/types.rs index 9a7109d133..521dd03a60 100644 --- a/traits/src/types.rs +++ b/traits/src/types.rs @@ -9,8 +9,6 @@ use tls_codec::{ SecretVLBytes, TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes, }; -use crate::key_store::{MlsEntity, MlsEntityId}; - #[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] #[repr(u16)] /// AEAD types @@ -181,6 +179,9 @@ pub enum HpkeKemType { /// DH KEM on x448 DhKem448 = 0x0021, + + /// XWing combiner for ML-KEM and X25519 + XWingKemDraft2 = 0x004D, } /// KDF Types for HPKE @@ -274,10 +275,6 @@ impl std::ops::Deref for HpkePrivateKey { } } -impl MlsEntity for HpkePrivateKey { - const ID: MlsEntityId = MlsEntityId::HpkePrivateKey; -} - /// Helper holding a (private, public) key pair as byte vectors. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HpkeKeyPair { @@ -374,6 +371,9 @@ pub enum Ciphersuite { /// DH KEM P384 | AES-GCM 256 | SHA2-384 | EcDSA P384 MLS_256_DHKEMP384_AES256GCM_SHA384_P384 = 0x0007, + + /// X-WING KEM draft-01 | Chacha20Poly1305 | SHA2-256 | Ed25519 + MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 = 0x004D, } impl core::fmt::Display for Ciphersuite { @@ -409,6 +409,7 @@ impl TryFrom for Ciphersuite { 0x0005 => Ok(Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521), 0x0006 => Ok(Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448), 0x0007 => Ok(Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384), + 0x004D => Ok(Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519), _ => Err(Self::Error::DecodingError(format!( "{v} is not a valid ciphersuite value" ))), @@ -465,9 +466,8 @@ impl Ciphersuite { match self { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 | Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 - | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => { - HashType::Sha2_256 - } + | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 + | Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => HashType::Sha2_256, Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 => HashType::Sha2_384, Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 | Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 @@ -480,7 +480,8 @@ impl Ciphersuite { pub const fn signature_algorithm(&self) -> SignatureScheme { match self { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 - | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => { + | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 + | Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => { SignatureScheme::ED25519 } Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 => { @@ -506,7 +507,8 @@ impl Ciphersuite { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 | Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 => AeadType::Aes128Gcm, Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 - | Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 => { + | Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 + | Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => { AeadType::ChaCha20Poly1305 } Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 @@ -521,9 +523,8 @@ impl Ciphersuite { match self { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 | Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 - | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => { - HpkeKdfType::HkdfSha256 - } + | Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 + | Self::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => HpkeKdfType::HkdfSha256, Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 => HpkeKdfType::HkdfSha384, Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448 | Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 @@ -546,6 +547,9 @@ impl Ciphersuite { | Ciphersuite::MLS_256_DHKEMX448_CHACHA20POLY1305_SHA512_Ed448 => HpkeKemType::DhKem448, Ciphersuite::MLS_256_DHKEMP384_AES256GCM_SHA384_P384 => HpkeKemType::DhKemP384, Ciphersuite::MLS_256_DHKEMP521_AES256GCM_SHA512_P521 => HpkeKemType::DhKemP521, + Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => { + HpkeKemType::XWingKemDraft2 + } } } @@ -555,7 +559,8 @@ impl Ciphersuite { match self { Ciphersuite::MLS_128_DHKEMX25519_AES128GCM_SHA256_Ed25519 | Ciphersuite::MLS_128_DHKEMP256_AES128GCM_SHA256_P256 => HpkeAeadType::AesGcm128, - Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 => { + Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519 + | Ciphersuite::MLS_256_XWING_CHACHA20POLY1305_SHA256_Ed25519 => { HpkeAeadType::ChaCha20Poly1305 } Ciphersuite::MLS_256_DHKEMX448_AES256GCM_SHA512_Ed448