diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java index f8f8bf1b..dfe852e4 100644 --- a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/TrinoAwsProxyServerBinding.java @@ -24,7 +24,9 @@ import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig; import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig; import io.trino.aws.proxy.spi.plugin.config.PluginIdentifierConfig; +import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig; import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider; import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; @@ -50,6 +52,11 @@ static Module s3SecurityFacadeProviderModule(String identifier, Class implementationClass, Module module) + { + return optionalPluginModule(S3RequestRewriterConfig.class, identifier, S3RequestRewriter.class, implementationClass, module); + } + static void bindIdentityType(Binder binder, Class type) { newOptionalBinder(binder, new TypeLiteral>() {}).setBinding().toProvider(() -> { diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java new file mode 100644 index 00000000..01e6c6c9 --- /dev/null +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/plugin/config/S3RequestRewriterConfig.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.spi.plugin.config; + +import io.airlift.configuration.Config; + +import java.util.Optional; + +public class S3RequestRewriterConfig + implements PluginIdentifierConfig +{ + private Optional identifier = Optional.empty(); + + @Override + public Optional getPluginIdentifier() + { + return identifier; + } + + @Config("s3-request-rewriter.type") + public void setPluginIdentifier(String identifier) + { + this.identifier = Optional.ofNullable(identifier); + } +} diff --git a/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java new file mode 100644 index 00000000..477731e6 --- /dev/null +++ b/trino-aws-proxy-spi/src/main/java/io/trino/aws/proxy/spi/rest/S3RequestRewriter.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.spi.rest; + +import io.trino.aws.proxy.spi.credentials.Credentials; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public interface S3RequestRewriter +{ + S3RequestRewriter NOOP = (_, _) -> Optional.empty(); + + record S3RewriteResult(String finalRequestBucket, String finalRequestKey) + { + public S3RewriteResult { + requireNonNull(finalRequestBucket, "finalRequestBucket is null"); + requireNonNull(finalRequestKey, "finalRequestKey is null"); + } + } + + Optional rewrite(Credentials credentials, ParsedS3Request request); +} diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java index 622a15d1..b0c0b3e8 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java @@ -53,7 +53,9 @@ import io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerPlugin; import io.trino.aws.proxy.spi.plugin.config.AssumedRoleProviderConfig; import io.trino.aws.proxy.spi.plugin.config.CredentialsProviderConfig; +import io.trino.aws.proxy.spi.plugin.config.S3RequestRewriterConfig; import io.trino.aws.proxy.spi.plugin.config.S3SecurityFacadeProviderConfig; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter; import io.trino.aws.proxy.spi.security.S3SecurityFacadeProvider; import io.trino.aws.proxy.spi.signing.SigningServiceType; import org.glassfish.jersey.server.model.Resource; @@ -121,6 +123,13 @@ protected void setup(Binder binder) }); newSetBinder(binder, com.fasterxml.jackson.databind.Module.class).addBinding().toProvider(JsonIdentityProvider.class).in(Scopes.SINGLETON); + // RequestRewriter binder + configBinder(binder).bindConfig(S3RequestRewriterConfig.class); + newOptionalBinder(binder, S3RequestRewriter.class).setDefault().toProvider(() -> { + log.info("Using default %s NOOP implementation", S3RequestRewriter.class.getSimpleName()); + return S3RequestRewriter.NOOP; + }); + // provided implementations install(new FileBasedCredentialsModule()); install(new OpaS3SecurityModule()); diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index 531068bc..92f5325c 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -25,6 +25,8 @@ import io.trino.aws.proxy.spi.credentials.Credentials; import io.trino.aws.proxy.spi.rest.ParsedS3Request; import io.trino.aws.proxy.spi.rest.RequestContent; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; import io.trino.aws.proxy.spi.security.SecurityResponse; import io.trino.aws.proxy.spi.security.SecurityResponse.Failure; import io.trino.aws.proxy.spi.signing.SigningContext; @@ -38,6 +40,7 @@ import jakarta.ws.rs.container.AsyncResponse; import jakarta.ws.rs.core.Response; import jakarta.ws.rs.core.UriBuilder; +import software.amazon.awssdk.utils.http.SdkHttpUtils; import java.io.InputStream; import java.lang.annotation.Retention; @@ -68,6 +71,7 @@ public class TrinoS3ProxyClient private final S3SecurityController s3SecurityController; private final S3PresignController s3PresignController; private final LimitStreamController limitStreamController; + private final S3RequestRewriter s3RequestRewriter; private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor(); private final boolean generatePresignedUrlsOnHead; @@ -84,7 +88,8 @@ public TrinoS3ProxyClient( S3SecurityController s3SecurityController, TrinoAwsProxyConfig trinoAwsProxyConfig, S3PresignController s3PresignController, - LimitStreamController limitStreamController) + LimitStreamController limitStreamController, + S3RequestRewriter s3RequestRewriter) { this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.signingController = requireNonNull(signingController, "signingController is null"); @@ -92,6 +97,7 @@ public TrinoS3ProxyClient( this.s3SecurityController = requireNonNull(s3SecurityController, "securityController is null"); this.s3PresignController = requireNonNull(s3PresignController, "presignController is null"); this.limitStreamController = requireNonNull(limitStreamController, "quotaStreamController is null"); + this.s3RequestRewriter = requireNonNull(s3RequestRewriter, "s3RequestRewriter is null"); generatePresignedUrlsOnHead = trinoAwsProxyConfig.isGeneratePresignedUrlsOnHead(); } @@ -106,7 +112,13 @@ public void shutDown() public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request request, AsyncResponse asyncResponse, RequestLoggingSession requestLoggingSession) { - URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), request.rawPath(), request.bucketName(), request.requestAuthorization().region()); + Optional rewriteResult = s3RequestRewriter.rewrite(signingMetadata.credentials(), request); + String targetBucket = rewriteResult.map(S3RewriteResult::finalRequestBucket).orElse(request.bucketName()); + String targetKey = rewriteResult + .map(S3RewriteResult::finalRequestKey) + .map(SdkHttpUtils::urlEncodeIgnoreSlashes) + .orElse(request.rawPath()); + URI remoteUri = remoteS3Facade.buildEndpoint(uriBuilder(request.queryParameters()), targetKey, targetBucket, request.requestAuthorization().region()); SecurityResponse securityResponse = s3SecurityController.apply(request, signingMetadata.credentials().identity()); if (securityResponse instanceof Failure(var error)) { diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java new file mode 100644 index 00000000..9e784bd9 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestPresignedRequests.java @@ -0,0 +1,347 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.node.TextNode; +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.Request; +import io.airlift.http.client.ResponseHandler; +import io.airlift.http.client.StaticBodyGenerator; +import io.airlift.http.client.StatusResponseHandler.StatusResponse; +import io.airlift.http.client.StringResponseHandler.StringResponse; +import io.airlift.http.server.testing.TestingHttpServer; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.TestingUtil; +import io.trino.aws.proxy.spi.credentials.Credential; +import io.trino.aws.proxy.spi.credentials.Credentials; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.UploadPartRequest; +import software.amazon.awssdk.services.s3.presigner.S3Presigner; +import software.amazon.awssdk.services.s3.presigner.model.CompleteMultipartUploadPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.CreateMultipartUploadPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.DeleteObjectPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedCompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedCreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedDeleteObjectRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedPutObjectRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedUploadPartRequest; +import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.UploadPartPresignRequest; + +import java.io.IOException; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Optional; + +import static io.airlift.http.client.Request.Builder.prepareDelete; +import static io.airlift.http.client.Request.Builder.prepareGet; +import static io.airlift.http.client.Request.Builder.prepareHead; +import static io.airlift.http.client.Request.Builder.preparePost; +import static io.airlift.http.client.Request.Builder.preparePut; +import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; +import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler; +import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; +import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class AbstractTestPresignedRequests +{ + private final HttpClient httpClient; + private final S3Client internalClient; + private final S3Client storageClient; + private final Credentials testingCredentials; + private final URI s3ProxyUrl; + private final XmlMapper xmlMapper; + private final TestingS3RequestRewriteController requestRewriteController; + + private static final Duration TEST_SIGNATURE_DURATION = Duration.ofMinutes(10); + + protected AbstractTestPresignedRequests( + HttpClient httpClient, + S3Client internalClient, + S3Client storageClient, + Credentials testingCredentials, + TestingHttpServer httpServer, + TrinoAwsProxyConfig s3ProxyConfig, + XmlMapper xmlMapper, + TestingS3RequestRewriteController requestRewriteController) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.internalClient = requireNonNull(internalClient, "internalClient is null"); + this.storageClient = requireNonNull(storageClient, "storageClient is null"); + this.testingCredentials = requireNonNull(testingCredentials, "testingCredentials is null"); + this.s3ProxyUrl = httpServer.getBaseUrl().resolve(s3ProxyConfig.getS3Path()); + this.xmlMapper = requireNonNull(xmlMapper, "xmlMapper is null"); + this.requestRewriteController = requireNonNull(requestRewriteController, "requestRewriteController is null"); + } + + @Test + public void testPresignedGet() + throws IOException + { + String bucketName = "one"; + String key = "presignedGet"; + uploadFileToStorage(requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key), TEST_FILE); + + try (S3Presigner presigner = buildPresigner()) { + GetObjectRequest objectRequest = GetObjectRequest.builder() + .bucket(bucketName) + .key(key) + .build(); + + GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder() + .signatureDuration(TEST_SIGNATURE_DURATION) + .getObjectRequest(objectRequest) + .build(); + + PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest); + StringResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStringResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getBody()).isEqualTo(Files.readString(TEST_FILE)); + } + } + + @Test + public void testPresignedPut() + throws IOException + { + String bucketName = "two"; + String key = "presignedPut"; + String fileContents = Files.readString(TEST_FILE, StandardCharsets.UTF_8); + + try (S3Presigner presigner = buildPresigner()) { + PutObjectRequest putObjectRequest = PutObjectRequest.builder() + .bucket(bucketName) + .key(key) + .contentEncoding("gzip") + .contentType("text/plain;charset=UTF-8") + .build(); + PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder() + .signatureDuration(TEST_SIGNATURE_DURATION) + .putObjectRequest(putObjectRequest) + .build(); + PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest); + + StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), fileContents, createStatusResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(200); + } + + assertThat(getFileFromStorage(bucketName, key)).isEqualTo(fileContents); + HeadObjectResponse headObjectResponse = headObjectInStorage(storageClient, requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key)); + assertThat(headObjectResponse.contentType()).isEqualTo("text/plain;charset=UTF-8"); + assertThat(headObjectResponse.contentEncoding()).isEqualTo("gzip"); + } + + @Test + public void testPresignedDelete() + { + String bucketName = "three"; + String key = "fileToDelete"; + uploadFileToStorage(requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key), TEST_FILE); + + try (S3Presigner presigner = buildPresigner()) { + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucketName).key(key).build(); + DeleteObjectPresignRequest presignRequest = DeleteObjectPresignRequest.builder() + .signatureDuration(TEST_SIGNATURE_DURATION) + .deleteObjectRequest(deleteObjectRequest) + .build(); + PresignedDeleteObjectRequest presignedRequest = presigner.presignDeleteObject(presignRequest); + + StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStatusResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(204); + } + + assertThat(listFilesInS3Bucket(storageClient, requestRewriteController.getTargetBucket(bucketName, key))).isEmpty(); + } + + @Test + public void testExpiredSignature() + throws InterruptedException, IOException + { + String bucketName = "three"; + String key = "fileToDeleteExpired"; + uploadFileToStorage(requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key), TEST_FILE); + + try (S3Presigner presigner = buildPresigner()) { + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucketName).key(key).build(); + DeleteObjectPresignRequest presignRequest = DeleteObjectPresignRequest.builder() + .signatureDuration(Duration.ofSeconds(1)) + .deleteObjectRequest(deleteObjectRequest) + .build(); + PresignedDeleteObjectRequest presignedRequest = presigner.presignDeleteObject(presignRequest); + + Thread.sleep(Duration.ofSeconds(2)); + + StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStatusResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(401); + } + + assertThat(getFileFromStorage(bucketName, key)).isEqualTo(Files.readString(TEST_FILE)); + } + + @Test + public void testMultipart() + throws IOException + { + String bucketName = "one"; + String key = "multipart-upload"; + String dummyPayload = "foo bar baz"; + try (S3Presigner presigner = buildPresigner()) { + CreateMultipartUploadRequest createMultipartUploadRequest = CreateMultipartUploadRequest.builder() + .bucket(bucketName) + .contentType("text/plain;charset=UTF-8") + .contentEncoding("gzip") + .metadata(ImmutableMap.of("some-metadata-key", "some-metadata-value")) + .key(key) + .build(); + CreateMultipartUploadPresignRequest presignCreateMultipartUploadRequest = CreateMultipartUploadPresignRequest.builder() + .signatureDuration(TEST_SIGNATURE_DURATION) + .createMultipartUploadRequest(createMultipartUploadRequest) + .build(); + PresignedCreateMultipartUploadRequest presignedCreateMultipartUploadRequest = presigner.presignCreateMultipartUpload(presignCreateMultipartUploadRequest); + StringResponse startMultipartResponse = executeHttpRequest(presignedCreateMultipartUploadRequest.httpRequest(), createStringResponseHandler()); + assertThat(startMultipartResponse.getStatusCode()).isEqualTo(200); + + String uploadId; + try (JsonParser objectMapper = xmlMapper.createParser(startMultipartResponse.getBody())) { + uploadId = ((TextNode) objectMapper.readValueAsTree().get("UploadId")).textValue(); + } + + UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucketName) + .key(key) + .uploadId(uploadId) + .partNumber(1) + .build(); + UploadPartPresignRequest presignUploadPartRequest = UploadPartPresignRequest.builder() + .signatureDuration(TEST_SIGNATURE_DURATION) + .uploadPartRequest(uploadPartRequest) + .build(); + PresignedUploadPartRequest presignedUploadPartRequest = presigner.presignUploadPart(presignUploadPartRequest); + StatusResponse uploadPartResponse = executeHttpRequest(presignedUploadPartRequest.httpRequest(), dummyPayload, createStatusResponseHandler()); + assertThat(uploadPartResponse.getStatusCode()).isEqualTo(200); + + String eTag = uploadPartResponse.getHeader("etag"); + + // If we provide a body for this request here, the AWS SDK will sign the contents even though it should not + // That results in Minio rejecting the request + CompleteMultipartUploadRequest.Builder completeMultipartUploadRequestBuilder = CompleteMultipartUploadRequest.builder() + .bucket(bucketName) + .key(key) + .uploadId(uploadId); + + // This is the signature for the request without a signed payload, just like all other presigned requests + PresignedCompleteMultipartUploadRequest presignedCompleteMultipartUploadRequest = presignCompleteMultipartUpload(presigner, completeMultipartUploadRequestBuilder); + + String completeMultipartUploadPayload = presignCompleteMultipartUpload( + presigner, + completeMultipartUploadRequestBuilder.multipartUpload(CompletedMultipartUpload.builder().parts(ImmutableList.of(CompletedPart.builder().partNumber(1).eTag(eTag).build())).build())) + .signedPayload().orElseThrow().asUtf8String(); + + StatusResponse completeMultipartResponse = executeHttpRequest(presignedCompleteMultipartUploadRequest.httpRequest(), completeMultipartUploadPayload, createStatusResponseHandler()); + assertThat(completeMultipartResponse.getStatusCode()).isEqualTo(200); + } + assertThat(getFileFromStorage(bucketName, key)).isEqualTo(dummyPayload); + HeadObjectResponse headResult = TestingUtil.headObjectInStorage(storageClient, requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key)); + assertThat(headResult.contentEncoding()).isEqualTo("gzip"); + assertThat(headResult.contentType()).isEqualTo("text/plain;charset=UTF-8"); + assertThat(headResult.metadata()).containsEntry("some-metadata-key", "some-metadata-value"); + } + + private PresignedCompleteMultipartUploadRequest presignCompleteMultipartUpload(S3Presigner presigner, CompleteMultipartUploadRequest.Builder completeMultipartUploadRequestBuilder) + { + return presigner.presignCompleteMultipartUpload(CompleteMultipartUploadPresignRequest.builder() + .completeMultipartUploadRequest(completeMultipartUploadRequestBuilder.build()) + .signatureDuration(TEST_SIGNATURE_DURATION) + .build()); + } + + void uploadFileToStorage(String bucketName, String key, Path filePath) + { + PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(bucketName).key(key).build(); + PutObjectResponse putObjectResponse = storageClient.putObject(putObjectRequest, filePath); + assertThat(putObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200); + } + + String getFileFromStorage(String bucketName, String key) + throws IOException + { + String dataFromProxy = TestingUtil.getFileFromStorage(internalClient, bucketName, key); + String dataFromStorage = TestingUtil.getFileFromStorage(storageClient, requestRewriteController.getTargetBucket(bucketName, key), requestRewriteController.getTargetKey(bucketName, key)); + assertThat(dataFromProxy).isEqualTo(dataFromStorage); + return dataFromStorage; + } + + T executeHttpRequest(SdkHttpRequest sdkRequest, ResponseHandler responseHandler) + { + return executeHttpRequest(sdkRequest, Optional.empty(), responseHandler); + } + + T executeHttpRequest(SdkHttpRequest sdkRequest, String body, ResponseHandler responseHandler) + { + return executeHttpRequest(sdkRequest, Optional.of(body), responseHandler); + } + + T executeHttpRequest(SdkHttpRequest sdkRequest, Optional body, ResponseHandler responseHandler) + { + Request.Builder requestBuilder = switch (sdkRequest.method()) { + case POST -> preparePost(); + case PUT -> preparePut(); + case GET -> prepareGet(); + case HEAD -> prepareHead(); + case DELETE -> prepareDelete(); + default -> throw new IllegalStateException("Unexpected HTTP method"); + }; + requestBuilder.setUri(sdkRequest.getUri()); + body.ifPresent(actualBody -> requestBuilder.setBodyGenerator(StaticBodyGenerator.createStaticBodyGenerator(actualBody, StandardCharsets.UTF_8))); + sdkRequest.forEachHeader((headerName, headerValues) -> headerValues.forEach(headerValue -> requestBuilder.addHeader(headerName, headerValue))); + return httpClient.execute(requestBuilder.build(), responseHandler); + } + + S3Presigner buildPresigner() + { + return buildPresigner(testingCredentials.emulated()); + } + + S3Presigner buildPresigner(Credential credential) + { + AwsBasicCredentials proxyCredentials = AwsBasicCredentials.create(credential.accessKey(), credential.secretKey()); + return S3Presigner.builder().region(Region.US_EAST_1).endpointOverride(s3ProxyUrl).credentialsProvider(StaticCredentialsProvider.create(proxyCredentials)).build(); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java index f1d7cb41..990f1e54 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/AbstractTestProxiedRequests.java @@ -13,7 +13,9 @@ */ package io.trino.aws.proxy.server; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingUtil; import jakarta.annotation.PreDestroy; import org.junit.jupiter.api.AfterEach; @@ -25,21 +27,18 @@ import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateBucketResponse; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; import software.amazon.awssdk.services.s3.model.DeleteObjectResponse; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.ListBucketsResponse; -import software.amazon.awssdk.services.s3.model.ListObjectsResponse; import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.file.Files; import java.time.Duration; @@ -55,22 +54,24 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination; import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; +import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; public abstract class AbstractTestProxiedRequests { - private final S3Client internalClient; - private final S3Client remoteClient; - private final List configuredBuckets; + final S3Client internalClient; + final S3Client remoteClient; + private final TestingS3RequestRewriteController requestRewriteController; private final ExecutorService executorService = Executors.newVirtualThreadPerTaskExecutor(); - protected AbstractTestProxiedRequests(S3Client internalClient, S3Client remoteClient, List configuredBuckets) + protected AbstractTestProxiedRequests(S3Client internalClient, S3Client remoteClient, TestingS3RequestRewriteController requestRewriteController) { this.internalClient = requireNonNull(internalClient, "internalClient is null"); this.remoteClient = requireNonNull(remoteClient, "remoteClient is null"); - this.configuredBuckets = requireNonNull(configuredBuckets, "configuredBuckets is null"); + this.requestRewriteController = requireNonNull(requestRewriteController, "requestRewriteController is null"); } @PreDestroy @@ -85,69 +86,71 @@ public void cleanupBuckets() TestingUtil.cleanupBuckets(remoteClient); } + @Test + public void testCreateBucket() + { + String newBucketName = "new-bucket"; + CreateBucketResponse createBucketResponse = internalClient.createBucket(r -> r.bucket(newBucketName)); + assertThat(createBucketResponse.sdkHttpResponse().statusCode()).isEqualTo(200); + + ListBucketsResponse listBucketsResponse = remoteClient.listBuckets(); + assertThat(listBucketsResponse.buckets()).extracting(Bucket::name).contains(requestRewriteController.getTargetBucket(newBucketName, "")); + } + @Test public void testListBuckets() { - ListBucketsResponse listBucketsResponse = internalClient.listBuckets(); - assertThat(listBucketsResponse.buckets()) - .extracting(Bucket::name) - .containsExactlyInAnyOrderElementsOf(configuredBuckets); - - assertThat(configuredBuckets.stream().map(bucketName -> internalClient.listObjects(request -> request.bucket(bucketName)).contents().size())) - .containsOnly(0) - .hasSize(configuredBuckets.size()); + List actualBuckets = remoteClient.listBuckets().buckets().stream().map(Bucket::name).collect(toImmutableList()); + List bucketsReportedByProxy = internalClient.listBuckets().buckets().stream().map(Bucket::name).collect(toImmutableList()); + + assertThat(bucketsReportedByProxy).containsExactlyElementsOf(actualBuckets); } @Test public void testListBucketsWithContents() { - String bucketToTest = configuredBuckets.getFirst(); + String bucketToTest = "one"; String testKey = "some-key"; - assertThat(internalClient.listObjects(request -> request.bucket(bucketToTest)).contents()).isEmpty(); + assertThat(listFilesInS3Bucket(internalClient, bucketToTest)).isEmpty(); - remoteClient.putObject(request -> request.bucket(bucketToTest).key(testKey), RequestBody.fromString("some-contents")); + remoteClient.putObject(request -> request.bucket(requestRewriteController.getTargetBucket(bucketToTest, testKey)).key(requestRewriteController.getTargetKey(bucketToTest, testKey)), RequestBody.fromString("some-contents")); - assertThat(internalClient.listObjects(request -> request.bucket(bucketToTest)).contents()) - .extracting(S3Object::key) - .containsOnly(testKey); + assertThat(listFilesInS3Bucket(internalClient, bucketToTest)).containsExactlyInAnyOrder(requestRewriteController.getTargetKey(bucketToTest, testKey)); } @Test public void testUploadAndDelete() throws IOException { - PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket("two").key("test").build(); + String bucket = "two"; + String key = "test"; + PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(bucket).key(key).build(); PutObjectResponse putObjectResponse = internalClient.putObject(putObjectRequest, TEST_FILE); assertThat(putObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200); - GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket("two").key("test").build(); - ByteArrayOutputStream readContents = new ByteArrayOutputStream(); - internalClient.getObject(getObjectRequest).transferTo(readContents); - String expectedContents = Files.readString(TEST_FILE); + assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(expectedContents); + assertThat(getFileFromStorage(remoteClient, requestRewriteController.getTargetBucket(bucket, key), requestRewriteController.getTargetKey(bucket, key))).isEqualTo(expectedContents); - assertThat(readContents.toString()).isEqualTo(expectedContents); - - ListObjectsResponse listObjectsResponse = internalClient.listObjects(request -> request.bucket("two")); - assertThat(listObjectsResponse.contents()) - .hasSize(1) - .first() - .extracting(S3Object::key, S3Object::size) - .containsExactlyInAnyOrder("test", Files.size(TEST_FILE)); + assertThat(listFilesInS3Bucket(internalClient, bucket)).containsExactlyInAnyOrder(requestRewriteController.getTargetKey(bucket, key)); + assertThat(listFilesInS3Bucket(remoteClient, requestRewriteController.getTargetBucket(bucket, key))).containsExactlyInAnyOrder(requestRewriteController.getTargetKey(bucket, key)); - DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket("two").key("test").build(); + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucket).key(key).build(); internalClient.deleteObject(deleteObjectRequest); - listObjectsResponse = internalClient.listObjects(request -> request.bucket("two")); - assertThat(listObjectsResponse.contents()).isEmpty(); + assertThat(listFilesInS3Bucket(internalClient, bucket)).isEmpty(); + assertThat(listFilesInS3Bucket(remoteClient, requestRewriteController.getTargetBucket(bucket, key))).isEmpty(); } @Test public void testUploadWithContentTypeAndMetadata() + throws IOException { + String bucket = "two"; + String key = "testWithMetadata"; PutObjectRequest putObjectRequest = PutObjectRequest.builder() - .bucket("two") - .key("testWithMetadata") + .bucket(bucket) + .key(key) .contentType("text/plain;charset=utf-8") .contentEncoding("gzip,compress") .metadata(ImmutableMap.of("metadata-key", "metadata-value")) @@ -155,7 +158,9 @@ public void testUploadWithContentTypeAndMetadata() PutObjectResponse putObjectResponse = internalClient.putObject(putObjectRequest, TEST_FILE); assertThat(putObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200); - HeadObjectResponse headObjectResponse = headObjectInStorage(internalClient, "two", "testWithMetadata"); + assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(Files.readString(TEST_FILE)); + assertThat(getFileFromStorage(remoteClient, requestRewriteController.getTargetBucket(bucket, key), requestRewriteController.getTargetKey(bucket, key))).isEqualTo(Files.readString(TEST_FILE)); + HeadObjectResponse headObjectResponse = headObjectInStorage(internalClient, bucket, key); assertThat(headObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200); assertThat(headObjectResponse.contentType()).isEqualTo("text/plain;charset=utf-8"); @@ -167,7 +172,9 @@ public void testUploadWithContentTypeAndMetadata() public void testMultipartUpload() throws IOException { - CreateMultipartUploadRequest multipartUploadRequest = CreateMultipartUploadRequest.builder().bucket("three").key("multi").build(); + String bucket = "three"; + String key = "multi"; + CreateMultipartUploadRequest multipartUploadRequest = CreateMultipartUploadRequest.builder().bucket(bucket).key(key).build(); CreateMultipartUploadResponse multipartUploadResponse = internalClient.createMultipartUpload(multipartUploadRequest); String uploadId = multipartUploadResponse.uploadId(); @@ -177,7 +184,7 @@ public void testMultipartUpload() List> futures = IntStream.rangeClosed(1, 5) .mapToObj(partNumber -> executorService.submit(() -> { String content = buildLine(partNumber); - UploadPartRequest part = UploadPartRequest.builder().bucket("three").key("multi").uploadId(uploadId).partNumber(partNumber).contentLength((long) content.length()).build(); + UploadPartRequest part = UploadPartRequest.builder().bucket(bucket).key(key).uploadId(uploadId).partNumber(partNumber).contentLength((long) content.length()).build(); UploadPartResponse uploadPartResponse = internalClient.uploadPart(part, RequestBody.fromString(content)); assertThat(uploadPartResponse.sdkHttpResponse().statusCode()).isEqualTo(200); @@ -203,8 +210,8 @@ public void testMultipartUpload() .build(); CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder() - .bucket("three") - .key("multi") + .bucket(bucket) + .key(key) .uploadId(uploadId) .multipartUpload(completedUpload) .build(); @@ -212,17 +219,14 @@ public void testMultipartUpload() CompleteMultipartUploadResponse completeResponse = internalClient.completeMultipartUpload(completeRequest); assertThat(completeResponse.sdkHttpResponse().statusCode()).isEqualTo(200); - GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket("three").key("multi").build(); - ByteArrayOutputStream readContents = new ByteArrayOutputStream(); - internalClient.getObject(getObjectRequest).transferTo(readContents); - String expected = IntStream.rangeClosed(1, 5) .mapToObj(AbstractTestProxiedRequests::buildLine) .collect(Collectors.joining()); - assertThat(readContents.toString()).isEqualTo(expected); + assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(expected); + assertThat(getFileFromStorage(remoteClient, requestRewriteController.getTargetBucket(bucket, key), requestRewriteController.getTargetKey(bucket, key))).isEqualTo(expected); - DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket("three").key("multi").build(); + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucket).key(key).build(); DeleteObjectResponse deleteObjectResponse = internalClient.deleteObject(deleteObjectRequest); assertThat(deleteObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(204); } @@ -230,16 +234,19 @@ public void testMultipartUpload() @Test public void testPathsNeedingEscaping() { - internalClient.createBucket(r -> r.bucket("escapes")); - internalClient.putObject(r -> r.bucket("escapes").key("a=1/b=2"), RequestBody.fromString("something")); - internalClient.putObject(r -> r.bucket("escapes").key("a=1%2Fb=2"), RequestBody.fromString("else")); - - ListObjectsResponse listObjectsResponse = internalClient.listObjects(request -> request.bucket("escapes")); - assertThat(listObjectsResponse.contents()).extracting(S3Object::key).containsExactlyInAnyOrder("a=1/b=2", "a=1%2Fb=2"); - - internalClient.deleteObject(r -> r.bucket("escapes").key("a=1/b=2")); - internalClient.deleteObject(r -> r.bucket("escapes").key("a=1%2Fb=2")); - internalClient.deleteBucket(r -> r.bucket("escapes")); + String bucket = "escapes"; + remoteClient.createBucket(r -> r.bucket(requestRewriteController.getTargetBucket(bucket, ""))); + internalClient.putObject(r -> r.bucket(bucket).key("a=1/b=2"), RequestBody.fromString("something")); + internalClient.putObject(r -> r.bucket(bucket).key("a=1%2Fb=2"), RequestBody.fromString("else")); + + List expectedKeys = ImmutableList.of(requestRewriteController.getTargetKey(bucket, "a=1/b=2"), requestRewriteController.getTargetKey(bucket, "a=1%2Fb=2")); + assertThat(listFilesInS3Bucket(internalClient, bucket)).containsExactlyInAnyOrderElementsOf(expectedKeys); + assertThat(listFilesInS3Bucket(remoteClient, requestRewriteController.getTargetBucket(bucket, ""))).containsExactlyInAnyOrderElementsOf(expectedKeys); + + internalClient.deleteObject(r -> r.bucket(bucket).key("a=1/b=2")); + internalClient.deleteObject(r -> r.bucket(bucket).key("a=1%2Fb=2")); + assertThat(listFilesInS3Bucket(internalClient, bucket)).isEmpty(); + internalClient.deleteBucket(r -> r.bucket(bucket)); } private static String buildLine(int partNumber) diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java index 190dfe74..1f593486 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java @@ -45,7 +45,6 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.auth.signer.internal.chunkedencoding.AwsS3V4ChunkSigner; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import java.io.IOException; @@ -64,7 +63,7 @@ import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; import static io.trino.aws.proxy.server.testing.TestingUtil.LOREM_IPSUM; import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; -import static io.trino.aws.proxy.server.testing.TestingUtil.cleanupBuckets; +import static io.trino.aws.proxy.server.testing.TestingUtil.deleteAllBuckets; import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket; @@ -125,8 +124,7 @@ public TestGenericRestRequests( @AfterEach public void cleanupStorage() { - cleanupBuckets(storageClient); - storageClient.listBuckets().buckets().forEach(bucket -> storageClient.deleteBucket(DeleteBucketRequest.builder().bucket(bucket.name()).build())); + deleteAllBuckets(storageClient); } @Test diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java index bddb79f9..853f6760 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestHttpChunked.java @@ -41,7 +41,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import java.io.ByteArrayInputStream; @@ -63,7 +62,7 @@ import static io.airlift.http.client.StreamingBodyGenerator.streamingBodyGenerator; import static io.trino.aws.proxy.server.testing.TestingUtil.LOREM_IPSUM; import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; -import static io.trino.aws.proxy.server.testing.TestingUtil.cleanupBuckets; +import static io.trino.aws.proxy.server.testing.TestingUtil.deleteAllBuckets; import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.sha256; @@ -107,8 +106,7 @@ public TestHttpChunked( @AfterEach public void cleanupStorage() { - cleanupBuckets(storageClient); - storageClient.listBuckets().buckets().forEach(bucket -> storageClient.deleteBucket(DeleteBucketRequest.builder().bucket(bucket.name()).build())); + deleteAllBuckets(storageClient); } private class ForceChunkInputStream diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java index b15b3a38..f7dfaf23 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequests.java @@ -13,89 +13,23 @@ */ package io.trino.aws.proxy.server; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.node.TextNode; import com.fasterxml.jackson.dataformat.xml.XmlMapper; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import io.airlift.http.client.HttpClient; -import io.airlift.http.client.Request; -import io.airlift.http.client.ResponseHandler; -import io.airlift.http.client.StaticBodyGenerator; -import io.airlift.http.client.StatusResponseHandler.StatusResponse; -import io.airlift.http.client.StringResponseHandler.StringResponse; import io.airlift.http.server.testing.TestingHttpServer; -import io.trino.aws.proxy.server.testing.TestingUtil; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient; import io.trino.aws.proxy.spi.credentials.Credentials; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; -import software.amazon.awssdk.http.SdkHttpRequest; -import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.HeadObjectResponse; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; -import software.amazon.awssdk.services.s3.model.PutObjectResponse; -import software.amazon.awssdk.services.s3.model.UploadPartRequest; -import software.amazon.awssdk.services.s3.presigner.S3Presigner; -import software.amazon.awssdk.services.s3.presigner.model.CompleteMultipartUploadPresignRequest; -import software.amazon.awssdk.services.s3.presigner.model.CreateMultipartUploadPresignRequest; -import software.amazon.awssdk.services.s3.presigner.model.DeleteObjectPresignRequest; -import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedCompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedCreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedDeleteObjectRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedPutObjectRequest; -import software.amazon.awssdk.services.s3.presigner.model.PresignedUploadPartRequest; -import software.amazon.awssdk.services.s3.presigner.model.PutObjectPresignRequest; -import software.amazon.awssdk.services.s3.presigner.model.UploadPartPresignRequest; - -import java.io.IOException; -import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.time.Duration; -import java.util.Optional; - -import static io.airlift.http.client.Request.Builder.prepareDelete; -import static io.airlift.http.client.Request.Builder.prepareGet; -import static io.airlift.http.client.Request.Builder.prepareHead; -import static io.airlift.http.client.Request.Builder.preparePost; -import static io.airlift.http.client.Request.Builder.preparePut; -import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler; -import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler; -import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; -import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; -import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; -import static java.util.Objects.requireNonNull; -import static org.assertj.core.api.Assertions.assertThat; @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, TestProxiedRequests.Filter.class}) public class TestPresignedRequests + extends AbstractTestPresignedRequests { - private final HttpClient httpClient; - private final S3Client internalClient; - private final S3Client storageClient; - private final Credentials testingCredentials; - private final URI s3ProxyUrl; - private final XmlMapper xmlMapper; - - private static final Duration TEST_SIGNATURE_DURATION = Duration.ofMinutes(10); - @Inject public TestPresignedRequests( @ForTesting HttpClient httpClient, @@ -104,231 +38,9 @@ public TestPresignedRequests( @ForTesting Credentials testingCredentials, TestingHttpServer httpServer, TrinoAwsProxyConfig s3ProxyConfig, - XmlMapper xmlMapper) - { - this.httpClient = requireNonNull(httpClient, "httpClient is null"); - this.internalClient = requireNonNull(internalClient, "internalClient is null"); - this.storageClient = requireNonNull(storageClient, "storageClient is null"); - this.testingCredentials = requireNonNull(testingCredentials, "testingCredentials is null"); - this.s3ProxyUrl = httpServer.getBaseUrl().resolve(s3ProxyConfig.getS3Path()); - this.xmlMapper = requireNonNull(xmlMapper, "xmlMapper is null"); - } - - @Test - public void testPresignedGet() - throws IOException - { - uploadFileToStorage("one", "presignedGet", TEST_FILE); - - try (S3Presigner presigner = buildPresigner()) { - GetObjectRequest objectRequest = GetObjectRequest.builder() - .bucket("one") - .key("presignedGet") - .build(); - - GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder() - .signatureDuration(TEST_SIGNATURE_DURATION) - .getObjectRequest(objectRequest) - .build(); - - PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest); - StringResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStringResponseHandler()); - assertThat(response.getStatusCode()).isEqualTo(200); - assertThat(response.getBody()).isEqualTo(Files.readString(TEST_FILE)); - } - } - - @Test - public void testPresignedPut() - throws IOException - { - String fileContents = Files.readString(TEST_FILE, StandardCharsets.UTF_8); - - try (S3Presigner presigner = buildPresigner()) { - PutObjectRequest putObjectRequest = PutObjectRequest.builder() - .bucket("two") - .key("presignedPut") - .contentEncoding("gzip") - .contentType("text/plain;charset=UTF-8") - .build(); - PutObjectPresignRequest presignRequest = PutObjectPresignRequest.builder() - .signatureDuration(TEST_SIGNATURE_DURATION) - .putObjectRequest(putObjectRequest) - .build(); - PresignedPutObjectRequest presignedRequest = presigner.presignPutObject(presignRequest); - - StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), fileContents, createStatusResponseHandler()); - assertThat(response.getStatusCode()).isEqualTo(200); - } - - assertThat(getFileFromStorage("two", "presignedPut")).isEqualTo(fileContents); - HeadObjectResponse headObjectResponse = headObjectInStorage(storageClient, "two", "presignedPut"); - assertThat(headObjectResponse.contentType()).isEqualTo("text/plain;charset=UTF-8"); - assertThat(headObjectResponse.contentEncoding()).isEqualTo("gzip"); - } - - @Test - public void testPresignedDelete() - { - uploadFileToStorage("three", "fileToDelete", TEST_FILE); - - try (S3Presigner presigner = buildPresigner()) { - DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket("three").key("fileToDelete").build(); - DeleteObjectPresignRequest presignRequest = DeleteObjectPresignRequest.builder() - .signatureDuration(TEST_SIGNATURE_DURATION) - .deleteObjectRequest(deleteObjectRequest) - .build(); - PresignedDeleteObjectRequest presignedRequest = presigner.presignDeleteObject(presignRequest); - - StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStatusResponseHandler()); - assertThat(response.getStatusCode()).isEqualTo(204); - } - - assertFileNotInS3(storageClient, "three", "fileToDelete"); - } - - @Test - public void testExpiredSignature() - throws InterruptedException, IOException - { - uploadFileToStorage("three", "fileToDeleteExpired", TEST_FILE); - - try (S3Presigner presigner = buildPresigner()) { - DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket("three").key("fileToDeleteExpired").build(); - DeleteObjectPresignRequest presignRequest = DeleteObjectPresignRequest.builder() - .signatureDuration(Duration.ofSeconds(1)) - .deleteObjectRequest(deleteObjectRequest) - .build(); - PresignedDeleteObjectRequest presignedRequest = presigner.presignDeleteObject(presignRequest); - - Thread.sleep(Duration.ofSeconds(2)); - - StatusResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStatusResponseHandler()); - assertThat(response.getStatusCode()).isEqualTo(401); - } - - assertThat(getFileFromStorage("three", "fileToDeleteExpired")).isEqualTo(Files.readString(TEST_FILE)); - } - - @Test - public void testMultipart() - throws IOException - { - String bucketName = "one"; - String key = "multipart-upload"; - String dummyPayload = "foo bar baz"; - try (S3Presigner presigner = buildPresigner()) { - CreateMultipartUploadRequest createMultipartUploadRequest = CreateMultipartUploadRequest.builder() - .bucket(bucketName) - .contentType("text/plain;charset=UTF-8") - .contentEncoding("gzip") - .metadata(ImmutableMap.of("some-metadata-key", "some-metadata-value")) - .key(key) - .build(); - CreateMultipartUploadPresignRequest presignCreateMultipartUploadRequest = CreateMultipartUploadPresignRequest.builder() - .signatureDuration(TEST_SIGNATURE_DURATION) - .createMultipartUploadRequest(createMultipartUploadRequest) - .build(); - PresignedCreateMultipartUploadRequest presignedCreateMultipartUploadRequest = presigner.presignCreateMultipartUpload(presignCreateMultipartUploadRequest); - StringResponse startMultipartResponse = executeHttpRequest(presignedCreateMultipartUploadRequest.httpRequest(), createStringResponseHandler()); - assertThat(startMultipartResponse.getStatusCode()).isEqualTo(200); - - String uploadId; - try (JsonParser objectMapper = xmlMapper.createParser(startMultipartResponse.getBody())) { - uploadId = ((TextNode) objectMapper.readValueAsTree().get("UploadId")).textValue(); - } - - UploadPartRequest uploadPartRequest = UploadPartRequest.builder() - .bucket(bucketName) - .key(key) - .uploadId(uploadId) - .partNumber(1) - .build(); - UploadPartPresignRequest presignUploadPartRequest = UploadPartPresignRequest.builder() - .signatureDuration(TEST_SIGNATURE_DURATION) - .uploadPartRequest(uploadPartRequest) - .build(); - PresignedUploadPartRequest presignedUploadPartRequest = presigner.presignUploadPart(presignUploadPartRequest); - StatusResponse uploadPartResponse = executeHttpRequest(presignedUploadPartRequest.httpRequest(), dummyPayload, createStatusResponseHandler()); - assertThat(uploadPartResponse.getStatusCode()).isEqualTo(200); - - String eTag = uploadPartResponse.getHeader("etag"); - - // If we provide a body for this request here, the AWS SDK will sign the contents even though it should not - // That results in Minio rejecting the request - CompleteMultipartUploadRequest.Builder completeMultipartUploadRequestBuilder = CompleteMultipartUploadRequest.builder() - .bucket(bucketName) - .key(key) - .uploadId(uploadId); - - // This is the signature for the request without a signed payload, just like all other presigned requests - PresignedCompleteMultipartUploadRequest presignedCompleteMultipartUploadRequest = presignCompleteMultipartUpload(presigner, completeMultipartUploadRequestBuilder); - - String completeMultipartUploadPayload = presignCompleteMultipartUpload( - presigner, - completeMultipartUploadRequestBuilder.multipartUpload(CompletedMultipartUpload.builder().parts(ImmutableList.of(CompletedPart.builder().partNumber(1).eTag(eTag).build())).build())) - .signedPayload().orElseThrow().asUtf8String(); - - StatusResponse completeMultipartResponse = executeHttpRequest(presignedCompleteMultipartUploadRequest.httpRequest(), completeMultipartUploadPayload, createStatusResponseHandler()); - assertThat(completeMultipartResponse.getStatusCode()).isEqualTo(200); - } - assertThat(getFileFromStorage(bucketName, key)).isEqualTo(dummyPayload); - HeadObjectResponse headResult = TestingUtil.headObjectInStorage(storageClient, bucketName, key); - assertThat(headResult.contentEncoding()).isEqualTo("gzip"); - assertThat(headResult.contentType()).isEqualTo("text/plain;charset=UTF-8"); - assertThat(headResult.metadata()).containsEntry("some-metadata-key", "some-metadata-value"); - } - - private PresignedCompleteMultipartUploadRequest presignCompleteMultipartUpload(S3Presigner presigner, CompleteMultipartUploadRequest.Builder completeMultipartUploadRequestBuilder) - { - return presigner.presignCompleteMultipartUpload(CompleteMultipartUploadPresignRequest.builder() - .completeMultipartUploadRequest(completeMultipartUploadRequestBuilder.build()) - .signatureDuration(TEST_SIGNATURE_DURATION) - .build()); - } - - private void uploadFileToStorage(String bucketName, String key, Path filePath) - { - PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(bucketName).key(key).build(); - PutObjectResponse putObjectResponse = storageClient.putObject(putObjectRequest, filePath); - assertThat(putObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200); - } - - private String getFileFromStorage(String bucketName, String key) - throws IOException - { - return TestingUtil.getFileFromStorage(internalClient, bucketName, key); - } - - private T executeHttpRequest(SdkHttpRequest sdkRequest, ResponseHandler responseHandler) - { - return executeHttpRequest(sdkRequest, Optional.empty(), responseHandler); - } - - private T executeHttpRequest(SdkHttpRequest sdkRequest, String body, ResponseHandler responseHandler) - { - return executeHttpRequest(sdkRequest, Optional.of(body), responseHandler); - } - - private T executeHttpRequest(SdkHttpRequest sdkRequest, Optional body, ResponseHandler responseHandler) - { - Request.Builder requestBuilder = switch (sdkRequest.method()) { - case POST -> preparePost(); - case PUT -> preparePut(); - case GET -> prepareGet(); - case HEAD -> prepareHead(); - case DELETE -> prepareDelete(); - default -> throw new IllegalStateException("Unexpected HTTP method"); - }; - requestBuilder.setUri(sdkRequest.getUri()); - body.ifPresent(actualBody -> requestBuilder.setBodyGenerator(StaticBodyGenerator.createStaticBodyGenerator(actualBody, StandardCharsets.UTF_8))); - sdkRequest.forEachHeader((headerName, headerValues) -> headerValues.forEach(headerValue -> requestBuilder.addHeader(headerName, headerValue))); - return httpClient.execute(requestBuilder.build(), responseHandler); - } - - private S3Presigner buildPresigner() + XmlMapper xmlMapper, + TestingS3RequestRewriteController requestRewriteController) { - AwsBasicCredentials proxyCredentials = AwsBasicCredentials.create(testingCredentials.emulated().accessKey(), testingCredentials.emulated().secretKey()); - return S3Presigner.builder().region(Region.US_EAST_1).endpointOverride(s3ProxyUrl).credentialsProvider(StaticCredentialsProvider.create(proxyCredentials)).build(); + super(httpClient, internalClient, storageClient, testingCredentials, httpServer, s3ProxyConfig, xmlMapper, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java new file mode 100644 index 00000000..c5d9e74d --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestPresignedRequestsWithRewrite.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server; + +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.google.inject.Inject; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.StringResponseHandler; +import io.airlift.http.server.testing.TestingHttpServer; +import io.trino.aws.proxy.server.testing.RequestRewriteUtil; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient; +import io.trino.aws.proxy.spi.credentials.Credentials; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.presigner.S3Presigner; +import software.amazon.awssdk.services.s3.presigner.model.GetObjectPresignRequest; +import software.amazon.awssdk.services.s3.presigner.model.PresignedGetObjectRequest; + +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; + +import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY; +import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static org.assertj.core.api.Assertions.assertThat; + +@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, RequestRewriteUtil.Filter.class}) +public class TestPresignedRequestsWithRewrite + extends AbstractTestPresignedRequests +{ + @Inject + public TestPresignedRequestsWithRewrite( + @ForTesting HttpClient httpClient, + S3Client internalClient, + @ForS3Container S3Client storageClient, + @ForTesting Credentials testingCredentials, + TestingHttpServer httpServer, + TrinoAwsProxyConfig s3ProxyConfig, + XmlMapper xmlMapper, + TestingS3RequestRewriteController requestRewriteController) + { + super(httpClient, internalClient, storageClient, testingCredentials, httpServer, s3ProxyConfig, xmlMapper, requestRewriteController); + } + + @Test + public void testPresignedRedirectBasedOnIdentity() + throws IOException + { + uploadFileToStorage(TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY, TEST_FILE); + + try (S3Presigner presigner = buildPresigner(CREDENTIAL_TO_REDIRECT)) { + GetObjectRequest objectRequest = GetObjectRequest.builder() + .bucket("foo") + .key("does-not-matter") + .build(); + + GetObjectPresignRequest presignRequest = GetObjectPresignRequest.builder() + .signatureDuration(Duration.ofDays(1)) + .getObjectRequest(objectRequest) + .build(); + + PresignedGetObjectRequest presignedRequest = presigner.presignGetObject(presignRequest); + StringResponseHandler.StringResponse response = executeHttpRequest(presignedRequest.httpRequest(), createStringResponseHandler()); + assertThat(response.getStatusCode()).isEqualTo(200); + assertThat(response.getBody()).isEqualTo(Files.readString(TEST_FILE)); + } + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedAssumedRoleRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedAssumedRoleRequests.java index 03dace19..82867800 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedAssumedRoleRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedAssumedRoleRequests.java @@ -16,6 +16,7 @@ import com.google.inject.Inject; import io.airlift.http.server.testing.TestingHttpServer; import io.trino.aws.proxy.server.testing.TestingCredentialsRolesProvider; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; @@ -31,7 +32,6 @@ import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider; import java.net.URI; -import java.util.List; import java.util.concurrent.CompletableFuture; import static java.util.Objects.requireNonNull; @@ -50,20 +50,19 @@ public TestProxiedAssumedRoleRequests( @ForTesting Credentials testingCredentials, TestingCredentialsRolesProvider credentialsController, @ForS3Container S3Client storageClient, - @ForS3Container List configuredBuckets, - TrinoAwsProxyConfig trinoAwsProxyConfig) + TrinoAwsProxyConfig trinoAwsProxyConfig, + TestingS3RequestRewriteController requestRewriteController) { - this(buildClient(httpServer, testingCredentials, trinoAwsProxyConfig.getS3Path(), trinoAwsProxyConfig.getStsPath()), testingCredentials, credentialsController, storageClient, configuredBuckets); + this(buildClient(httpServer, testingCredentials, trinoAwsProxyConfig.getS3Path(), trinoAwsProxyConfig.getStsPath()), credentialsController, storageClient, requestRewriteController); } protected TestProxiedAssumedRoleRequests( S3Client internalClient, - Credentials testingCredentials, TestingCredentialsRolesProvider credentialsController, S3Client storageClient, - List configuredBuckets) + TestingS3RequestRewriteController requestRewriteController) { - super(internalClient, storageClient, configuredBuckets); + super(internalClient, storageClient, requestRewriteController); this.credentialsController = requireNonNull(credentialsController, "credentialsController is null"); } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedEmulatedAndRemoteAssumedRoleRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedEmulatedAndRemoteAssumedRoleRequests.java index 51e3895a..89610761 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedEmulatedAndRemoteAssumedRoleRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedEmulatedAndRemoteAssumedRoleRequests.java @@ -16,27 +16,24 @@ import com.google.inject.Inject; import io.airlift.http.server.testing.TestingHttpServer; import io.trino.aws.proxy.server.testing.TestingCredentialsRolesProvider; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServerModule.ForTestingRemoteCredentials; -import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.spi.credentials.Credentials; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - public class TestProxiedEmulatedAndRemoteAssumedRoleRequests extends TestProxiedAssumedRoleRequests { @Inject public TestProxiedEmulatedAndRemoteAssumedRoleRequests( TestingHttpServer httpServer, - @ForTesting Credentials testingCredentials, TestingCredentialsRolesProvider credentialsController, @ForS3Container S3Client storageClient, - @ForS3Container List configuredBuckets, @ForTestingRemoteCredentials Credentials remoteCredentials, - TrinoAwsProxyConfig trinoAwsProxyConfig) + TrinoAwsProxyConfig trinoAwsProxyConfig, + TestingS3RequestRewriteController requestRewriteController) { - super(buildClient(httpServer, remoteCredentials, trinoAwsProxyConfig.getS3Path(), trinoAwsProxyConfig.getStsPath()), testingCredentials, credentialsController, storageClient, configuredBuckets); + super(buildClient(httpServer, remoteCredentials, trinoAwsProxyConfig.getS3Path(), trinoAwsProxyConfig.getStsPath()), credentialsController, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequests.java index 63b5407e..8ae87f40 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequests.java @@ -14,6 +14,7 @@ package io.trino.aws.proxy.server; import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.BuilderFilter; @@ -21,8 +22,6 @@ import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, TestProxiedRequests.Filter.class}) public class TestProxiedRequests extends AbstractTestProxiedRequests @@ -38,8 +37,8 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil } @Inject - public TestProxiedRequests(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List configuredBuckets) + public TestProxiedRequests(S3Client s3Client, @ForS3Container S3Client storageClient, TestingS3RequestRewriteController requestRewriteController) { - super(s3Client, storageClient, configuredBuckets); + super(s3Client, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java index f5fb81f1..b76229df 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsToVirtualHostEndpoint.java @@ -14,21 +14,20 @@ package io.trino.aws.proxy.server; import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithVirtualHostAddressing; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostAddressing.class}) public class TestProxiedRequestsToVirtualHostEndpoint extends AbstractTestProxiedRequests { @Inject - public TestProxiedRequestsToVirtualHostEndpoint(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List configuredBuckets) + public TestProxiedRequestsToVirtualHostEndpoint(S3Client s3Client, @ForS3Container S3Client storageClient, TestingS3RequestRewriteController requestRewriteController) { - super(s3Client, storageClient, configuredBuckets); + super(s3Client, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithEmptyPath.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithEmptyPath.java index a9f24d36..73b16d1b 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithEmptyPath.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithEmptyPath.java @@ -14,6 +14,7 @@ package io.trino.aws.proxy.server; import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.BuilderFilter; @@ -21,8 +22,6 @@ import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, TestProxiedRequestsWithEmptyPath.Filter.class}) public class TestProxiedRequestsWithEmptyPath extends AbstractTestProxiedRequests @@ -38,8 +37,8 @@ public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Buil } @Inject - public TestProxiedRequestsWithEmptyPath(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List configuredBuckets) + public TestProxiedRequestsWithEmptyPath(S3Client s3Client, @ForS3Container S3Client storageClient, TestingS3RequestRewriteController requestRewriteController) { - super(s3Client, storageClient, configuredBuckets); + super(s3Client, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithPathStyleOnVirtualHostProxy.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithPathStyleOnVirtualHostProxy.java index a1e71400..453a3441 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithPathStyleOnVirtualHostProxy.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithPathStyleOnVirtualHostProxy.java @@ -14,21 +14,20 @@ package io.trino.aws.proxy.server; import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithVirtualHostEnabledProxy; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostEnabledProxy.class}) public class TestProxiedRequestsWithPathStyleOnVirtualHostProxy extends AbstractTestProxiedRequests { @Inject - public TestProxiedRequestsWithPathStyleOnVirtualHostProxy(S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List configuredBuckets) + public TestProxiedRequestsWithPathStyleOnVirtualHostProxy(S3Client s3Client, @ForS3Container S3Client storageClient, TestingS3RequestRewriteController requestRewriteController) { - super(s3Client, storageClient, configuredBuckets); + super(s3Client, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java new file mode 100644 index 00000000..96b25a33 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithRewrite.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server; + +import com.google.inject.Inject; +import io.airlift.http.server.testing.TestingHttpServer; +import io.trino.aws.proxy.server.testing.RequestRewriteUtil; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +import java.io.IOException; +import java.net.URI; +import java.nio.file.Files; +import java.util.Optional; + +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.CREDENTIAL_TO_REDIRECT; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_BUCKET; +import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.TEST_CREDENTIAL_REDIRECT_KEY; +import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3; +import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder; +import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; +import static org.assertj.core.api.Assertions.assertThat; + +@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class}) +public class TestProxiedRequestsWithRewrite + extends AbstractTestProxiedRequests +{ + private final URI baseUri; + private final String relativePath; + + @Inject + public TestProxiedRequestsWithRewrite( + S3Client s3Client, + @ForS3Container S3Client storageClient, + TestingS3RequestRewriteController s3RequestRewriteController, + TestingHttpServer testingHttpServer, + TrinoAwsProxyConfig config) + { + super(s3Client, storageClient, s3RequestRewriteController); + this.baseUri = testingHttpServer.getBaseUrl(); + this.relativePath = config.getS3Path(); + } + + @Test + public void testRewriteBasedOnIdentity() + throws IOException + { + String testBucket = "dummy"; + String testKey = "dummy-key"; + try (S3Client testS3Client = clientBuilder(baseUri, Optional.of(relativePath)) + .credentialsProvider(() -> AwsBasicCredentials.create(CREDENTIAL_TO_REDIRECT.accessKey(), CREDENTIAL_TO_REDIRECT.secretKey())) + .build()) { + PutObjectRequest uploadRequest = PutObjectRequest.builder().bucket(testBucket).key(testKey).build(); + PutObjectResponse uploadResponse = testS3Client.putObject(uploadRequest, TEST_FILE); + assertThat(uploadResponse.sdkHttpResponse().statusCode()).isEqualTo(200); + + assertThat(getFileFromStorage(testS3Client, testBucket, testKey)).isEqualTo(Files.readString(TEST_FILE)); + } + assertThat(getFileFromStorage(remoteClient, TEST_CREDENTIAL_REDIRECT_BUCKET, TEST_CREDENTIAL_REDIRECT_KEY)).isEqualTo(Files.readString(TEST_FILE)); + assertFileNotInS3(remoteClient, testBucket, testKey); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java index 9df4fce4..376c4847 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestProxiedRequestsWithVirtualHostProxy.java @@ -15,21 +15,20 @@ import com.google.inject.Inject; import io.trino.aws.proxy.server.testing.TestingS3ClientModule.ForVirtualHostProxy; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithVirtualHostEnabledProxy; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; - @TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithVirtualHostEnabledProxy.class}) public class TestProxiedRequestsWithVirtualHostProxy extends AbstractTestProxiedRequests { @Inject - public TestProxiedRequestsWithVirtualHostProxy(@ForVirtualHostProxy S3Client s3Client, @ForS3Container S3Client storageClient, @ForS3Container List configuredBuckets) + public TestProxiedRequestsWithVirtualHostProxy(@ForVirtualHostProxy S3Client s3Client, @ForS3Container S3Client storageClient, TestingS3RequestRewriteController requestRewriteController) { - super(s3Client, storageClient, configuredBuckets); + super(s3Client, storageClient, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestRemoteSessionProxiedRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestRemoteSessionProxiedRequests.java index a048e6b7..831175d5 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestRemoteSessionProxiedRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestRemoteSessionProxiedRequests.java @@ -15,6 +15,7 @@ import com.google.inject.Inject; import io.airlift.http.server.testing.TestingHttpServer; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServerModule.ForTestingRemoteCredentials; import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; @@ -23,7 +24,6 @@ import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.services.s3.S3Client; -import java.util.List; import java.util.Optional; import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder; @@ -33,9 +33,9 @@ public class TestRemoteSessionProxiedRequests extends AbstractTestProxiedRequests { @Inject - public TestRemoteSessionProxiedRequests(@ForS3Container S3Client storageClient, @ForTestingRemoteCredentials Credentials remoteCredentials, TestingHttpServer httpServer, @ForS3Container List configuredBuckets, TrinoAwsProxyConfig trinoAwsProxyConfig) + public TestRemoteSessionProxiedRequests(@ForS3Container S3Client storageClient, @ForTestingRemoteCredentials Credentials remoteCredentials, TestingHttpServer httpServer, TrinoAwsProxyConfig trinoAwsProxyConfig, TestingS3RequestRewriteController requestRewriteController) { - super(buildInternalClient(remoteCredentials, httpServer, trinoAwsProxyConfig.getS3Path()), storageClient, configuredBuckets); + super(buildInternalClient(remoteCredentials, httpServer, trinoAwsProxyConfig.getS3Path()), storageClient, requestRewriteController); } private static S3Client buildInternalClient(Credentials credentials, TestingHttpServer httpServer, String s3Path) diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/AbstractTestPresigningHeaders.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/AbstractTestPresigningHeaders.java new file mode 100644 index 00000000..153922ec --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/AbstractTestPresigningHeaders.java @@ -0,0 +1,269 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.rest; + +import com.google.common.io.ByteStreams; +import io.trino.aws.proxy.server.testing.TestingS3PresignController; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.TestingS3SecurityController; +import jakarta.ws.rs.core.MediaType; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; +import software.amazon.awssdk.services.s3.model.CompletedPart; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; +import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; +import static io.trino.aws.proxy.spi.security.SecurityResponse.FAILURE; +import static io.trino.aws.proxy.spi.security.SecurityResponse.SUCCESS; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +public abstract class AbstractTestPresigningHeaders +{ + private final S3Client storageClient; + private final S3Client internalClient; + private final TestingS3SecurityController securityController; + private final TestingS3RequestRewriteController s3RequestRewriteController; + + public AbstractTestPresigningHeaders(S3Client storageClient, S3Client internalClient, TestingS3PresignController presignController, TestingS3SecurityController securityController, TestingS3RequestRewriteController s3RequestRewriteController) + { + this.storageClient = requireNonNull(storageClient, "storageClient is null"); + this.internalClient = requireNonNull(internalClient, "internalClient is null"); + this.securityController = requireNonNull(securityController, "securityController is null"); + this.s3RequestRewriteController = requireNonNull(s3RequestRewriteController, "s3RequestRewriteController is null"); + + presignController.setRewriteUrisForContainers(false); + } + + @AfterEach + public void reset() + { + securityController.clear(); + } + + @Test + public void testPresignHeaderGet() + throws Exception + { + String bucketName = "one"; + String key = "getTest"; + PutObjectRequest putObjectRequest = PutObjectRequest.builder() + .bucket(s3RequestRewriteController.getTargetBucket(bucketName, key)) + .key(s3RequestRewriteController.getTargetKey(bucketName, key)) + .build(); + storageClient.putObject(putObjectRequest, TEST_FILE); + + URI uri = getPresigned("get", bucketName, key).uri; + + // test the pre-signed URL by using it directly without any additional headers, signing, etc. + try (InputStream inputStream = uri.toURL().openStream()) { + String readContents = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8); + String expectedContents = Files.readString(TEST_FILE); + assertThat(readContents).isEqualTo(expectedContents); + } + } + + @Test + public void testPresignHeaderSecurity() + { + String bucketName = "one"; + String key = "getTest"; + PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(bucketName).key(key).build(); + internalClient.putObject(putObjectRequest, TEST_FILE); + + Presigned presigned = getPresigned("get", bucketName, key); + assertThat(presigned.presignedHeaderMethods).containsExactlyInAnyOrder("GET", "PUT", "POST", "DELETE"); + + securityController.setDelegate((request, _) -> lowercaseAction -> { + // Authorization happens with the original bucket and key, not the rewritten one + if (!request.bucketName().equals(bucketName) || !request.keyInBucket().equals(key)) { + return FAILURE; + } + return request.httpVerb().equalsIgnoreCase("DELETE") ? FAILURE : SUCCESS; + }); + + presigned = getPresigned("get", bucketName, key); + assertThat(presigned.presignedHeaderMethods).containsExactlyInAnyOrder("GET", "PUT", "POST"); + } + + @Test + public void testPresignHeaderMultiPart() + throws Exception + { + CreateMultipartUploadRequest multipartUploadRequest = CreateMultipartUploadRequest.builder().bucket("three").key("multi").build(); + CreateMultipartUploadResponse multipartUploadResponse = internalClient.createMultipartUpload(multipartUploadRequest); + + String uploadId = multipartUploadResponse.uploadId(); + + record Part(URI presignedUri, String content, int partNumber) {} + + List parts = IntStream.rangeClosed(1, 5).mapToObj(partNumber -> { + HeadObjectRequest request = HeadObjectRequest.builder() + .bucket("three") + .key("multi") + .partNumber(partNumber) + .overrideConfiguration(c -> c.putRawQueryParameter("uploadId", uploadId)) + .build(); + URI uri = getPresigned("PUT", request).uri; + String content = buildLine(partNumber); + return new Part(uri, content, partNumber); + }).collect(toImmutableList()); + + List completedParts = parts.stream().map(part -> { + try { + String eTag = upload(part.presignedUri, part.content); + return CompletedPart.builder().eTag(eTag).partNumber(part.partNumber).build(); + } + catch (IOException e) { + throw new RuntimeException(e); + } + }).collect(toImmutableList()); + + CompletedMultipartUpload completedUpload = CompletedMultipartUpload.builder() + .parts(completedParts) + .build(); + + CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder() + .bucket("three") + .key("multi") + .uploadId(uploadId) + .multipartUpload(completedUpload) + .build(); + + CompleteMultipartUploadResponse completeResponse = internalClient.completeMultipartUpload(completeRequest); + assertThat(completeResponse.sdkHttpResponse().statusCode()).isEqualTo(200); + + GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket("three").key("multi").build(); + ByteArrayOutputStream readContents = new ByteArrayOutputStream(); + internalClient.getObject(getObjectRequest).transferTo(readContents); + + String expected = IntStream.rangeClosed(1, 5) + .mapToObj(AbstractTestPresigningHeaders::buildLine) + .collect(Collectors.joining()); + + assertThat(readContents.toString()).isEqualTo(expected); + } + + @Test + public void testPresignHeaderPut() + throws Exception + { + String bucketName = "two"; + String key = "puttest"; + URI uri = getPresigned("put", bucketName, key).uri; + + String fileContents = Files.readString(TEST_FILE); + + // test the pre-signed URL by using it directly without any additional headers, signing, etc. + upload(uri, fileContents); + + // check that the file was uploaded correctly + assertThat(getFileFromStorage(internalClient, bucketName, key)).isEqualTo(fileContents); + assertThat(getFileFromStorage(storageClient, s3RequestRewriteController.getTargetBucket(bucketName, key), s3RequestRewriteController.getTargetKey(bucketName, key))).isEqualTo(fileContents); + } + + private static String upload(URI uri, String contents) + throws IOException + { + HttpURLConnection urlConnection = (HttpURLConnection) uri.toURL().openConnection(); + try { + urlConnection.setFixedLengthStreamingMode(contents.length()); + urlConnection.setRequestProperty("Content-Type", MediaType.APPLICATION_OCTET_STREAM); + urlConnection.setDoOutput(true); + urlConnection.setDoInput(true); + urlConnection.setRequestMethod("PUT"); + urlConnection.connect(); + try (OutputStream outputStream = urlConnection.getOutputStream()) { + ByteStreams.copy(new ByteArrayInputStream(contents.getBytes(StandardCharsets.UTF_8)), outputStream); + outputStream.flush(); + } + + return urlConnection.getHeaderField("eTag"); + } + finally { + urlConnection.disconnect(); + } + } + + private record Presigned(URI uri, Set presignedHeaderMethods) {} + + private Presigned getPresigned(String type, String bucket, String key) + { + HeadObjectRequest request = HeadObjectRequest.builder() + .bucket(bucket) + .key(key) + .build(); + return getPresigned(type, request); + } + + private Presigned getPresigned(String type, HeadObjectRequest request) + { + SdkHttpResponse sdkHttpResponse; + try { + HeadObjectResponse response = internalClient.headObject(request); + sdkHttpResponse = response.sdkHttpResponse(); + } + catch (AwsServiceException e) { + // when the bucket isn't found an exception is thrown - but response headers still have pre-signed URLs + sdkHttpResponse = e.awsErrorDetails().sdkHttpResponse(); + } + + Set presignedHeaderMethods = sdkHttpResponse.headers() + .keySet() + .stream() + .filter(header -> header.toLowerCase(Locale.ROOT).startsWith("x-trino-pre-signed-url-")) + .map(header -> header.substring("x-trino-pre-signed-url-".length())) + .collect(toImmutableSet()); + + // use an odd case for the header name on purpose + String header = "x-TRINO-pre-SIGNED-uRl-" + type; + String uri = sdkHttpResponse.firstMatchingHeader(header).orElseThrow(); + return new Presigned(URI.create(uri), presignedHeaderMethods); + } + + private static String buildLine(int partNumber) + { + // min multi-part is 5MB + return Character.toString('a' + (partNumber - 1)).repeat(1024 * 1024 * 5); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeaders.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeaders.java index 7c61dfcc..7f4e2a4f 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeaders.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeaders.java @@ -13,248 +13,22 @@ */ package io.trino.aws.proxy.server.rest; -import com.google.common.io.ByteStreams; import com.google.inject.Inject; import io.trino.aws.proxy.server.testing.TestingS3PresignController; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; import io.trino.aws.proxy.server.testing.TestingS3SecurityController; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; -import jakarta.ws.rs.core.MediaType; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Test; -import software.amazon.awssdk.awscore.exception.AwsServiceException; -import software.amazon.awssdk.core.ResponseInputStream; -import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.services.s3.S3Client; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; -import software.amazon.awssdk.services.s3.model.CompletedPart; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CreateMultipartUploadResponse; -import software.amazon.awssdk.services.s3.model.GetObjectRequest; -import software.amazon.awssdk.services.s3.model.GetObjectResponse; -import software.amazon.awssdk.services.s3.model.HeadObjectRequest; -import software.amazon.awssdk.services.s3.model.HeadObjectResponse; -import software.amazon.awssdk.services.s3.model.PutObjectRequest; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.net.HttpURLConnection; -import java.net.URI; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.util.List; -import java.util.Locale; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE; -import static io.trino.aws.proxy.spi.security.SecurityResponse.FAILURE; -import static io.trino.aws.proxy.spi.security.SecurityResponse.SUCCESS; -import static java.util.Objects.requireNonNull; -import static org.assertj.core.api.Assertions.assertThat; @TrinoAwsProxyTest(filters = WithConfiguredBuckets.class) public class TestPresigningHeaders + extends AbstractTestPresigningHeaders { - private final S3Client s3Client; - private final TestingS3SecurityController securityController; - @Inject - public TestPresigningHeaders(S3Client s3Client, TestingS3PresignController presignController, TestingS3SecurityController securityController) - { - this.s3Client = requireNonNull(s3Client, "s3Client is null"); - this.securityController = requireNonNull(securityController, "securityController is null"); - - presignController.setRewriteUrisForContainers(false); - } - - @AfterEach - public void reset() - { - securityController.clear(); - } - - @Test - public void testPresignHeaderGet() - throws Exception - { - PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket("one").key("gettest").build(); - s3Client.putObject(putObjectRequest, TEST_FILE); - - URI uri = getPresigned("get", "one", "gettest").uri; - - // test the pre-signed URL by using it directly without any additional headers, signing, etc. - try (InputStream inputStream = uri.toURL().openStream()) { - String readContents = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8); - String expectedContents = Files.readString(TEST_FILE); - assertThat(readContents).isEqualTo(expectedContents); - } - } - - @Test - public void testPresignHeaderSecurity() - { - PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket("one").key("gettest").build(); - s3Client.putObject(putObjectRequest, TEST_FILE); - - Presigned presigned = getPresigned("get", "one", "gettest"); - assertThat(presigned.presignedHeaderMethods).containsExactlyInAnyOrder("GET", "PUT", "POST", "DELETE"); - - securityController.setDelegate((request, _) -> lowercaseAction -> request.httpVerb().equalsIgnoreCase("DELETE") ? FAILURE : SUCCESS); - - presigned = getPresigned("get", "one", "gettest"); - assertThat(presigned.presignedHeaderMethods).containsExactlyInAnyOrder("GET", "PUT", "POST"); - } - - @Test - public void testPresignHeaderMultiPart() - throws Exception - { - CreateMultipartUploadRequest multipartUploadRequest = CreateMultipartUploadRequest.builder().bucket("three").key("multi").build(); - CreateMultipartUploadResponse multipartUploadResponse = s3Client.createMultipartUpload(multipartUploadRequest); - - String uploadId = multipartUploadResponse.uploadId(); - - record Part(URI presignedUri, String content, int partNumber) {} - - List parts = IntStream.rangeClosed(1, 5).mapToObj(partNumber -> { - HeadObjectRequest request = HeadObjectRequest.builder() - .bucket("three") - .key("multi") - .partNumber(partNumber) - .overrideConfiguration(c -> c.putRawQueryParameter("uploadId", uploadId)) - .build(); - URI uri = getPresigned("PUT", request).uri; - String content = buildLine(partNumber); - return new Part(uri, content, partNumber); - }).collect(toImmutableList()); - - List completedParts = parts.stream().map(part -> { - try { - String eTag = upload(part.presignedUri, part.content); - return CompletedPart.builder().eTag(eTag).partNumber(part.partNumber).build(); - } - catch (IOException e) { - throw new RuntimeException(e); - } - }).collect(toImmutableList()); - - CompletedMultipartUpload completedUpload = CompletedMultipartUpload.builder() - .parts(completedParts) - .build(); - - CompleteMultipartUploadRequest completeRequest = CompleteMultipartUploadRequest.builder() - .bucket("three") - .key("multi") - .uploadId(uploadId) - .multipartUpload(completedUpload) - .build(); - - CompleteMultipartUploadResponse completeResponse = s3Client.completeMultipartUpload(completeRequest); - assertThat(completeResponse.sdkHttpResponse().statusCode()).isEqualTo(200); - - GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket("three").key("multi").build(); - ByteArrayOutputStream readContents = new ByteArrayOutputStream(); - s3Client.getObject(getObjectRequest).transferTo(readContents); - - String expected = IntStream.rangeClosed(1, 5) - .mapToObj(TestPresigningHeaders::buildLine) - .collect(Collectors.joining()); - - assertThat(readContents.toString()).isEqualTo(expected); - } - - @Test - public void testPresignHeaderPut() - throws Exception - { - URI uri = getPresigned("put", "two", "puttest").uri; - - String fileContents = Files.readString(TEST_FILE); - - // test the pre-signed URL by using it directly without any additional headers, signing, etc. - upload(uri, fileContents); - - // check that the file was uploaded correctly - GetObjectRequest getObjectRequest = GetObjectRequest.builder() - .bucket("two") - .key("puttest") - .build(); - ResponseInputStream responseInputStream = s3Client.getObject(getObjectRequest); - String readContents = new String(responseInputStream.readAllBytes(), StandardCharsets.UTF_8); - assertThat(readContents).isEqualTo(fileContents); - } - - private static String upload(URI uri, String contents) - throws IOException - { - HttpURLConnection urlConnection = (HttpURLConnection) uri.toURL().openConnection(); - try { - urlConnection.setFixedLengthStreamingMode(contents.length()); - urlConnection.setRequestProperty("Content-Type", MediaType.APPLICATION_OCTET_STREAM); - urlConnection.setDoOutput(true); - urlConnection.setDoInput(true); - urlConnection.setRequestMethod("PUT"); - urlConnection.connect(); - try (OutputStream outputStream = urlConnection.getOutputStream()) { - ByteStreams.copy(new ByteArrayInputStream(contents.getBytes(StandardCharsets.UTF_8)), outputStream); - outputStream.flush(); - } - - return urlConnection.getHeaderField("eTag"); - } - finally { - urlConnection.disconnect(); - } - } - - private record Presigned(URI uri, Set presignedHeaderMethods) {} - - private Presigned getPresigned(String type, String bucket, String key) - { - HeadObjectRequest request = HeadObjectRequest.builder() - .bucket(bucket) - .key(key) - .build(); - return getPresigned(type, request); - } - - private Presigned getPresigned(String type, HeadObjectRequest request) - { - SdkHttpResponse sdkHttpResponse; - try { - HeadObjectResponse response = s3Client.headObject(request); - sdkHttpResponse = response.sdkHttpResponse(); - } - catch (AwsServiceException e) { - // when the bucket isn't found an exception is thrown - but response headers still have pre-signed URLs - sdkHttpResponse = e.awsErrorDetails().sdkHttpResponse(); - } - - Set presignedHeaderMethods = sdkHttpResponse.headers() - .keySet() - .stream() - .filter(header -> header.toLowerCase(Locale.ROOT).startsWith("x-trino-pre-signed-url-")) - .map(header -> header.substring("x-trino-pre-signed-url-".length())) - .collect(toImmutableSet()); - - // use an odd case for the header name on purpose - String header = "x-TRINO-pre-SIGNED-uRl-" + type; - String uri = sdkHttpResponse.firstMatchingHeader(header).orElseThrow(); - return new Presigned(URI.create(uri), presignedHeaderMethods); - } - - private static String buildLine(int partNumber) + public TestPresigningHeaders(@ForS3Container S3Client storageClient, S3Client internalClient, TestingS3PresignController presignController, TestingS3SecurityController securityController, TestingS3RequestRewriteController requestRewriteController) { - // min multi-part is 5MB - return Character.toString('a' + (partNumber - 1)).repeat(1024 * 1024 * 5); + super(storageClient, internalClient, presignController, securityController, requestRewriteController); } } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeadersWithRewrite.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeadersWithRewrite.java new file mode 100644 index 00000000..23ef341d --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestPresigningHeadersWithRewrite.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.rest; + +import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.RequestRewriteUtil; +import io.trino.aws.proxy.server.testing.TestingS3PresignController; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.TestingS3SecurityController; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets; +import software.amazon.awssdk.services.s3.S3Client; + +@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, RequestRewriteUtil.Filter.class}) +public class TestPresigningHeadersWithRewrite + extends AbstractTestPresigningHeaders +{ + @Inject + public TestPresigningHeadersWithRewrite(@ForS3Container S3Client storageClient, S3Client internalClient, TestingS3PresignController presignController, TestingS3SecurityController securityController, TestingS3RequestRewriteController requestRewriteController) + { + super(storageClient, internalClient, presignController, securityController, requestRewriteController); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/ContainerS3Facade.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/ContainerS3Facade.java index 3361acc1..351abc77 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/ContainerS3Facade.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/ContainerS3Facade.java @@ -30,7 +30,6 @@ public static class PathStyleContainerS3Facade @Inject public PathStyleContainerS3Facade(S3Container s3Container, TestingRemoteS3Facade delegatingFacade) { - //super((ignored1, ignored2) -> "127.0.0.1", false, Optional.of(5432)); super((ignored1, ignored2) -> s3Container.containerHost().getHost(), false, Optional.of(s3Container.containerHost().getPort())); delegatingFacade.setDelegate(this); } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java new file mode 100644 index 00000000..a29cf1e9 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/RequestRewriteUtil.java @@ -0,0 +1,88 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.testing; + +import com.google.inject.Inject; +import com.google.inject.Scopes; +import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; +import io.trino.aws.proxy.server.testing.containers.S3Container.ForS3Container; +import io.trino.aws.proxy.server.testing.harness.BuilderFilter; +import io.trino.aws.proxy.spi.credentials.Credential; +import io.trino.aws.proxy.spi.credentials.Credentials; +import software.amazon.awssdk.services.s3.S3Client; + +import java.util.List; +import java.util.Optional; +import java.util.UUID; + +import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder; +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.s3RequestRewriterModule; + +public final class RequestRewriteUtil +{ + public static final String TEST_CREDENTIAL_REDIRECT_BUCKET = "redirected-bucket-for-credential"; + public static final String TEST_CREDENTIAL_REDIRECT_KEY = "redirected-key-for-credential"; + public static final Credential CREDENTIAL_TO_REDIRECT = new Credential("credential-to-redirect", UUID.randomUUID().toString()); + + private RequestRewriteUtil() {} + + public static class SetupRequestRewrites + { + @Inject + public SetupRequestRewrites( + TestingCredentialsRolesProvider credentialsRolesProvider, + @ForTesting Credentials testingCredentials, + @ForS3Container List configuredBuckets, + @ForS3Container S3Client storageClient) + { + credentialsRolesProvider.addCredentials(Credentials.build(CREDENTIAL_TO_REDIRECT, testingCredentials.requiredRemoteCredential())); + configuredBuckets.forEach(bucket -> storageClient.createBucket(r -> r.bucket(getTargetName(bucket)))); + storageClient.createBucket(r -> r.bucket(TEST_CREDENTIAL_REDIRECT_BUCKET)); + } + } + + public static class Filter + implements BuilderFilter + { + @Override + public TestingTrinoAwsProxyServer.Builder filter(TestingTrinoAwsProxyServer.Builder builder) + { + return builder + .addModule(s3RequestRewriterModule("testing", TestingS3RequestRewriter.class, binder -> { + newOptionalBinder(binder, TestingS3RequestRewriter.class).setBinding().to(Rewriter.class).in(Scopes.SINGLETON); + binder.bind(SetupRequestRewrites.class).asEagerSingleton(); + })) + .withProperty("s3-request-rewriter.type", "testing"); + } + } + + public static class Rewriter + implements TestingS3RequestRewriter + { + @Override + public Optional testRewrite(Credentials credentials, String bucketName, String keyName) + { + boolean redirectForTestCredential = credentials.emulated().accessKey().equalsIgnoreCase(CREDENTIAL_TO_REDIRECT.accessKey()); + if (redirectForTestCredential) { + return Optional.of(new S3RewriteResult(TEST_CREDENTIAL_REDIRECT_BUCKET, keyName.isEmpty() ? "" : TEST_CREDENTIAL_REDIRECT_KEY)); + } + return Optional.of(new S3RewriteResult(getTargetName(bucketName), getTargetName(keyName))); + } + } + + private static String getTargetName(String name) + { + return name.isEmpty() ? "" : "redirected-%s".formatted(name); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java new file mode 100644 index 00000000..0c79816e --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriteController.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.testing; + +import com.google.inject.Inject; +import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; +import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter.S3RewriteResult; + +import static java.util.Objects.requireNonNull; + +public class TestingS3RequestRewriteController +{ + private final TestingS3RequestRewriter s3RequestRewriter; + private final Credentials defaultCredentials; + + @Inject + public TestingS3RequestRewriteController(TestingS3RequestRewriter rewriter, @ForTesting Credentials defaultCredentials) + { + this.s3RequestRewriter = requireNonNull(rewriter, "rewriter is null"); + this.defaultCredentials = requireNonNull(defaultCredentials, "defaultCredentials is null"); + } + + private S3RewriteResult rewriteOrNoop(Credentials credentials, String bucket, String key) + { + return s3RequestRewriter.testRewrite(credentials, bucket, key).orElseGet(() -> new S3RewriteResult(bucket, key)); + } + + public String getTargetBucket(Credentials credentials, String bucket, String key) + { + return rewriteOrNoop(credentials, bucket, key).finalRequestBucket(); + } + + public String getTargetBucket(String bucket, String key) + { + return getTargetBucket(defaultCredentials, bucket, key); + } + + public String getTargetKey(Credentials credentials, String bucket, String key) + { + return rewriteOrNoop(credentials, bucket, key).finalRequestKey(); + } + + public String getTargetKey(String bucket, String key) + { + return getTargetKey(defaultCredentials, bucket, key); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java new file mode 100644 index 00000000..441a2f91 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingS3RequestRewriter.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.testing; + +import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.rest.ParsedS3Request; +import io.trino.aws.proxy.spi.rest.S3RequestRewriter; + +import java.util.Optional; + +@FunctionalInterface +public interface TestingS3RequestRewriter + extends S3RequestRewriter +{ + TestingS3RequestRewriter NOOP = (_, _, _) -> Optional.empty(); + + Optional testRewrite(Credentials credentials, String bucketName, String keyName); + + @Override + default Optional rewrite(Credentials credentials, ParsedS3Request request) + { + return testRewrite(credentials, request.bucketName(), request.keyInBucket()); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java index f608028d..d41c5ed1 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java @@ -22,6 +22,7 @@ import software.amazon.awssdk.services.s3.S3Client; import software.amazon.awssdk.services.s3.S3ClientBuilder; import software.amazon.awssdk.services.s3.model.Delete; +import software.amazon.awssdk.services.s3.model.DeleteBucketRequest; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; @@ -123,6 +124,12 @@ public static void cleanupBuckets(S3Client storageClient) })); } + public static void deleteAllBuckets(S3Client storageClient) + { + cleanupBuckets(storageClient); + storageClient.listBuckets().buckets().forEach(bucket -> storageClient.deleteBucket(DeleteBucketRequest.builder().bucket(bucket.name()).build())); + } + public static void assertFileNotInS3(S3Client storageClient, String bucket, String key) { assertThatExceptionOfType(S3Exception.class) diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/harness/TrinoAwsProxyTestExtension.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/harness/TrinoAwsProxyTestExtension.java index d13eb824..cd5e2e97 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/harness/TrinoAwsProxyTestExtension.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/harness/TrinoAwsProxyTestExtension.java @@ -19,6 +19,8 @@ import io.trino.aws.proxy.server.remote.RemoteS3Facade; import io.trino.aws.proxy.server.testing.ContainerS3Facade; import io.trino.aws.proxy.server.testing.TestingS3ClientModule; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriteController; +import io.trino.aws.proxy.server.testing.TestingS3RequestRewriter; import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer; import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting; import io.trino.aws.proxy.server.testing.containers.S3Container; @@ -66,6 +68,8 @@ public Object createTestInstance(TestInstanceFactoryContext factoryContext, Exte .setDefault() .to(ContainerS3Facade.PathStyleContainerS3Facade.class) .asEagerSingleton(); + newOptionalBinder(binder, TestingS3RequestRewriter.class).setDefault().toInstance(TestingS3RequestRewriter.NOOP); + binder.bind(TestingS3RequestRewriteController.class).in(Scopes.SINGLETON); }) .buildAndStart(); testingServersRegistry.put(extensionContext.getUniqueId(), trinoS3ProxyServer);