diff --git a/trino-aws-proxy/pom.xml b/trino-aws-proxy/pom.xml index 50b579ea..d5f4c645 100644 --- a/trino-aws-proxy/pom.xml +++ b/trino-aws-proxy/pom.xml @@ -44,6 +44,11 @@ jackson-dataformat-xml + + com.fasterxml.jackson.datatype + jackson-datatype-jdk8 + + com.github.ben-manes.caffeine caffeine diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java index 59372e8f..afe6de28 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/TrinoAwsProxyServerModule.java @@ -19,6 +19,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.module.SimpleModule; import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; import com.google.common.annotations.VisibleForTesting; import com.google.inject.Binder; import com.google.inject.Provides; @@ -38,6 +39,7 @@ import io.trino.aws.proxy.server.rest.RequestFilter; import io.trino.aws.proxy.server.rest.RequestLoggerController; import io.trino.aws.proxy.server.rest.S3PresignController; +import io.trino.aws.proxy.server.rest.ThrowableMapper; import io.trino.aws.proxy.server.rest.TrinoLogsResource; import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient; import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient; @@ -90,6 +92,7 @@ protected void setup(Binder binder) MapBinder, SigningServiceType> signingServiceTypesMapBinder = newMapBinder(binder, new TypeLiteral<>() {}, new TypeLiteral<>() {}); jaxrsBinder.bind(RequestFilter.class); + jaxrsBinder.bind(ThrowableMapper.class); bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.S3, TrinoS3Resource.class, builtConfig.getS3Path()); bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.STS, TrinoStsResource.class, builtConfig.getStsPath()); bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.LOGS, TrinoLogsResource.class, builtConfig.getLogsPath()); @@ -160,6 +163,7 @@ public XmlMapper newXmlMapper() { // NOTE: this is _not_ a singleton on purpose. ObjectMappers/XmlMappers are mutable. XmlMapper xmlMapper = new XmlMapper(); + xmlMapper.registerModule(new Jdk8Module()); xmlMapper.setPropertyNamingStrategy(PropertyNamingStrategies.UPPER_CAMEL_CASE); return xmlMapper; } diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java new file mode 100644 index 00000000..35c53a36 --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ErrorResponse.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.rest; + +import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlRootElement; + +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +@JacksonXmlRootElement(localName = "Error") +public record ErrorResponse(String code, Optional message, String resource, Optional requestId) +{ + public ErrorResponse + { + requireNonNull(code, "code is null"); + requireNonNull(message, "message is null"); + requireNonNull(resource, "resource is null"); + requireNonNull(requestId, "requestId is null"); + } +} diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java index 42d19468..be13c839 100644 --- a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/RequestFilter.java @@ -87,12 +87,15 @@ public void filter(ContainerRequestContext requestContext) } Request request = RequestBuilder.fromRequest(containerRequest); + containerRequest.setProperty(Request.class.getName(), request); + RequestLoggingSession requestLoggingSession = requestLoggerController.newRequestSession(request, signingServiceType); containerRequest.setProperty(RequestLoggingSession.class.getName(), requestLoggingSession); SigningMetadata signingMetadata; try { signingMetadata = signingController.validateAndParseAuthorization(request, signingServiceType); + containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata); } catch (Exception e) { requestLoggingSession.logException(e); @@ -100,12 +103,10 @@ public void filter(ContainerRequestContext requestContext) switch (Throwables.getRootCause(e)) { case WebApplicationException webApplicationException -> throw webApplicationException; case IOException ioException -> throw ioException; + case RuntimeException runtimeException -> throw runtimeException; default -> throw new RuntimeException(e); } } - - containerRequest.setProperty(Request.class.getName(), request); - containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata); } else { log.warn("%s is not a ContainerRequest", requestContext.getRequest().getClass().getName()); diff --git a/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java new file mode 100644 index 00000000..ea3f35cf --- /dev/null +++ b/trino-aws-proxy/src/main/java/io/trino/aws/proxy/server/rest/ThrowableMapper.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.rest; + +import com.fasterxml.jackson.dataformat.xml.XmlMapper; +import com.google.inject.Inject; +import io.airlift.http.client.HttpStatus; +import io.airlift.log.Logger; +import io.trino.aws.proxy.spi.rest.Request; +import jakarta.ws.rs.WebApplicationException; +import jakarta.ws.rs.container.ResourceContext; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.Response.ResponseBuilder; +import jakarta.ws.rs.ext.ExceptionMapper; +import org.glassfish.jersey.server.ContainerRequest; + +import java.util.Optional; +import java.util.UUID; + +import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static jakarta.ws.rs.core.MediaType.APPLICATION_XML_TYPE; +import static java.util.Objects.requireNonNull; + +public class ThrowableMapper + implements ExceptionMapper +{ + private static final Logger log = Logger.get(ThrowableMapper.class); + private static final String X_AMZ_REQUEST_ID = "x-amz-request-id"; + + @Context + private ResourceContext resourceContext; + + private final XmlMapper xmlMapper; + + @Inject + public ThrowableMapper(XmlMapper xmlMapper) + { + this.xmlMapper = requireNonNull(xmlMapper, "xmlMapper is null"); + } + + @Override + public Response toResponse(Throwable throwable) + { + ContainerRequest containerRequest = resourceContext.getResource(ContainerRequest.class); + Optional requestId = Optional.ofNullable((Request) containerRequest.getProperty(Request.class.getName())) + .map(Request::requestId) + .map(UUID::toString); + + HttpStatus status = switch (throwable) { + case WebApplicationException webApplicationException -> HttpStatus.fromStatusCode(webApplicationException.getResponse().getStatus()); + default -> { + log.error(throwable, "Request failed for %s", containerRequest.getRequestUri()); + yield HttpStatus.INTERNAL_SERVER_ERROR; + } + }; + + try { + ErrorResponse response = new ErrorResponse( + status.reason(), + Optional.ofNullable(throwable.getMessage()), + containerRequest.getRequestUri().getPath(), + requestId); + + ResponseBuilder responseBuilder = Response.status(status.code()) + .header(CONTENT_TYPE, APPLICATION_XML_TYPE); + requestId.ifPresent(id -> responseBuilder.header(X_AMZ_REQUEST_ID, id)); + return responseBuilder.entity(xmlMapper.writeValueAsString(response)).build(); + } + catch (Exception exception) { + log.error(exception, "Processing of throwable %s caused an exception", throwable); + return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build(); + } + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java index 3e34463a..f087d2c8 100644 --- a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/TestDatabaseSecurity.java @@ -109,7 +109,7 @@ public void testDatabaseSecurity() try { facadeProvider.disallowGets.set(true); - clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Error Code: 401 Unauthorized")); + clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Status Code: 401; Error Code: Unauthorized")); } finally { facadeProvider.disallowGets.set(false); diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java new file mode 100644 index 00000000..4dc12d1d --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/credentials/DelegatingCredentialsProvider.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.credentials; + +import io.trino.aws.proxy.spi.credentials.Credentials; +import io.trino.aws.proxy.spi.credentials.CredentialsProvider; + +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Objects.requireNonNull; + +public class DelegatingCredentialsProvider + implements CredentialsProvider +{ + private final AtomicReference delegate = new AtomicReference<>(); + + public void setDelegate(CredentialsProvider delegate) + { + this.delegate.set(requireNonNull(delegate, "delegate is null")); + } + + @Override + public Optional credentials(String emulatedAccessKey, Optional session) + { + return delegate.get().credentials(emulatedAccessKey, session); + } +} diff --git a/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java new file mode 100644 index 00000000..e9c8f6b3 --- /dev/null +++ b/trino-aws-proxy/src/test/java/io/trino/aws/proxy/server/plugin/exception/TestCredentialsProviderExceptionPropagation.java @@ -0,0 +1,82 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.aws.proxy.server.plugin.exception; + +import com.google.inject.Inject; +import com.google.inject.Scopes; +import io.airlift.http.client.HttpStatus; +import io.trino.aws.proxy.server.credentials.DelegatingCredentialsProvider; +import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder; +import io.trino.aws.proxy.server.testing.harness.BuilderFilter; +import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest; +import jakarta.ws.rs.WebApplicationException; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.S3Exception; + +import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; + +@TrinoAwsProxyTest(filters = TestCredentialsProviderExceptionPropagation.Filter.class) +public class TestCredentialsProviderExceptionPropagation +{ + private final DelegatingCredentialsProvider delegatingCredentialsProvider; + private final S3Client internalClient; + + public static class Filter + implements BuilderFilter + { + @Override + public Builder filter(Builder builder) + { + return builder.withoutTestingCredentialsRoleProviders() + .addModule(credentialsProviderModule("testing", DelegatingCredentialsProvider.class, binder -> binder.bind(DelegatingCredentialsProvider.class).in(Scopes.SINGLETON))) + .withProperty("credentials-provider.type", "testing"); + } + } + + @Inject + public TestCredentialsProviderExceptionPropagation(DelegatingCredentialsProvider delegatingCredentialsProvider, S3Client internalClient) + { + this.delegatingCredentialsProvider = requireNonNull(delegatingCredentialsProvider, "delegatingCredentialsProvider is null"); + this.internalClient = requireNonNull(internalClient, "internalClient is null"); + } + + @Test + public void testRuntimeException() + { + delegatingCredentialsProvider.setDelegate((_, _) -> { throw new RuntimeException("Testing exception"); }); + assertThatThrownBy(internalClient::listBuckets) + .asInstanceOf(type(S3Exception.class)) + .satisfies(s3Exception -> { + assertThat(s3Exception.statusCode()).isEqualTo(500); + assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception"); + }); + } + + @Test + public void testWebApplicationException() + { + delegatingCredentialsProvider.setDelegate((_, _) -> { throw new WebApplicationException("Testing exception", HttpStatus.IM_A_TEAPOT.code()); }); + assertThatThrownBy(internalClient::listBuckets) + .asInstanceOf(type(S3Exception.class)) + .satisfies(s3Exception -> { + assertThat(s3Exception.statusCode()).isEqualTo(418); + assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception"); + }); + } +}