diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index 085b81245..05b905a5d 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -233,20 +233,28 @@ where tokio::spawn(async move { let client = client.clone(); - let mut handle = Self::relay_messages(client.clone(), tx.clone(), group_id_to_info.clone()); + let mut messages_stream = client.clone().stream_messages(group_id_to_info.clone()).await?; let mut convo_stream = Self::stream_conversations(&client).await?; loop { - // TODO:insipx We should more closely investigate whether - // the stream mapping in `stream_conversations` is cancellation safe - // otherwise it could lead to hard-to-find bugs tokio::select! { + // biased enforces an order to select!. If a message and a group are both ready + // at the same time, `biased` mode will process the message before the new + // group. + biased; + + Some(message) = messages_stream.next() => { + // an error can only mean the receiver has been dropped or closed so we're + // safe to end the stream + if tx.send(message).is_err() { + break; + } + } Some(new_group) = convo_stream.next() => { if group_id_to_info.contains_key(&new_group.group_id) { continue; } - //TODO: Should we await the handle to ensure it finishes? - handle.abort(); + // messages_stream.cancel(); for info in group_id_to_info.values_mut() { info.cursor = 0; } @@ -257,21 +265,8 @@ where cursor: 1, // For the new group, stream all messages since the group was created }, ); - // TODO:insipx Can remove the indiretion in `relay_messages` and just use - // `stream_messages` directly? - handle = Self::relay_messages(client.clone(), tx.clone(), group_id_to_info.clone()); + messages_stream = client.clone().stream_messages(group_id_to_info.clone()).await?; }, - maybe_finished = &mut handle => { - match maybe_finished { - // if all is well it means the stream closed (receiver is dropped or ended) - Ok(_) => break, - Err(e) => { - // if we have an error, try to restart the stream. - log::error!("{}", e.to_string()); - handle = Self::relay_messages(client.clone(), tx.clone(), group_id_to_info.clone()); - } - } - } } } Ok::<_, ClientError>(()) @@ -301,23 +296,6 @@ where let _ = tokio::task::block_in_place(|| rx.blocking_recv()); handle } - - fn relay_messages( - client: Arc>, - tx: UnboundedSender, - group_id_to_info: HashMap, MessagesStreamInfo>, - ) -> JoinHandle> { - tokio::spawn(async move { - let mut stream = client.stream_messages(group_id_to_info).await?; - while let Some(message) = stream.next().await { - // an error can only mean the receiver has been dropped or closed - if tx.send(message).is_err() { - break; - } - } - Ok::<_, ClientError>(()) - }) - } } #[cfg(test)]