Skip to content

Commit

Permalink
aws-powertools#1298 first implem - api gateway events validation will…
Browse files Browse the repository at this point in the history
… be catched and returned as a 400 error
  • Loading branch information
Pascal Romanens committed Oct 24, 2023
1 parent c739383 commit f311722
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import org.slf4j.LoggerFactory;
import software.amazon.lambda.powertools.validation.Validation;
import software.amazon.lambda.powertools.validation.ValidationConfig;
import software.amazon.lambda.powertools.validation.ValidationException;

/**
* Aspect for {@link Validation} annotation
Expand All @@ -73,6 +74,11 @@ public Object around(ProceedingJoinPoint pjp,
if (validation.schemaVersion() != V201909) {
ValidationConfig.get().setSchemaVersion(validation.schemaVersion());
}

// we need this result object to be null at this point as validation of API events, if
// it fails, will catch the ValidationException and generate a 400 API response. This response
// will be stored in the result object to prevent executing the lambda
Object result = null;

if (placedOnRequestHandler(pjp)) {
validationNeeded = true;
Expand All @@ -85,10 +91,10 @@ public Object around(ProceedingJoinPoint pjp,
validate(obj, inboundJsonSchema, validation.envelope());
} else if (obj instanceof APIGatewayProxyRequestEvent) {
APIGatewayProxyRequestEvent event = (APIGatewayProxyRequestEvent) obj;
validate(event.getBody(), inboundJsonSchema);
result = validateAPIGatewayProxyBody(event.getBody(), inboundJsonSchema);
} else if (obj instanceof APIGatewayV2HTTPEvent) {
APIGatewayV2HTTPEvent event = (APIGatewayV2HTTPEvent) obj;
validate(event.getBody(), inboundJsonSchema);
result = validateAPIGatewayV2HTTPBody(event.getBody(), inboundJsonSchema);
} else if (obj instanceof SNSEvent) {
SNSEvent event = (SNSEvent) obj;
event.getRecords().forEach(record -> validate(record.getSNS().getMessage(), inboundJsonSchema));
Expand Down Expand Up @@ -140,33 +146,88 @@ record -> validate(decode(record.getData()), inboundJsonSchema)));
}
}

Object result = pjp.proceed(proceedArgs);

if (validationNeeded && !validation.outboundSchema().isEmpty()) {
JsonSchema outboundJsonSchema = getJsonSchema(validation.outboundSchema(), true);

if (result instanceof APIGatewayProxyResponseEvent) {
APIGatewayProxyResponseEvent response = (APIGatewayProxyResponseEvent) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof APIGatewayV2HTTPResponse) {
APIGatewayV2HTTPResponse response = (APIGatewayV2HTTPResponse) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof APIGatewayV2WebSocketResponse) {
APIGatewayV2WebSocketResponse response = (APIGatewayV2WebSocketResponse) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof ApplicationLoadBalancerResponseEvent) {
ApplicationLoadBalancerResponseEvent response = (ApplicationLoadBalancerResponseEvent) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof KinesisAnalyticsInputPreprocessingResponse) {
KinesisAnalyticsInputPreprocessingResponse response =
(KinesisAnalyticsInputPreprocessingResponse) result;
response.getRecords().forEach(record -> validate(decode(record.getData()), outboundJsonSchema));
} else {
LOG.warn("Unhandled response type {}, please use the 'envelope' parameter to specify what to validate",
result.getClass().getName());
}
// don't execute the lambda if result was set by previous validation step
// in that case result should already hold a response with validation information
if (result != null) {
LOG.error("Incoming API event's body failed inbound schema validation.");
}
else {
result = pjp.proceed(proceedArgs);

if (validationNeeded && !validation.outboundSchema().isEmpty()) {
JsonSchema outboundJsonSchema = getJsonSchema(validation.outboundSchema(), true);

Object overridenResponse = null;
if (result instanceof APIGatewayProxyResponseEvent) {
APIGatewayProxyResponseEvent response = (APIGatewayProxyResponseEvent) result;
overridenResponse = validateAPIGatewayProxyBody(response.getBody(), outboundJsonSchema);
} else if (result instanceof APIGatewayV2HTTPResponse) {
APIGatewayV2HTTPResponse response = (APIGatewayV2HTTPResponse) result;
overridenResponse = validateAPIGatewayV2HTTPBody(response.getBody(), outboundJsonSchema);
} else if (result instanceof APIGatewayV2WebSocketResponse) {
APIGatewayV2WebSocketResponse response = (APIGatewayV2WebSocketResponse) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof ApplicationLoadBalancerResponseEvent) {
ApplicationLoadBalancerResponseEvent response = (ApplicationLoadBalancerResponseEvent) result;
validate(response.getBody(), outboundJsonSchema);
} else if (result instanceof KinesisAnalyticsInputPreprocessingResponse) {
KinesisAnalyticsInputPreprocessingResponse response =
(KinesisAnalyticsInputPreprocessingResponse) result;
response.getRecords().forEach(record -> validate(decode(record.getData()), outboundJsonSchema));
} else {
LOG.warn("Unhandled response type {}, please use the 'envelope' parameter to specify what to validate",
result.getClass().getName());
}

if (overridenResponse != null) {
result = overridenResponse;
LOG.error("API response failed outbound schema validation.");
}
}
}

return result;
}

/**
* Validates the given body against the provided JsonSchema. If validation fails the ValidationException
* will be catched and transformed to a 400, bad request, API response
* @param body body of the event to validate
* @param inboundJsonSchema validation schema
* @return null if validation passed, or a 400 response object otherwise
*/
private APIGatewayProxyResponseEvent validateAPIGatewayProxyBody(final String body, final JsonSchema jsonSchema) {
APIGatewayProxyResponseEvent result = null;
try {
validate(body, jsonSchema);
} catch (ValidationException e) {
LOG.error("There were validation errors: {}", e.getMessage());
result = new APIGatewayProxyResponseEvent();
result.setBody(e.getMessage());
result.setStatusCode(400);
result.setIsBase64Encoded(false);
}
return result;
}

