Skip to content

Commit

Permalink
Return XML formatted error messages
Browse files Browse the repository at this point in the history
  • Loading branch information
mosiac1 committed Sep 26, 2024
1 parent 6724a1d commit 7fecc58
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 4 deletions.
5 changes: 5 additions & 0 deletions trino-aws-proxy/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
<artifactId>jackson-dataformat-xml</artifactId>
</dependency>

<dependency>
<groupId>com.fasterxml.jackson.datatype</groupId>
<artifactId>jackson-datatype-jdk8</artifactId>
</dependency>

<dependency>
<groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.google.common.annotations.VisibleForTesting;
import com.google.inject.Binder;
import com.google.inject.Provides;
Expand All @@ -38,6 +39,7 @@
import io.trino.aws.proxy.server.rest.RequestFilter;
import io.trino.aws.proxy.server.rest.RequestLoggerController;
import io.trino.aws.proxy.server.rest.S3PresignController;
import io.trino.aws.proxy.server.rest.ThrowableMapper;
import io.trino.aws.proxy.server.rest.TrinoLogsResource;
import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient;
import io.trino.aws.proxy.server.rest.TrinoS3ProxyClient.ForProxyClient;
Expand Down Expand Up @@ -90,6 +92,7 @@ protected void setup(Binder binder)
MapBinder<Class<?>, SigningServiceType> signingServiceTypesMapBinder = newMapBinder(binder, new TypeLiteral<>() {}, new TypeLiteral<>() {});

