diff --git a/Examples/SlackClone/AppView.swift b/Examples/SlackClone/AppView.swift index 0fa0da07..4442bb6c 100644 --- a/Examples/SlackClone/AppView.swift +++ b/Examples/SlackClone/AppView.swift @@ -41,26 +41,33 @@ final class AppViewModel { @MainActor struct AppView: View { @Bindable var model: AppViewModel + let log = LogStore.shared + + @State var logPresented = false @ViewBuilder var body: some View { if model.session != nil { NavigationSplitView { ChannelListView(channel: $model.selectedChannel) + .toolbar { + ToolbarItem { + Button("Log") { + logPresented = true + } + } + } } detail: { if let channel = model.selectedChannel { MessagesView(channel: channel).id(channel.id) } } - .overlay(alignment: .bottom) { - LabeledContent( - "Connection Status", - value: model.realtimeConnectionStatus?.description ?? "Unknown" - ) - .padding() - .background(.regularMaterial) - .clipShape(Capsule()) - .padding() + .sheet(isPresented: $logPresented) { + List { + ForEach(0 ..< log.messages.count, id: \.self) { i in + Text(log.messages[i].description) + } + } } } else { AuthView() diff --git a/Examples/SlackClone/Logger.swift b/Examples/SlackClone/Logger.swift index 02b4792f..8905700a 100644 --- a/Examples/SlackClone/Logger.swift +++ b/Examples/SlackClone/Logger.swift @@ -13,11 +13,17 @@ extension Logger { static let main = Self(subsystem: "com.supabase.SlackClone", category: "app") } -final class SupabaseLoggerImpl: SupabaseLogger, @unchecked Sendable { +@Observable +final class LogStore: SupabaseLogger { private let lock = NSLock() private var loggers: [String: Logger] = [:] + static let shared = LogStore() + + var messages: [SupabaseLogMessage] = [] + func log(message: SupabaseLogMessage) { + messages.append(message) lock.withLock { if loggers[message.system] == nil { loggers[message.system] = Logger( diff --git a/Examples/SlackClone/Supabase.swift b/Examples/SlackClone/Supabase.swift index f4b46bc8..dcd3c3e9 100644 --- a/Examples/SlackClone/Supabase.swift +++ b/Examples/SlackClone/Supabase.swift @@ -21,10 +21,10 @@ let decoder: JSONDecoder = { }() let supabase = SupabaseClient( - supabaseURL: URL(string: "http://192.168.0.4:54321")!, + supabaseURL: URL(string: "http://192.168.0.6:54321")!, supabaseKey: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6ImFub24iLCJleHAiOjE5ODM4MTI5OTZ9.CRXP1A7WOeoJeXxjNni43kdQwgnWNReilDMblYTn_I0", options: SupabaseClientOptions( db: .init(encoder: encoder, decoder: decoder), - global: SupabaseClientOptions.GlobalOptions(logger: SupabaseLoggerImpl()) + global: SupabaseClientOptions.GlobalOptions(logger: LogStore.shared) ) ) diff --git a/Sources/Realtime/V2/RealtimeClientV2.swift b/Sources/Realtime/V2/RealtimeClientV2.swift index ef18c0db..c78a9c26 100644 --- a/Sources/Realtime/V2/RealtimeClientV2.swift +++ b/Sources/Realtime/V2/RealtimeClientV2.swift @@ -64,18 +64,19 @@ public actor RealtimeClientV2 { } } + let config: Configuration + let ws: any WebSocketClient + var accessToken: String? var ref = 0 var pendingHeartbeatRef: Int? + var heartbeatTask: Task? var messageTask: Task? - var inFlightConnectionTask: Task? + var connectionTask: Task? public private(set) var subscriptions: [String: RealtimeChannelV2] = [:] - let config: Configuration - let ws: any WebSocketClient - private let statusEventEmitter = EventEmitter(initialEvent: .disconnected) public var statusChange: AsyncStream { @@ -93,22 +94,8 @@ public actor RealtimeClientV2 { statusEventEmitter.attach(listener) } - deinit { - heartbeatTask?.cancel() - messageTask?.cancel() - subscriptions = [:] - } - public init(config: Configuration) { - let sessionConfiguration = URLSessionConfiguration.default - sessionConfiguration.httpAdditionalHeaders = config.headers - let ws = DefaultWebSocketClient( - realtimeURL: config.realtimeWebSocketURL, - configuration: sessionConfiguration, - logger: config.logger - ) - - self.init(config: config, ws: ws) + self.init(config: config, ws: WebSocket(config: config)) } init(config: Configuration, ws: any WebSocketClient) { @@ -122,68 +109,86 @@ public actor RealtimeClientV2 { } } - public func connect() async { - guard status != .connected else { - return - } + deinit { + heartbeatTask?.cancel() + messageTask?.cancel() + subscriptions = [:] + } + public func connect() async { await connect(reconnect: false) } func connect(reconnect: Bool) async { - if let inFlightConnectionTask { - return await inFlightConnectionTask.value - } + if status == .disconnected { + connectionTask = Task { + if reconnect { + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(config.reconnectDelay)) - inFlightConnectionTask = Task { [self] in - defer { inFlightConnectionTask = nil } - - if reconnect { - try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(config.reconnectDelay)) + if Task.isCancelled { + config.logger?.debug("Reconnect cancelled, returning") + return + } + } - if Task.isCancelled { - config.logger?.debug("reconnect cancelled, returning") + if status == .connected { + config.logger?.debug("WebsSocket already connected") return } - } - if status == .connected { - config.logger?.debug("Websocket already connected") - return - } + status = .connecting - status = .connecting - - Task { for await connectionStatus in ws.connect() { + if Task.isCancelled { + break + } + switch connectionStatus { - case .open: - status = .connected - config.logger?.debug("Connected to realtime WebSocket") - listenForMessages() - startHeartbeating() - if reconnect { - await rejoinChannels() - } - - case .close: - config.logger?.debug("WebSocket connection closed. Trying again in \(config.reconnectDelay) seconds.") - disconnect() - await connect(reconnect: true) - - case .complete(let error): - config.logger?.error( - "WebSocket connection error: \(error?.localizedDescription ?? "")" - ) - disconnect() + case .connected: + await onConnected(reconnect: reconnect) + + case .disconnected: + await onDisconnected() + + case let .error(error): + await onError(error) } } } + } - _ = await statusChange.first { @Sendable in $0 == .connected } + _ = await statusChange.first { @Sendable in $0 == .connected } + } + + private func onConnected(reconnect: Bool) async { + status = .connected + config.logger?.debug("Connected to realtime WebSocket") + listenForMessages() + startHeartbeating() + if reconnect { + await rejoinChannels() } + } - await inFlightConnectionTask?.value + private func onDisconnected() async { + config.logger? + .debug( + "WebSocket disconnected. Trying again in \(config.reconnectDelay)" + ) + await reconnect() + } + + private func onError(_ error: (any Error)?) async { + config.logger? + .debug( + "WebSocket error \(error?.localizedDescription ?? ""). Trying again in \(config.reconnectDelay)" + ) + await reconnect() + } + + private func reconnect() async { + disconnect() + await connect(reconnect: true) } public func channel( @@ -222,14 +227,8 @@ public actor RealtimeClientV2 { } private func rejoinChannels() async { - await withTaskGroup(of: Void.self) { group in - for channel in subscriptions.values { - _ = group.addTaskUnlessCancelled { - await channel.subscribe() - } - - await group.waitForAll() - } + for channel in subscriptions.values { + await channel.subscribe() } } @@ -249,22 +248,19 @@ public actor RealtimeClientV2 { config.logger?.debug( "Error while listening for messages. Trying again in \(config.reconnectDelay) \(error)" ) - await disconnect() - await connect(reconnect: true) + await reconnect() } } } private func startHeartbeating() { - heartbeatTask = Task { [weak self] in - guard let self else { return } - + heartbeatTask = Task { [weak self, config] in while !Task.isCancelled { try? await Task.sleep(nanoseconds: NSEC_PER_SEC * UInt64(config.heartbeatInterval)) if Task.isCancelled { break } - await sendHeartbeat() + await self?.sendHeartbeat() } } } @@ -272,9 +268,9 @@ public actor RealtimeClientV2 { private func sendHeartbeat() async { if pendingHeartbeatRef != nil { pendingHeartbeatRef = nil - config.logger?.debug("Heartbeat timeout. Trying to reconnect in \(config.reconnectDelay)") - disconnect() - await connect(reconnect: true) + config.logger?.debug("Heartbeat timeout") + + await reconnect() return } @@ -296,7 +292,8 @@ public actor RealtimeClientV2 { ref = 0 messageTask?.cancel() heartbeatTask?.cancel() - ws.cancel() + connectionTask?.cancel() + ws.disconnect() status = .disconnected } diff --git a/Sources/Realtime/V2/WebSocketClient.swift b/Sources/Realtime/V2/WebSocketClient.swift index 6b77121b..44f2cda0 100644 --- a/Sources/Realtime/V2/WebSocketClient.swift +++ b/Sources/Realtime/V2/WebSocketClient.swift @@ -14,21 +14,25 @@ import Foundation #endif enum ConnectionStatus { - case open - case close - case complete((any Error)?) + case connected + case disconnected(reason: String, code: URLSessionWebSocketTask.CloseCode) + case error((any Error)?) } protocol WebSocketClient: Sendable { - func send(_ message: RealtimeMessageV2) async throws -> Void + func send(_ message: RealtimeMessageV2) async throws func receive() -> AsyncThrowingStream func connect() -> AsyncStream - func cancel() + func disconnect(closeCode: URLSessionWebSocketTask.CloseCode) } -class DefaultWebSocketClient: NSObject, URLSessionWebSocketDelegate, WebSocketClient, - @unchecked Sendable -{ +extension WebSocketClient { + func disconnect() { + disconnect(closeCode: .normalClosure) + } +} + +final class WebSocket: NSObject, URLSessionWebSocketDelegate, WebSocketClient, @unchecked Sendable { private let realtimeURL: URL private let configuration: URLSessionConfiguration private let logger: (any SupabaseLogger)? @@ -40,14 +44,13 @@ class DefaultWebSocketClient: NSObject, URLSessionWebSocketDelegate, WebSocketCl let mutableState = LockIsolated(MutableState()) - init(realtimeURL: URL, configuration: URLSessionConfiguration, logger: (any SupabaseLogger)?) { - self.realtimeURL = realtimeURL - self.configuration = configuration - self.logger = logger - } + init(config: RealtimeClientV2.Configuration) { + realtimeURL = config.realtimeWebSocketURL - deinit { - cancel() + let sessionConfiguration = URLSessionConfiguration.default + sessionConfiguration.httpAdditionalHeaders = config.headers + configuration = sessionConfiguration + logger = config.logger } func connect() -> AsyncStream { @@ -62,9 +65,9 @@ class DefaultWebSocketClient: NSObject, URLSessionWebSocketDelegate, WebSocketCl } } - func cancel() { + func disconnect(closeCode: URLSessionWebSocketTask.CloseCode) { mutableState.withValue { state in - state.task?.cancel() + state.task?.cancel(with: closeCode, reason: nil) } } @@ -116,16 +119,21 @@ class DefaultWebSocketClient: NSObject, URLSessionWebSocketDelegate, WebSocketCl webSocketTask _: URLSessionWebSocketTask, didOpenWithProtocol _: String? ) { - mutableState.continuation?.yield(.open) + mutableState.continuation?.yield(.connected) } func urlSession( _: URLSession, webSocketTask _: URLSessionWebSocketTask, - didCloseWith _: URLSessionWebSocketTask.CloseCode, - reason _: Data? + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? ) { - mutableState.continuation?.yield(.close) + let status = ConnectionStatus.disconnected( + reason: reason.flatMap { String(data: $0, encoding: .utf8) } ?? "", + code: closeCode + ) + + mutableState.continuation?.yield(status) } func urlSession( @@ -133,6 +141,6 @@ class DefaultWebSocketClient: NSObject, URLSessionWebSocketDelegate, WebSocketCl task _: URLSessionTask, didCompleteWithError error: (any Error)? ) { - mutableState.continuation?.yield(.complete(error)) + mutableState.continuation?.yield(.error(error)) } } diff --git a/Tests/RealtimeTests/MockWebSocketClient.swift b/Tests/RealtimeTests/MockWebSocketClient.swift index 0bbd85fb..84ab2f58 100644 --- a/Tests/RealtimeTests/MockWebSocketClient.swift +++ b/Tests/RealtimeTests/MockWebSocketClient.swift @@ -16,6 +16,10 @@ final class MockWebSocketClient: WebSocketClient { sentMessages.withValue { $0.append(message) } + + if let callback = onCallback.value, let response = callback(message) { + mockReceive(response) + } } private let receiveContinuation = @@ -24,6 +28,11 @@ final class MockWebSocketClient: WebSocketClient { receiveContinuation.value?.yield(message) } + private let onCallback = LockIsolated<((RealtimeMessageV2) -> RealtimeMessageV2?)?>(nil) + func on(_ callback: @escaping (RealtimeMessageV2) -> RealtimeMessageV2?) { + onCallback.setValue(callback) + } + func receive() -> AsyncThrowingStream { let (stream, continuation) = AsyncThrowingStream.makeStream() receiveContinuation.setValue(continuation) @@ -41,5 +50,5 @@ final class MockWebSocketClient: WebSocketClient { return stream } - func cancel() {} + func disconnect(closeCode _: URLSessionWebSocketTask.CloseCode) {} } diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift index ee8925a5..6a145f75 100644 --- a/Tests/RealtimeTests/RealtimeTests.swift +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -8,14 +8,6 @@ import TestHelpers final class RealtimeTests: XCTestCase { let url = URL(string: "https://localhost:54321/realtime/v1")! let apiKey = "anon.api.key" - let accessToken = - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjoxNzA1Nzc4MTAxLCJpYXQiOjE3MDU3NzQ1MDEsImlzcyI6Imh0dHA6Ly8xMjcuMC4wLjE6NTQzMjEvYXV0aC92MSIsInN1YiI6ImFiZTQ1NjMwLTM0YTAtNDBhNS04Zjg5LTQxY2NkYzJjNjQyNCIsImVtYWlsIjoib2dyc291emErbWFjQGdtYWlsLmNvbSIsInBob25lIjoiIiwiYXBwX21ldGFkYXRhIjp7InByb3ZpZGVyIjoiZW1haWwiLCJwcm92aWRlcnMiOlsiZW1haWwiXX0sInVzZXJfbWV0YWRhdGEiOnt9LCJyb2xlIjoiYXV0aGVudGljYXRlZCIsImFhbCI6ImFhbDEiLCJhbXIiOlt7Im1ldGhvZCI6Im1hZ2ljbGluayIsInRpbWVzdGFtcCI6MTcwNTYwODcxOX1dLCJzZXNzaW9uX2lkIjoiMzFmMmQ4NGQtODZmYi00NWE2LTljMTItODMyYzkwYTgyODJjIn0.RY1y5U7CK97v6buOgJj_jQNDHW_1o0THbNP2UQM1HVE" - - var ref: Int = 0 - func makeRef() -> String { - ref += 1 - return "\(ref)" - } override func invokeTest() { withMainSerialExecutor { @@ -23,46 +15,136 @@ final class RealtimeTests: XCTestCase { } } - func testConnectAndSubscribe() async { - let mock = MockWebSocketClient() - - let realtime = RealtimeClientV2( - config: RealtimeClientV2.Configuration(url: url, apiKey: apiKey), - ws: mock + var ws: MockWebSocketClient! + var sut: RealtimeClientV2! + + override func setUp() { + super.setUp() + + ws = MockWebSocketClient() + sut = RealtimeClientV2( + config: RealtimeClientV2.Configuration( + url: url, + apiKey: apiKey, + heartbeatInterval: 1, + reconnectDelay: 1 + ), + ws: ws ) + } - let channel = await realtime.channel("public:messages") + func testBehavior() async { + let channel = await sut.channel("public:messages") _ = await channel.postgresChange(InsertAction.self, table: "messages") _ = await channel.postgresChange(UpdateAction.self, table: "messages") _ = await channel.postgresChange(DeleteAction.self, table: "messages") - let statusChange = await realtime.statusChange + let statusChange = await sut.statusChange - Task { - await realtime.connect() - } - await Task.megaYield() - mock.mockConnect(.open) - - await realtime.setAuth(accessToken) + await connectSocketAndWait() let status = await statusChange.prefix(3).collect() XCTAssertEqual(status, [.disconnected, .connecting, .connected]) - let messageTask = await realtime.messageTask + let messageTask = await sut.messageTask XCTAssertNotNil(messageTask) - let heartbeatTask = await realtime.heartbeatTask + let heartbeatTask = await sut.heartbeatTask XCTAssertNotNil(heartbeatTask) let subscription = Task { await channel.subscribe() } await Task.megaYield() - mock.mockReceive(.messagesSubscribed) + ws.mockReceive(.messagesSubscribed) + // Wait until channel subscribed await subscription.value - XCTAssertNoDifference(mock.sentMessages.value, [.subscribeToMessages]) + + XCTAssertNoDifference(ws.sentMessages.value, [.subscribeToMessages]) + } + + func testHeartbeat() async throws { + let expectation = expectation(description: "heartbeat") + expectation.expectedFulfillmentCount = 2 + + ws.on { message in + if message.event == "heartbeat" { + expectation.fulfill() + return RealtimeMessageV2( + joinRef: message.joinRef, + ref: message.ref, + topic: "phoenix", + event: "phx_reply", + payload: [ + "response": [:], + "status": "ok", + ] + ) + } + + return nil + } + + await connectSocketAndWait() + + await fulfillment(of: [expectation], timeout: 3) + } + + func testHeartbeat_whenNoResponse_shouldReconnect() async throws { + let sentHeartbeatExpectation = expectation(description: "sentHeartbeat") + + ws.on { + if $0.event == "heartbeat" { + sentHeartbeatExpectation.fulfill() + } + + return nil + } + + let statuses = LockIsolated<[RealtimeClientV2.Status]>([]) + + Task { + for await status in await sut.statusChange { + statuses.withValue { + $0.append(status) + } + } + } + await Task.megaYield() + await connectSocketAndWait() + + await fulfillment(of: [sentHeartbeatExpectation], timeout: 2) + + let pendingHeartbeatRef = await sut.pendingHeartbeatRef + XCTAssertNotNil(pendingHeartbeatRef) + + // Wait until next heartbeat + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + + // Wait for reconnect delay + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 1) + + XCTAssertEqual( + statuses.value, + [ + .disconnected, + .connecting, + .connected, + .disconnected, + .connecting, + ] + ) + } + + private func connectSocketAndWait() async { + let connection = Task { + await sut.connect() + } + await Task.megaYield() + + ws.mockConnect(.connected) + await connection.value } } @@ -79,7 +161,7 @@ extension RealtimeMessageV2 { topic: "realtime:public:messages", event: "phx_join", payload: [ - "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjoxNzA1Nzc4MTAxLCJpYXQiOjE3MDU3NzQ1MDEsImlzcyI6Imh0dHA6Ly8xMjcuMC4wLjE6NTQzMjEvYXV0aC92MSIsInN1YiI6ImFiZTQ1NjMwLTM0YTAtNDBhNS04Zjg5LTQxY2NkYzJjNjQyNCIsImVtYWlsIjoib2dyc291emErbWFjQGdtYWlsLmNvbSIsInBob25lIjoiIiwiYXBwX21ldGFkYXRhIjp7InByb3ZpZGVyIjoiZW1haWwiLCJwcm92aWRlcnMiOlsiZW1haWwiXX0sInVzZXJfbWV0YWRhdGEiOnt9LCJyb2xlIjoiYXV0aGVudGljYXRlZCIsImFhbCI6ImFhbDEiLCJhbXIiOlt7Im1ldGhvZCI6Im1hZ2ljbGluayIsInRpbWVzdGFtcCI6MTcwNTYwODcxOX1dLCJzZXNzaW9uX2lkIjoiMzFmMmQ4NGQtODZmYi00NWE2LTljMTItODMyYzkwYTgyODJjIn0.RY1y5U7CK97v6buOgJj_jQNDHW_1o0THbNP2UQM1HVE", + "access_token": "anon.api.key", "config": [ "broadcast": [ "self": false,