Skip to content

Commit

Permalink
refactor(crypto): Use the simplified locks across the crypto crate
Browse files Browse the repository at this point in the history
  • Loading branch information
poljar committed Jan 8, 2025
1 parent 46dc2a9 commit 62567ca
Show file tree
Hide file tree
Showing 16 changed files with 213 additions and 258 deletions.
7 changes: 4 additions & 3 deletions crates/matrix-sdk-crypto/src/backups/keys/backup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::{Arc, Mutex};
use std::sync::Arc;

use matrix_sdk_common::locks::Mutex;
use ruma::{
api::client::backup::{EncryptedSessionDataInit, KeyBackupData, KeyBackupDataInit},
serde::Base64,
Expand Down Expand Up @@ -87,15 +88,15 @@ impl MegolmV1BackupKey {

/// Get the backup version that this key is used with, if any.
pub fn backup_version(&self) -> Option<String> {
self.inner.version.lock().unwrap().clone()
self.inner.version.lock().clone()
}

/// Set the backup version that this `MegolmV1BackupKey` will be used with.
///
/// The key won't be able to encrypt room keys unless a version has been
/// set.
pub fn set_version(&self, version: String) {
*self.inner.version.lock().unwrap() = Some(version);
*self.inner.version.lock() = Some(version);
}

/// Export the given inbound group session, and encrypt the data, ready for
Expand Down
53 changes: 25 additions & 28 deletions crates/matrix-sdk-crypto/src/gossiping/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock as StdRwLock,
Arc,
},
};

