diff --git a/trino-aws-proxy/pom.xml b/trino-aws-proxy/pom.xml
index 50b579ea..d5f4c645 100644
--- a/trino-aws-proxy/pom.xml
+++ b/trino-aws-proxy/pom.xml
@@ -44,6 +44,11 @@
jackson-dataformat-xml
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jdk8
+
+
com.github.ben-manes.caffeine
caffeine
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
index 59372e8f..afe6de28 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java
@@ -19,6 +19,7 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
+import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.google.common.annotations.VisibleForTesting;
import com.google.inject.Binder;
import com.google.inject.Provides;
@@ -38,6 +39,7 @@
import io.trino.aws.proxy.server.rest.RequestFilter;
import io.trino.aws.proxy.server.rest.RequestLoggerController;
import io.trino.aws.proxy.server.rest.S3PresignController;
+import io.trino.aws.proxy.server.rest.ThrowableMapper;
import io.trino.aws.proxy.server.rest.TrinoLogsResource;
import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient;
import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient;
@@ -90,6 +92,7 @@ protected void setup(Binder binder)
MapBinder, SigningServiceType> signingServiceTypesMapBinder = newMapBinder(binder, new TypeLiteral<>() {}, new TypeLiteral<>() {});
jaxrsBinder.bind(RequestFilter.class);
+ jaxrsBinder.bind(ThrowableMapper.class);
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.S3, TrinoS3Resource.class, builtConfig.getS3Path());
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.STS, TrinoStsResource.class, builtConfig.getStsPath());
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.LOGS, TrinoLogsResource.class, builtConfig.getLogsPath());
@@ -160,6 +163,7 @@ public XmlMapper newXmlMapper()
{
// NOTE: this is _not_ a singleton on purpose. ObjectMappers/XmlMappers are mutable.
XmlMapper xmlMapper = new XmlMapper();
+ xmlMapper.registerModule(new Jdk8Module());
xmlMapper.setPropertyNamingStrategy(PropertyNamingStrategies.UPPER_CAMEL_CASE);
return xmlMapper;
}
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java
new file mode 100644
index 00000000..35c53a36
--- /dev/null
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server.rest;
+
+import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlRootElement;
+
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+
+@JacksonXmlRootElement(localName = "Error")
+public record ErrorResponse(String code, Optional message, String resource, Optional requestId)
+{
+ public ErrorResponse
+ {
+ requireNonNull(code, "code is null");
+ requireNonNull(message, "message is null");
+ requireNonNull(resource, "resource is null");
+ requireNonNull(requestId, "requestId is null");
+ }
+}
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java
index 42d19468..be13c839 100644
--- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java
@@ -87,12 +87,15 @@ public void filter(ContainerRequestContext requestContext)
}
Request request = RequestBuilder.fromRequest(containerRequest);
+ containerRequest.setProperty(Request.class.getName(), request);
+
RequestLoggingSession requestLoggingSession = requestLoggerController.newRequestSession(request, signingServiceType);
containerRequest.setProperty(RequestLoggingSession.class.getName(), requestLoggingSession);
SigningMetadata signingMetadata;
try {
signingMetadata = signingController.validateAndParseAuthorization(request, signingServiceType);
+ containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata);
}
catch (Exception e) {
requestLoggingSession.logException(e);
@@ -100,12 +103,10 @@ public void filter(ContainerRequestContext requestContext)
switch (Throwables.getRootCause(e)) {
case WebApplicationException webApplicationException -> throw webApplicationException;
case IOException ioException -> throw ioException;
+ case RuntimeException runtimeException -> throw runtimeException;
default -> throw new RuntimeException(e);
}
}
-
- containerRequest.setProperty(Request.class.getName(), request);
- containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata);
}
else {
log.warn("%s is not a ContainerRequest", requestContext.getRequest().getClass().getName());
diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java
new file mode 100644
index 00000000..ea3f35cf
--- /dev/null
+++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server.rest;
+
+import com.fasterxml.jackson.dataformat.xml.XmlMapper;
+import com.google.inject.Inject;
+import io.airlift.http.client.HttpStatus;
+import io.airlift.log.Logger;
+import io.trino.aws.proxy.spi.rest.Request;
+import jakarta.ws.rs.WebApplicationException;
+import jakarta.ws.rs.container.ResourceContext;
+import jakarta.ws.rs.core.Context;
+import jakarta.ws.rs.core.Response;
+import jakarta.ws.rs.core.Response.ResponseBuilder;
+import jakarta.ws.rs.ext.ExceptionMapper;
+import org.glassfish.jersey.server.ContainerRequest;
+
+import java.util.Optional;
+import java.util.UUID;
+
+import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE;
+import static jakarta.ws.rs.core.MediaType.APPLICATION_XML_TYPE;
+import static java.util.Objects.requireNonNull;
+
+public class ThrowableMapper
+ implements ExceptionMapper
+{
+ private static final Logger log = Logger.get(ThrowableMapper.class);
+ private static final String X_AMZ_REQUEST_ID = "x-amz-request-id";
+
+ @Context
+ private ResourceContext resourceContext;
+
+ private final XmlMapper xmlMapper;
+
+ @Inject
+ public ThrowableMapper(XmlMapper xmlMapper)
+ {
+ this.xmlMapper = requireNonNull(xmlMapper, "xmlMapper is null");
+ }
+
+ @Override
+ public Response toResponse(Throwable throwable)
+ {
+ ContainerRequest containerRequest = resourceContext.getResource(ContainerRequest.class);
+ Optional requestId = Optional.ofNullable((Request) containerRequest.getProperty(Request.class.getName()))
+ .map(Request::requestId)
+ .map(UUID::toString);
+
+ HttpStatus status = switch (throwable) {
+ case WebApplicationException webApplicationException -> HttpStatus.fromStatusCode(webApplicationException.getResponse().getStatus());
+ default -> {
+ log.error(throwable, "Request failed for %s", containerRequest.getRequestUri());
+ yield HttpStatus.INTERNAL_SERVER_ERROR;
+ }
+ };
+
+ try {
+ ErrorResponse response = new ErrorResponse(
+ status.reason(),
+ Optional.ofNullable(throwable.getMessage()),
+ containerRequest.getRequestUri().getPath(),
+ requestId);
+
+ ResponseBuilder responseBuilder = Response.status(status.code())
+ .header(CONTENT_TYPE, APPLICATION_XML_TYPE);
+ requestId.ifPresent(id -> responseBuilder.header(X_AMZ_REQUEST_ID, id));
+ return responseBuilder.entity(xmlMapper.writeValueAsString(response)).build();
+ }
+ catch (Exception exception) {
+ log.error(exception, "Processing of throwable %s caused an exception", throwable);
+ return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build();
+ }
+ }
+}
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
index 3e34463a..f087d2c8 100644
--- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java
@@ -109,7 +109,7 @@ public void testDatabaseSecurity()
try {
facadeProvider.disallowGets.set(true);
- clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Error Code: 401 Unauthorized"));
+ clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Status Code: 401; Error Code: Unauthorized"));
}
finally {
facadeProvider.disallowGets.set(false);
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java
new file mode 100644
index 00000000..4dc12d1d
--- /dev/null
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server.credentials;
+
+import io.trino.aws.proxy.spi.credentials.Credentials;
+import io.trino.aws.proxy.spi.credentials.CredentialsProvider;
+
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static java.util.Objects.requireNonNull;
+
+public class DelegatingCredentialsProvider
+ implements CredentialsProvider
+{
+ private final AtomicReference delegate = new AtomicReference<>();
+
+ public void setDelegate(CredentialsProvider delegate)
+ {
+ this.delegate.set(requireNonNull(delegate, "delegate is null"));
+ }
+
+ @Override
+ public Optional credentials(String emulatedAccessKey, Optional session)
+ {
+ return delegate.get().credentials(emulatedAccessKey, session);
+ }
+}
diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java
new file mode 100644
index 00000000..e9c8f6b3
--- /dev/null
+++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.aws.proxy.server.plugin.exception;
+
+import com.google.inject.Inject;
+import com.google.inject.Scopes;
+import io.airlift.http.client.HttpStatus;
+import io.trino.aws.proxy.server.credentials.DelegatingCredentialsProvider;
+import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder;
+import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
+import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
+import jakarta.ws.rs.WebApplicationException;
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.services.s3.S3Client;
+import software.amazon.awssdk.services.s3.model.S3Exception;
+
+import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule;
+import static java.util.Objects.requireNonNull;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
+
+@TrinoAwsProxyTest(filters = TestCredentialsProviderExceptionPropagation.Filter.class)
+public class TestCredentialsProviderExceptionPropagation
+{
+ private final DelegatingCredentialsProvider delegatingCredentialsProvider;
+ private final S3Client internalClient;
+
+ public static class Filter
+ implements BuilderFilter
+ {
+ @Override
+ public Builder filter(Builder builder)
+ {
+ return builder.withoutTestingCredentialsRoleProviders()
+ .addModule(credentialsProviderModule("testing", DelegatingCredentialsProvider.class, binder -> binder.bind(DelegatingCredentialsProvider.class).in(Scopes.SINGLETON)))
+ .withProperty("credentials-provider.type", "testing");
+ }
+ }
+
+ @Inject
+ public TestCredentialsProviderExceptionPropagation(DelegatingCredentialsProvider delegatingCredentialsProvider, S3Client internalClient)
+ {
+ this.delegatingCredentialsProvider = requireNonNull(delegatingCredentialsProvider, "delegatingCredentialsProvider is null");
+ this.internalClient = requireNonNull(internalClient, "internalClient is null");
+ }
+
+ @Test
+ public void testRuntimeException()
+ {
+ delegatingCredentialsProvider.setDelegate((_, _) -> { throw new RuntimeException("Testing exception"); });
+ assertThatThrownBy(internalClient::listBuckets)
+ .asInstanceOf(type(S3Exception.class))
+ .satisfies(s3Exception -> {
+ assertThat(s3Exception.statusCode()).isEqualTo(500);
+ assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception");
+ });
+ }
+
+ @Test
+ public void testWebApplicationException()
+ {
+ delegatingCredentialsProvider.setDelegate((_, _) -> { throw new WebApplicationException("Testing exception", HttpStatus.IM_A_TEAPOT.code()); });
+ assertThatThrownBy(internalClient::listBuckets)
+ .asInstanceOf(type(S3Exception.class))
+ .satisfies(s3Exception -> {
+ assertThat(s3Exception.statusCode()).isEqualTo(418);
+ assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception");
+ });
+ }
+}