Skip to content

Commit

Permalink
Make RequestRewriters into their own module
Browse files Browse the repository at this point in the history
  • Loading branch information
vagaerg committed Aug 20, 2024
1 parent 6d15367 commit 7a173c9
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig;
import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig;
import io.trino.aws.proxy.spi.plugin.config.PluginIdentifierConfig;
import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig;
import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider;
Expand Down Expand Up @@ -51,6 +52,11 @@ static Module s3SecurityFacadeProviderModule(String identifier, Class<? extends
return optionalPluginModule(S3SecurityFacadeProviderConfig.class, identifier, S3SecurityFacadeProvider.class, implementationClass, module);
}

static Module s3RequestRewriterModule(String identifier, Class<? extends S3RequestRewriter> implementationClass, Module module)
{
return optionalPluginModule(S3RequestRewriterConfig.class, identifier, S3RequestRewriter.class, implementationClass, module);
}

static <T extends Identity> void bindIdentityType(Binder binder, Class<T> type)
{
newOptionalBinder(binder, new TypeLiteral<Class<? extends Identity>>() {}).setBinding().toProvider(() -> {
Expand All @@ -59,12 +65,6 @@ static <T extends Identity> void bindIdentityType(Binder binder, Class<T> type)
});
}

static void bindS3RequestRewriter(Binder binder, Class<S3RequestRewriter> type)
{
log.info("Using %s request rewriter", type.getSimpleName());
newOptionalBinder(binder, S3RequestRewriter.class).setBinding().to(type);
}