use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
api::client::keys::claim_keys::v3::Request as KeysClaimRequest,
events::secret::request::{
Expand Down Expand Up @@ -168,14 +169,13 @@ impl GossipMachine {
) -> Result<Vec<OutgoingRequest>, CryptoStoreError> {
let mut key_requests = self.load_outgoing_requests().await?;
let key_forwards: Vec<OutgoingRequest> =
self.inner.outgoing_requests.read().unwrap().values().cloned().collect();
self.inner.outgoing_requests.read().values().cloned().collect();
key_requests.extend(key_forwards);

let users_for_key_claim: BTreeMap<_, _> = self
.inner
.users_for_key_claim
.read()
.unwrap()
.iter()
.map(|(key, value)| {
let device_map = value
Expand Down Expand Up @@ -213,7 +213,7 @@ impl GossipMachine {
trace!("Received a secret request event from ourselves, ignoring")
} else {
let request_info = event.to_request_info();
self.inner.incoming_key_requests.write().unwrap().insert(request_info, event);
self.inner.incoming_key_requests.write().insert(request_info, event);
}
}

Expand All @@ -229,8 +229,7 @@ impl GossipMachine {
) -> OlmResult<Vec<Session>> {
let mut changed_sessions = Vec::new();

let incoming_key_requests =
mem::take(&mut *self.inner.incoming_key_requests.write().unwrap());
let incoming_key_requests = mem::take(&mut *self.inner.incoming_key_requests.write());

for event in incoming_key_requests.values() {
if let Some(s) = match event {
Expand All @@ -254,7 +253,6 @@ impl GossipMachine {
self.inner
.users_for_key_claim
.write()
.unwrap()
.entry(device.user_id().to_owned())
.or_default()
.insert(device.device_id().into());
Expand All @@ -275,7 +273,7 @@ impl GossipMachine {
/// * `device_id` - The device ID of the device that got the Olm session.
pub fn retry_keyshare(&self, user_id: &UserId, device_id: &DeviceId) {
if let Entry::Occupied(mut e) =
self.inner.users_for_key_claim.write().unwrap().entry(user_id.to_owned())
self.inner.users_for_key_claim.write().entry(user_id.to_owned())
{
e.get_mut().remove(device_id);

Expand All @@ -284,7 +282,7 @@ impl GossipMachine {
}
}

let mut incoming_key_requests = self.inner.incoming_key_requests.write().unwrap();
let mut incoming_key_requests = self.inner.incoming_key_requests.write();
for (key, event) in self.inner.wait_queue.remove(user_id, device_id) {
incoming_key_requests.entry(key).or_insert(event);
}
Expand Down Expand Up @@ -555,7 +553,7 @@ impl GossipMachine {
request_id: request.txn_id.clone(),
request: Arc::new(request.into()),
};
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);

Ok(used_session)
}
Expand All @@ -581,7 +579,7 @@ impl GossipMachine {
request_id: request.txn_id.clone(),
request: Arc::new(request.into()),
};
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);

Ok(used_session)
}
Expand Down Expand Up @@ -824,7 +822,7 @@ impl GossipMachine {
self.save_outgoing_key_info(info).await?;
}

self.inner.outgoing_requests.write().unwrap().remove(id);
self.inner.outgoing_requests.write().remove(id);

Ok(())
}
Expand All @@ -840,13 +838,13 @@ impl GossipMachine {
"Successfully received a secret, removing the request"
);

self.inner.outgoing_requests.write().unwrap().remove(&key_info.request_id);
self.inner.outgoing_requests.write().remove(&key_info.request_id);
// TODO return the key info instead of deleting it so the sync handler
// can delete it in one transaction.
self.delete_key_info(key_info).await?;

let request = key_info.to_cancellation(self.device_id());
self.inner.outgoing_requests.write().unwrap().insert(request.request_id.clone(), request);
self.inner.outgoing_requests.write().insert(request.request_id.clone(), request);

Ok(())
}
Expand Down Expand Up @@ -1511,7 +1509,6 @@ mod tests {
.inner
.outgoing_requests
.read()
.unwrap()
.first_key_value()
.map(|(_, r)| r.request_id.clone())
.unwrap();
Expand Down Expand Up @@ -1692,7 +1689,7 @@ mod tests {
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();

// Bob doesn't have any outgoing requests.
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(bob_machine.inner.outgoing_requests.read().is_empty());

// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
Expand All @@ -1702,7 +1699,7 @@ mod tests {
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());

// Get the request and convert it to a encrypted to-device event.
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
Expand Down Expand Up @@ -1774,7 +1771,7 @@ mod tests {
alice_machine.mark_outgoing_request_as_sent(&request.request_id).await.unwrap();

// Bob doesn't have any outgoing requests.
assert!(bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(bob_machine.inner.outgoing_requests.read().is_empty());

// Receive the room key request from alice.
bob_machine.receive_incoming_key_request(&event);
Expand All @@ -1783,7 +1780,7 @@ mod tests {
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
}
// Now bob does have an outgoing request.
assert!(!bob_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!bob_machine.inner.outgoing_requests.read().is_empty());

// Get the request and convert it to a encrypted to-device event.
let requests = bob_machine.outgoing_to_device_requests().await.unwrap();
Expand Down Expand Up @@ -1875,13 +1872,13 @@ mod tests {
};

// No secret found
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());
alice_machine.receive_incoming_secret_request(&event);
{
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());

// No device found
alice_machine.inner.store.reset_cross_signing_identity().await;
Expand All @@ -1890,7 +1887,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());

alice_machine.inner.store.save_device_data(&[bob_device]).await.unwrap();

Expand All @@ -1901,7 +1898,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());

let event = RumaToDeviceEvent {
sender: alice_id().to_owned(),
Expand All @@ -1918,7 +1915,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(alice_machine.inner.outgoing_requests.read().is_empty());

// We need a trusted device, otherwise we won't serve secrets
alice_device.set_trust_state(LocalTrust::Verified);
Expand All @@ -1929,7 +1926,7 @@ mod tests {
let alice_cache = alice_machine.inner.store.cache().await.unwrap();
alice_machine.collect_incoming_key_requests(&alice_cache).await.unwrap();
}
assert!(!alice_machine.inner.outgoing_requests.read().unwrap().is_empty());
assert!(!alice_machine.inner.outgoing_requests.read().is_empty());
}

#[async_test]
Expand Down Expand Up @@ -2053,7 +2050,7 @@ mod tests {

// Bob doesn't have any outgoing requests.
assert!(bob_machine.outgoing_to_device_requests().await.unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
assert!(bob_machine.inner.wait_queue.is_empty());

// Receive the room key request from alice.
Expand All @@ -2068,7 +2065,7 @@ mod tests {
bob_machine.outgoing_to_device_requests().await.unwrap()[0].request(),
AnyOutgoingRequest::KeysClaim(_)
);
assert!(!bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(!bob_machine.inner.users_for_key_claim.read().is_empty());
assert!(!bob_machine.inner.wait_queue.is_empty());

let (alice_session, bob_session) = alice_machine
Expand Down Expand Up @@ -2096,7 +2093,7 @@ mod tests {
bob_machine.inner.store.save_sessions(&[bob_session]).await.unwrap();

bob_machine.retry_keyshare(alice_id(), alice_device_id());
assert!(bob_machine.inner.users_for_key_claim.read().unwrap().is_empty());
assert!(bob_machine.inner.users_for_key_claim.read().is_empty());
{
let bob_cache = bob_machine.inner.store.cache().await.unwrap();
bob_machine.collect_incoming_key_requests(&bob_cache).await.unwrap();
Expand Down
10 changes: 6 additions & 4 deletions crates/matrix-sdk-crypto/src/gossiping/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ mod machine;

use std::{
collections::{BTreeMap, BTreeSet},
sync::{Arc, RwLock as StdRwLock},
sync::Arc,
};

pub(crate) use machine::GossipMachine;
use matrix_sdk_common::locks::RwLock as StdRwLock;
use ruma::{
events::{
room_key_request::{Action, ToDeviceRoomKeyRequestEventContent},
Expand Down Expand Up @@ -323,7 +324,7 @@ impl WaitQueue {

#[cfg(all(test, feature = "automatic-room-key-forwarding"))]
fn is_empty(&self) -> bool {
let read_guard = self.inner.read().unwrap();
let read_guard = self.inner.read();
read_guard.requests_ids_waiting.is_empty()
&& read_guard.requests_waiting_for_session.is_empty()
}
Expand All @@ -337,13 +338,14 @@ impl WaitQueue {
);
let ids_waiting_key = (device.user_id().to_owned(), device.device_id().into());

let mut write_guard = self.inner.write().unwrap();
let mut write_guard = self.inner.write();
write_guard.requests_waiting_for_session.insert(requests_waiting_key, event);
write_guard.requests_ids_waiting.entry(ids_waiting_key).or_default().insert(request_id);
}

fn remove(&self, user_id: &UserId, device_id: &DeviceId) -> Vec<(RequestInfo, RequestEvent)> {
let mut write_guard = self.inner.write().unwrap();
let mut write_guard = self.inner.write();

write_guard
.requests_ids_waiting
.remove(&(user_id.to_owned(), device_id.into()))
Expand Down
10 changes: 5 additions & 5 deletions crates/matrix-sdk-crypto/src/identities/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ use std::{
ops::Deref,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
Arc,
},
};

use matrix_sdk_common::deserialized_responses::WithheldCode;
use matrix_sdk_common::{deserialized_responses::WithheldCode, locks::RwLock};
use ruma::{
api::client::keys::upload_signatures::v3::Request as SignatureUploadRequest,
events::{key::verification::VerificationMethod, AnyToDeviceEventContent},
Expand Down Expand Up @@ -470,7 +470,7 @@ impl Device {
) -> OlmResult<Raw<ToDeviceEncryptedEventContent>> {
let (used_session, raw_encrypted) = self.encrypt(event_type, content).await?;

// perist the used session
// Persist the used session
self.verification_machine
.store
.save_changes(Changes { sessions: vec![used_session], ..Default::default() })
Expand Down Expand Up @@ -626,7 +626,7 @@ impl DeviceData {

/// Get the trust state of the device.
pub fn local_trust_state(&self) -> LocalTrust {
*self.trust_state.read().unwrap()
*self.trust_state.read()
}

/// Is the device locally marked as trusted.
Expand All @@ -646,7 +646,7 @@ impl DeviceData {
/// Note: This should only done in the crypto store where the trust state
/// can be stored.
pub(crate) fn set_trust_state(&self, state: LocalTrust) {
*self.trust_state.write().unwrap() = state;
*self.trust_state.write() = state;
}

pub(crate) fn mark_withheld_code_as_sent(&self) {
Expand Down
Loading

0 comments on commit 62567ca

Please sign in to comment.