Skip to content

Commit

Permalink
feat: implemented persisted queries for GET methods with only SHA-256…
Browse files Browse the repository at this point in the history
… hash of query string
  • Loading branch information
malaquf committed Dec 30, 2024
1 parent 681c70d commit 548482b
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import java.io.IOException
internal const val REQUEST_PARAM_QUERY = "query"
internal const val REQUEST_PARAM_OPERATION_NAME = "operationName"
internal const val REQUEST_PARAM_VARIABLES = "variables"
internal const val REQUEST_PARAM_EXTENSIONS = "extensions"
internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery"

/**
* GraphQL Ktor [ApplicationRequest] parser.
Expand All @@ -46,8 +48,12 @@ open class KtorGraphQLRequestParser(
else -> null
}

private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest? {
val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: throw IllegalStateException("Invalid HTTP request - GET request has to specify query parameter")
private fun parseGetRequest(request: ApplicationRequest): GraphQLServerRequest {
val extensions = request.queryParameters[REQUEST_PARAM_EXTENSIONS]
val query = request.queryParameters[REQUEST_PARAM_QUERY] ?: ""
check(query.isNotEmpty() || extensions?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true) {
"Invalid HTTP request - GET request has to specify either query parameter or persisted query extension"
}
if (query.startsWith("mutation ") || query.startsWith("subscription ")) {
throw UnsupportedOperationException("Invalid GraphQL operation - only queries are supported for GET requests")
}
Expand All @@ -56,7 +62,15 @@ open class KtorGraphQLRequestParser(
val graphQLVariables: Map<String, Any>? = variables?.let {
mapper.readValue(it, mapTypeReference)
}
return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables)
val extensionsMap: Map<String, Any>? = request.queryParameters[REQUEST_PARAM_EXTENSIONS]?.let {
mapper.readValue(it, mapTypeReference)
}
return GraphQLRequest(
query = query,
operationName = operationName,
variables = graphQLVariables,
extensions = extensionsMap
)
}

private suspend fun parsePostRequest(request: ApplicationRequest): GraphQLServerRequest? = try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,25 @@ class GraphQLPluginTest {
}

@Test
fun `server should handle valid GET requests`() {
fun `server should handle valid GET requests with persisted query`() {
testApplication {
val response = client.get("/graphql") {
parameter("query", "query HelloQuery(\$name: String){ hello(name: \$name) }")
parameter("operationName", "HelloQuery")
parameter("variables", """{"name":"junit"}""")
parameter("extensions", """{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""")
}
assertEquals(HttpStatusCode.OK, response.status)
assertEquals("""{"data":{"hello":"Hello junit"}}""", response.bodyAsText().trim())
}
}

@Test
fun `server should return Method Not Allowed for Mutation GET requests`() {
fun `server should return Method Not Allowed for Mutation GET requests persisted query`() {
testApplication {
val response = client.get("/graphql") {
parameter("query", "mutation { foo }")
parameter("extensions", """{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""")
}
assertEquals(HttpStatusCode.MethodNotAllowed, response.status)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class KtorGraphQLRequestParserTest {
fun `parseRequest should throw IllegalStateException if request method is GET without query`() = runTest {
val request = mockk<ApplicationRequest>(relaxed = true) {
every { queryParameters[REQUEST_PARAM_QUERY] } returns null
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
assertFailsWith<IllegalStateException> {
Expand All @@ -60,6 +61,7 @@ class KtorGraphQLRequestParserTest {
every { queryParameters[REQUEST_PARAM_QUERY] } returns "{ foo }"
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns null
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns null
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
Expand All @@ -76,6 +78,7 @@ class KtorGraphQLRequestParserTest {
every { queryParameters[REQUEST_PARAM_QUERY] } returns "query MyFoo { foo }"
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo"
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}"""
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns null
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
Expand All @@ -86,6 +89,23 @@ class KtorGraphQLRequestParserTest {
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is GET with hash only`() = runTest {
val serverRequest = mockk<ApplicationRequest>(relaxed = true) {
every { queryParameters[REQUEST_PARAM_QUERY] } returns null
every { queryParameters[REQUEST_PARAM_OPERATION_NAME] } returns "MyFoo"
every { queryParameters[REQUEST_PARAM_VARIABLES] } returns """{"a":1}"""
every { queryParameters[REQUEST_PARAM_EXTENSIONS] } returns """{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}"""
every { local.method } returns HttpMethod.Get
}
val graphQLRequest = parser.parseRequest(serverRequest)
assertNotNull(graphQLRequest)
assertTrue(graphQLRequest is GraphQLRequest)
assertEquals("", graphQLRequest.query)
assertEquals("MyFoo", graphQLRequest.operationName)
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is POST`() = runTest {
val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.reactive.function.server.awaitBody
import org.springframework.web.reactive.function.server.bodyToMono
import org.springframework.web.server.ResponseStatusException
import kotlin.jvm.optionals.getOrNull

internal const val REQUEST_PARAM_QUERY = "query"
internal const val REQUEST_PARAM_OPERATION_NAME = "operationName"
internal const val REQUEST_PARAM_VARIABLES = "variables"
internal const val REQUEST_PARAM_EXTENSIONS = "extensions"
internal const val REQUEST_PARAM_PERSISTED_QUERY = "persistedQuery"
internal val graphQLMediaType = MediaType("application", "graphql")

open class SpringGraphQLRequestParser(
Expand All @@ -43,20 +46,32 @@ open class SpringGraphQLRequestParser(
private val mapTypeReference: MapType = TypeFactory.defaultInstance().constructMapType(HashMap::class.java, String::class.java, Any::class.java)

override suspend fun parseRequest(request: ServerRequest): GraphQLServerRequest? = when {
request.queryParam(REQUEST_PARAM_QUERY).isPresent -> { getRequestFromGet(request) }
request.method().equals(HttpMethod.POST) -> { getRequestFromPost(request) }
request.isGetPersistedQuery() || request.hasQueryParam() -> { getRequestFromGet(request) }
request.method() == HttpMethod.POST -> getRequestFromPost(request)
else -> null
}

private fun ServerRequest.hasQueryParam() = queryParam(REQUEST_PARAM_QUERY).isPresent

private fun ServerRequest.isGetPersistedQuery() = queryParam(REQUEST_PARAM_EXTENSIONS).getOrNull()?.contains(REQUEST_PARAM_PERSISTED_QUERY) == true

private fun getRequestFromGet(serverRequest: ServerRequest): GraphQLServerRequest {
val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).get()
val query = serverRequest.queryParam(REQUEST_PARAM_QUERY).orElse("")
val operationName: String? = serverRequest.queryParam(REQUEST_PARAM_OPERATION_NAME).orElseGet { null }
val variables: String? = serverRequest.queryParam(REQUEST_PARAM_VARIABLES).orElseGet { null }
val graphQLVariables: Map<String, Any>? = variables?.let {
objectMapper.readValue(it, mapTypeReference)
}
val extensions: Map<String, Any>? = serverRequest.queryParam(REQUEST_PARAM_EXTENSIONS).takeIf { it.isPresent }?.get()?.let {
objectMapper.readValue(it, mapTypeReference)
}

return GraphQLRequest(query = query, operationName = operationName, variables = graphQLVariables)
return GraphQLRequest(
query = query,
operationName = operationName,
variables = graphQLVariables,
extensions = extensions
)
}

private suspend fun getRequestFromPost(serverRequest: ServerRequest): GraphQLServerRequest? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import io.mockk.every
import io.mockk.mockk
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runBlockingTest
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.springframework.http.HttpHeaders
Expand All @@ -44,18 +43,20 @@ class SpringGraphQLRequestParserTest {
private val parser = SpringGraphQLRequestParser(objectMapper)

@Test
fun `parseRequest should return null if request method is not valid`() = runBlockingTest {
fun `parseRequest should return null if request method is not valid`() = runTest {
val request = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { method() } returns HttpMethod.PUT
}
assertNull(parser.parseRequest(request))
}

@Test
fun `parseRequest should return null if request method is GET without query`() = runBlockingTest {
fun `parseRequest should return null if request method is GET without query`() = runTest {
val request = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { method() } returns HttpMethod.GET
}
assertNull(parser.parseRequest(request))
Expand All @@ -65,6 +66,7 @@ class SpringGraphQLRequestParserTest {
fun `parseRequest should return request if method is GET with simple query`() = runTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("{ foo }")
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty()
every { method() } returns HttpMethod.GET
Expand All @@ -82,6 +84,7 @@ class SpringGraphQLRequestParserTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.of("query MyFoo { foo }")
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.of("MyFoo")
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.of("""{ "a": 1 }""")
every { method() } returns HttpMethod.GET
}
Expand All @@ -93,6 +96,23 @@ class SpringGraphQLRequestParserTest {
assertEquals(1, graphQLRequest.variables?.get("a"))
}

@Test
fun `parseRequest should return request if method is GET with hash only`() = runTest {
val serverRequest = mockk<ServerRequest>(relaxed = true) {
every { queryParam(REQUEST_PARAM_QUERY) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_EXTENSIONS) } returns Optional.of("""{"persistedQuery":{"version":1,"sha256Hash":"some-hash"}}""")
every { queryParam(REQUEST_PARAM_OPERATION_NAME) } returns Optional.empty()
every { queryParam(REQUEST_PARAM_VARIABLES) } returns Optional.empty()
every { method() } returns HttpMethod.GET
}
val graphQLRequest = parser.parseRequest(serverRequest)
assertNotNull(graphQLRequest)
assertTrue(graphQLRequest is GraphQLRequest)
assertEquals("", graphQLRequest.query)
assertNull(graphQLRequest.operationName)
assertNull(graphQLRequest.variables)
}

@Test
fun `parseRequest should return request if method is POST with no content-type`() = runTest {
val mockRequest = GraphQLRequest("query MyFoo { foo }", "MyFoo", mapOf("a" to 1))
Expand Down

0 comments on commit 548482b

Please sign in to comment.