Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strict signature validation of aws-chunked streams #130

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ private enum State
private State state = State.FIRST_CHUNK;
private boolean delegateIsDone;
private int bytesRemainingInChunk;
private int bytesAccountedFor;
private final int decodedContentLength;

AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession)
AwsChunkedInputStream(InputStream delegate, ChunkSigningSession chunkSigningSession, int decodedContentLength)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.chunkSigningSession = requireNonNull(chunkSigningSession, "chunkSigningSession is null");
this.decodedContentLength = decodedContentLength;
}

@Override
Expand Down Expand Up @@ -185,6 +188,7 @@ private void nextChunk()
chunkSigningSession.complete();
state = State.LAST_CHUNK;
}
bytesAccountedFor += chunkSize;

success = true;
}
Expand All @@ -193,6 +197,9 @@ private void nextChunk()
if (!success) {
throw new IOException("Invalid chunk header: " + header);
}
if (bytesAccountedFor > decodedContentLength) {
throw new IllegalStateException("chunked data headers report a larger size than originally declared in the request: declared %s sent %s".formatted(decodedContentLength, bytesAccountedFor));
}
}

private void readEmptyLine()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public void proxyRequest(SigningMetadata signingMetadata, ParsedS3Request reques
private Optional<InputStream> contentInputStream(RequestContent requestContent, SigningMetadata signingMetadata)
{
return switch (requestContent.contentType()) {
case AWS_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession()));
case AWS_CHUNKED -> requestContent.inputStream().map(inputStream -> new AwsChunkedInputStream(inputStream, signingMetadata.requiredSigningContext().chunkSigningSession(), requestContent.contentLength().orElseThrow()));

case STANDARD, W3C_CHUNKED -> requestContent.inputStream().map(inputStream -> {
SigningContext signingContext = signingMetadata.requiredSigningContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import static io.trino.aws.proxy.server.testing.TestingUtil.cleanupBuckets;
import static io.trino.aws.proxy.server.testing.TestingUtil.getFileFromStorage;
import static io.trino.aws.proxy.server.testing.TestingUtil.headObjectInStorage;
import static io.trino.aws.proxy.server.testing.TestingUtil.listFilesInS3Bucket;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -78,6 +79,7 @@ public class TestGenericRestRequests
private final S3Client storageClient;

private static final String TEST_CONTENT_TYPE = "text/plain;charset=utf-8";
private static final String ILLEGAL_CHUNK_SIGNATURE = "0".repeat(AwsS3V4ChunkSigner.getSignatureLength());
private static final String goodContent = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Viverra aliquet eget sit amet tellus cras adipiscing. Viverra mauris in aliquam sem fringilla. Facilisis mauris sit amet massa vitae. Mauris vitae ultricies leo integer malesuada. Sed libero enim sed faucibus turpis in eu mi bibendum. Lorem sed risus ultricies tristique nulla aliquet enim. Quis blandit turpis cursus in hac habitasse platea dictumst quisque. Diam maecenas ultricies mi eget mauris pharetra et ultrices neque. Aliquam sem fringilla ut morbi.";

// first char is different case
Expand Down Expand Up @@ -151,7 +153,7 @@ public void testAwsChunkedUploadValid()
}

@Test
public void testAwsChunkedUploadCornerCases()
public void testAwsChunkedUploadInvalidContent()
throws IOException
{
String bucket = "test-aws-chunked";
Expand All @@ -174,9 +176,8 @@ public void testAwsChunkedUploadCornerCases()

// Final chunk has an invalid size
Function<String, String> changeSizeOfFinalChunk = chunked -> chunked.replaceFirst("\\r\\n0;chunk-signature=(\\w+)", "\r\n1;chunk-signature=$1");
// TODO: this currently will be accepted, we need to add stricter validation - use a different key so it does not interfere with other cases
assertThat(doAwsChunkedUpload(bucket, "final_chunk_invalid_size", goodContent, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500);
// assertFileNotInS3(storageClient, bucket, "final_chunk_invalid_size");
assertThat(doAwsChunkedUpload(bucket, fileKey, goodContent, 2, validCredential, changeSizeOfFinalChunk).getStatusCode()).isEqualTo(500);
assertFileNotInS3(storageClient, bucket, fileKey);

// First chunk has an invalid size
Function<String, String> changeSizeOfFirstChunk = chunked -> {
Expand Down Expand Up @@ -272,6 +273,75 @@ private StatusResponse doAwsChunkedUpload(
return httpClient.execute(requestBuilder.build(), createStatusResponseHandler());
}

@Test
public void testAwsChunkedCornerCases()
throws InterruptedException
{
String bucket = "test-aws-chunked-illegal";
String dummyContent = "hello there";
String longDummyContent = dummyContent.repeat(4096);
storageClient.createBucket(r -> r.bucket(bucket).build());

// Illegal signature and no final chunk
testAwsChunkedIllegalChunks(bucket, "no-final-chunk", buildFakeChunk(longDummyContent, longDummyContent.length()), longDummyContent.length(), 500);
// Illegal signature with a final chunk
testAwsChunkedIllegalChunks(bucket, "with-final-chunk", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), longDummyContent.length(), 401);
// Illegal signature and no final chunk - more chunked data than we report in the x-amz-decoded-content-length header
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-more-data-than-headers-indicate", buildFakeChunk(longDummyContent, longDummyContent.length()), 4096, 500);

// Illegal signature with a final chunk - more chunked data than we report in the x-amz-decoded-content-length header
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-more-data-than-headers-indicate", "%s%s".formatted(buildFakeChunk(longDummyContent, longDummyContent.length()), buildFakeChunk("", 0)), 4096, 500);

// Illegal signature and no final chunk - chunk misreports its size
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-underreports-size", buildFakeChunk(longDummyContent, 4096), 4096, 500);
// Illegal signature with a final chunk - chunk misreports its size
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-underreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 4096), buildFakeChunk("", 0)), 4096, 500);

// Illegal signature and no final chunk - chunk misreports its size
testAwsChunkedIllegalChunks(bucket, "no-final-chunk-chunk-overreports-size", buildFakeChunk(longDummyContent, 9_000_000), 4096, 500);
// Illegal signature with a final chunk - chunk misreports its size
testAwsChunkedIllegalChunks(bucket, "with-final-chunk-chunk-overreports-size", "%s%s".formatted(buildFakeChunk(longDummyContent, 9_000_000), buildFakeChunk("", 0)), 4096, 500);
Thread.sleep(1000);
assertThat(listFilesInS3Bucket(storageClient, bucket)).isEmpty();
}

private static String buildFakeChunk(String dataInChunk, int reportedChunkSize)
{
return "%s;chunk-signature=%s\r\n%s\r\n".formatted(Integer.toString(reportedChunkSize, 16), ILLEGAL_CHUNK_SIGNATURE, dataInChunk);
}

private void testAwsChunkedIllegalChunks(String bucket, String key, String rawContent, int decodedContentLength, int expectedStatusCode)
{
Instant requestDate = Instant.now();
Credential validCredential = new Credential(UUID.randomUUID().toString(), UUID.randomUUID().toString());
credentialsRolesProvider.addCredentials(Credentials.build(validCredential, testingCredentials.requiredRemoteCredential()));

ImmutableMultiMap.Builder requestHeaderBuilder = ImmutableMultiMap.builder(false);
requestHeaderBuilder
.add("Host", "%s:%d".formatted(baseUri.getHost(), baseUri.getPort()))
.add("X-Amz-Date", AwsTimestamp.toRequestFormat(requestDate))
.add("X-Amz-Content-Sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
.add("X-Amz-Decoded-Content-Length", String.valueOf(decodedContentLength))
.add("Content-Length", String.valueOf(rawContent.length()))
.add("Content-Type", TEST_CONTENT_TYPE)
.add("Content-Encoding", "aws-chunked");
InternalSigningController signingController = new InternalSigningController(
new CredentialsController(new TestingRemoteS3Facade(), credentialsRolesProvider),
new SigningControllerConfig().setMaxClockDrift(new Duration(10, TimeUnit.SECONDS)),
new RequestLoggerController(new TrinoAwsProxyConfig()));

URI requestUri = UriBuilder.fromUri(baseUri).path(bucket).path(key).build();
RequestAuthorization requestAuthorization = signingController.signRequest(new SigningMetadata(SigningServiceType.S3, Credentials.build(validCredential, testingCredentials.requiredRemoteCredential()), Optional.empty()),
"us-east-1", requestDate, Optional.empty(), Credentials::emulated, requestUri, requestHeaderBuilder.build(), ImmutableMultiMap.empty(), "PUT").signingAuthorization();
Request.Builder requestBuilder = preparePut().setUri(requestUri);

requestHeaderBuilder.add("Authorization", requestAuthorization.authorization());
requestHeaderBuilder.build().forEachEntry(requestBuilder::addHeader);
requestBuilder.setBodyGenerator(createStaticBodyGenerator(rawContent.getBytes(StandardCharsets.UTF_8)));

assertThat(httpClient.execute(requestBuilder.build(), createStatusResponseHandler()).getStatusCode()).isEqualTo(expectedStatusCode);
}

private StatusResponse doPutObject(String content, String sha256)
{
URI uri = UriBuilder.fromUri(baseUri)
Expand Down Expand Up @@ -306,7 +376,7 @@ private static Function<String, String> getMutatorToBreakSignatureForChunk(int c
for (String part : parts) {
if (part.contains(";chunk-signature=")) {
if (remainingChunks-- == 0) {
resultBuilder.append(part.replaceFirst("([0-9a-f]+;chunk-signature=)(\\w+)", "$1" + "0".repeat(AwsS3V4ChunkSigner.getSignatureLength())));
resultBuilder.append(part.replaceFirst("([0-9a-f]+;chunk-signature=)(\\w+)", "$1" + ILLEGAL_CHUNK_SIGNATURE));
resultBuilder.append("\r\n");
continue;
}
Expand Down
Loading
Loading