static <Implementation> Module optionalPluginModule(
Class<? extends PluginIdentifierConfig> configClass,
String identifier,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.spi.plugin.config;

import io.airlift.configuration.Config;

import java.util.Optional;

public class S3RequestRewriterConfig
implements PluginIdentifierConfig
{
private Optional<String> identifier = Optional.empty();

@Override
public Optional<String> getPluginIdentifier()
{
return identifier;
}

@Config("s3-request-rewriter.type")
public void setPluginIdentifier(String identifier)
{
this.identifier = Optional.ofNullable(identifier);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

public interface S3RequestRewriter
{
S3RequestRewriter NOOP = (_, _) -> Optional.empty();

record S3RewriteResult(String finalRequestBucket, String finalRequestKey)
{
public S3RewriteResult {
Expand All @@ -29,8 +31,5 @@ record S3RewriteResult(String finalRequestBucket, String finalRequestKey)
}
}

default Optional<S3RewriteResult> rewrite(Credentials credentials, ParsedS3Request request)
{
return Optional.empty();
}
Optional<S3RewriteResult> rewrite(Credentials credentials, ParsedS3Request request);
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerPlugin;
import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig;
import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig;
import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig;
import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider;
Expand Down Expand Up @@ -123,7 +124,11 @@ protected void setup(Binder binder)
newSetBinder(binder, com.fasterxml.jackson.databind.Module.class).addBinding().toProvider(JsonIdentityProvider.class).in(Scopes.SINGLETON);

// RequestRewriter binder
newOptionalBinder(binder, S3RequestRewriter.class);
configBinder(binder).bindConfig(S3RequestRewriterConfig.class);
newOptionalBinder(binder, S3RequestRewriter.class).setDefault().toProvider(() -> {
log.info("Using default %s NOOP implementation", S3RequestRewriter.class.getSimpleName());
return S3RequestRewriter.NOOP;
});

// provided implementations
install(new FileBasedCredentialsModule());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.aws.proxy.spi.rest.ParsedS3Request;
import io.trino.aws.proxy.spi.rest.RequestContent;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult;
import io.trino.aws.proxy.spi.security.SecurityResponse;
import io.trino.aws.proxy.spi.security.SecurityResponse.Failure;
import io.trino.aws.proxy.spi.signing.SigningContext;
Expand Down Expand Up @@ -70,7 +71,7 @@ public class TrinoS3ProxyClient
private final S3SecurityController s3SecurityController;
private final S3PresignController s3PresignController;
private final LimitStreamController limitStreamController;
private final Optional<S3RequestRewriter> s3RequestRewriter;
private final S3RequestRewriter s3RequestRewriter;
private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor();
private final boolean generatePresignedUrlsOnHead;

Expand All @@ -88,7 +89,7 @@ public TrinoS3ProxyClient(
TrinoAwsProxyConfig trinoAwsProxyConfig,
S3PresignController s3PresignController,
LimitStreamController limitStreamController,
Optional<S3RequestRewriter> s3RequestRewriter)
S3RequestRewriter s3RequestRewriter)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.signingController = requireNonNull(signingController, "signingController is null");
Expand All @@ -111,10 +112,10 @@ public void shutDown()

public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession)
{
Optional<S3RequestRewriter.S3RewriteResult> rewriteResult = s3RequestRewriter.flatMap(rewriter -> rewriter.rewrite(signingMetadata.credentials(), request));
String targetBucket = rewriteResult.map(S3RequestRewriter.S3RewriteResult::finalRequestBucket).orElse(request.bucketName());
Optional<S3RewriteResult> rewriteResult = s3RequestRewriter.rewrite(signingMetadata.credentials(), request);
String targetBucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(request.bucketName());
String targetKey = rewriteResult
.map(S3RequestRewriter.S3RewriteResult::finalRequestKey)
.map(S3RewriteResult::finalRequestKey)
.map(SdkHttpUtils::urlEncodeIgnoreSlashes)
.orElse(request.rawPath());
URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,12 @@
package io.trino.aws.proxy.server.testing;

import com.google.inject.Inject;
import com.google.inject.Provider;
import io.trino.aws.proxy.spi.credentials.Credentials;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult;

import static java.util.Objects.requireNonNull;

public class TestingS3RequestRewriteController
implements Provider<S3RequestRewriter>
{
private final TestingS3RequestRewriter s3RequestRewriter;
private final Credentials defaultCredentials;
Expand All @@ -34,12 +31,6 @@ public TestingS3RequestRewriteController(TestingS3RequestRewriter rewriter, @Tes
this.defaultCredentials = requireNonNull(defaultCredentials, "defaultCredentials is null");
}

@Override
public S3RequestRewriter get()
{
return s3RequestRewriter;
}

private S3RewriteResult rewriteOrNoop(Credentials credentials, String bucket, String key)
{
return s3RequestRewriter.testRewrite(credentials, bucket, key).orElseGet(() -> new S3RewriteResult(bucket, key));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import io.trino.aws.proxy.server.testing.containers.S3Container;
import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container;
import io.trino.aws.proxy.spi.credentials.Credentials;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;

import java.io.Closeable;
import java.util.Collection;
Expand All @@ -54,6 +53,7 @@
import static io.trino.aws.proxy.server.testing.TestingUtil.TESTING_CREDENTIALS;
import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.assumedRoleProviderModule;
import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule;
import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.s3RequestRewriterModule;

public final class TestingTrinoAwsProxyServer
implements Closeable
Expand Down Expand Up @@ -211,6 +211,11 @@ public Builder withOpaContainer()

public TestingTrinoAwsProxyServer buildAndStart()
{
addModule(s3RequestRewriterModule("testing", TestingS3RequestRewriter.class, binder -> {
binder.bind(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, TestingS3RequestRewriter.class).setDefault().toInstance(TestingS3RequestRewriter.NOP);
}));
withProperty("s3-request-rewriter.type", "testing");
if (addTestingCredentialsRoleProviders) {
if (mockS3ContainerAdded) {
modules.add(binder -> binder.bind(TestingCredentialsInitializer.class).asEagerSingleton());
Expand All @@ -222,11 +227,6 @@ public TestingTrinoAwsProxyServer buildAndStart()
withProperty("assumed-role-provider.type", "testing");

modules.add(binder -> binder.bind(Credentials.class).annotatedWith(ForTestingRemoteCredentials.class).toProvider(TestingRemoteCredentialsProvider.class));
modules.add(binder -> {
binder.bind(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON);
newOptionalBinder(binder, TestingS3RequestRewriter.class).setDefault().toInstance(TestingS3RequestRewriter.NOP);
newOptionalBinder(binder, S3RequestRewriter.class).setBinding().toProvider(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON);
});
}

return start(modules.build(), properties.buildKeepingLast());
Expand Down

0 comments on commit 7a173c9

Please sign in to comment.