diff --git a/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/ApolloExtensions.kt b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/ApolloExtensions.kt index 4b515f4ea..cd1da092a 100644 --- a/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/ApolloExtensions.kt +++ b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/ApolloExtensions.kt @@ -18,13 +18,16 @@ package com.amplifyframework.apollo.appsync import com.amplifyframework.apollo.appsync.util.UserAgentHeader +import com.amplifyframework.apollo.appsync.util.WebSocketConnectionInterceptor import com.apollographql.apollo.ApolloClient import com.apollographql.apollo.api.ApolloRequest import com.apollographql.apollo.api.NullableAnyAdapter import com.apollographql.apollo.api.Operation import com.apollographql.apollo.api.http.DefaultHttpRequestComposer import com.apollographql.apollo.api.toJsonString +import com.apollographql.apollo.network.ws.DefaultWebSocketEngine import com.apollographql.apollo.network.ws.WebSocketNetworkTransport +import okhttp3.OkHttpClient // Use the requestUuid as the subscriptionId internal val ApolloRequest.subscriptionId: String @@ -38,15 +41,25 @@ internal fun ApolloRequest<*>.toJson() = * Convenience function that configures the [WebSocketNetworkTransport] to connect to AppSync. This function: * 1. Sets the serverUrl * 2. Sets up an [AppSyncProtocol] using the given endpoint and authorizer + * 3. Adds an interceptor to append the authorization payload to the connection request * @param endpoint The [AppSyncEndpoint] to connect to * @param authorizer The [AppSyncAuthorizer] that determines the authorization mode to use when connecting to AppSync * @return The builder instance for chaining */ fun WebSocketNetworkTransport.Builder.appSync(endpoint: AppSyncEndpoint, authorizer: AppSyncAuthorizer) = apply { - serverUrl { endpoint.createWebsocketServerUrl(authorizer) } + // Set the connection URL + serverUrl(endpoint.websocketConnection.toString()) + // Add User-agent header addHeader(UserAgentHeader.NAME, UserAgentHeader.value) + // Add an interceptor that appends the authorization headers + val client = OkHttpClient.Builder() + .addInterceptor(WebSocketConnectionInterceptor(endpoint, authorizer)) + .build() + webSocketEngine(DefaultWebSocketEngine(client)) + + // Set the WebSocket protocol protocol( AppSyncProtocol.Factory( endpoint = endpoint, diff --git a/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/AppSyncEndpoint.kt b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/AppSyncEndpoint.kt index 534c69b52..ca66eb30c 100644 --- a/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/AppSyncEndpoint.kt +++ b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/AppSyncEndpoint.kt @@ -62,6 +62,7 @@ class AppSyncEndpoint(serverUrl: String) { * Creates the serverUrl to be used for the WebSocketTransport's serverUrl. For AppSync, this URL has authorization * information appended in query parameters. Set this value as the serverUrl for the WebSocketTransport. */ + @Deprecated("Use HTTP header authorization instead of appending a query parameter") suspend fun createWebsocketServerUrl(authorizer: AppSyncAuthorizer): String { val headers = mapOf("host" to serverUrl.host) + authorizer.getWebsocketConnectionHeaders(this) val authorization = headers.base64() diff --git a/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptor.kt b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptor.kt new file mode 100644 index 000000000..0708ec818 --- /dev/null +++ b/apollo/apollo-appsync/src/main/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptor.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amplifyframework.apollo.appsync.util + +import com.amplifyframework.apollo.appsync.AppSyncAuthorizer +import com.amplifyframework.apollo.appsync.AppSyncEndpoint +import kotlinx.coroutines.runBlocking +import okhttp3.Interceptor +import okhttp3.Response + +/** + * Intercepts the WebSocket connection request to append the authorization headers + */ +internal class WebSocketConnectionInterceptor( + private val endpoint: AppSyncEndpoint, + private val authorizer: AppSyncAuthorizer +) : Interceptor { + override fun intercept(chain: Interceptor.Chain): Response { + // runBlocking is okay because we are on an IO thread when the interceptor is called + val headers = runBlocking { authorizer.getWebsocketConnectionHeaders(endpoint) } + val builder = chain.request().newBuilder() + headers.forEach { header -> builder.header(header.key, header.value) } + builder.header("host", endpoint.serverUrl.host) + return chain.proceed(builder.build()) + } +} diff --git a/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/ApolloExtensionsTest.kt b/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/ApolloExtensionsTest.kt index 80c21a3d8..c68917da5 100644 --- a/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/ApolloExtensionsTest.kt +++ b/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/ApolloExtensionsTest.kt @@ -22,16 +22,12 @@ import com.apollographql.apollo.api.AnyAdapter import com.apollographql.apollo.api.CustomScalarAdapters import com.apollographql.apollo.api.json.BufferedSourceJsonReader import com.apollographql.apollo.network.ws.WebSocketNetworkTransport -import io.kotest.matchers.shouldBe import io.mockk.mockk -import io.mockk.slot import io.mockk.spyk import io.mockk.verify import kotlinx.coroutines.test.runTest -import okhttp3.HttpUrl.Companion.toHttpUrl import okio.Buffer import okio.ByteString -import okio.ByteString.Companion.decodeBase64 import org.junit.Test class ApolloExtensionsTest { @@ -84,23 +80,21 @@ class ApolloExtensionsTest { val transportBuilder = mockk(relaxed = true) builder.appSync(endpoint, authorizer, transportBuilder) - val slot = slot String>() verify { - transportBuilder.serverUrl(capture(slot)) + transportBuilder.serverUrl( + "https://example1234567890123456789.appsync-realtime-api.us-east-1.amazonaws.com/graphql/connect" + ) } + } - val serverUrl = slot.captured().toHttpUrl() - - // Expected URL: - // https://example1234567890123456789.appsync-realtime-api.us-east-1.amazonaws.com/graphql/connect - serverUrl.host shouldBe "example1234567890123456789.appsync-realtime-api.us-east-1.amazonaws.com" - serverUrl.encodedPath shouldBe "/graphql/connect" - - val header = serverUrl.queryParameter("header")?.decodeBase64()!!.toJsonMap() - header["host"] shouldBe "example1234567890123456789.appsync-api.us-east-1.amazonaws.com" - header["x-api-key"] shouldBe "apiKey" + @Test + fun `sets websocket engine`() { + val transportBuilder = mockk(relaxed = true) + builder.appSync(endpoint, authorizer, transportBuilder) - serverUrl.queryParameter("payload") shouldBe "e30=" + verify { + transportBuilder.webSocketEngine(any()) + } } @Suppress("UNCHECKED_CAST") diff --git a/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptorTest.kt b/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptorTest.kt new file mode 100644 index 000000000..800905d53 --- /dev/null +++ b/apollo/apollo-appsync/src/test/java/com/amplifyframework/apollo/appsync/util/WebSocketConnectionInterceptorTest.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package com.amplifyframework.apollo.appsync.util + +import com.amplifyframework.apollo.appsync.AppSyncAuthorizer +import com.amplifyframework.apollo.appsync.AppSyncEndpoint +import io.mockk.coEvery +import io.mockk.mockk +import io.mockk.verify +import okhttp3.Interceptor +import okhttp3.Request +import org.junit.Test + +/** + * Unit tests for the [WebSocketConnectionInterceptor] class + */ +class WebSocketConnectionInterceptorTest { + private val url = "https://example1234567890123456789.appsync-api.us-east-1.amazonaws.com/graphql" + + @Test + fun `adds expected headers`() { + val endpoint = AppSyncEndpoint(url) + val authorizer = mockk { + coEvery { getWebsocketConnectionHeaders(endpoint) } returns mapOf("test" to "value") + } + val builder = mockk(relaxed = true) + val chain = mockk { + coEvery { request().newBuilder() } returns builder + coEvery { proceed(any()) } returns mockk() + } + + val interceptor = WebSocketConnectionInterceptor(endpoint, authorizer) + interceptor.intercept(chain) + + verify { + builder.header("test", "value") + builder.header("host", "example1234567890123456789.appsync-api.us-east-1.amazonaws.com") + } + } +}