Skip to content

Commit

Permalink
Add tests for rewrites in presigned header responses
Browse files Browse the repository at this point in the history
  • Loading branch information
vagaerg committed Aug 16, 2024
1 parent c3e6a5d commit 2ae84c3
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 290 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
import static io.airlift.http.client.Request.Builder.preparePut;
import static io.airlift.http.client.StatusResponseHandler.createStatusResponseHandler;
import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.getTargetBucket;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.getTargetKey;
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage;
import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket;
Expand Down Expand Up @@ -111,23 +113,13 @@ protected AbstractTestPresignedRequests(
this.requestRewriter = requireNonNull(requestRewriter, "requestRewriter is null");
}

private String getTargetBucket(String sourceBucket, String sourceKey)
{
return requestRewriter.testRewrite(Optional.empty(), sourceBucket, sourceKey).finalRequestBucket();
}

private String getTargetKey(String sourceBucket, String sourceKey)
{
return requestRewriter.testRewrite(Optional.empty(), sourceBucket, sourceKey).finalRequestKey();
}

@Test
public void testPresignedGet()
throws IOException
{
String bucketName = "one";
String key = "presignedGet";
uploadFileToStorage(getTargetBucket(bucketName, key), getTargetKey(bucketName, key), TEST_FILE);
uploadFileToStorage(getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key), TEST_FILE);

try (S3Presigner presigner = buildPresigner()) {
GetObjectRequest objectRequest = GetObjectRequest.builder()
Expand Down Expand Up @@ -173,7 +165,7 @@ public void testPresignedPut()
}

assertThat(getFileFromStorage(bucketName, key)).isEqualTo(fileContents);
HeadObjectResponse headObjectResponse = headObjectInStorage(storageClient, getTargetBucket(bucketName, key), getTargetKey(bucketName, key));
HeadObjectResponse headObjectResponse = headObjectInStorage(storageClient, getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key));
assertThat(headObjectResponse.contentType()).isEqualTo("text/plain;charset=UTF-8");
assertThat(headObjectResponse.contentEncoding()).isEqualTo("gzip");
}
Expand All @@ -183,7 +175,7 @@ public void testPresignedDelete()
{
String bucketName = "three";
String key = "fileToDelete";
uploadFileToStorage(getTargetBucket(bucketName, key), getTargetKey(bucketName, key), TEST_FILE);
uploadFileToStorage(getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key), TEST_FILE);

try (S3Presigner presigner = buildPresigner()) {
DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucketName).key(key).build();
Expand All @@ -197,7 +189,7 @@ public void testPresignedDelete()
assertThat(response.getStatusCode()).isEqualTo(204);
}

assertThat(listFilesInS3Bucket(storageClient, getTargetBucket(bucketName, key))).isEmpty();
assertThat(listFilesInS3Bucket(storageClient, getTargetBucket(requestRewriter, bucketName, key))).isEmpty();
}

@Test
Expand All @@ -206,7 +198,7 @@ public void testExpiredSignature()
{
String bucketName = "three";
String key = "fileToDeleteExpired";
uploadFileToStorage(getTargetBucket(bucketName, key), getTargetKey(bucketName, key), TEST_FILE);
uploadFileToStorage(getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key), TEST_FILE);

