Skip to content

Commit

Permalink
Provide input type correctly (#174)
Browse files Browse the repository at this point in the history
Cleanup the Lambda Handler code a little by providing a much better input type. Since the authorizer is custom, the new type can't solve every ugly type-casting scenario for us, but it is still better than it was.
  • Loading branch information
casewalker authored Sep 10, 2024
1 parent 3b274a3 commit 4704e6b
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 33 deletions.
5 changes: 3 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ repositories {

dependencies {
api(
// From the Serverless generated code
libs.comAmazonaws.awsLambdaJavaCore,
// For getting the Cognito UserPool Group Description, reading YAML, and accessing SSM
libs.softwareAmazonAwssdk.cognitoidentityprovider,
libs.comFasterxmlJacksonDataformat.jacksonDataformatYaml,
Expand All @@ -21,6 +19,9 @@ dependencies {
libs.orgBouncycastle.bcpkixJdk18on
)
implementation(
// From the Serverless generated code
libs.comAmazonaws.awsLambdaJavaCore,
libs.comAmazonaws.awsLambdaJavaEvents,
libs.orgApacheLoggingLog4j.log4jSlf4j2Impl,
libs.orgApacheLoggingLog4j.log4jCore,
libs.orgApacheLoggingLog4j.log4jApi,
Expand Down
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ bouncyCastleVersion = "1.78.1"

[libraries]
comAmazonaws-awsLambdaJavaCore = "com.amazonaws:aws-lambda-java-core:1.2.3"
comAmazonaws-awsLambdaJavaEvents = "com.amazonaws:aws-lambda-java-events:3.13.0"
softwareAmazonAwssdk-cognitoidentityprovider = { group = "software.amazon.awssdk", name = "cognitoidentityprovider", version.ref = "amazonSoftwareVersion" }
comFasterxmlJacksonDataformat-jacksonDataformatYaml = "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:2.17.2"
softwareAmazonAwssdk-ssm = { group = "software.amazon.awssdk", name = "ssm", version.ref = "amazonSoftwareVersion" }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import org.jetbrains.annotations.VisibleForTesting;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.ssm.SsmClient;

import static com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.ProxyRequestContext;

/**
* Handler for getting a generated SAML Response, modified from code generated by Serverless ({@code
* com.serverless.Handler}). This relies on the structure of the handler-input including:
Expand Down Expand Up @@ -47,8 +50,9 @@
* }
* </pre>
*
* If no query string is present, the <code>queryStringParameters</code> entry may be omitted completely, so that must
* be dealt with.
* This shape seems to be most aligned with {@link APIGatewayProxyRequestEvent}, but unfortunately everything inside the
* <code>authorizer</code> object is just a <code>Map</code> of <code>String</code> to <code>Object</code>, so obnoxious
* casting must be done inside.
* <p>
* For security, the "user" passed to {@link SamlGenerator#SamlGenerator(String, String, String, KeysWrapper)} will be
* sourced exclusively from the JWT claim "email", and the "groupName" passed in the pathParameters will be validated
Expand All @@ -60,30 +64,26 @@
*
* @author Case Walker ([email protected])
*/
public class GetSamlResponseHandler implements RequestHandler<Map<String, Object>, Map<String, String>> {
public class GetSamlResponseHandler implements RequestHandler<APIGatewayProxyRequestEvent, Map<String, String>> {

private static final Logger logger = LogManager.getLogger(GetSamlResponseHandler.class);
private static final String DEFAULT_DURATION = System.getenv("DEFAULT_SESSION_DURATION");
private static final String REGION = System.getenv("COGNITO_REGION");
private static final String USER_POOL = System.getenv("COGNITO_USER_POOL");
private static final Pattern DIGITS_PATTERN = Pattern.compile("\\d+");
private static final String INPUT_PATH_PARAMETERS = "pathParameters";
private static final String PATH_PARAMETER_GROUP_NAME = System.getenv("PATH_PARAMETER_GROUP_NAME");
private static final String INPUT_QUERY_STRING_PARAMETERS = "queryStringParameters";
private static final String DURATION_PARAMETER = "duration";
private static final String NESTED_INPUT_REQUEST_CONTEXT = "requestContext";
private static final String NESTED_INPUT_AUTHORIZER = "authorizer";
private static final String NESTED_INPUT_JWT = "jwt";
private static final String NESTED_INPUT_CLAIMS = "claims";
private static final String AUTHORIZER_JWT = "jwt";
private static final String JWT_CLAIMS = "claims";
private static final String EMAIL_CLAIM = "email";
private static final String COGNITO_GROUPS_CLAIM = "cognito:groups";

// Left open on purpose for testing
// Left open (not final) on purpose for testing
@VisibleForTesting
private SsmClient ssmClient = SsmClient.builder().region(Region.of(REGION)).build();

@Override
public Map<String, String> handleRequest(final Map<String, Object> input, final Context context) {
public Map<String, String> handleRequest(final APIGatewayProxyRequestEvent input, final Context context) {
final RequestParameters rp = extractRequestParametersFromInput(input);
if (rp.groupName() == null || rp.groupName().isBlank()) {
return createErrorReturnMap(Status.INPUT_ERROR,
Expand Down Expand Up @@ -138,36 +138,31 @@ public Map<String, String> handleRequest(final Map<String, Object> input, final
}
}

/* Casting the Object results of input.get() leads to unchecked warnings. Ignore them for now. */
@SuppressWarnings("unchecked")
private RequestParameters extractRequestParametersFromInput(final Map<String, Object> input) {
final Map<String, String> pathParams = (Map<String, String>) input.get(INPUT_PATH_PARAMETERS);
private RequestParameters extractRequestParametersFromInput(final APIGatewayProxyRequestEvent input) {
final Map<String, String> pathParams = input.getPathParameters();
final String groupName = pathParams != null && !pathParams.isEmpty() ?
pathParams.get(PATH_PARAMETER_GROUP_NAME) : null;

final Map<String, String> queryStringParams = (Map<String, String>) input.get(INPUT_QUERY_STRING_PARAMETERS);
final Map<String, String> queryStringParams = input.getQueryStringParameters();
final String duration = queryStringParams != null && !queryStringParams.isEmpty() ?
queryStringParams.getOrDefault(DURATION_PARAMETER, DEFAULT_DURATION) : DEFAULT_DURATION;

return new RequestParameters(groupName, duration);
}

/* Casting the Object results of input.get() leads to unchecked warnings. Ignore them for now. */
/* Casting the Object results of authorizer.get() leads to unchecked warnings. Ignore them for now. */
@SuppressWarnings("unchecked")
private AuthorizerContextDetails extractAuthorizerDetailsFromInput(final Map<String, Object> input) {
final Map<String, Object> requestContext = (Map<String, Object>) input.get(NESTED_INPUT_REQUEST_CONTEXT);
final Map<String, Object> authorizer = requestContext != null && !requestContext.isEmpty() ?
(Map<String, Object>) requestContext.get(NESTED_INPUT_AUTHORIZER) : null;
private AuthorizerContextDetails extractAuthorizerDetailsFromInput(final APIGatewayProxyRequestEvent input) {
final ProxyRequestContext requestContext = input.getRequestContext();
final Map<String, Object> authorizer = requestContext != null ? requestContext.getAuthorizer() : null;
final Map<String, Object> jwt = authorizer != null && !authorizer.isEmpty() ?
(Map<String, Object>) authorizer.get(NESTED_INPUT_JWT) : null;
(Map<String, Object>) authorizer.get(AUTHORIZER_JWT) : null;
final Map<String, Object> claims = jwt != null && !jwt.isEmpty() ?
(Map<String, Object>) jwt.get(NESTED_INPUT_CLAIMS) : null;
(Map<String, Object>) jwt.get(JWT_CLAIMS) : null;

final String email = claims != null && !claims.isEmpty() ?
(String) claims.get(EMAIL_CLAIM) : null;
final String email = claims != null && !claims.isEmpty() ? (String) claims.get(EMAIL_CLAIM) : null;
final String usersGroupsString = claims != null && !claims.isEmpty() ?
(String) claims.get(COGNITO_GROUPS_CLAIM) : null;

final List<String> usersGroups = usersGroupsString != null ?
List.of(usersGroupsString.substring(1, usersGroupsString.length() - 1).split(" ")) : null;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package gov.nj.innovation.customAwsIdp.lambda;

import com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent;
import gov.nj.innovation.customAwsIdp.lambda.helpers.CognitoGroupDescriptionMetadataExtractor;
import gov.nj.innovation.customAwsIdp.lambda.helpers.data.CognitoGroupDescriptionMetadata;
import org.junit.jupiter.api.AfterAll;
Expand All @@ -22,6 +23,7 @@
import java.util.Map;
import java.util.stream.Stream;

import static com.amazonaws.services.lambda.runtime.events.APIGatewayProxyRequestEvent.ProxyRequestContext;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mockStatic;

Expand Down Expand Up @@ -181,13 +183,13 @@ void testYamlParserThrows() {
/**
* Create the input map following the layout described in the {@link GetSamlResponseHandler} class.
*/
private Map<String, Object> setupHandlerInput(String groupName, String duration, String email, String groups) {
Map<String, Object> input = new HashMap<>();
private APIGatewayProxyRequestEvent setupHandlerInput(String groupName, String duration, String email, String groups) {
APIGatewayProxyRequestEvent input = new APIGatewayProxyRequestEvent();
Map<String, String> pathParams = new HashMap<>();
pathParams.put("groupName", groupName);
input.put("pathParameters", pathParams);
input.setPathParameters(pathParams);
if (duration != null) {
input.put("queryStringParameters", Map.of("duration", duration));
input.setQueryStringParameters(Map.of("duration", duration));
}

Map<String, String> claims = new HashMap<>();
Expand All @@ -197,7 +199,9 @@ private Map<String, Object> setupHandlerInput(String groupName, String duration,
if (groups != null) {
claims.put("cognito:groups", groups);
}
input.put("requestContext", Map.of("authorizer", Map.of("jwt", Map.of("claims", claims))));
ProxyRequestContext requestContext = new ProxyRequestContext();
requestContext.setAuthorizer(Map.of("jwt", Map.of("claims", claims)));
input.setRequestContext(requestContext);

return input;
}
Expand Down

0 comments on commit 4704e6b

Please sign in to comment.