jaxrsBinder.bind(RequestFilter.class);
jaxrsBinder.bind(ThrowableMapper.class);
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.S3, TrinoS3Resource.class, builtConfig.getS3Path());
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.STS, TrinoStsResource.class, builtConfig.getStsPath());
bindResourceAtPath(jaxrsBinder, signingServiceTypesMapBinder, SigningServiceType.LOGS, TrinoLogsResource.class, builtConfig.getLogsPath());
Expand Down Expand Up @@ -160,6 +163,7 @@ public XmlMapper newXmlMapper()
{
// NOTE: this is _not_ a singleton on purpose. ObjectMappers/XmlMappers are mutable.
XmlMapper xmlMapper = new XmlMapper();
xmlMapper.registerModule(new Jdk8Module());
xmlMapper.setPropertyNamingStrategy(PropertyNamingStrategies.UPPER_CAMEL_CASE);
return xmlMapper;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.rest;

import com.fasterxml.jackson.dataformat.xml.annotation.JacksonXmlRootElement;

import java.util.Optional;

import static java.util.Objects.requireNonNull;

@JacksonXmlRootElement(localName = "Error")
public record ErrorResponse(String code, Optional<String> message, String resource, Optional<String> requestId)
{
public ErrorResponse
{
requireNonNull(code, "code is null");
requireNonNull(message, "message is null");
requireNonNull(resource, "resource is null");
requireNonNull(requestId, "requestId is null");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,26 @@ public void filter(ContainerRequestContext requestContext)
}

Request request = RequestBuilder.fromRequest(containerRequest);
containerRequest.setProperty(Request.class.getName(), request);

RequestLoggingSession requestLoggingSession = requestLoggerController.newRequestSession(request, signingServiceType);
containerRequest.setProperty(RequestLoggingSession.class.getName(), requestLoggingSession);

SigningMetadata signingMetadata;
try {
signingMetadata = signingController.validateAndParseAuthorization(request, signingServiceType);
containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata);
}
catch (Exception e) {
requestLoggingSession.logException(e);

switch (Throwables.getRootCause(e)) {
case WebApplicationException webApplicationException -> throw webApplicationException;
case IOException ioException -> throw ioException;
case RuntimeException runtimeException -> throw runtimeException;
default -> throw new RuntimeException(e);
}
}

containerRequest.setProperty(Request.class.getName(), request);
containerRequest.setProperty(SigningMetadata.class.getName(), signingMetadata);
}
else {
log.warn("%s is not a ContainerRequest", requestContext.getRequest().getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.rest;

import com.fasterxml.jackson.dataformat.xml.XmlMapper;
import com.google.inject.Inject;
import io.airlift.http.client.HttpStatus;
import io.airlift.log.Logger;
import io.trino.aws.proxy.spi.rest.Request;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.container.ResourceContext;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.Response;
import jakarta.ws.rs.core.Response.ResponseBuilder;
import jakarta.ws.rs.ext.ExceptionMapper;
import org.glassfish.jersey.server.ContainerRequest;

import java.util.Optional;
import java.util.UUID;

import static jakarta.ws.rs.core.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.APPLICATION_XML_TYPE;
import static java.util.Objects.requireNonNull;

public class ThrowableMapper
implements ExceptionMapper<Throwable>
{
private static final Logger log = Logger.get(ThrowableMapper.class);
private static final String X_AMZ_REQUEST_ID = "x-amz-request-id";

@Context
private ResourceContext resourceContext;

private final XmlMapper xmlMapper;

@Inject
public ThrowableMapper(XmlMapper xmlMapper)
{
this.xmlMapper = requireNonNull(xmlMapper, "xmlMapper is null");
}

@Override
public Response toResponse(Throwable throwable)
{
ContainerRequest containerRequest = resourceContext.getResource(ContainerRequest.class);
Optional<String> requestId = Optional.ofNullable((Request) containerRequest.getProperty(Request.class.getName()))
.map(Request::requestId)
.map(UUID::toString);

HttpStatus status = switch (throwable) {
case WebApplicationException webApplicationException -> HttpStatus.fromStatusCode(webApplicationException.getResponse().getStatus());
default -> {
log.error(throwable, "Request failed for %s", containerRequest.getRequestUri());
yield HttpStatus.INTERNAL_SERVER_ERROR;
}
};

try {
ErrorResponse response = new ErrorResponse(
status.reason(),
Optional.ofNullable(throwable.getMessage()),
containerRequest.getRequestUri().getPath(),
requestId);

ResponseBuilder responseBuilder = Response.status(status.code())
.header(CONTENT_TYPE, APPLICATION_XML_TYPE);
requestId.ifPresent(id -> responseBuilder.header(X_AMZ_REQUEST_ID, id));
return responseBuilder.entity(xmlMapper.writeValueAsString(response)).build();
}
catch (Exception exception) {
log.error(exception, "Processing of throwable %s caused an exception", throwable);
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public void testDatabaseSecurity()

try {
facadeProvider.disallowGets.set(true);
clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Error Code: 401 Unauthorized"));
clearInputStreamAndClose(inputToContainerStdin(pySparkContainer.containerId(), "spark.sql(\"select * from %s.%s\").show()".formatted(DATABASE_NAME, TABLE_NAME)), line -> line.contains("Status Code: 401; Error Code: Unauthorized"));
}
finally {
facadeProvider.disallowGets.set(false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.credentials;

import io.trino.aws.proxy.spi.credentials.Credentials;
import io.trino.aws.proxy.spi.credentials.CredentialsProvider;

import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.Objects.requireNonNull;

public class DelegatingCredentialsProvider
implements CredentialsProvider
{
private final AtomicReference<CredentialsProvider> delegate = new AtomicReference<>();

public void setDelegate(CredentialsProvider delegate)
{
this.delegate.set(requireNonNull(delegate, "delegate is null"));
}

@Override
public Optional<Credentials> credentials(String emulatedAccessKey, Optional<String> session)
{
return delegate.get().credentials(emulatedAccessKey, session);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.aws.proxy.server.plugin.exception;

import com.google.inject.Inject;
import com.google.inject.Scopes;
import io.airlift.http.client.HttpStatus;
import io.trino.aws.proxy.server.credentials.DelegatingCredentialsProvider;
import io.trino.aws.proxy.server.testing.TestingTrinoAwsProxyServer.Builder;
import io.trino.aws.proxy.server.testing.harness.BuilderFilter;
import io.trino.aws.proxy.server.testing.harness.TrinoAwsProxyTest;
import jakarta.ws.rs.WebApplicationException;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.S3Exception;

import static io.trino.aws.proxy.spi.plugin.TrinoAwsProxyServerBinding.credentialsProviderModule;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.InstanceOfAssertFactories.type;

@TrinoAwsProxyTest(filters = TestCredentialsProviderExceptionPropagation.Filter.class)
public class TestCredentialsProviderExceptionPropagation
{
private final DelegatingCredentialsProvider delegatingCredentialsProvider;
private final S3Client internalClient;

public static class Filter
implements BuilderFilter
{
@Override
public Builder filter(Builder builder)
{
return builder.withoutTestingCredentialsRoleProviders()
.addModule(credentialsProviderModule("testing", DelegatingCredentialsProvider.class, binder -> binder.bind(DelegatingCredentialsProvider.class).in(Scopes.SINGLETON)))
.withProperty("credentials-provider.type", "testing");
}
}

@Inject
public TestCredentialsProviderExceptionPropagation(DelegatingCredentialsProvider delegatingCredentialsProvider, S3Client internalClient)
{
this.delegatingCredentialsProvider = requireNonNull(delegatingCredentialsProvider, "delegatingCredentialsProvider is null");
this.internalClient = requireNonNull(internalClient, "internalClient is null");
}

@Test
public void testRuntimeException()
{
delegatingCredentialsProvider.setDelegate((_, _) -> { throw new RuntimeException("Testing exception"); });
assertThatThrownBy(internalClient::listBuckets)
.asInstanceOf(type(S3Exception.class))
.satisfies(s3Exception -> {
assertThat(s3Exception.statusCode()).isEqualTo(500);
assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception");
});
}

@Test
public void testWebApplicationException()
{
delegatingCredentialsProvider.setDelegate((_, _) -> { throw new WebApplicationException("Testing exception", HttpStatus.IM_A_TEAPOT.code()); });
assertThatThrownBy(internalClient::listBuckets)
.asInstanceOf(type(S3Exception.class))
.satisfies(s3Exception -> {
assertThat(s3Exception.statusCode()).isEqualTo(418);
assertThat(s3Exception.awsErrorDetails().errorMessage()).isEqualTo("Testing exception");
});
}
}

0 comments on commit 7fecc58

Please sign in to comment.