try (S3Presigner presigner = buildPresigner()) {
DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucketName).key(key).build();
Expand Down Expand Up @@ -288,7 +280,7 @@ public void testMultipart()
assertThat(completeMultipartResponse.getStatusCode()).isEqualTo(200);
}
assertThat(getFileFromStorage(bucketName, key)).isEqualTo(dummyPayload);
HeadObjectResponse headResult = TestingUtil.headObjectInStorage(storageClient, getTargetBucket(bucketName, key), getTargetKey(bucketName, key));
HeadObjectResponse headResult = TestingUtil.headObjectInStorage(storageClient, getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key));
assertThat(headResult.contentEncoding()).isEqualTo("gzip");
assertThat(headResult.contentType()).isEqualTo("text/plain;charset=UTF-8");
assertThat(headResult.metadata()).containsEntry("some-metadata-key", "some-metadata-value");
Expand All @@ -313,7 +305,7 @@ String getFileFromStorage(String bucketName, String key)
throws IOException
{
String dataFromProxy = TestingUtil.getFileFromStorage(internalClient, bucketName, key);
String dataFromStorage = TestingUtil.getFileFromStorage(storageClient, getTargetBucket(bucketName, key), getTargetKey(bucketName, key));
String dataFromStorage = TestingUtil.getFileFromStorage(storageClient, getTargetBucket(requestRewriter, bucketName, key), getTargetKey(requestRewriter, bucketName, key));
assertThat(dataFromProxy).isEqualTo(dataFromStorage);
return dataFromStorage;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import java.time.Duration;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -54,6 +53,8 @@

import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.util.concurrent.MoreExecutors.shutdownAndAwaitTermination;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.getTargetBucket;
import static io.trino.aws.proxy.server.testing.RequestRewriteUtil.getTargetKey;
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage;
Expand All @@ -80,16 +81,6 @@ protected AbstractTestProxiedRequests(S3Client internalClient, S3Client remoteCl
this(internalClient, remoteClient, TestingS3RequestRewriter.NOP);
}

private String getTargetBucket(String sourceBucket, String sourceKey)
{
return requestRewriter.testRewrite(Optional.empty(), sourceBucket, sourceKey).finalRequestBucket();
}

private String getTargetKey(String sourceBucket, String sourceKey)
{
return requestRewriter.testRewrite(Optional.empty(), sourceBucket, sourceKey).finalRequestKey();
}

@PreDestroy
public void shutdown()
{
Expand All @@ -110,7 +101,7 @@ public void testCreateBucket()
assertThat(createBucketResponse.sdkHttpResponse().statusCode()).isEqualTo(200);

ListBucketsResponse listBucketsResponse = remoteClient.listBuckets();
assertThat(listBucketsResponse.buckets()).extracting(Bucket::name).contains(getTargetBucket(newBucketName, ""));
assertThat(listBucketsResponse.buckets()).extracting(Bucket::name).contains(getTargetBucket(requestRewriter, newBucketName, ""));
}

@Test
Expand All @@ -129,9 +120,9 @@ public void testListBucketsWithContents()
String testKey = "some-key";
assertThat(listFilesInS3Bucket(internalClient, bucketToTest)).isEmpty();

remoteClient.putObject(request -> request.bucket(getTargetBucket(bucketToTest, testKey)).key(getTargetKey(bucketToTest, testKey)), RequestBody.fromString("some-contents"));
remoteClient.putObject(request -> request.bucket(getTargetBucket(requestRewriter, bucketToTest, testKey)).key(getTargetKey(requestRewriter, bucketToTest, testKey)), RequestBody.fromString("some-contents"));

assertThat(listFilesInS3Bucket(internalClient, bucketToTest)).containsExactlyInAnyOrder(getTargetKey(bucketToTest, testKey));
assertThat(listFilesInS3Bucket(internalClient, bucketToTest)).containsExactlyInAnyOrder(getTargetKey(requestRewriter, bucketToTest, testKey));
}

@Test
Expand All @@ -146,16 +137,16 @@ public void testUploadAndDelete()

String expectedContents = Files.readString(TEST_FILE);
assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(expectedContents);
assertThat(getFileFromStorage(remoteClient, getTargetBucket(bucket, key), getTargetKey(bucket, key))).isEqualTo(expectedContents);
assertThat(getFileFromStorage(remoteClient, getTargetBucket(requestRewriter, bucket, key), getTargetKey(requestRewriter, bucket, key))).isEqualTo(expectedContents);

assertThat(listFilesInS3Bucket(internalClient, bucket)).containsExactlyInAnyOrder(getTargetKey(bucket, key));
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(bucket, key))).containsExactlyInAnyOrder(getTargetKey(bucket, key));
assertThat(listFilesInS3Bucket(internalClient, bucket)).containsExactlyInAnyOrder(getTargetKey(requestRewriter, bucket, key));
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(requestRewriter, bucket, key))).containsExactlyInAnyOrder(getTargetKey(requestRewriter, bucket, key));

DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucket).key(key).build();
internalClient.deleteObject(deleteObjectRequest);

assertThat(listFilesInS3Bucket(internalClient, bucket)).isEmpty();
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(bucket, key))).isEmpty();
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(requestRewriter, bucket, key))).isEmpty();
}

@Test
Expand All @@ -175,7 +166,7 @@ public void testUploadWithContentTypeAndMetadata()
assertThat(putObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200);

assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(Files.readString(TEST_FILE));
assertThat(getFileFromStorage(remoteClient, getTargetBucket(bucket, key), getTargetKey(bucket, key))).isEqualTo(Files.readString(TEST_FILE));
assertThat(getFileFromStorage(remoteClient, getTargetBucket(requestRewriter, bucket, key), getTargetKey(requestRewriter, bucket, key))).isEqualTo(Files.readString(TEST_FILE));
HeadObjectResponse headObjectResponse = headObjectInStorage(internalClient, bucket, key);
assertThat(headObjectResponse.sdkHttpResponse().statusCode()).isEqualTo(200);

