Skip to content

Commit

Permalink
lock it
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Dec 13, 2024
1 parent 85675b5 commit 6fd93fe
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
39 changes: 23 additions & 16 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
groups::device_sync::DeviceSyncContent,
groups::{
device_sync::preference_sync::UserPreferenceUpdate, intents::UpdateMetadataIntentData,
validated_commit::ValidatedCommit,
serial::OpenMlsLock, validated_commit::ValidatedCommit,
},
hpke::{encrypt_welcome, HpkeError},
identity::{parse_credential, IdentityError},
Expand Down Expand Up @@ -358,7 +358,7 @@ where
async fn process_own_message(
&self,
intent: StoredGroupIntent,
openmls_group: &mut OpenMlsGroup,
locked_openmls_group: &mut OpenMlsGroup,
provider: &XmtpOpenMlsProvider,
message: ProtocolMessage,
envelope: &GroupMessageV1,
Expand All @@ -368,17 +368,18 @@ where
id: ref msg_id,
..
} = *envelope;
let mut locked_openmls_group = locked_openmls_group.lock();

if intent.state == IntentState::Committed {
return Ok(IntentState::Committed);
}
let message_epoch = message.epoch();
let group_epoch = openmls_group.epoch();
let group_epoch = locked_openmls_group.epoch();
debug!(
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_id,
intent.id,
intent.kind = %intent.kind,
Expand Down Expand Up @@ -407,7 +408,7 @@ where
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_id,
intent.id,
intent.kind = %intent.kind,
Expand Down Expand Up @@ -436,7 +437,7 @@ where
self.client.as_ref(),
conn,
&pending_commit,
openmls_group,
&locked_openmls_group,
)
.await;

Expand All @@ -458,7 +459,9 @@ where
self.context().inbox_id(),
intent.id
);
if let Err(err) = openmls_group.merge_staged_commit(&provider, pending_commit) {
if let Err(err) =
locked_openmls_group.merge_staged_commit(&provider, pending_commit)
{
tracing::error!("error merging commit: {}", err);
return Ok(IntentState::ToPublish);
} else {
Expand Down Expand Up @@ -498,18 +501,22 @@ where
id: ref msg_id,
..
} = *envelope;
let mut locked_openmls_group = openmls_group.lock();

let decrypted_message = openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) =
extract_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?;
let decrypted_message = locked_openmls_group.process_message(provider, message)?;
let (sender_inbox_id, sender_installation_id) = extract_message_sender(
&mut *locked_openmls_group,
&decrypted_message,
envelope_timestamp_ns,
)?;

tracing::info!(
inbox_id = self.client.inbox_id(),
installation_id = %self.client.installation_id(),
sender_inbox_id = sender_inbox_id,
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch = decrypted_message.epoch().as_u64(),
msg_group_id = hex::encode(decrypted_message.group_id().as_slice()),
msg_id,
Expand All @@ -530,7 +537,7 @@ where
sender_installation_id = hex::encode(&sender_installation_id),
installation_id = %self.client.installation_id(),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
Expand Down Expand Up @@ -658,7 +665,7 @@ where
installation_id = %self.client.installation_id(),
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
Expand All @@ -673,7 +680,7 @@ where
self.client.as_ref(),
provider.conn_ref(),
&sc,
openmls_group,
&locked_openmls_group,
)
.await?;
tracing::info!(
Expand All @@ -682,14 +689,14 @@ where
installation_id = %self.client.installation_id(),
sender_installation_id = hex::encode(&sender_installation_id),
group_id = hex::encode(&self.group_id),
current_epoch = openmls_group.epoch().as_u64(),
current_epoch = locked_openmls_group.epoch().as_u64(),
msg_epoch,
msg_group_id,
msg_id,
"[{}] staged commit is valid, will attempt to merge",
self.context().inbox_id()
);
openmls_group.merge_staged_commit(provider, sc)?;
locked_openmls_group.merge_staged_commit(provider, sc)?;
self.save_transcript_message(
provider.conn_ref(),
validated_commit,
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/groups/serial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ impl OpenMlsLock for OpenMlsGroup {
.or_default()
.clone();

// this may block
let lock = mutex.lock();
let lock = unsafe {
// let the borrow checker know that this guard's mutex is going to be owned by the struct it's returning
std::mem::transmute::<MutexGuard<'_, ()>, MutexGuard<'a, ()>>(lock)
};

// this may block
SerialOpenMlsGroup {
group: self,
lock,
Expand Down

0 comments on commit 6fd93fe

Please sign in to comment.