/**
* Validates the given body against the provided JsonSchema. If validation fails the ValidationException
* will be catched and transformed to a 400, bad request, API response
* @param body body of the event to validate
* @param inboundJsonSchema validation schema
* @return null if validation passed, or a 400 response object otherwise
*/
private APIGatewayV2HTTPResponse validateAPIGatewayV2HTTPBody(final String body, final JsonSchema jsonSchema) {
APIGatewayV2HTTPResponse result = null;
try {
validate(body, jsonSchema);
} catch (ValidationException e) {
LOG.error("There were validation errors: {}", e.getMessage());
result = new APIGatewayV2HTTPResponse();
result.setBody(e.getMessage());
result.setStatusCode(400);
result.setIsBase64Encoded(false);
}
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package software.amazon.lambda.powertools.validation.handlers;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;

import software.amazon.lambda.powertools.validation.Validation;

public class GenericSchemaV7APIGatewayProxyRequestEventHandler implements RequestHandler<APIGatewayProxyRequestEvent, APIGatewayProxyResponseEvent> {

@Validation(inboundSchema = "classpath:/schema_v7.json")
@Override
public APIGatewayProxyResponseEvent handleRequest(APIGatewayProxyRequestEvent input, Context context) {
APIGatewayProxyResponseEvent response = new APIGatewayProxyResponseEvent();
response.setBody("valid-test");
response.setStatusCode(200);
return response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;

import software.amazon.lambda.powertools.validation.Validation;

public class GenericSchemaV7Handler<T> implements RequestHandler<T, String> {
public class GenericSchemaV7StringHandler<T> implements RequestHandler<T, String> {

@Validation(inboundSchema = "classpath:/schema_v7.json")
@Override
public String handleRequest(T input, Context context) {
return "OK";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPResponse;

import software.amazon.lambda.powertools.validation.Validation;


public class ValidationInboundStringHandler implements RequestHandler<APIGatewayV2HTTPEvent, String> {
public class ValidationInboundAPIGatewayV2HTTPEventHandler implements RequestHandler<APIGatewayV2HTTPEvent, APIGatewayV2HTTPResponse> {

private static final String schema = "{\n" +
" \"$schema\": \"http://json-schema.org/draft-07/schema\",\n" +
Expand Down Expand Up @@ -80,7 +82,10 @@ public class ValidationInboundStringHandler implements RequestHandler<APIGateway

@Override
@Validation(inboundSchema = schema)
public String handleRequest(APIGatewayV2HTTPEvent input, Context context) {
return "OK";
public APIGatewayV2HTTPResponse handleRequest(APIGatewayV2HTTPEvent input, Context context) {
APIGatewayV2HTTPResponse response = new APIGatewayV2HTTPResponse();
response.setBody("valid-test");
response.setStatusCode(200);
return response;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates.
* Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package software.amazon.lambda.powertools.validation.internal;

import java.util.stream.Stream;

import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;

import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2HTTPResponse;
import com.amazonaws.services.lambda.runtime.events.APIGatewayV2WebSocketResponse;
import com.amazonaws.services.lambda.runtime.events.ApplicationLoadBalancerResponseEvent;

public class HandledResponseEventsArgumentsProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {

String body = "{id";

final APIGatewayProxyResponseEvent apiGWProxyResponseEvent = new APIGatewayProxyResponseEvent().withBody(body);

APIGatewayV2HTTPResponse apiGWV2HTTPResponse = new APIGatewayV2HTTPResponse();
apiGWV2HTTPResponse.setBody(body);

APIGatewayV2WebSocketResponse apiGWV2WebSocketResponse = new APIGatewayV2WebSocketResponse();
apiGWV2WebSocketResponse.setBody(body);

ApplicationLoadBalancerResponseEvent albResponseEvent = new ApplicationLoadBalancerResponseEvent();
albResponseEvent.setBody(body);

return Stream.of(apiGWProxyResponseEvent, apiGWV2HTTPResponse).map(Arguments::of);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext context) {

String body = "{id";

final APIGatewayProxyResponseEvent apiGWProxyResponseEvent = new APIGatewayProxyResponseEvent().withBody(body);

APIGatewayV2HTTPResponse apiGWV2HTTPResponse = new APIGatewayV2HTTPResponse();
apiGWV2HTTPResponse.setBody(body);

APIGatewayV2WebSocketResponse apiGWV2WebSocketResponse = new APIGatewayV2WebSocketResponse();
apiGWV2WebSocketResponse.setBody(body);

Expand All @@ -53,7 +48,7 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
KinesisAnalyticsInputPreprocessingResponse.Result.Ok, buffer));
kaipResponse.setRecords(records);

return Stream.of(apiGWProxyResponseEvent, apiGWV2HTTPResponse, apiGWV2WebSocketResponse, albResponseEvent,
return Stream.of(apiGWV2WebSocketResponse, albResponseEvent,
kaipResponse).map(Arguments::of);
}
}
Loading

0 comments on commit f311722

Please sign in to comment.