Expand Down Expand Up @@ -240,7 +231,7 @@ public void testMultipartUpload()
.collect(Collectors.joining());

assertThat(getFileFromStorage(internalClient, bucket, key)).isEqualTo(expected);
assertThat(getFileFromStorage(remoteClient, getTargetBucket(bucket, key), getTargetKey(bucket, key))).isEqualTo(expected);
assertThat(getFileFromStorage(remoteClient, getTargetBucket(requestRewriter, bucket, key), getTargetKey(requestRewriter, bucket, key))).isEqualTo(expected);

DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(bucket).key(key).build();
DeleteObjectResponse deleteObjectResponse = internalClient.deleteObject(deleteObjectRequest);
Expand All @@ -251,13 +242,13 @@ public void testMultipartUpload()
public void testPathsNeedingEscaping()
{
String bucket = "escapes";
remoteClient.createBucket(r -> r.bucket(getTargetBucket(bucket, "")));
remoteClient.createBucket(r -> r.bucket(getTargetBucket(requestRewriter, bucket, "")));
internalClient.putObject(r -> r.bucket(bucket).key("a=1/b=2"), RequestBody.fromString("something"));
internalClient.putObject(r -> r.bucket(bucket).key("a=1%2Fb=2"), RequestBody.fromString("else"));

List<String> expectedKeys = ImmutableList.of(getTargetKey(bucket, "a=1/b=2"), getTargetKey(bucket, "a=1%2Fb=2"));
List<String> expectedKeys = ImmutableList.of(getTargetKey(requestRewriter, bucket, "a=1/b=2"), getTargetKey(requestRewriter, bucket, "a=1%2Fb=2"));
assertThat(listFilesInS3Bucket(internalClient, bucket)).containsExactlyInAnyOrderElementsOf(expectedKeys);
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(bucket, ""))).containsExactlyInAnyOrderElementsOf(expectedKeys);
assertThat(listFilesInS3Bucket(remoteClient, getTargetBucket(requestRewriter, bucket, ""))).containsExactlyInAnyOrderElementsOf(expectedKeys);

internalClient.deleteObject(r -> r.bucket(bucket).key("a=1/b=2"));
internalClient.deleteObject(r -> r.bucket(bucket).key("a=1%2Fb=2"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,24 @@
import io.airlift.http.client.HttpClient;
import io.airlift.http.server.testing.TestingHttpServer;
import io.trino.aws.proxy.server.testing.TestingS3RequestRewriter;
import io.trino.aws.proxy.server.testing.TestingUtil;
import io.trino.aws.proxy.server.testing.TestingUtil.ForTesting;
import io.trino.aws.proxy.server.testing.containers.S3Container;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithTestingHttpClient;
import io.trino.aws.proxy.spi.credentials.Credentials;
import software.amazon.awssdk.services.s3.S3Client;

@TrinoAwsProxyTest(filters = {TrinoAwsProxyTestCommonModules.WithConfiguredBuckets.class, TrinoAwsProxyTestCommonModules.WithTestingHttpClient.class, TestProxiedRequests.Filter.class})
@TrinoAwsProxyTest(filters = {WithConfiguredBuckets.class, WithTestingHttpClient.class, TestProxiedRequests.Filter.class})
public class TestPresignedRequests
extends AbstractTestPresignedRequests
{
@Inject
public TestPresignedRequests(
@TestingUtil.ForTesting HttpClient httpClient,
@ForTesting HttpClient httpClient,
S3Client internalClient,
@S3Container.ForS3Container S3Client storageClient,
@TestingUtil.ForTesting Credentials testingCredentials,
@ForTesting Credentials testingCredentials,
TestingHttpServer httpServer,
TrinoAwsProxyConfig s3ProxyConfig,
XmlMapper xmlMapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTestCommonModules.WithConfiguredBuckets;
import io.trino.aws.proxy.spi.rest.S3RequestRewriter;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -39,7 +38,6 @@
import static io.trino.aws.proxy.server.testing.TestingUtil.TEST_FILE;
import static io.trino.aws.proxy.server.testing.TestingUtil.assertFileNotInS3;
import static io.trino.aws.proxy.server.testing.TestingUtil.clientBuilder;
import static io.trino.aws.proxy.server.testing.TestingUtil.deleteAllBuckets;
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -50,12 +48,6 @@ public class TestProxiedRequestsWithRewrite
private final URI baseUri;
private final String relativePath;

@AfterAll
public void cleanupBuckets()
{
deleteAllBuckets(remoteClient);
}

@Inject
public TestProxiedRequestsWithRewrite(
S3Client s3Client,
Expand Down
Loading

0 comments on commit 2ae84c3

Please sign in to comment.