Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(realtime): pull access token mechanism #615

Merged
merged 5 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions Sources/Realtime/V2/RealtimeChannelV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ struct Socket: Sendable {
var broadcastURL: @Sendable () -> URL
var status: @Sendable () -> RealtimeClientStatus
var options: @Sendable () -> RealtimeClientOptions
var accessToken: @Sendable () -> String?
var accessToken: @Sendable () async -> String?
var apiKey: @Sendable () -> String?
var makeRef: @Sendable () -> Int

Expand All @@ -46,7 +46,12 @@ extension Socket {
broadcastURL: { [weak client] in client?.broadcastURL ?? URL(string: "http://localhost")! },
status: { [weak client] in client?.status ?? .disconnected },
options: { [weak client] in client?.options ?? .init() },
accessToken: { [weak client] in client?.mutableState.accessToken },
accessToken: { [weak client] in
if let accessToken = try? await client?.options.accessToken?() {
return accessToken
}
return client?.mutableState.accessToken
},
apiKey: { [weak client] in client?.apikey },
makeRef: { [weak client] in client?.makeRef() ?? 0 },
connect: { [weak client] in await client?.connect() },
Expand Down Expand Up @@ -139,7 +144,7 @@ public final class RealtimeChannelV2: Sendable {

let payload = RealtimeJoinPayload(
config: joinConfig,
accessToken: socket.accessToken()
accessToken: await socket.accessToken()
)

let joinRef = socket.makeRef().description
Expand Down Expand Up @@ -213,7 +218,7 @@ public final class RealtimeChannelV2: Sendable {
if let apiKey = socket.apiKey() {
headers[.apiKey] = apiKey
}
if let accessToken = socket.accessToken() {
if let accessToken = await socket.accessToken() {
headers[.authorization] = "Bearer \(accessToken)"
}

Expand Down
26 changes: 22 additions & 4 deletions Sources/Realtime/V2/RealtimeClientV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,11 @@ public final class RealtimeClientV2: Sendable {
apikey = options.apikey

mutableState.withValue {
$0.accessToken = options.accessToken ?? options.apikey
if let accessToken = options.headers[.authorization]?.split(separator: " ").last {
$0.accessToken = String(accessToken)
} else {
$0.accessToken = options.apikey
}
}
}

Expand Down Expand Up @@ -361,8 +365,22 @@ public final class RealtimeClientV2: Sendable {
}

/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
/// - Parameter token: A JWT string.
public func setAuth(_ token: String?) async {
///
/// If `token` is nil it will use the ``RealtimeClientOptions/accessToken`` callback function or the token set on the client.
///
/// On callback used, it will set the value of the token internal to the client.
/// - Parameter token: A JWT string to override the token set on the client.
public func setAuth(_ token: String? = nil) async {
var token = token

if token == nil {
token = try? await options.accessToken?()
}

if token == nil {
token = mutableState.accessToken
}

if let token, let payload = JWT.decodePayload(token),
let exp = payload["exp"] as? TimeInterval, exp < Date().timeIntervalSince1970
{
Expand All @@ -371,7 +389,7 @@ public final class RealtimeClientV2: Sendable {
return
}

mutableState.withValue {
mutableState.withValue { [token] in
$0.accessToken = token
}

Expand Down
12 changes: 4 additions & 8 deletions Sources/Realtime/V2/Types.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
//

import Foundation
import Helpers
import HTTPTypes
import Helpers

#if canImport(FoundationNetworking)
import FoundationNetworking
Expand All @@ -22,6 +22,7 @@ public struct RealtimeClientOptions: Sendable {
var disconnectOnSessionLoss: Bool
var connectOnSubscribe: Bool
var fetch: (@Sendable (_ request: URLRequest) async throws -> (Data, URLResponse))?
package var accessToken: (@Sendable () async throws -> String)?
package var logger: (any SupabaseLogger)?

public static let defaultHeartbeatInterval: TimeInterval = 15
Expand All @@ -38,6 +39,7 @@ public struct RealtimeClientOptions: Sendable {
disconnectOnSessionLoss: Bool = Self.defaultDisconnectOnSessionLoss,
connectOnSubscribe: Bool = Self.defaultConnectOnSubscribe,
fetch: (@Sendable (_ request: URLRequest) async throws -> (Data, URLResponse))? = nil,
accessToken: (@Sendable () async throws -> String)? = nil,
logger: (any SupabaseLogger)? = nil
) {
self.headers = HTTPFields(headers)
Expand All @@ -47,19 +49,13 @@ public struct RealtimeClientOptions: Sendable {
self.disconnectOnSessionLoss = disconnectOnSessionLoss
self.connectOnSubscribe = connectOnSubscribe
self.fetch = fetch
self.accessToken = accessToken
self.logger = logger
}

var apikey: String? {
headers[.apiKey]
}

var accessToken: String? {
guard let accessToken = headers[.authorization]?.split(separator: " ").last else {
return nil
}
return String(accessToken)
}
}

public typealias RealtimeSubscription = ObservationToken
Expand Down
80 changes: 56 additions & 24 deletions Sources/Supabase/SupabaseClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import ConcurrencyExtras
import Foundation
@_exported import Functions
import HTTPTypes
import Helpers
import IssueReporting
@_exported import PostgREST
@_exported import Realtime
@_exported import Storage
import HTTPTypes

#if canImport(FoundationNetworking)
import FoundationNetworking
Expand All @@ -33,10 +33,11 @@ public final class SupabaseClient: Sendable {
/// Supabase Auth allows you to create and manage user sessions for access to data that is secured by access policies.
public var auth: AuthClient {
if options.auth.accessToken != nil {
reportIssue("""
Supabase Client is configured with the auth.accessToken option,
accessing supabase.auth is not possible.
""")
reportIssue(
"""
Supabase Client is configured with the auth.accessToken option,
accessing supabase.auth is not possible.
""")
}
return _auth
}
Expand Down Expand Up @@ -80,7 +81,14 @@ public final class SupabaseClient: Sendable {
let _realtime: UncheckedSendable<RealtimeClient>

/// Realtime client for Supabase
public let realtimeV2: RealtimeClientV2
public var realtimeV2: RealtimeClientV2 {
mutableState.withValue {
if $0.realtime == nil {
$0.realtime = _initRealtimeClient()
}
return $0.realtime!
}
}

/// Supabase Functions allows you to deploy and invoke edge functions.
public var functions: FunctionsClient {
Expand Down Expand Up @@ -112,6 +120,7 @@ public final class SupabaseClient: Sendable {
var storage: SupabaseStorageClient?
var rest: PostgrestClient?
var functions: FunctionsClient?
var realtime: RealtimeClientV2?

var changedAccessToken: String?
}
Expand Down Expand Up @@ -189,18 +198,6 @@ public final class SupabaseClient: Sendable {
)
)

var realtimeOptions = options.realtime
realtimeOptions.headers.merge(with: _headers)

if realtimeOptions.logger == nil {
realtimeOptions.logger = options.global.logger
}

realtimeV2 = RealtimeClientV2(
url: supabaseURL.appendingPathComponent("/realtime/v1"),
options: realtimeOptions
)

if options.auth.accessToken == nil {
listenForAuthEvents()
}
Expand Down Expand Up @@ -351,11 +348,7 @@ public final class SupabaseClient: Sendable {
}

private func adapt(request: URLRequest) async -> URLRequest {
let token: String? = if let accessToken = options.auth.accessToken {
try? await accessToken()
} else {
try? await auth.session.accessToken
}
let token = try? await _getAccessToken()

var request = request
if let token {
Expand All @@ -364,6 +357,14 @@ public final class SupabaseClient: Sendable {
return request
}

private func _getAccessToken() async throws -> String {
if let accessToken = options.auth.accessToken {
try await accessToken()
} else {
try await auth.session.accessToken
}
}

private func listenForAuthEvents() {
let task = Task {
for await (event, session) in auth.authStateChanges {
Expand All @@ -377,7 +378,9 @@ public final class SupabaseClient: Sendable {

private func handleTokenChanged(event: AuthChangeEvent, session: Session?) async {
let accessToken: String? = mutableState.withValue {
if [.initialSession, .signedIn, .tokenRefreshed].contains(event), $0.changedAccessToken != session?.accessToken {
if [.initialSession, .signedIn, .tokenRefreshed].contains(event),
$0.changedAccessToken != session?.accessToken
{
$0.changedAccessToken = session?.accessToken
return session?.accessToken ?? supabaseKey
}
Expand All @@ -393,4 +396,33 @@ public final class SupabaseClient: Sendable {
realtime.setAuth(accessToken)
await realtimeV2.setAuth(accessToken)
}

private func _initRealtimeClient() -> RealtimeClientV2 {
var realtimeOptions = options.realtime
realtimeOptions.headers.merge(with: _headers)

if realtimeOptions.logger == nil {
realtimeOptions.logger = options.global.logger
}

if realtimeOptions.accessToken == nil {
realtimeOptions.accessToken = { [weak self] in
try await self?._getAccessToken() ?? ""
}
} else {
reportIssue(
"""
You assigned a custom `accessToken` closure to the RealtimeClientV2. This might not work as you expect
as SupabaseClient uses Auth for pulling an access token to send on the realtime channels.

Please make sure you know what you're doing.
"""
)
}

return RealtimeClientV2(
url: supabaseURL.appendingPathComponent("/realtime/v1"),
options: realtimeOptions
)
}
}
13 changes: 8 additions & 5 deletions Tests/RealtimeTests/RealtimeTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ final class RealtimeTests: XCTestCase {
headers: ["apikey": apiKey],
heartbeatInterval: 1,
reconnectDelay: 1,
timeoutInterval: 2
timeoutInterval: 2,
accessToken: {
"custom.access.token"
}
),
ws: ws,
http: http
Expand Down Expand Up @@ -100,7 +103,7 @@ final class RealtimeTests: XCTestCase {
"event" : "phx_join",
"join_ref" : "1",
"payload" : {
"access_token" : "anon.api.key",
"access_token" : "custom.access.token",
"config" : {
"broadcast" : {
"ack" : false,
Expand Down Expand Up @@ -179,7 +182,7 @@ final class RealtimeTests: XCTestCase {
"event" : "phx_join",
"join_ref" : "1",
"payload" : {
"access_token" : "anon.api.key",
"access_token" : "custom.access.token",
"config" : {
"broadcast" : {
"ack" : false,
Expand All @@ -201,7 +204,7 @@ final class RealtimeTests: XCTestCase {
"event" : "phx_join",
"join_ref" : "2",
"payload" : {
"access_token" : "anon.api.key",
"access_token" : "custom.access.token",
"config" : {
"broadcast" : {
"ack" : false,
Expand Down Expand Up @@ -322,7 +325,7 @@ final class RealtimeTests: XCTestCase {
assertInlineSnapshot(of: request?.urlRequest, as: .raw(pretty: true)) {
"""
POST https://localhost:54321/realtime/v1/api/broadcast
Authorization: Bearer anon.api.key
Authorization: Bearer custom.access.token
Content-Type: application/json
apiKey: anon.api.key

Expand Down
Loading