Skip to content

Commit

Permalink
fix: stream sync groups
Browse files Browse the repository at this point in the history
  • Loading branch information
tuddman committed Aug 2, 2024
1 parent ae95fb3 commit 8efba60
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 7 additions & 8 deletions xmtp_mls/src/groups/message_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ where
Ok(())
}

pub async fn send_history_request(&self) -> Result<String, GroupError> {
// returns (sync_group_id, pin_code)
pub async fn send_history_request(&self) -> Result<(Vec<u8>, String), GroupError> {
// find the sync group
let conn = self.store().conn()?;
let sync_group_id = conn
Expand Down Expand Up @@ -134,9 +135,7 @@ where
log::error!("error publishing sync group intents: {:?}", err);
}

// TODO: set up stream here?? for the history sync requester

Ok(pin_code)
Ok((sync_group.group_id, pin_code))
}

pub(crate) async fn send_history_reply(
Expand Down Expand Up @@ -650,7 +649,7 @@ mod tests {
assert_ok!(client.allow_history_sync().await);

// test that the request is sent, and that the pin code is returned
let pin_code = client
let (_group_id, pin_code) = client
.send_history_request()
.await
.expect("history request");
Expand Down Expand Up @@ -682,7 +681,7 @@ mod tests {

amal_a.sync_welcomes().await.expect("sync_welcomes");

let _sent = amal_b
let (_group_id, _pin_code) = amal_b
.send_history_request()
.await
.expect("history request");
Expand Down Expand Up @@ -715,7 +714,7 @@ mod tests {

amal_a.sync_welcomes().await.expect("sync_welcomes");

let pin_code = amal_b
let (_, pin_code) = amal_b
.send_history_request()
.await
.expect("history request");
Expand Down Expand Up @@ -792,7 +791,7 @@ mod tests {
amal_a.sync_welcomes().await.expect("sync_welcomes");

// amal_b sends a message history request to sync group messages
let _pin_code = amal_b
let (_group_id, _pin_code) = amal_b
.send_history_request()
.await
.expect("history request");
Expand Down
42 changes: 32 additions & 10 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ where
.map_err(|e| ClientError::Generic(e.to_string()))?;

let welcome = self.process_streamed_welcome(envelope).await?;

Ok(welcome)
}

// really, stream *groups*
pub async fn stream_conversations(
&self,
) -> Result<Pin<Box<dyn Stream<Item = MlsGroup> + Send + '_>>, ClientError> {
Expand Down Expand Up @@ -173,6 +175,12 @@ where
Ok(Box::pin(futures::stream::select(stream, event_queue)))
}

pub async fn stream_sync_groups(
&self,
) -> Result<Pin<Box<dyn Stream<Item = MlsGroup> + Send + '_>>, ClientError> {
Self::stream_conversations(self).await
}

#[tracing::instrument(skip(self, group_id_to_info))]
pub(crate) async fn stream_messages(
self: Arc<Self>,
Expand Down Expand Up @@ -278,26 +286,40 @@ where

pub async fn stream_all_messages(
client: Arc<Client<ApiClient>>,
is_for_sync_groups: bool,
) -> Result<impl Stream<Item = StoredGroupMessage>, ClientError> {
let (tx, rx) = mpsc::unbounded_channel();

client.sync_welcomes().await?;

let mut group_id_to_info = client
.store()
.conn()?
.find_groups(None, None, None, None)?
.into_iter()
.map(Into::into)
.collect::<HashMap<Vec<u8>, MessagesStreamInfo>>();
let mut group_id_to_info;

if !is_for_sync_groups {
// Gather all regular conversational groups
group_id_to_info = client
.store()
.conn()?
.find_groups(None, None, None, None)?
.into_iter()
.map(Into::into)
.collect::<HashMap<Vec<u8>, MessagesStreamInfo>>();
} else {
// Gather the sync groups
group_id_to_info = client
.store()
.conn()?
.find_sync_groups()?
.into_iter()
.map(Into::into)
.collect::<HashMap<Vec<u8>, MessagesStreamInfo>>();
}

tokio::spawn(async move {
let client = client.clone();
let mut convo_stream = Self::stream_conversations(&client).await?;
let mut messages_stream = client
.clone()
.stream_messages(group_id_to_info.clone())
.await?;
let mut convo_stream = Self::stream_conversations(&client).await?;
let mut extra_messages = Vec::new();

loop {
Expand Down Expand Up @@ -362,7 +384,7 @@ where
let (tx, rx) = oneshot::channel();

let handle = tokio::spawn(async move {
let mut stream = Self::stream_all_messages(client).await?;
let mut stream = Self::stream_all_messages(client, false).await?;
let _ = tx.send(());
while let Some(message) = stream.next().await {
callback(message)
Expand Down

0 comments on commit 8efba60

Please sign in to comment.