diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java index 9f768bbc..dfe852e4 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java @@ -24,6 +24,7 @@ import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig; import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig; import io.trino.aws.proxy.spi.plugin.config.PluginIdentifierConfig; +import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig; import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig; import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider; @@ -51,6 +52,11 @@ static Module s3SecurityFacadeProviderModule(String identifier, Class implementationClass, Module module) + { + return optionalPluginModule(S3RequestRewriterConfig.class, identifier, S3RequestRewriter.class, implementationClass, module); + } + static void bindIdentityType(Binder binder, Class type) { newOptionalBinder(binder, new TypeLiteral>() {}).setBinding().toProvider(() -> { @@ -59,12 +65,6 @@ static void bindIdentityType(Binder binder, Class type) }); } - static void bindS3RequestRewriter(Binder binder, Class type) - { - log.info("Using %s request rewriter", type.getSimpleName()); - newOptionalBinder(binder, S3RequestRewriter.class).setBinding().to(type); - } - static Module optionalPluginModule( Class configClass, String identifier, diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java new file mode 100644 index 00000000..01e6c6c9 --- /dev/null +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.spi.plugin.config; + +import io.airlift.configuration.Config; + +import java.util.Optional; + +public class S3RequestRewriterConfig + implements PluginIdentifierConfig +{ + private Optional identifier = Optional.empty(); + + @Override + public Optional getPluginIdentifier() + { + return identifier; + } + + @Config("s3-request-rewriter.type") + public void setPluginIdentifier(String identifier) + { + this.identifier = Optional.ofNullable(identifier); + } +} diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java index b1f0a4e9..477731e6 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java @@ -21,6 +21,8 @@ public interface S3RequestRewriter { + S3RequestRewriter NOOP = (_, _) -> Optional.empty(); + record S3RewriteResult(String finalRequestBucket, String finalRequestKey) { public S3RewriteResult { @@ -29,8 +31,5 @@ record S3RewriteResult(String finalRequestBucket, String finalRequestKey) } } - default Optional rewrite(Credentials credentials, ParsedS3Request request) - { - return Optional.empty(); - } + Optional rewrite(Credentials credentials, ParsedS3Request request); } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java index f5ad7a85..b0c0b3e8 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java @@ -53,6 +53,7 @@ import io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerPlugin; import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig; import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig; +import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig; import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig; import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider; @@ -123,7 +124,11 @@ protected void setup(Binder binder) newSetBinder(binder, com.fasterxml.jackson.databind.Module.class).addBinding().toProvider(JsonIdentityProvider.class).in(Scopes.SINGLETON); // RequestRewriter binder - newOptionalBinder(binder, S3RequestRewriter.class); + configBinder(binder).bindConfig(S3RequestRewriterConfig.class); + newOptionalBinder(binder, S3RequestRewriter.class).setDefault().toProvider(() -> { + log.info("Using default %s NOOP implementation", S3RequestRewriter.class.getSimpleName()); + return S3RequestRewriter.NOOP; + }); // provided implementations install(new FileBasedCredentialsModule()); diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index 251b65c8..92f5325c 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -26,6 +26,7 @@ import io.trino.aws.proxy.spi.rest.ParsedS3Request; import io.trino.aws.proxy.spi.rest.RequestContent; import io.trino.aws.proxy.spi.rest.S3RequestRewriter; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; import io.trino.aws.proxy.spi.security.SecurityResponse; import io.trino.aws.proxy.spi.security.SecurityResponse.Failure; import io.trino.aws.proxy.spi.signing.SigningContext; @@ -70,7 +71,7 @@ public class TrinoS3ProxyClient private final S3SecurityController s3SecurityController; private final S3PresignController s3PresignController; private final LimitStreamController limitStreamController; - private final Optional s3RequestRewriter; + private final S3RequestRewriter s3RequestRewriter; private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor(); private final boolean generatePresignedUrlsOnHead; @@ -88,7 +89,7 @@ public TrinoS3ProxyClient( TrinoAwsProxyConfig trinoAwsProxyConfig, S3PresignController s3PresignController, LimitStreamController limitStreamController, - Optional s3RequestRewriter) + S3RequestRewriter s3RequestRewriter) { this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.signingController = requireNonNull(signingController, "signingController is null"); @@ -111,10 +112,10 @@ public void shutDown() public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession) { - Optional rewriteResult = s3RequestRewriter.flatMap(rewriter -> rewriter.rewrite(signingMetadata.credentials(), request)); - String targetBucket = rewriteResult.map(S3RequestRewriter.S3RewriteResult::finalRequestBucket).orElse(request.bucketName()); + Optional rewriteResult = s3RequestRewriter.rewrite(signingMetadata.credentials(), request); + String targetBucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(request.bucketName()); String targetKey = rewriteResult - .map(S3RequestRewriter.S3RewriteResult::finalRequestKey) + .map(S3RewriteResult::finalRequestKey) .map(SdkHttpUtils::urlEncodeIgnoreSlashes) .orElse(request.rawPath()); URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region()); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java index 378fed96..2e7c1e0e 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java @@ -14,15 +14,12 @@ package io.trino.aws.proxy.server.testing; import com.google.inject.Inject; -import com.google.inject.Provider; import io.trino.aws.proxy.spi.credentials.Credentials; -import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; import static java.util.Objects.requireNonNull; public class TestingS3RequestRewriteController - implements Provider { private final TestingS3RequestRewriter s3RequestRewriter; private final Credentials defaultCredentials; @@ -34,12 +31,6 @@ public TestingS3RequestRewriteController(TestingS3RequestRewriter rewriter, @Tes this.defaultCredentials = requireNonNull(defaultCredentials, "defaultCredentials is null"); } - @Override - public S3RequestRewriter get() - { - return s3RequestRewriter; - } - private S3RewriteResult rewriteOrNoop(Credentials credentials, String bucket, String key) { return s3RequestRewriter.testRewrite(credentials, bucket, key).orElseGet(() -> new S3RewriteResult(bucket, key)); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java index 31f0b0e7..bb1c2dd8 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingTrinoAwsProxyServer.java @@ -43,7 +43,6 @@ import io.trino.aws.proxy.server.testing.containers.S3Container; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.spi.credentials.Credentials; -import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import java.io.Closeable; import java.util.Collection; @@ -54,6 +53,7 @@ import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_CREDENTIALS; import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.assumedRoleProviderModule; import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule; +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.s3RequestRewriterModule; public final class TestingTrinoAwsProxyServer implements Closeable @@ -211,6 +211,11 @@ public Builder withOpaContainer() public TestingTrinoAwsProxyServer buildAndStart() { + addModule(s3RequestRewriterModule("testing", TestingS3RequestRewriter.class, binder -> { + binder.bind(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON); + newOptionalBinder(binder, TestingS3RequestRewriter.class).setDefault().toInstance(TestingS3RequestRewriter.NOP); + })); + withProperty("s3-request-rewriter.type", "testing"); if (addTestingCredentialsRoleProviders) { if (mockS3ContainerAdded) { modules.add(binder -> binder.bind(TestingCredentialsInitializer.class).asEagerSingleton()); @@ -222,11 +227,6 @@ public TestingTrinoAwsProxyServer buildAndStart() withProperty("assumed-role-provider.type", "testing"); modules.add(binder -> binder.bind(Credentials.class).annotatedWith(ForTestingRemoteCredentials.class).toProvider(TestingRemoteCredentialsProvider.class)); - modules.add(binder -> { - binder.bind(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON); - newOptionalBinder(binder, TestingS3RequestRewriter.class).setDefault().toInstance(TestingS3RequestRewriter.NOP); - newOptionalBinder(binder, S3RequestRewriter.class).setBinding().toProvider(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON); - }); } return start(modules.build(), properties.buildKeepingLast());