From 80561ccd8819fb0bdbc1fd22062ab21e33cb6d5c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Sat, 25 Nov 2023 00:01:36 +0000 Subject: [PATCH] Expose new methods for importing room keys (#54) This exposes the new, separate methods for importing room keys depending on whether they came from backup or manual export. It also adds a new return type which doesn't rely quite so much on JSON-encoded objects. --- CHANGELOG.md | 5 ++ src/machine.rs | 146 ++++++++++++++++++++++++++++++++++++------ src/types.rs | 58 ++++++++++++++++- tests/machine.test.js | 134 +++++++++++++++++++++++++++----------- 4 files changed, 285 insertions(+), 58 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4b8233c7..07b27efbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ ## Other changes +- `OlmMachine.importRoomKeys` is now deprecated in favour of separate + methods for importing room keys from backup and export, + `OlmMachine.importBackedUpRoomKeys` and + `OlmMachine.importExportedRoomKeys`. + - Devices which have exhausted their one-time-keys will now be correctly handled in `/keys/claim` responses (we will register them as "failed" and stop attempting to send to them for a while.) diff --git a/src/machine.rs b/src/machine.rs index 211537712..3b79bbd95 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -3,16 +3,17 @@ use std::{collections::BTreeMap, ops::Deref, time::Duration}; use futures_util::{pin_mut, StreamExt}; -use js_sys::{Array, Function, Map, Promise, Set}; +use js_sys::{Array, Function, JsString, Map, Promise, Set}; use matrix_sdk_common::ruma::{ self, events::secret::request::SecretName, serde::Raw, DeviceKeyAlgorithm, OwnedTransactionId, UInt, }; use matrix_sdk_crypto::{ backups::MegolmV1BackupKey, + olm::BackedUpRoomKey, store::{DeviceChanges, IdentityChanges}, types::RoomKeyBackupInfo, - EncryptionSyncChanges, GossippedSecret, + CryptoStoreError, EncryptionSyncChanges, GossippedSecret, }; use serde_json::json; use serde_wasm_bindgen; @@ -33,7 +34,7 @@ use crate::{ store, store::RoomKeyInfo, sync_events, - types::{self, SignatureVerification}, + types::{self, RoomKeyImportResult, SignatureVerification}, verification, vodozemac, }; @@ -819,12 +820,17 @@ impl OlmMachine { /// Import the given room keys into our store. /// - /// `exported_keys` is a list of previously exported keys that should be - /// imported into our store. If we already have a better version of a key, - /// the key will _not_ be imported. + /// Mostly, a deprecated alias for `importExportedRoomKeys`, though the + /// return type is different. /// - /// `progress_listener` is a closure that takes 2 arguments: `progress` and - /// `total`, and returns nothing. + /// Returns a String containing a JSON-encoded object, holding three + /// properties: + /// * `total_count` (the total number of keys found in the export data). + /// * `imported_count` (the number of keys that were imported). + /// * `keys` (the keys that were imported; a map from room id to a map of + /// the sender key to a list of session ids). + /// + /// @deprecated Use `importExportedRoomKeys` or `importBackedUpRoomKeys`. #[wasm_bindgen(js_name = "importRoomKeys")] pub fn import_room_keys( &self, @@ -832,21 +838,12 @@ impl OlmMachine { progress_listener: Function, ) -> Result { let me = self.inner.clone(); - let exported_room_keys: Vec = - serde_json::from_str(exported_room_keys)?; + let exported_room_keys = serde_json::from_str(exported_room_keys)?; Ok(future_to_promise(async move { - let matrix_sdk_crypto::RoomKeyImportResult { imported_count, total_count, keys } = me - .store() - .import_exported_room_keys(exported_room_keys, |progress, total| { - let progress: u64 = progress.try_into().unwrap(); - let total: u64 = total.try_into().unwrap(); - - progress_listener - .call2(&JsValue::NULL, &JsValue::from(progress), &JsValue::from(total)) - .expect("Progress listener passed to `import_room_keys` failed"); - }) - .await?; + let matrix_sdk_crypto::RoomKeyImportResult { imported_count, total_count, keys } = + Self::import_exported_room_keys_helper(&me, exported_room_keys, progress_listener) + .await?; Ok(serde_json::to_string(&json!({ "imported_count": imported_count, @@ -856,6 +853,113 @@ impl OlmMachine { })) } + /// Import the given room keys into our store. + /// + /// `exported_keys` is a JSON-encoded list of previously exported keys that + /// should be imported into our store. If we already have a better + /// version of a key, the key will _not_ be imported. + /// + /// `progress_listener` is a closure that takes 2 `BigInt` arguments: + /// `progress` and `total`, and returns nothing. + /// + /// Returns a {@link RoomKeyImportResult}. + #[wasm_bindgen(js_name = "importExportedRoomKeys")] + pub fn import_exported_room_keys( + &self, + exported_room_keys: &str, + progress_listener: Function, + ) -> Result { + let me = self.inner.clone(); + let exported_room_keys = serde_json::from_str(exported_room_keys)?; + + Ok(future_to_promise(async move { + let result: RoomKeyImportResult = + Self::import_exported_room_keys_helper(&me, exported_room_keys, progress_listener) + .await? + .into(); + Ok(result) + })) + } + + /// Shared helper for `import_exported_room_keys` and `import_room_keys`. + /// + /// Wraps the progress listener in a Rust closure and runs + /// `Store::import_exported_room_keys` + async fn import_exported_room_keys_helper( + inner: &matrix_sdk_crypto::OlmMachine, + exported_room_keys: Vec, + progress_listener: Function, + ) -> Result { + inner + .store() + .import_exported_room_keys(exported_room_keys, |progress, total| { + progress_listener + .call2(&JsValue::NULL, &JsValue::from(progress), &JsValue::from(total)) + .expect("Progress listener passed to `importExportedRoomKeys` failed"); + }) + .await + } + + /// Import the given room keys into our store. + /// + /// # Arguments + /// + /// * `backed_up_room_keys`: keys that were retrieved from backup and that + /// should be added to our store (provided they are better than our + /// current versions of those keys). Specifically, it should be a Map from + /// {@link RoomId}, to a Map from session ID to a (decrypted) session data + /// structure. + /// + /// * `progress_listener`: an optional callback that takes 2 arguments: + /// `progress` and `total`, and returns nothing. + /// + /// # Returns + /// + /// A {@link RoomKeyImportResult}. + #[wasm_bindgen(js_name = "importBackedUpRoomKeys")] + pub fn import_backed_up_room_keys( + &self, + backed_up_room_keys: &Map, + progress_listener: Option, + ) -> Result { + let me = self.inner.clone(); + + // convert the js-side data into rust data + let mut keys: BTreeMap<_, BTreeMap<_, _>> = BTreeMap::new(); + for backed_up_room_keys_entry in backed_up_room_keys.entries() { + let backed_up_room_keys_entry: Array = backed_up_room_keys_entry?.dyn_into()?; + let room_id = + &downcast::(&backed_up_room_keys_entry.get(0), "RoomId")? + .inner; + + let room_room_keys: Map = backed_up_room_keys_entry.get(1).dyn_into()?; + + for room_room_keys_entry in room_room_keys.entries() { + let room_room_keys_entry: Array = room_room_keys_entry?.dyn_into()?; + let session_id: JsString = room_room_keys_entry.get(0).dyn_into()?; + let key: BackedUpRoomKey = + serde_wasm_bindgen::from_value(room_room_keys_entry.get(1))?; + + keys.entry(room_id.clone()).or_default().insert(session_id.into(), key); + } + } + + Ok(future_to_promise(async move { + let result: RoomKeyImportResult = me + .backup_machine() + .import_backed_up_room_keys(keys, |progress, total| { + if let Some(callback) = &progress_listener { + callback + .call2(&JsValue::NULL, &JsValue::from(progress), &JsValue::from(total)) + .expect("Progress listener passed to `importBackedUpRoomKeys` failed"); + } + }) + .await? + .into(); + Ok(result) + })) + } + /// Store the backup decryption key in the crypto store. /// /// This is useful if the client wants to support gossiping of the backup diff --git a/src/types.rs b/src/types.rs index eeee7f666..6ce6b116e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,6 +1,9 @@ //! Extra types, like `Signatures`. -use js_sys::{JsString, Map}; +use std::collections::{BTreeMap, BTreeSet}; + +use js_sys::{Array, JsString, Map, Set}; +use matrix_sdk_common::ruma::OwnedRoomId; use matrix_sdk_crypto::backups::{ SignatureState as InnerSignatureState, SignatureVerification as InnerSignatureVerification, }; @@ -227,3 +230,56 @@ impl SignatureVerification { self.inner.trusted() } } + +/// The result of a call to {@link OlmMachine.importExportedRoomKeys} or +/// {@link OlmMachine.importBackedUpRoomKeys}. +#[derive(Clone, Debug)] +#[wasm_bindgen] +pub struct RoomKeyImportResult { + /// The number of room keys that were imported. + #[wasm_bindgen(readonly, js_name = "importedCount")] + pub imported_count: usize, + + /// The total number of room keys that were found in the export. + #[wasm_bindgen(readonly, js_name = "totalCount")] + pub total_count: usize, + + /// The map of keys that were imported. + /// + /// A map from room id to a map of the sender key to a set of session ids. + keys: BTreeMap>>, +} + +#[wasm_bindgen] +impl RoomKeyImportResult { + /// The keys that were imported. + /// + /// A Map from room id to a Map of the sender key to a Set of session ids. + /// + /// Typescript type: `Map>`. + pub fn keys(&self) -> Map { + let key_map = Map::new(); + + for (room_id, room_result) in self.keys.iter() { + let room_map = Map::new(); + key_map.set(&JsString::from(room_id.to_string()), &room_map); + + for (sender_key, sessions) in room_result.iter() { + let s: Array = sessions.iter().map(|s| JsString::from(s.as_ref())).collect(); + room_map.set(&JsString::from(sender_key.as_ref()), &Set::new(&s)); + } + } + + key_map + } +} + +impl From for RoomKeyImportResult { + fn from(value: matrix_sdk_crypto::RoomKeyImportResult) -> Self { + RoomKeyImportResult { + imported_count: value.imported_count, + total_count: value.total_count, + keys: value.keys, + } + } +} diff --git a/tests/machine.test.js b/tests/machine.test.js index b4302c525..1f49f5a4a 100644 --- a/tests/machine.test.js +++ b/tests/machine.test.js @@ -686,52 +686,55 @@ describe(OlmMachine.name, () => { expect(userId.toString()).toEqual(user.toString()); }); - describe("can export/import room keys", () => { + test("can export room keys", async () => { + let m = await machine(); + await m.shareRoomKey(room, [new UserId("@bob:example.org")], new EncryptionSettings()); + + let exportedRoomKeys = await m.exportRoomKeys((session) => { + expect(session).toBeInstanceOf(InboundGroupSession); + expect(session.senderKey.toBase64()).toEqual(m.identityKeys.curve25519.toBase64()); + expect(session.roomId.toString()).toStrictEqual(room.toString()); + expect(session.sessionId).toBeDefined(); + expect(session.hasBeenImported()).toStrictEqual(false); + + return true; + }); + + const roomKeys = JSON.parse(exportedRoomKeys); + expect(roomKeys).toHaveLength(1); + expect(roomKeys[0]).toMatchObject({ + algorithm: expect.any(String), + room_id: room.toString(), + sender_key: expect.any(String), + session_id: expect.any(String), + session_key: expect.any(String), + sender_claimed_keys: { + ed25519: expect.any(String), + }, + forwarding_curve25519_key_chain: [], + }); + }); + + describe("can process exported room keys", () => { let exportedRoomKeys; - test("can export room keys", async () => { + beforeEach(async () => { let m = await machine(); await m.shareRoomKey(room, [new UserId("@bob:example.org")], new EncryptionSettings()); - exportedRoomKeys = await m.exportRoomKeys((session) => { - expect(session).toBeInstanceOf(InboundGroupSession); - expect(session.senderKey.toBase64()).toEqual(m.identityKeys.curve25519.toBase64()); - expect(session.roomId.toString()).toStrictEqual(room.toString()); - expect(session.sessionId).toBeDefined(); - expect(session.hasBeenImported()).toStrictEqual(false); - - return true; - }); - - const roomKeys = JSON.parse(exportedRoomKeys); - expect(roomKeys).toHaveLength(1); - expect(roomKeys[0]).toMatchObject({ - algorithm: expect.any(String), - room_id: room.toString(), - sender_key: expect.any(String), - session_id: expect.any(String), - session_key: expect.any(String), - sender_claimed_keys: { - ed25519: expect.any(String), - }, - forwarding_curve25519_key_chain: [], - }); + exportedRoomKeys = await m.exportRoomKeys((_session) => true); }); - let encryptedExportedRoomKeys; - let encryptionPassphrase = "Hello, Matrix!"; - - test("can encrypt the exported room keys", () => { - encryptedExportedRoomKeys = OlmMachine.encryptExportedRoomKeys( + test("can encrypt and decrypt the exported room keys", () => { + let encryptionPassphrase = "Hello, Matrix!"; + let encryptedExportedRoomKeys = OlmMachine.encryptExportedRoomKeys( exportedRoomKeys, encryptionPassphrase, 100_000, ); expect(encryptedExportedRoomKeys).toMatch(/^-----BEGIN MEGOLM SESSION DATA-----/); - }); - test("can decrypt the exported room keys", () => { const decryptedExportedRoomKeys = OlmMachine.decryptExportedRoomKeys( encryptedExportedRoomKeys, encryptionPassphrase, @@ -740,13 +743,13 @@ describe(OlmMachine.name, () => { expect(decryptedExportedRoomKeys).toStrictEqual(exportedRoomKeys); }); - test("can import room keys", async () => { + test("can import room keys via importRoomKeys", async () => { const progressListener = (progress, total) => { expect(progress).toBeLessThan(total); // Since it's called only once, let's be crazy. - expect(progress).toStrictEqual(0n); - expect(total).toStrictEqual(1n); + expect(progress).toStrictEqual(0); + expect(total).toStrictEqual(1); }; let m = await machine(); @@ -759,12 +762,28 @@ describe(OlmMachine.name, () => { }); }); + test("can import room keys via importExportedRoomKeys", async () => { + const progressListener = (progress, total) => { + expect(progress).toStrictEqual(0); + expect(total).toStrictEqual(1); + }; + + let m = await machine(); + const result = await m.importExportedRoomKeys(exportedRoomKeys, progressListener); + + expect(result.importedCount).toStrictEqual(1); + expect(result.totalCount).toStrictEqual(1); + expect(result.keys()).toMatchObject( + new Map([[room.toString(), new Map([[expect.any(String), new Set([expect.any(String)])]])]]), + ); + }); + test("importing room keys calls RoomKeyUpdatedCallback", async () => { const callback = jest.fn(); callback.mockImplementation(() => Promise.resolve(undefined)); let m = await machine(); m.registerRoomKeyUpdatedCallback(callback); - await m.importRoomKeys(exportedRoomKeys, (_, _1) => {}); + await m.importRoomKeys(exportedRoomKeys, (_, _1) => undefined); expect(callback).toHaveBeenCalledTimes(1); let keyInfoList = callback.mock.calls[0][0]; expect(keyInfoList.length).toEqual(1); @@ -1071,5 +1090,48 @@ describe(OlmMachine.name, () => { expect(savedKey.decryptionKeyBase64).toStrictEqual(keyBackupKey.toBase64()); expect(savedKey.backupVersion).toStrictEqual("3"); }); + + test("can import keys via importBackedUpRoomKeys", async () => { + // first do a backup from one OlmMachine + const m = await machine(); + await m.shareRoomKey(room, [new UserId("@bob:example.org")], new EncryptionSettings()); + const keyBackupKey = BackupDecryptionKey.createRandomKey(); + await m.enableBackupV1(keyBackupKey.megolmV1PublicKey.publicKeyBase64, "1"); + const outgoing = await m.backupRoomKeys(); + expect(outgoing.type).toStrictEqual(RequestType.KeysBackup); + const exportedKeys = JSON.parse(outgoing.body); + + // decrypt the backup + const decryptedKeyMap = new Map(); + for (const [roomId, roomKeys] of Object.entries(exportedKeys.rooms)) { + const decryptedRoomKeyMap = new Map(); + decryptedKeyMap.set(new RoomId(roomId), decryptedRoomKeyMap); + for (const [sessionId, keyBackupData] of Object.entries(roomKeys.sessions)) { + const decrypted = JSON.parse( + keyBackupKey.decryptV1( + keyBackupData.session_data.ephemeral, + keyBackupData.session_data.mac, + keyBackupData.session_data.ciphertext, + ), + ); + expect(decrypted.algorithm).toStrictEqual("m.megolm.v1.aes-sha2"); + decryptedRoomKeyMap.set(sessionId, decrypted); + } + } + + // now import the backup into a new OlmMachine + const progressListener = jest.fn(); + const m2 = await machine(); + await m2.saveBackupDecryptionKey(keyBackupKey, "1"); + const result = await m2.importBackedUpRoomKeys(decryptedKeyMap, progressListener); + expect(result.importedCount).toStrictEqual(1); + expect(result.totalCount).toStrictEqual(1); + expect(result.keys()).toMatchObject( + new Map([[room.toString(), new Map([[expect.any(String), new Set([expect.any(String)])]])]]), + ); + + expect(progressListener).toHaveBeenCalledTimes(1); + expect(progressListener).toHaveBeenCalledWith(0, 1); + }); }); });