Skip to content

Commit

Permalink
update the stream to correctly terminate the underlying stream when c…
Browse files Browse the repository at this point in the history
…anceled
  • Loading branch information
nplasterer committed Sep 2, 2024
1 parent 3fe600e commit 687e31d
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 19 deletions.
52 changes: 36 additions & 16 deletions Sources/XMTPiOS/Conversations.swift
Original file line number Diff line number Diff line change
Expand Up @@ -426,20 +426,26 @@ public actor Conversations {

public func streamAllGroupDecryptedMessages() -> AsyncThrowingStream<DecryptedMessage, Error> {
AsyncThrowingStream { continuation in
Task {
do {
self.streamHolder.stream = try await self.client.v3Client?.conversations().streamAllMessages(
messageCallback: MessageCallback(client: self.client) { message in
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
}
let task = Task {
self.streamHolder.stream = await self.client.v3Client?.conversations().streamAllMessages(
messageCallback: MessageCallback(client: self.client) { message in
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end() // End the stream upon cancellation
return
}
)
} catch {
print("STREAM ERR: \(error)")
}
do {
continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt())
} catch {
print("Error onMessage \(error)")
}
}
)
}

continuation.onTermination = { _ in
task.cancel() // Ensure the task is cancelled if the stream ends
self.streamHolder.stream?.end() // Ensure the stream is ended if the task is cancelled
}
}
}
Expand All @@ -450,23 +456,37 @@ public actor Conversations {
do {
var iterator = stream.makeAsyncIterator()
while let element = try await iterator.next() {
guard !Task.isCancelled else {
continuation.finish()
self.streamHolder.stream?.end() // End the stream upon cancellation
return
}
continuation.yield(element)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
}
Task {
await forwardStreamToMerged(stream: try streamAllV2DecryptedMessages())

let task = Task {
await forwardStreamToMerged(stream: streamAllV2DecryptedMessages())
}
if (includeGroups) {

if includeGroups {
Task {
await forwardStreamToMerged(stream: streamAllGroupDecryptedMessages())
}
}

continuation.onTermination = { _ in
task.cancel() // Ensure the task is cancelled if the stream ends
self.streamHolder.stream?.end() // Ensure the stream is ended if the task is cancelled
}
}
}




func streamAllV2DecryptedMessages() -> AsyncThrowingStream<DecryptedMessage, Error> {
Expand Down
55 changes: 52 additions & 3 deletions Tests/XMTPTests/GroupTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -550,19 +550,19 @@ class GroupTests: XCTestCase {
func testCanStreamGroupsAndConversationsWorksGroups() async throws {
let fixtures = try await localFixtures()

let expectation1 = expectation(description: "got a conversation")
let expectation1 = XCTestExpectation(description: "got a conversation")
expectation1.expectedFulfillmentCount = 2

Task(priority: .userInitiated) {
for try await _ in try await fixtures.aliceClient.conversations.streamAll() {
for try await _ in await fixtures.aliceClient.conversations.streamAll() {
expectation1.fulfill()
}
}

_ = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address])
_ = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address)

await waitForExpectations(timeout: 3)
await fulfillment(of: [expectation1], timeout: 3)
}

func testStreamGroupsAndAllMessages() async throws {
Expand Down Expand Up @@ -912,4 +912,53 @@ class GroupTests: XCTestCase {
}
}
}

func testCanStreamAllDecryptedMessagesAndCancelStream() async throws {
let fixtures = try await localFixtures()

var messages = 0
let messagesQueue = DispatchQueue(label: "messages.queue") // Serial queue to synchronize access to `messages`

let convo = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address)
let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address])
try await fixtures.aliceClient.conversations.sync()

// Create a Task to handle streaming
let streamingTask = Task(priority: .userInitiated) {
for try await _ in try await fixtures.aliceClient.conversations.streamAllDecryptedMessages(includeGroups: true) {
// Update the messages count in a thread-safe manner
messagesQueue.sync {
messages += 1
}
}
}

// Send messages to trigger the stream
_ = try await group.send(content: "hi")
_ = try await convo.send(content: "hi")

// Allow some time for messages to be processed (adjust this according to your needs)
try await Task.sleep(nanoseconds: 1_000_000_000)

// Cancel the streaming task
streamingTask.cancel()

// Check the message count
messagesQueue.sync {
XCTAssertEqual(messages, 2)
}

try await Task.sleep(nanoseconds: 1_000_000_000)

_ = try await group.send(content: "hi")
_ = try await group.send(content: "hi")
_ = try await group.send(content: "hi")
_ = try await convo.send(content: "hi")

try await Task.sleep(nanoseconds: 1_000_000_000)

messagesQueue.sync {
XCTAssertEqual(messages, 2)
}
}
}

0 comments on commit 687e31d

Please sign in to comment.