From 7f36c79c8e3e018f30febcfad6c843e5afa7d09a Mon Sep 17 00:00:00 2001 From: Pablo Arteaga Date: Wed, 31 Jul 2024 12:09:59 +0100 Subject: [PATCH] Add strict signature validation of aws-chunked streams Closes #102 --- .../server/rest/AwsChunkedInputStream.java | 9 +- .../proxy/server/rest/TrinoS3ProxyClient.java | 2 +- .../proxy/server/TestGenericRestRequests.java | 80 ++++- .../rest/TestAwsChunkedInputStream.java | 288 +++++++++++++++--- .../signing/TestingChunkSigningSession.java | 12 +- .../aws/proxy/server/testing/TestingUtil.java | 6 + 6 files changed, 346 insertions(+), 51 deletions(-) diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java index 72e9276c..5abcc74b 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/AwsChunkedInputStream.java @@ -41,11 +41,14 @@ private enum State private State state = State.FIRST_CHUNK; private boolean delegateIsDone; private int bytesRemainingInChunk; + private int bytesAccountedFor; + private final int decodedContentLength; - AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession) + AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession, int decodedContentLength) { this.delegate = requireNonNull(delegate, "delegate is null"); this.chunkSigningSession = requireNonNull(chunkSigningSession, "chunkSigningSession is null"); + this.decodedContentLength = decodedContentLength; } @Override @@ -185,6 +188,7 @@ private void nextChunk() chunkSigningSession.complete(); state = State.LAST_CHUNK; } + bytesAccountedFor += chunkSize; success = true; } @@ -193,6 +197,9 @@ private void nextChunk() if (!success) { throw new IOException("Invalid chunk header: " + header); } + if (bytesAccountedFor > decodedContentLength) { + throw new IllegalStateException("chunked data headers report a larger size than originally declared in the request: declared %s sent %s".formatted(decodedContentLength, bytesAccountedFor)); + } } private void readEmptyLine() diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java index 9545f85f..d25a9856 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/TrinoS3ProxyClient.java @@ -185,7 +185,7 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques private Optional contentInputStream(RequestContent requestContent, SigningMetadata signingMetadata) { return switch (requestContent.contentType()) { - case AWS_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession())); + case AWS_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow())); case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> { SigningContext signingContext = signingMetadata.requiredSigningContext(); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java index 9be724a5..e841053b 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestGenericRestRequests.java @@ -65,6 +65,7 @@ import static io.trino.aws.proxy.server.testing.TestingUtil.cleanupBuckets; import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage; import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage; +import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -78,6 +79,7 @@ public class TestGenericRestRequests private final S3Client storageClient; private static final String TEST_CONTENT_TYPE = "text/plain;charset=utf-8"; + private static final String ILLEGAL_CHUNK_SIGNATURE = "0".repeat(AwsS3V4ChunkSigner.getSignatureLength()); private static final String goodContent = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Viverra aliquet eget sit amet tellus cras adipiscing. Viverra mauris in aliquam sem fringilla. Facilisis mauris sit amet massa vitae. Mauris vitae ultricies leo integer malesuada. Sed libero enim sed faucibus turpis in eu mi bibendum. Lorem sed risus ultricies tristique nulla aliquet enim. Quis blandit turpis cursus in hac habitasse platea dictumst quisque. Diam maecenas ultricies mi eget mauris pharetra et ultrices neque. Aliquam sem fringilla ut morbi."; // first char is different case @@ -151,7 +153,7 @@ public void testAwsChunkedUploadValid() } @Test - public void testAwsChunkedUploadCornerCases() + public void testAwsChunkedUploadInvalidContent() throws IOException { String bucket = "test-aws-chunked"; @@ -174,9 +176,8 @@ public void testAwsChunkedUploadCornerCases() // Final chunk has an invalid size Function changeSizeOfFinalChunk = chunked -> chunked.replaceFirst("\\r\\n0;chunk-signature=(\\w+)", "\r\n1;chunk-signature=$1"); - // TODO: this currently will be accepted, we need to add stricter validation - use a different key so it does not interfere with other cases - assertThat(doAwsChunkedUpload(bucket, "final_chunk_invalid_size", goodContent, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500); - // assertFileNotInS3(storageClient, bucket, "final_chunk_invalid_size"); + assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500); + assertFileNotInS3(storageClient, bucket, fileKey); // First chunk has an invalid size Function changeSizeOfFirstChunk = chunked -> { @@ -272,6 +273,75 @@ private StatusResponse doAwsChunkedUpload( return httpClient.execute(requestBuilder.build(), createStatusResponseHandler()); } + @Test + public void testAwsChunkedCornerCases() + throws InterruptedException + { + String bucket = "test-aws-chunked-illegal"; + String dummyContent = "hello there"; + String longDummyContent = dummyContent.repeat(4096); + storageClient.createBucket(r -> r.bucket(bucket).build()); + + // Illegal signature and no final chunk + testAwsChunkedIllegalChunks(bucket, "no-final-chunk", buildFakeChunk(longDummyContent, longDummyContent.length()), longDummyContent.length(), 500); + // Illegal signature with a final chunk + testAwsChunkedIllegalChunks(bucket, "with-final-chunk", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), longDummyContent.length(), 401); + // Illegal signature and no final chunk - more chunked data than we report in the x-amz-decoded-content-length header + testAwsChunkedIllegalChunks(bucket, "no-final-chunk-more-data-than-headers-indicate", buildFakeChunk(longDummyContent, longDummyContent.length()), 4096, 500); + + // Illegal signature with a final chunk - more chunked data than we report in the x-amz-decoded-content-length header + testAwsChunkedIllegalChunks(bucket, "with-final-chunk-more-data-than-headers-indicate", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), 4096, 500); + + // Illegal signature and no final chunk - chunk misreports its size + testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-underreports-size", buildFakeChunk(longDummyContent, 4096), 4096, 500); + // Illegal signature with a final chunk - chunk misreports its size + testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-underreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 4096), buildFakeChunk("", 0)), 4096, 500); + + // Illegal signature and no final chunk - chunk misreports its size + testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-overreports-size", buildFakeChunk(longDummyContent, 9_000_000), 4096, 500); + // Illegal signature with a final chunk - chunk misreports its size + testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-overreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 9_000_000), buildFakeChunk("", 0)), 4096, 500); + Thread.sleep(1000); + assertThat(listFilesInS3Bucket(storageClient, bucket)).isEmpty(); + } + + private static String buildFakeChunk(String dataInChunk, int reportedChunkSize) + { + return "%s;chunk-signature=%s\r\n%s\r\n".formatted(Integer.toString(reportedChunkSize, 16), ILLEGAL_CHUNK_SIGNATURE, dataInChunk); + } + + private void testAwsChunkedIllegalChunks(String bucket, String key, String rawContent, int decodedContentLength, int expectedStatusCode) + { + Instant requestDate = Instant.now(); + Credential validCredential = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + credentialsRolesProvider.addCredentials(Credentials.build(validCredential, testingCredentials.requiredRemoteCredential())); + + ImmutableMultiMap.Builder requestHeaderBuilder = ImmutableMultiMap.builder(false); + requestHeaderBuilder + .add("Host", "%s:%d".formatted(baseUri.getHost(), baseUri.getPort())) + .add("X-Amz-Date", AwsTimestamp.toRequestFormat(requestDate)) + .add("X-Amz-Content-Sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD") + .add("X-Amz-Decoded-Content-Length", String.valueOf(decodedContentLength)) + .add("Content-Length", String.valueOf(rawContent.length())) + .add("Content-Type", TEST_CONTENT_TYPE) + .add("Content-Encoding", "aws-chunked"); + InternalSigningController signingController = new InternalSigningController( + new CredentialsController(new TestingRemoteS3Facade(), credentialsRolesProvider), + new SigningControllerConfig().setMaxClockDrift(new Duration(10, TimeUnit.SECONDS)), + new RequestLoggerController(new TrinoAwsProxyConfig())); + + URI requestUri = UriBuilder.fromUri(baseUri).path(bucket).path(key).build(); + RequestAuthorization requestAuthorization = signingController.signRequest(new SigningMetadata(SigningServiceType.S3, Credentials.build(validCredential, testingCredentials.requiredRemoteCredential()), Optional.empty()), + "us-east-1", requestDate, Optional.empty(), Credentials::emulated, requestUri, requestHeaderBuilder.build(), ImmutableMultiMap.empty(), "PUT").signingAuthorization(); + Request.Builder requestBuilder = preparePut().setUri(requestUri); + + requestHeaderBuilder.add("Authorization", requestAuthorization.authorization()); + requestHeaderBuilder.build().forEachEntry(requestBuilder::addHeader); + requestBuilder.setBodyGenerator(createStaticBodyGenerator(rawContent.getBytes(StandardCharsets.UTF_8))); + + assertThat(httpClient.execute(requestBuilder.build(), createStatusResponseHandler()).getStatusCode()).isEqualTo(expectedStatusCode); + } + private StatusResponse doPutObject(String content, String sha256) { URI uri = UriBuilder.fromUri(baseUri) @@ -306,7 +376,7 @@ private static Function getMutatorToBreakSignatureForChunk(int c for (String part : parts) { if (part.contains(";chunk-signature=")) { if (remainingChunks-- == 0) { - resultBuilder.append(part.replaceFirst("([0-9a-f]+;chunk-signature=)(\\w+)", "$1" + "0".repeat(AwsS3V4ChunkSigner.getSignatureLength()))); + resultBuilder.append(part.replaceFirst("([0-9a-f]+;chunk-signature=)(\\w+)", "$1" + ILLEGAL_CHUNK_SIGNATURE)); resultBuilder.append("\r\n"); continue; } diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java index 0f51effd..5a57c67d 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/rest/TestAwsChunkedInputStream.java @@ -13,15 +13,19 @@ */ package io.trino.aws.proxy.server.rest; +import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import io.trino.aws.proxy.server.signing.TestingChunkSigningSession; import io.trino.aws.proxy.spi.credentials.Credential; import io.trino.aws.proxy.spi.signing.ChunkSigningSession; +import io.trino.aws.proxy.spi.util.AwsTimestamp; import jakarta.ws.rs.WebApplicationException; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.signer.internal.chunkedencoding.AwsS3V4ChunkSigner; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; @@ -41,32 +45,47 @@ public class TestAwsChunkedInputStream private static final Credential BAD_CREDENTIAL = new Credential("BAD_TEST_ACCESS_KEY", "BAD_TEST_SECRET_KEY"); private static final String BAD_SEED = "THIS IS A FAKE BAD SEED"; + private static final String ILLEGAL_CHUNK_SIGNATURE = "0".repeat(AwsS3V4ChunkSigner.getSignatureLength()); + + private interface ChunkReader + { + void read(String chunkedData, int decodedContentLength, TestingChunkSigningSession signingSession, ByteArrayOutputStream output) + throws IOException; + } + + private static TestingChunkSigningSession goodTestSigningSession() + { + return TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); + } + + private static TestingChunkSigningSession fixedTimeGoodTestSigningSession() + { + return TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED, AwsTimestamp.fromRequestTimestamp("20240801T010203Z")); + } @Test public void testGood() throws IOException { - TestingChunkSigningSession signingSession = TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); - String chunkedStream = signingSession.generateChunkedStream(GOOD_CONTENT, 3); + TestingChunkSigningSession session = goodTestSigningSession(); + String chunkedStream = session.generateChunkedStream(GOOD_CONTENT, 3); - assertThat(readChunked(chunkedStream, signingSession)).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, session)).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); } @Test public void testRecreateSessionValidatesGoodPayload() throws IOException { - TestingChunkSigningSession signingSession = TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); - String chunkedStream = signingSession.generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); - assertThat(readChunked(chunkedStream, TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED))).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, goodTestSigningSession())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); } @Test public void testBadSeed() { - TestingChunkSigningSession signingSession = TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); - String chunkedStream = signingSession.generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(GOOD_CREDENTIAL, BAD_SEED))) .isInstanceOf(WebApplicationException.class); @@ -75,8 +94,7 @@ public void testBadSeed() @Test public void testBadCredential() { - TestingChunkSigningSession signingSession = TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); - String chunkedStream = signingSession.generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); assertThatThrownBy(() -> readChunked(chunkedStream, TestingChunkSigningSession.build(BAD_CREDENTIAL, BAD_SEED))) .isInstanceOf(WebApplicationException.class); @@ -86,17 +104,192 @@ public void testBadCredential() public void testMultipleExtensions() throws IOException { - TestingChunkSigningSession signingSession = TestingChunkSigningSession.build(GOOD_CREDENTIAL, GOOD_SEED); - String chunkedStream = signingSession.generateChunkedStream(GOOD_CONTENT, 3); + String chunkedStream = goodTestSigningSession().generateChunkedStream(GOOD_CONTENT, 3); chunkedStream = chunkedStream.replace(";chunk-signature=", ";foo=bar;chunk-signature="); - assertThat(readChunked(chunkedStream, signingSession)).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + assertThat(readChunked(chunkedStream, goodTestSigningSession())).isEqualTo(GOOD_CONTENT.getBytes(UTF_8)); + } + + @Test + public void testAwsChunkedCornerCases() + throws IOException + { + // Ensure that the input stream always validates the data prior to returning x-amz-decoded-content-length bytes + // Otherwise we could be skipping signature verification if a chunk reports a larger size than x-amz-decoded-content-length + // We need to ensure this is the case regardless of the amount of bytes we read at a time - which is controlled by Jetty + for (ChunkReader readerMethod : ImmutableList.of( + // Read 1 byte at a time using the read() method + TestAwsChunkedInputStream::tryReadAwsChunkedData, + // Read using the read(byte[], off, len) method + // Read 1 byte at a time + buildBatchedAwsChunkedReader(1), + // Read an even number of bytes + buildBatchedAwsChunkedReader(2), + // Read an odd number of bytes + buildBatchedAwsChunkedReader(3), + // Read a very large number of bytes + buildBatchedAwsChunkedReader(4096))) { + // Both chunks are the same size + testAwsChunkedCornerCases("abcdef", "ghijkl", readerMethod); + // Chunks are different sizes + testAwsChunkedCornerCases("abcdef", "ghi", readerMethod); + // Potential tricky case: the last chunk is a single byte (meaning a single read from the chunk should force signature verification) + testAwsChunkedCornerCases("abcdef", "g", readerMethod); + } + } + + private void testAwsChunkedCornerCases(String firstChunkContent, String secondChunkContent, ChunkReader readMethod) + throws IOException + { + final int totalContentLength = firstChunkContent.length() + secondChunkContent.length(); + + String firstChunkSignature = fixedTimeGoodTestSigningSession().getChunkSignature(firstChunkContent, GOOD_SEED); + String secondChunkSignature = fixedTimeGoodTestSigningSession().getChunkSignature(secondChunkContent, firstChunkSignature); + String finalChunkSignature = fixedTimeGoodTestSigningSession().getChunkSignature("", secondChunkSignature); + + String validFirstChunk = buildChunk(firstChunkContent.length(), firstChunkSignature, firstChunkContent); + String validSecondChunk = buildChunk(secondChunkContent.length(), secondChunkSignature, secondChunkContent); + String validFinalChunk = buildChunk(0, finalChunkSignature, ""); + + // Sanity check - the below should be read correctly + ByteArrayOutputStream testOutput = new ByteArrayOutputStream(); + String correctChunkedData = String.join("", validFirstChunk, validSecondChunk, validFinalChunk); + readMethod.read(correctChunkedData, totalContentLength, fixedTimeGoodTestSigningSession(), testOutput); + assertThat(testOutput.toByteArray()).hasSize(totalContentLength); + assertThat(testOutput.toString(UTF_8)).isEqualTo(firstChunkContent + secondChunkContent); + + // Data is correctly chunked, but the decoded content length is underreported + testIllegalAwsChunkedData( + correctChunkedData, + totalContentLength - 1, + fixedTimeGoodTestSigningSession(), + readMethod); + // Data is correctly chunked, but the decoded content length is overreported + testIllegalAwsChunkedData( + correctChunkedData, + totalContentLength + 1, + fixedTimeGoodTestSigningSession(), + readMethod); + + // Missing final chunk + testIllegalAwsChunkedData( + String.join("", validFirstChunk, validSecondChunk), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + + // First chunk has an invalid signature + testIllegalAwsChunkedData( + String.join("", buildChunk(firstChunkContent.length(), ILLEGAL_CHUNK_SIGNATURE, firstChunkContent), validSecondChunk, validFinalChunk), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + // Second chunk has an invalid signature + testIllegalAwsChunkedData( + String.join("", validFirstChunk, buildChunk(secondChunkContent.length(), ILLEGAL_CHUNK_SIGNATURE, secondChunkContent), validFinalChunk), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + // Final chunk has an invalid signature + testIllegalAwsChunkedData( + String.join("", validFirstChunk, validSecondChunk, buildChunk(0, ILLEGAL_CHUNK_SIGNATURE, "")), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + + // First chunk overreports its size - total content length unchanged + testIllegalAwsChunkedData( + String.join("", buildChunk(firstChunkContent.length() + 1, firstChunkSignature, firstChunkContent), validSecondChunk, validFinalChunk), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + // First chunk overreports its size - total content length increased to match + testIllegalAwsChunkedData( + String.join("", buildChunk(firstChunkContent.length() + 1, firstChunkSignature, firstChunkContent), validSecondChunk, validFinalChunk), + totalContentLength + 1, + fixedTimeGoodTestSigningSession(), + readMethod); + + // Second chunk overreports its size - total content length unchanged + testIllegalAwsChunkedData( + String.join("", validFirstChunk, buildChunk(secondChunkContent.length() + 1, secondChunkSignature, secondChunkContent), validFinalChunk), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + // Second chunk overreports its size - total content length increased to match + testIllegalAwsChunkedData( + String.join("", validFirstChunk, buildChunk(secondChunkContent.length() + 1, secondChunkSignature, secondChunkContent), validFinalChunk), + totalContentLength + 1, + fixedTimeGoodTestSigningSession(), + readMethod); + + // Final chunk has invalid size - total content length unchanged + testIllegalAwsChunkedData( + String.join("", validFirstChunk, validSecondChunk, buildChunk(1, finalChunkSignature, "")), + totalContentLength, + fixedTimeGoodTestSigningSession(), + readMethod); + // Final chunk has invalid size - total content length increased to match + testIllegalAwsChunkedData( + String.join("", validFirstChunk, validSecondChunk, buildChunk(1, finalChunkSignature, "")), + totalContentLength + 1, + fixedTimeGoodTestSigningSession(), + readMethod); + } + + private static String buildChunk(int reportedChunkSize, String chunkSignature, String chunkContent) + { + return "%s;chunk-signature=%s\r\n%s\r\n".formatted(Integer.toString(reportedChunkSize, 16), chunkSignature, chunkContent); + } + + private static ChunkReader buildBatchedAwsChunkedReader(int bytesToReadAtATime) + { + return (chunkedData, decodedContentLength, signingSession, output) -> tryReadAwsChunkedDataBatch(chunkedData, decodedContentLength, signingSession, output, bytesToReadAtATime); + } + + private static void tryReadAwsChunkedDataBatch(String chunkedData, int decodedContentLength, TestingChunkSigningSession signingSession, ByteArrayOutputStream output, int bytesToReadAtATime) + throws IOException + { + int remainingBytes = decodedContentLength; + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), signingSession, decodedContentLength)) { + while (remainingBytes > 0) { + byte[] readBytes = new byte[bytesToReadAtATime]; + int count = in.read(readBytes, 0, bytesToReadAtATime); + if (count < 0) { + throw new EOFException("Unexpected EOF"); + } + remainingBytes -= count; + output.write(readBytes, 0, count); + } + } + } + + private static void tryReadAwsChunkedData(String chunkedData, int decodedContentLength, TestingChunkSigningSession signingSession, ByteArrayOutputStream output) + throws IOException + { + int remainingBytes = decodedContentLength; + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedData.getBytes(UTF_8)), signingSession, decodedContentLength)) { + while (remainingBytes-- > 0) { + int readByte = in.read(); + if (readByte == -1) { + throw new EOFException("Unexpected EOF"); + } + output.write(readByte); + } + } + } + + private static void testIllegalAwsChunkedData(String chunkedData, int decodedContentLength, TestingChunkSigningSession signingSession, ChunkReader readerMethod) + { + ByteArrayOutputStream testOutput = new ByteArrayOutputStream(); + assertThatThrownBy(() -> readerMethod.read(chunkedData, decodedContentLength, signingSession, testOutput)).isInstanceOfAny(IllegalStateException.class, WebApplicationException.class, IOException.class); + assertThat(testOutput.toByteArray().length).isLessThan(decodedContentLength); } - private byte[] readChunked(String chunkedStream, TestingChunkSigningSession signingSession) + private static byte[] readChunked(String chunkedStream, TestingChunkSigningSession signingSession) throws IOException { - try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedStream.getBytes(UTF_8)), signingSession)) { + try (InputStream in = new AwsChunkedInputStream(new ByteArrayInputStream(chunkedStream.getBytes(UTF_8)), signingSession, chunkedStream.length())) { return ByteStreams.toByteArray(in); } } @@ -113,8 +306,9 @@ private byte[] readChunked(String chunkedStream, TestingChunkSigningSession sign public void testChunkedInputStreamLargeBuffer() throws IOException { - ByteArrayInputStream inputStream = new ByteArrayInputStream(CHUNKED_INPUT.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = CHUNKED_INPUT.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); int len; @@ -135,8 +329,9 @@ public void testChunkedInputStreamLargeBuffer() public void testChunkedInputStreamSmallBuffer() throws IOException { - ByteArrayInputStream inputStream = new ByteArrayInputStream(CHUNKED_INPUT.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = CHUNKED_INPUT.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] buffer = new byte[7]; ByteArrayOutputStream out = new ByteArrayOutputStream(); @@ -157,8 +352,9 @@ public void testChunkedInputStreamOneByteRead() throws IOException { String s = "5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); int ch; int i = '0'; while ((ch = in.read()) != -1) { @@ -176,8 +372,9 @@ public void testChunkedInputStreamOneByteRead() public void testChunkedInputStreamNoClosingChunk() { String s = "5;chunk-signature=0\r\n01234\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] tmp = new byte[5]; // altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad assertThrows(IOException.class, () -> in.read(tmp)); @@ -191,8 +388,9 @@ public void testCorruptChunkedInputStreamTruncatedCRLF() // altered to add a few more bad stings Stream.of("5;chunk-signature=0\r\n01234", ";chunk-signature=0\r\n01234\r\n", "5;chunk-signature=0\r\n012340;chunk-signature=0\r\n\r\n") .forEach(s -> { - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] tmp = new byte[5]; // altered from original test. Our AwsChunkedInputStream is improved and throws when the final chunk is missing or bad assertThrows(IOException.class, () -> in.read(tmp)); @@ -210,8 +408,9 @@ public void testCorruptChunkedInputStreamMissingCRLF() throws IOException { String s = "5;chunk-signature=0\r\n012345\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); assertThrows(IOException.class, () -> { @@ -229,8 +428,9 @@ public void testCorruptChunkedInputStreamMissingLF() throws IOException { String s = "5;chunk-signature=0\r01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); assertThrows(IOException.class, in::read); in.close(); } @@ -241,8 +441,9 @@ public void testCorruptChunkedInputStreamInvalidSize() throws IOException { String s = "whatever;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); assertThrows(IOException.class, in::read); in.close(); } @@ -253,8 +454,9 @@ public void testCorruptChunkedInputStreamNegativeSize() throws IOException { String s = "-5;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); assertThrows(IOException.class, in::read); in.close(); } @@ -265,8 +467,9 @@ public void testCorruptChunkedInputStreamTruncatedChunk() throws IOException { String s = "3;chunk-signature=0\r\n12"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] buffer = new byte[300]; assertEquals(2, in.read(buffer)); assertThrows(IOException.class, () -> in.read(buffer)); @@ -278,8 +481,9 @@ public void testCorruptChunkedInputStreamClose() throws IOException { String s = "whatever;chunk-signature=0\r\n01234\r\n5;chunk-signature=0\r\n56789\r\n0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - try (InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession())) { + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + try (InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length)) { assertThrows(IOException.class, in::read); } } @@ -289,8 +493,9 @@ public void testEmptyChunkedInputStream() throws IOException { String s = "0;chunk-signature=0\r\n\r\n"; - ByteArrayInputStream inputStream = new ByteArrayInputStream(s.getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = s.getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), rawBytes.length); byte[] buffer = new byte[300]; ByteArrayOutputStream out = new ByteArrayOutputStream(); int len; @@ -306,8 +511,9 @@ public void testEmptyChunkedInputStream() public void testHugeChunk() throws IOException { - ByteArrayInputStream inputStream = new ByteArrayInputStream("499602D2;chunk-signature=0\r\n01234567".getBytes(UTF_8)); - InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession()); + byte[] rawBytes = "499602D2;chunk-signature=0\r\n01234567".getBytes(UTF_8); + ByteArrayInputStream inputStream = new ByteArrayInputStream(rawBytes); + InputStream in = new AwsChunkedInputStream(inputStream, new DummyChunkSigningSession(), 1234567890); ByteArrayOutputStream out = new ByteArrayOutputStream(); for (int i = 0; i < 8; ++i) { diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java index 67552229..55bdd9ba 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/signing/TestingChunkSigningSession.java @@ -103,9 +103,7 @@ public String generateChunkedStream(String content, int partitions) int thisLength = Math.min(chunkSize, content.length() - index); String thisChunk = content.substring(index, index + thisLength); - Hasher hasher = Hashing.sha256().newHasher(); - hasher.putString(thisChunk, UTF_8); - String thisSignature = chunkSigner.signChunk(hasher.hash(), previousSignature); + String thisSignature = getChunkSignature(thisChunk, previousSignature); chunkedStream.append(Integer.toHexString(thisLength)).append(";chunk-signature=").append(thisSignature).append("\r\n"); chunkedStream.append(thisChunk).append("\r\n"); previousSignature = thisSignature; @@ -119,6 +117,14 @@ public String generateChunkedStream(String content, int partitions) return chunkedStream.toString(); } + @SuppressWarnings("UnstableApiUsage") + public String getChunkSignature(String chunkContent, String previousSignature) + { + Hasher hasher = Hashing.sha256().newHasher(); + hasher.putString(chunkContent, UTF_8); + return chunkSigner.signChunk(hasher.hash(), previousSignature); + } + private TestingChunkSigningSession(String seed, Instant instant, byte[] signingKey, String keyPath) { super(new ChunkSigner(instant, keyPath, signingKey), seed); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java index ca899c15..229fd4ca 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/testing/TestingUtil.java @@ -26,6 +26,7 @@ import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.ObjectIdentifier; import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.S3Object; import java.io.ByteArrayOutputStream; import java.io.File; @@ -130,4 +131,9 @@ public static void assertFileNotInS3(S3Client storageClient, String bucket, Stri .extracting(S3Exception::statusCode) .isEqualTo(404); } + + public static List listFilesInS3Bucket(S3Client storageClient, String bucket) + { + return storageClient.listObjects(request -> request.bucket(bucket)).contents().stream().map(S3Object::key).collect(toImmutableList()); + } }