diff --git a/aws-lambda-scorer/lambda-template/build.gradle b/aws-lambda-scorer/lambda-template/build.gradle index ba1f5311..989d417b 100644 --- a/aws-lambda-scorer/lambda-template/build.gradle +++ b/aws-lambda-scorer/lambda-template/build.gradle @@ -18,6 +18,7 @@ dependencies { compile group: 'ai.h2o', name: 'mojo2-runtime-api', version: '0.13.7' compile group: 'ai.h2o', name: 'mojo2-runtime-impl', version: '0.13.7' compile group: 'com.amazonaws', name: 'aws-lambda-java-core', version: '1.2.0' + compile group: 'com.amazonaws', name: 'aws-lambda-java-events', version: '2.2.3' compile group: 'com.amazonaws', name: 'aws-java-sdk-s3', version: '1.11.445' compile group: 'com.google.code.gson', name: 'gson', version: '2.3.1' compile group: 'io.swagger.core.v3', name: 'swagger-annotations', version: '2.0.5' diff --git a/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/dia/deploy/aws/lambda/ApiGatewayWrapper.java b/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/dia/deploy/aws/lambda/ApiGatewayWrapper.java new file mode 100644 index 00000000..061fccbf --- /dev/null +++ b/aws-lambda-scorer/lambda-template/src/main/java/ai/h2o/dia/deploy/aws/lambda/ApiGatewayWrapper.java @@ -0,0 +1,53 @@ +package ai.h2o.dia.deploy.aws.lambda; + +import ai.h2o.dai.deploy.aws.lambda.model.ScoreRequest; +import ai.h2o.dai.deploy.aws.lambda.model.ScoreResponse; +import ai.h2o.mojos.runtime.lic.LicenseException; +import com.amazonaws.services.lambda.runtime.Context; +import com.amazonaws.services.lambda.runtime.RequestHandler; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent; +import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyResponseEvent; +import com.google.gson.Gson; +import com.google.gson.JsonSyntaxException; +import org.apache.http.HttpStatus; + +import java.io.IOException; + +/** + * Wraps {@link MojoScorer} so that it can be called via the AWS API gateway. + * + *

Note that AWS API Gateway strictly mandates the input and output formats that we need to adhere to here. + * In addition, this class does error handling. + * + * @see + * Build an API Gateway API with Lambda Proxy Integration + */ +public class ApiGatewayWrapper implements RequestHandler { + private final Gson gson = new Gson(); + private final MojoScorer mojoScorer = new MojoScorer(); + + @Override + public APIGatewayProxyResponseEvent handleRequest(APIGatewayProxyRequestEvent input, Context context) { + try { + ScoreRequest scoreRequest = gson.fromJson(input.getBody(), ScoreRequest.class); + ScoreResponse scoreResponse = mojoScorer.score(scoreRequest, context); + return toSuccessResponse(gson.toJson(scoreResponse)); + } catch (JsonSyntaxException e) { + return toErrorResponse(HttpStatus.SC_BAD_REQUEST, e); + } catch (IOException e) { + return toErrorResponse(HttpStatus.SC_NOT_FOUND, e); + } catch (LicenseException e) { + return toErrorResponse(HttpStatus.SC_SERVICE_UNAVAILABLE, e); + } catch (Exception e) { + return toErrorResponse(HttpStatus.SC_INTERNAL_SERVER_ERROR, e); + } + } + + private static APIGatewayProxyResponseEvent toErrorResponse(int statusCode, Exception exception) { + return new APIGatewayProxyResponseEvent().withStatusCode(statusCode).withBody(exception.toString()); + } + + private static APIGatewayProxyResponseEvent toSuccessResponse(String body) { + return new APIGatewayProxyResponseEvent().withStatusCode(HttpStatus.SC_OK).withBody(body); + } +} diff --git a/aws-lambda-scorer/terraform-recipe/api_gateway.tf b/aws-lambda-scorer/terraform-recipe/api_gateway.tf new file mode 100644 index 00000000..4b3b9155 --- /dev/null +++ b/aws-lambda-scorer/terraform-recipe/api_gateway.tf @@ -0,0 +1,52 @@ +resource "aws_api_gateway_rest_api" "scorer_api" { + name = "${var.lambda_id}_api" + description = "H2O Driverless AI Mojo Scorer API (${var.lambda_id})" +} + +resource "aws_api_gateway_resource" "proxy_resource" { + rest_api_id = "${aws_api_gateway_rest_api.scorer_api.id}" + parent_id = "${aws_api_gateway_rest_api.scorer_api.root_resource_id}" + path_part = "score" +} + +resource "aws_api_gateway_method" "proxy_method" { + rest_api_id = "${aws_api_gateway_rest_api.scorer_api.id}" + resource_id = "${aws_api_gateway_resource.proxy_resource.id}" + http_method = "POST" + authorization = "NONE" +} + +resource "aws_api_gateway_integration" "scorer_integration" { + rest_api_id = "${aws_api_gateway_rest_api.scorer_api.id}" + resource_id = "${aws_api_gateway_method.proxy_method.resource_id}" + http_method = "${aws_api_gateway_method.proxy_method.http_method}" + + integration_http_method = "POST" + type = "AWS_PROXY" + uri = "${aws_lambda_function.scorer_lambda.invoke_arn}" +} + +resource "aws_api_gateway_deployment" "scorer_api_deployment" { + depends_on = [ + "aws_api_gateway_integration.scorer_integration" + ] + + rest_api_id = "${aws_api_gateway_rest_api.scorer_api.id}" + stage_name = "test" +} + +resource "aws_lambda_permission" "apigw" { + statement_id = "AllowExecutionFromAPIGateway" + action = "lambda:InvokeFunction" + function_name = "${aws_lambda_function.scorer_lambda.arn}" + principal = "apigateway.amazonaws.com" + + # The /*/*/* part allows invocation from any stage, method and resource path + # within API Gateway REST API. + source_arn = "${aws_api_gateway_rest_api.scorer_api.execution_arn}/*/*/*" +} + +output "base_url" { + value = "${aws_api_gateway_deployment.scorer_api_deployment.invoke_url}" +} + diff --git a/aws-lambda-scorer/terraform-recipe/main.tf b/aws-lambda-scorer/terraform-recipe/main.tf index 6b61cd29..9c2626f6 100644 --- a/aws-lambda-scorer/terraform-recipe/main.tf +++ b/aws-lambda-scorer/terraform-recipe/main.tf @@ -36,9 +36,9 @@ resource "aws_s3_bucket_object" "mojo" { // AWS Lambda function with a Java implementation of the Mojo scorer. resource "aws_lambda_function" "scorer_lambda" { function_name = "${var.lambda_id}_function" - description = "H2O Driverless AI Mojo Scorer" + description = "H2O Driverless AI Mojo Scorer (${var.lambda_id})" filename = "${var.lambda_zip_path}" - handler = "ai.h2o.dia.deploy.aws.lambda.MojoScorer::score" + handler = "ai.h2o.dia.deploy.aws.lambda.ApiGatewayWrapper::handleRequest" source_code_hash = "${base64sha256(file(var.lambda_zip_path))}" role = "${aws_iam_role.scorer_lambda_iam_role.arn}" runtime = "java8"