From 5bac14654822d9eccbf4d2dd591fe08d5865d59f Mon Sep 17 00:00:00 2001 From: Pranav Ramachandra Date: Wed, 7 Aug 2024 11:16:14 +0100 Subject: [PATCH] Added http credentials provider --- trino-aws-proxy/pom.xml | 6 + .../server/TrinoAwsProxyServerModule.java | 2 + .../http/ForHttpCredentialsProvider.java | 30 ++++ .../http/HttpCredentialsModule.java | 42 +++++ .../http/HttpCredentialsProvider.java | 68 +++++++ .../http/HttpCredentialsProviderConfig.java | 37 ++++ .../http/TestHttpCredentialsProvider.java | 170 ++++++++++++++++++ .../TestHttpCredentialsProviderConfig.java | 36 ++++ 8 files changed, 391 insertions(+) create mode 100644 trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/ForHttpCredentialsProvider.java create mode 100644 trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsModule.java create mode 100644 trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java create mode 100644 trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java create mode 100644 trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java create mode 100644 trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java diff --git a/trino-aws-proxy/pom.xml b/trino-aws-proxy/pom.xml index 8b1ed98b..9d1fe5bf 100644 --- a/trino-aws-proxy/pom.xml +++ b/trino-aws-proxy/pom.xml @@ -180,6 +180,12 @@ runtime + + jakarta.servlet + jakarta.servlet-api + runtime + + ${project.groupId} trino-aws-proxy-spark3 diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java index bc622cb6..67122736 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java @@ -31,6 +31,7 @@ import io.airlift.log.Logger; import io.trino.aws.proxy.server.credentials.CredentialsController; import io.trino.aws.proxy.server.credentials.file.FileBasedCredentialsModule; +import io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule; import io.trino.aws.proxy.server.remote.RemoteS3Module; import io.trino.aws.proxy.server.rest.RequestFilter; import io.trino.aws.proxy.server.rest.RequestLoggerController; @@ -119,6 +120,7 @@ protected void setup(Binder binder) // provided implementations install(new FileBasedCredentialsModule()); install(new OpaS3SecurityModule()); + install(new HttpCredentialsModule()); // AssumedRoleProvider binder configBinder(binder).bindConfig(AssumedRoleProviderConfig.class); diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/ForHttpCredentialsProvider.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/ForHttpCredentialsProvider.java new file mode 100644 index 00000000..befe76a6 --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/ForHttpCredentialsProvider.java @@ -0,0 +1,30 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@BindingAnnotation +@Target({FIELD, PARAMETER, METHOD}) +@Retention(RUNTIME) +public @interface ForHttpCredentialsProvider { +} diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsModule.java new file mode 100644 index 00000000..bd7b749c --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsModule.java @@ -0,0 +1,42 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import com.google.inject.Binder; +import io.airlift.configuration.AbstractConfigurationAwareModule; + +import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.http.client.HttpClientBinder.httpClientBinder; +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule; + +public class HttpCredentialsModule + extends AbstractConfigurationAwareModule +{ + // set as config value for "credentials-provider.type" + public static final String HTTP_CREDENTIALS_PROVIDER_IDENTIFIER = "http"; + public static final String HTTP_CREDENTIALS_PROVIDER_HTTP_CLIENT_NAME = "http-credentials-provider"; + + @Override + protected void setup(Binder binder) + { + install(credentialsProviderModule( + HTTP_CREDENTIALS_PROVIDER_IDENTIFIER, + HttpCredentialsProvider.class, + innerBinder -> { + configBinder(innerBinder).bindConfig(HttpCredentialsProviderConfig.class); + innerBinder.bind(HttpCredentialsProvider.class); + httpClientBinder(innerBinder).bindHttpClient(HTTP_CREDENTIALS_PROVIDER_HTTP_CLIENT_NAME, ForHttpCredentialsProvider.class); + })); + } +} diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java new file mode 100644 index 00000000..3255b9e7 --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProvider.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.module.SimpleModule; +import com.google.inject.Inject; +import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.HttpStatus; +import io.airlift.http.client.Request; +import io.airlift.json.JsonCodec; +import io.airlift.json.JsonCodecFactory; +import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.credentials.CredentialsProvider; +import io.trino.aws.proxy.spi.credentials.Identity; +import jakarta.ws.rs.core.UriBuilder; + +import java.net.URI; +import java.util.Optional; + +import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; +import static io.airlift.http.client.Request.Builder.prepareGet; +import static java.util.Objects.requireNonNull; + +public class HttpCredentialsProvider + implements CredentialsProvider +{ + private final HttpClient httpClient; + private final JsonCodec jsonCodec; + private final URI httpCredentialsProviderEndpoint; + + @Inject + public HttpCredentialsProvider(@ForHttpCredentialsProvider HttpClient httpClient, HttpCredentialsProviderConfig config, ObjectMapper objectMapper, Class identityClass) + { + requireNonNull(objectMapper, "objectMapper is null"); + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.httpCredentialsProviderEndpoint = config.getEndpoint(); + ObjectMapper adjustedObjectMapper = objectMapper.registerModule(new SimpleModule().addAbstractTypeMapping(Identity.class, identityClass)); + this.jsonCodec = new JsonCodecFactory(() -> adjustedObjectMapper).jsonCodec(Credentials.class); + } + + @Override + public Optional credentials(String emulatedAccessKey, Optional session) + { + UriBuilder uriBuilder = UriBuilder.fromUri(httpCredentialsProviderEndpoint).path(emulatedAccessKey); + session.ifPresent(sessionToken -> uriBuilder.queryParam("sessionToken", sessionToken)); + Request request = prepareGet() + .setUri(uriBuilder.build()) + .build(); + JsonResponse response = httpClient.execute(request, createFullJsonResponseHandler(jsonCodec)); + if (response.getStatusCode() == HttpStatus.NOT_FOUND.code() || !response.hasValue()) { + return Optional.empty(); + } + return Optional.of(response.getValue()); + } +} diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java new file mode 100644 index 00000000..1188b61a --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/credentials/http/HttpCredentialsProviderConfig.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import io.airlift.configuration.Config; +import jakarta.validation.constraints.NotNull; + +import java.net.URI; + +public class HttpCredentialsProviderConfig +{ + private URI endpoint; + + @NotNull + public URI getEndpoint() + { + return endpoint; + } + + @Config("credentials-provider.http.endpoint") + public HttpCredentialsProviderConfig setEndpoint(String endpoint) + { + this.endpoint = URI.create(endpoint); + return this; + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java new file mode 100644 index 00000000..6788a59f --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProvider.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Inject; +import io.airlift.http.server.HttpServerConfig; +import io.airlift.http.server.HttpServerInfo; +import io.airlift.http.server.testing.TestingHttpServer; +import io.airlift.json.ObjectMapperProvider; +import io.airlift.node.NodeInfo; +import io.trino.aws.proxy.server.testing.TestingIdentity; +import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; +import io.trino.aws.proxy.server.testing.harness.BuilderFilter; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import io.trino.aws.proxy.spi.credentials.Credential; +import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.credentials.CredentialsProvider; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Optional; + +import static io.trino.aws.proxy.server.credentials.http.HttpCredentialsModule.HTTP_CREDENTIALS_PROVIDER_IDENTIFIER; +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.bindIdentityType; +import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +@TrinoAwsProxyTest(filters = TestHttpCredentialsProvider.Filter.class) +public class TestHttpCredentialsProvider +{ + private final CredentialsProvider credentialsProvider; + + public static class Filter + implements BuilderFilter + { + @Override + public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) + { + TestingHttpServer httpCredentialsServer; + try { + httpCredentialsServer = createTestingHttpCredentialsServer(); + httpCredentialsServer.start(); + } + catch (Exception e) { + throw new RuntimeException("Failed to start test http credentials provider server", e); + } + return builder.withoutTestingCredentialsRoleProviders() + .addModule(new HttpCredentialsModule()) + .addModule(binder -> bindIdentityType(binder, TestingIdentity.class)) + .withProperty("credentials-provider.type", HTTP_CREDENTIALS_PROVIDER_IDENTIFIER) + .withProperty("credentials-provider.http.endpoint", httpCredentialsServer.getBaseUrl().toString()); + } + } + + @Inject + public TestHttpCredentialsProvider(CredentialsProvider credentialsProvider) + { + this.credentialsProvider = requireNonNull(credentialsProvider, "credentialsProvider is null"); + } + + @Test + public void testValidCredentialsWithEmptySession() + { + Credential emulated = new Credential("test-emulated-access-key", "test-emulated-secret"); + Credential remote = new Credential("test-remote-access-key", "test-remote-secret"); + Credentials expected = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq"))); + Optional actual = credentialsProvider.credentials("test-emulated-access-key", Optional.empty()); + assertThat(actual).contains(expected); + } + + @Test + public void testValidCredentialsWithValidSession() + { + Credential emulated = new Credential("test-emulated-access-key", "test-emulated-secret"); + Credential remote = new Credential("test-remote-access-key", "test-remote-secret"); + Credentials expected = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq"))); + Optional actual = credentialsProvider.credentials("test-emulated-access-key", Optional.of("test-emulated-access-key")); + assertThat(actual).contains(expected); + } + + @Test + public void testInvalidCredentialsWithEmptySession() + { + Optional actual = credentialsProvider.credentials("non-existent-key", Optional.empty()); + assertThat(actual).isEmpty(); + } + + @Test + public void testValidCredentialsWithInvalidSession() + { + Optional actual = credentialsProvider.credentials("test-emulated-access-key", Optional.of("sessionToken-not-equals-accessKey")); + assertThat(actual).isEmpty(); + } + + @Test + public void testInvalidCredentialsWithInvalidSession() + { + Optional actual = credentialsProvider.credentials("non-existent-key", Optional.of("sessionToken-not-equals-accessKey")); + assertThat(actual).isEmpty(); + } + + @Test + public void testIncorrectResponseFromServer() + { + Optional actual = credentialsProvider.credentials("incorrect-response", Optional.empty()); + assertThat(actual).isEmpty(); + } + + private static TestingHttpServer createTestingHttpCredentialsServer() + throws IOException + { + NodeInfo nodeInfo = new NodeInfo("test"); + HttpServerConfig config = new HttpServerConfig().setHttpPort(0); + HttpServerInfo httpServerInfo = new HttpServerInfo(config, nodeInfo); + return new TestingHttpServer(httpServerInfo, nodeInfo, config, new HttpCredentialsServlet(), ImmutableMap.of()); + } + + private static class HttpCredentialsServlet + extends HttpServlet + { + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException + { + Optional sessionToken = Optional.ofNullable(request.getParameter("sessionToken")); + String emulatedAccessKey = request.getPathInfo().substring(1); + String credentialsIdentifier = ""; + if (sessionToken.isPresent()) { + // Simulate valid session - When accessKey equals sessionToken + if (emulatedAccessKey.equals(sessionToken.get())) { + credentialsIdentifier = sessionToken.get(); + } + } + else { + credentialsIdentifier = emulatedAccessKey; + } + switch (credentialsIdentifier) { + case "test-emulated-access-key" -> { + Credential emulated = new Credential("test-emulated-access-key", "test-emulated-secret"); + Credential remote = new Credential("test-remote-access-key", "test-remote-secret"); + Credentials credentials = new Credentials(emulated, Optional.of(remote), Optional.empty(), Optional.of(new TestingIdentity("test-username", ImmutableList.of(), "xyzpdq"))); + String jsonCredentials = new ObjectMapperProvider().get().writeValueAsString(credentials); + response.setContentType(APPLICATION_JSON); + response.getWriter().print(jsonCredentials); + } + case "incorrect-response" -> { + response.getWriter().print("incorrect response"); + } + default -> response.setStatus(HttpServletResponse.SC_NOT_FOUND); + } + } + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java new file mode 100644 index 00000000..a07d6079 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/http/TestHttpCredentialsProviderConfig.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials.http; + +import com.google.common.collect.ImmutableMap; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; + +public class TestHttpCredentialsProviderConfig +{ + @Test + public void testExplicitPropertyMappings() + throws IOException + { + Map properties = ImmutableMap.of( + "credentials-provider.http.endpoint", "http://usersvc:9000/api/v1/users"); + HttpCredentialsProviderConfig expected = new HttpCredentialsProviderConfig() + .setEndpoint("http://usersvc:9000/api/v1/users"); + assertFullMapping(properties, expected); + } +}