From bcce4f31a03e601a87a155c7b20f162ef4ccfd54 Mon Sep 17 00:00:00 2001 From: Naomi Plasterer Date: Mon, 2 Sep 2024 20:12:14 -0600 Subject: [PATCH] Fix stream cancelation cancels underlying stream (#395) * update the stream to correctly terminate the underlying stream when canceled * fix up a bunch of test warnings * update group streaming as well * bump the pod * make sure all streams cancel correctly --- Sources/XMTPiOS/Conversations.swift | 130 +++++++++++++++++------ Sources/XMTPiOS/Group.swift | 68 +++++++----- Tests/XMTPTests/ConversationTests.swift | 18 ++-- Tests/XMTPTests/GroupTests.swift | 131 +++++++++++++++++------- XMTP.podspec | 2 +- 5 files changed, 246 insertions(+), 103 deletions(-) diff --git a/Sources/XMTPiOS/Conversations.swift b/Sources/XMTPiOS/Conversations.swift index 4cb0d1ea..f35637c7 100644 --- a/Sources/XMTPiOS/Conversations.swift +++ b/Sources/XMTPiOS/Conversations.swift @@ -130,29 +130,52 @@ public actor Conversations { public func streamGroups() async throws -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task { + let task = Task { let groupCallback = GroupStreamCallback(client: self.client) { group in + guard !Task.isCancelled else { + continuation.finish() + return + } continuation.yield(group) } - guard let stream = try await self.client.v3Client?.conversations().stream(callback: groupCallback) else { + guard let stream = await self.client.v3Client?.conversations().stream(callback: groupCallback) else { continuation.finish(throwing: GroupError.streamingFailure) return } + + self.streamHolder.stream = stream continuation.onTermination = { @Sendable reason in stream.end() } } + + continuation.onTermination = { @Sendable reason in + task.cancel() + self.streamHolder.stream?.end() + } } } private func streamGroupConversations() -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task { - self.streamHolder.stream = try await self.client.v3Client?.conversations().stream( + let task = Task { + self.streamHolder.stream = await self.client.v3Client?.conversations().stream( callback: GroupStreamCallback(client: self.client) { group in + guard !Task.isCancelled else { + continuation.finish() + return + } continuation.yield(Conversation.group(group)) } ) + continuation.onTermination = { @Sendable reason in + self.streamHolder.stream?.end() + } + } + + continuation.onTermination = { @Sendable reason in + task.cancel() + self.streamHolder.stream?.end() } } } @@ -383,29 +406,41 @@ public actor Conversations { public func streamAllGroupMessages() -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task { - let messageCallback = MessageCallback(client: self.client) { message in - if let decodedMessage = MessageV3(client: self.client, ffiMessage: message).decodeOrNull() { - continuation.yield(decodedMessage) + let task = Task { + self.streamHolder.stream = await self.client.v3Client?.conversations().streamAllMessages( + messageCallback: MessageCallback(client: self.client) { message in + guard !Task.isCancelled else { + continuation.finish() + self.streamHolder.stream?.end() // End the stream upon cancellation + return + } + do { + continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode()) + } catch { + print("Error onMessage \(error)") + } } - } - guard let stream = try await client.v3Client?.conversations().streamAllMessages(messageCallback: messageCallback) else { - continuation.finish(throwing: GroupError.streamingFailure) - return - } - continuation.onTermination = { @Sendable reason in - stream.end() - } + ) + } + + continuation.onTermination = { _ in + task.cancel() + self.streamHolder.stream?.end() } } } - public func streamAllMessages(includeGroups: Bool = false) async throws -> AsyncThrowingStream { + public func streamAllMessages(includeGroups: Bool = false) -> AsyncThrowingStream { AsyncThrowingStream { continuation in @Sendable func forwardStreamToMerged(stream: AsyncThrowingStream) async { do { var iterator = stream.makeAsyncIterator() while let element = try await iterator.next() { + guard !Task.isCancelled else { + continuation.finish() + self.streamHolder.stream?.end() + return + } continuation.yield(element) } continuation.finish() @@ -413,33 +448,46 @@ public actor Conversations { continuation.finish(throwing: error) } } - Task { + + let task = Task { await forwardStreamToMerged(stream: streamAllV2Messages()) } + if includeGroups { Task { await forwardStreamToMerged(stream: streamAllGroupMessages()) } } + + continuation.onTermination = { _ in + task.cancel() + self.streamHolder.stream?.end() + } } } public func streamAllGroupDecryptedMessages() -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task { - do { - self.streamHolder.stream = try await self.client.v3Client?.conversations().streamAllMessages( - messageCallback: MessageCallback(client: self.client) { message in - do { - continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt()) - } catch { - print("Error onMessage \(error)") - } + let task = Task { + self.streamHolder.stream = await self.client.v3Client?.conversations().streamAllMessages( + messageCallback: MessageCallback(client: self.client) { message in + guard !Task.isCancelled else { + continuation.finish() + self.streamHolder.stream?.end() // End the stream upon cancellation + return } - ) - } catch { - print("STREAM ERR: \(error)") - } + do { + continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt()) + } catch { + print("Error onMessage \(error)") + } + } + ) + } + + continuation.onTermination = { _ in + task.cancel() + self.streamHolder.stream?.end() } } } @@ -450,6 +498,11 @@ public actor Conversations { do { var iterator = stream.makeAsyncIterator() while let element = try await iterator.next() { + guard !Task.isCancelled else { + continuation.finish() + self.streamHolder.stream?.end() + return + } continuation.yield(element) } continuation.finish() @@ -457,16 +510,25 @@ public actor Conversations { continuation.finish(throwing: error) } } - Task { - await forwardStreamToMerged(stream: try streamAllV2DecryptedMessages()) + + let task = Task { + await forwardStreamToMerged(stream: streamAllV2DecryptedMessages()) } - if (includeGroups) { + + if includeGroups { Task { await forwardStreamToMerged(stream: streamAllGroupDecryptedMessages()) } } + + continuation.onTermination = { _ in + task.cancel() + self.streamHolder.stream?.end() + } } } + + func streamAllV2DecryptedMessages() -> AsyncThrowingStream { diff --git a/Sources/XMTPiOS/Group.swift b/Sources/XMTPiOS/Group.swift index 028f4bc6..a255bd3a 100644 --- a/Sources/XMTPiOS/Group.swift +++ b/Sources/XMTPiOS/Group.swift @@ -293,41 +293,61 @@ public struct Group: Identifiable, Equatable, Hashable { public func streamMessages() -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task.detached { - do { - self.streamHolder.stream = try await ffiGroup.stream( - messageCallback: MessageCallback(client: self.client) { message in - do { - continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode()) - } catch { - print("Error onMessage \(error)") - } + let task = Task.detached { + self.streamHolder.stream = await self.ffiGroup.stream( + messageCallback: MessageCallback(client: self.client) { message in + guard !Task.isCancelled else { + continuation.finish() + return } - ) - } catch { - print("STREAM ERR: \(error)") + do { + continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decode()) + } catch { + print("Error onMessage \(error)") + continuation.finish(throwing: error) + } + } + ) + + continuation.onTermination = { @Sendable reason in + self.streamHolder.stream?.end() } } + + continuation.onTermination = { @Sendable reason in + task.cancel() + self.streamHolder.stream?.end() + } } } public func streamDecryptedMessages() -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task.detached { - do { - self.streamHolder.stream = try await ffiGroup.stream( - messageCallback: MessageCallback(client: self.client) { message in - do { - continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt()) - } catch { - print("Error onMessage \(error)") - } + let task = Task.detached { + self.streamHolder.stream = await self.ffiGroup.stream( + messageCallback: MessageCallback(client: self.client) { message in + guard !Task.isCancelled else { + continuation.finish() + return } - ) - } catch { - print("STREAM ERR: \(error)") + do { + continuation.yield(try MessageV3(client: self.client, ffiMessage: message).decrypt()) + } catch { + print("Error onMessage \(error)") + continuation.finish(throwing: error) + } + } + ) + + continuation.onTermination = { @Sendable reason in + self.streamHolder.stream?.end() } } + + continuation.onTermination = { @Sendable reason in + task.cancel() + self.streamHolder.stream?.end() + } } } diff --git a/Tests/XMTPTests/ConversationTests.swift b/Tests/XMTPTests/ConversationTests.swift index 1dd168a2..04721db8 100644 --- a/Tests/XMTPTests/ConversationTests.swift +++ b/Tests/XMTPTests/ConversationTests.swift @@ -84,16 +84,16 @@ class ConversationTests: XCTestCase { } func testDoesNotAllowConversationWithSelf() async throws { - let expectation = expectation(description: "convo with self throws") + let expectation = XCTestExpectation(description: "convo with self throws") let client = aliceClient! do { - try await client.conversations.newConversation(with: alice.walletAddress) + _ = try await client.conversations.newConversation(with: alice.walletAddress) } catch { expectation.fulfill() } - wait(for: [expectation], timeout: 0.1) + await fulfillment(of: [expectation], timeout: 3) } func testCanStreamConversationsV2() async throws { @@ -103,7 +103,7 @@ class ConversationTests: XCTestCase { let wallet2 = try PrivateKey.generate() let client2 = try await Client.create(account: wallet2, options: options) - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") expectation1.expectedFulfillmentCount = 2 Task(priority: .userInitiated) { @@ -140,7 +140,7 @@ class ConversationTests: XCTestCase { try await conversation2.send(content: "hi from new wallet") - await waitForExpectations(timeout: 30) + await fulfillment(of: [expectation1], timeout: 30) } func publishLegacyContact(client: Client) async throws { @@ -161,7 +161,7 @@ class ConversationTests: XCTestCase { return } - let expectation = expectation(description: "got a message") + let expectation = XCTestExpectation(description: "got a message") Task(priority: .userInitiated) { for try await message in conversation.streamMessages() { @@ -174,7 +174,7 @@ class ConversationTests: XCTestCase { // Stream a message try await conversation.send(content: "hi alice") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation], timeout: 3) } func testCanLoadV2Messages() async throws { @@ -458,7 +458,7 @@ class ConversationTests: XCTestCase { XCTAssertTrue(isAllowed) try await bobClient.contacts.deny(addresses: [alice.address]) - try await bobClient.contacts.refreshConsentList() + _ = try await bobClient.contacts.refreshConsentList() let isDenied = (try await bobConversation.consentState()) == .denied @@ -491,7 +491,7 @@ class ConversationTests: XCTestCase { XCTAssertTrue(isUnknown) try await aliceConversation.send(content: "hey bob") - try await aliceClient.contacts.refreshConsentList() + _ = try await aliceClient.contacts.refreshConsentList() let isNowAllowed = (try await aliceConversation.consentState()) == .allowed // Conversations you send a message to get marked as allowed diff --git a/Tests/XMTPTests/GroupTests.swift b/Tests/XMTPTests/GroupTests.swift index 5ea5ec48..7662a614 100644 --- a/Tests/XMTPTests/GroupTests.swift +++ b/Tests/XMTPTests/GroupTests.swift @@ -469,7 +469,7 @@ class GroupTests: XCTestCase { try await aliceGroup.sync() aliceMessagesCount = try await aliceGroup.messages().count - var aliceMessagesUnpublishedCount = try await aliceGroup.messages(deliveryStatus: .unpublished).count + let aliceMessagesUnpublishedCount = try await aliceGroup.messages(deliveryStatus: .unpublished).count aliceMessagesPublishedCount = try await aliceGroup.messages(deliveryStatus: .published).count XCTAssertEqual(3, aliceMessagesCount) XCTAssertEqual(0, aliceMessagesUnpublishedCount) @@ -516,7 +516,7 @@ class GroupTests: XCTestCase { let fixtures = try await localFixtures() let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) let membershipChange = GroupUpdated() - let expectation1 = expectation(description: "got a message") + let expectation1 = XCTestExpectation(description: "got a message") expectation1.expectedFulfillmentCount = 1 Task(priority: .userInitiated) { @@ -528,13 +528,13 @@ class GroupTests: XCTestCase { _ = try await group.send(content: "hi") _ = try await group.send(content: membershipChange, options: SendOptions(contentType: ContentTypeGroupUpdated)) - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanStreamGroups() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a group") + let expectation1 = XCTestExpectation(description: "got a group") Task(priority: .userInitiated) { for try await _ in try await fixtures.aliceClient.conversations.streamGroups() { @@ -544,17 +544,17 @@ class GroupTests: XCTestCase { _ = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanStreamGroupsAndConversationsWorksGroups() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") expectation1.expectedFulfillmentCount = 2 Task(priority: .userInitiated) { - for try await _ in try await fixtures.aliceClient.conversations.streamAll() { + for try await _ in await fixtures.aliceClient.conversations.streamAll() { expectation1.fulfill() } } @@ -562,14 +562,14 @@ class GroupTests: XCTestCase { _ = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) _ = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address) - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testStreamGroupsAndAllMessages() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a group") - let expectation2 = expectation(description: "got a message") + let expectation1 = XCTestExpectation(description: "got a group") + let expectation2 = XCTestExpectation(description: "got a message") Task(priority: .userInitiated) { @@ -579,32 +579,32 @@ class GroupTests: XCTestCase { } Task(priority: .userInitiated) { - for try await _ in try await fixtures.aliceClient.conversations.streamAllMessages(includeGroups: true) { + for try await _ in await fixtures.aliceClient.conversations.streamAllMessages(includeGroups: true) { expectation2.fulfill() } } let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) - try await group.send(content: "hello") + _ = try await group.send(content: "hello") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1, expectation2], timeout: 3) } func testCanStreamAndUpdateNameWithoutForkingGroup() async throws { let fixtures = try await localFixtures() - let expectation = expectation(description: "got a message") + let expectation = XCTestExpectation(description: "got a message") expectation.expectedFulfillmentCount = 5 Task(priority: .userInitiated) { - for try await _ in try await fixtures.bobClient.conversations.streamAllGroupMessages(){ + for try await _ in await fixtures.bobClient.conversations.streamAllGroupMessages(){ expectation.fulfill() } } let alixGroup = try await fixtures.aliceClient.conversations.newGroup(with: [fixtures.bob.address]) try await alixGroup.updateGroupName(groupName: "hello") - try await alixGroup.send(content: "hello1") + _ = try await alixGroup.send(content: "hello1") try await fixtures.bobClient.conversations.sync() @@ -616,8 +616,8 @@ class GroupTests: XCTestCase { let boMessages1 = try await boGroup.messages() XCTAssertEqual(boMessages1.count, 2, "should have 2 messages on first load received \(boMessages1.count)") - try await boGroup.send(content: "hello2") - try await boGroup.send(content: "hello3") + _ = try await boGroup.send(content: "hello2") + _ = try await boGroup.send(content: "hello3") try await alixGroup.sync() let alixMessages = try await alixGroup.messages() @@ -626,7 +626,7 @@ class GroupTests: XCTestCase { } XCTAssertEqual(alixMessages.count, 5, "should have 5 messages on first load received \(alixMessages.count)") - try await alixGroup.send(content: "hello4") + _ = try await alixGroup.send(content: "hello4") try await boGroup.sync() let boMessages2 = try await boGroup.messages() @@ -635,13 +635,13 @@ class GroupTests: XCTestCase { } XCTAssertEqual(boMessages2.count, 5, "should have 5 messages on second load received \(boMessages2.count)") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation], timeout: 3) } func testCanStreamAllMessages() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") expectation1.expectedFulfillmentCount = 2 let convo = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address) let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) @@ -655,20 +655,20 @@ class GroupTests: XCTestCase { _ = try await group.send(content: "hi") _ = try await convo.send(content: "hi") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanStreamAllDecryptedMessages() async throws { let fixtures = try await localFixtures() let membershipChange = GroupUpdated() - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") expectation1.expectedFulfillmentCount = 2 let convo = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address) let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) try await fixtures.aliceClient.conversations.sync() Task(priority: .userInitiated) { - for try await _ in try await fixtures.aliceClient.conversations.streamAllDecryptedMessages(includeGroups: true) { + for try await _ in await fixtures.aliceClient.conversations.streamAllDecryptedMessages(includeGroups: true) { expectation1.fulfill() } } @@ -677,42 +677,42 @@ class GroupTests: XCTestCase { _ = try await group.send(content: membershipChange, options: SendOptions(contentType: ContentTypeGroupUpdated)) _ = try await convo.send(content: "hi") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanStreamAllGroupMessages() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) try await fixtures.aliceClient.conversations.sync() Task(priority: .userInitiated) { - for try await _ in try await fixtures.aliceClient.conversations.streamAllGroupMessages() { + for try await _ in await fixtures.aliceClient.conversations.streamAllGroupMessages() { expectation1.fulfill() } } _ = try await group.send(content: "hi") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanStreamAllGroupDecryptedMessages() async throws { let fixtures = try await localFixtures() - let expectation1 = expectation(description: "got a conversation") + let expectation1 = XCTestExpectation(description: "got a conversation") let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) try await fixtures.aliceClient.conversations.sync() Task(priority: .userInitiated) { - for try await _ in try await fixtures.aliceClient.conversations.streamAllGroupDecryptedMessages() { + for try await _ in await fixtures.aliceClient.conversations.streamAllGroupDecryptedMessages() { expectation1.fulfill() } } _ = try await group.send(content: "hi") - await waitForExpectations(timeout: 3) + await fulfillment(of: [expectation1], timeout: 3) } func testCanUpdateGroupMetadata() async throws { @@ -800,7 +800,7 @@ class GroupTests: XCTestCase { try await fixtures.aliceClient.conversations.sync() let alixGroup = try fixtures.aliceClient.findGroup(groupId: boGroup.id) try await alixGroup?.sync() - let alixMessage = try fixtures.aliceClient.findMessage(messageId: boMessageId) + _ = try fixtures.aliceClient.findMessage(messageId: boMessageId) XCTAssertEqual(alixGroup?.id, boGroup.id) } @@ -843,12 +843,12 @@ class GroupTests: XCTestCase { var groups: [Group] = [] for _ in 0..<100 { - var group = try await fixtures.aliceClient.conversations.newGroup(with: [fixtures.bob.address]) + let group = try await fixtures.aliceClient.conversations.newGroup(with: [fixtures.bob.address]) groups.append(group) } try await fixtures.bobClient.conversations.sync() let bobGroup = try fixtures.bobClient.findGroup(groupId: groups[0].id) - try await groups[0].send(content: "hi") + _ = try await groups[0].send(content: "hi") let messageCount = try await bobGroup!.messages().count XCTAssertEqual(messageCount, 0) do { @@ -888,7 +888,7 @@ class GroupTests: XCTestCase { var groups: [Group] = [] for _ in 0..<100 { - var group = try await fixtures.aliceClient.conversations.newGroup(with: [fixtures.bob.address]) + let group = try await fixtures.aliceClient.conversations.newGroup(with: [fixtures.bob.address]) groups.append(group) } do { @@ -912,4 +912,65 @@ class GroupTests: XCTestCase { } } } + + func testCanStreamAllDecryptedMessagesAndCancelStream() async throws { + let fixtures = try await localFixtures() + + var messages = 0 + let messagesQueue = DispatchQueue(label: "messages.queue") // Serial queue to synchronize access to `messages` + + let convo = try await fixtures.bobClient.conversations.newConversation(with: fixtures.alice.address) + let group = try await fixtures.bobClient.conversations.newGroup(with: [fixtures.alice.address]) + try await fixtures.aliceClient.conversations.sync() + + let streamingTask = Task(priority: .userInitiated) { + for try await _ in await fixtures.aliceClient.conversations.streamAllDecryptedMessages(includeGroups: true) { + messagesQueue.sync { + messages += 1 + } + } + } + + _ = try await group.send(content: "hi") + _ = try await convo.send(content: "hi") + + try await Task.sleep(nanoseconds: 1_000_000_000) + + streamingTask.cancel() + + messagesQueue.sync { + XCTAssertEqual(messages, 2) + } + + try await Task.sleep(nanoseconds: 1_000_000_000) + + _ = try await group.send(content: "hi") + _ = try await group.send(content: "hi") + _ = try await group.send(content: "hi") + _ = try await convo.send(content: "hi") + + try await Task.sleep(nanoseconds: 1_000_000_000) + + messagesQueue.sync { + XCTAssertEqual(messages, 2) + } + + let streamingTask2 = Task(priority: .userInitiated) { + for try await _ in await fixtures.aliceClient.conversations.streamAllDecryptedMessages(includeGroups: true) { + // Update the messages count in a thread-safe manner + messagesQueue.sync { + messages += 1 + } + } + } + + _ = try await group.send(content: "hi") + _ = try await convo.send(content: "hi") + + try await Task.sleep(nanoseconds: 1_000_000_000) + + messagesQueue.sync { + XCTAssertEqual(messages, 4) + } + } } diff --git a/XMTP.podspec b/XMTP.podspec index a86e32d8..a26632a3 100644 --- a/XMTP.podspec +++ b/XMTP.podspec @@ -16,7 +16,7 @@ Pod::Spec.new do |spec| # spec.name = "XMTP" - spec.version = "0.14.10" + spec.version = "0.14.11" spec.summary = "XMTP SDK Cocoapod" # This description is used to generate tags and improve search results.