Skip to content
This repository has been archived by the owner on Jan 12, 2024. It is now read-only.

Commit

Permalink
feat: allow roles in AWS China to authenticate
Browse files Browse the repository at this point in the history
* allow roles in AWS China to authenticate
* add feature flag for partition (AWS China/AWS Global) lock for authentication and SDB permissions

* fix: use "enabled" instead of "allowed"; move duplicate into a method; add test

* * Add docs
* Move partition names into constants

* test: Move to using constants for AWS partition in tests

* version bump

Co-authored-by: Todd Underwood <[email protected]>
  • Loading branch information
mayitbeegh and tunderwood authored Aug 11, 2020
1 parent 3a50f28 commit b9132e2
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ public enum DefaultApiError implements ApiError {
"Failed to validate factor. Please try again or try a different factor.",
SC_UNAUTHORIZED),

/** AWS China ARNs are not allowed. */
AWS_CHINA_NOT_ALLOWED(
99244,
"The AWS China partition is disabled by the admin. If you're creating or updating an SDB, please remove IAM principal ARNs that start with \"arn:aws-cn:\"",
SC_UNAUTHORIZED),

/** AWS Global ARNs are not allowed. */
AWS_GLOBAL_NOT_ALLOWED(
99245,
"The AWS Global partition is disabled by the admin. If you're creating or updating an SDB, please remove IAM principal ARNs that start with \"arn:aws:\"",
SC_UNAUTHORIZED),

/** Generic not found error. */
ENTITY_NOT_FOUND(99996, "Not found", SC_NOT_FOUND),

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@

public class DomainConstants {

public static final String AWS_IAM_ROLE_ARN_TEMPLATE = "arn:aws:iam::%s:role/%s";
public static final String AWS_GLOBAL_PARTITION_NAME = "aws";

public static final String AWS_CHINA_PARTITION_NAME = "aws-cn";

public static final String AWS_IAM_ROLE_ARN_TEMPLATE = "arn:%s:iam::%s:role/%s";

/**
* Pattern used to determine if an ARN should be allowed in DB.
*
* <p>This is also the list of ARN types that are allowed in KMS key policies.
*/
public static final String AWS_IAM_PRINCIPAL_ARN_REGEX_ALLOWED =
"^arn:aws:(iam|sts)::(?<accountId>\\d+?):(role|user|federated-user|assumed-role).*/.+(?<!\\s)$";
"^arn:(?<partition>aws|aws-cn):(iam|sts)::(?<accountId>\\d+?):(role|user|federated-user|assumed-role).*/.+(?<!\\s)$";

/**
* Pattern used to determine if an ARN should be allowed in DB
Expand All @@ -46,7 +50,7 @@ public class DomainConstants {
* because they can't go in KMS key policies.
*/
public static final String AWS_IAM_PRINCIPAL_ARN_REGEX_ROLE_GENERATION =
"^arn:aws:(iam|sts)::(?<accountId>\\d+?):(?!group).+?/(?<roleName>.+)$";
"^arn:(?<partition>aws|aws-cn):(iam|sts)::(?<accountId>\\d+?):(?!group).+?/(?<roleName>.+)$";

/**
* Pattern used for generating a role from another ARN.
Expand All @@ -58,20 +62,21 @@ public class DomainConstants {
public static final Pattern IAM_PRINCIPAL_ARN_PATTERN_ROLE_GENERATION =
Pattern.compile(AWS_IAM_PRINCIPAL_ARN_REGEX_ROLE_GENERATION);

public static final String AWS_ACCOUNT_ROOT_ARN_REGEX = "^arn:aws:iam::(?<accountId>\\d+?):root$";
public static final String AWS_ACCOUNT_ROOT_ARN_REGEX =
"^arn:(?<partition>aws|aws-cn):iam::(?<accountId>\\d+?):root$";
public static final Pattern AWS_ACCOUNT_ROOT_ARN_PATTERN =
Pattern.compile(AWS_ACCOUNT_ROOT_ARN_REGEX);
private static final String AWS_IAM_ROLE_ARN_REGEX =
"^arn:aws:iam::(?<accountId>\\d+?):role/(?<roleName>.+)$";
"^arn:(?<partition>aws|aws-cn):iam::(?<accountId>\\d+?):role/(?<roleName>.+)$";
public static final Pattern IAM_ROLE_ARN_PATTERN = Pattern.compile(AWS_IAM_ROLE_ARN_REGEX);
private static final String AWS_IAM_ASSUMED_ROLE_ARN_REGEX =
"^arn:aws:sts::(?<accountId>\\d+?):assumed-role/(?<roleName>.+)/.+$";
"^arn:(?<partition>aws|aws-cn):sts::(?<accountId>\\d+?):assumed-role/(?<roleName>.+)/.+$";
public static final Pattern IAM_ASSUMED_ROLE_ARN_PATTERN =
Pattern.compile(AWS_IAM_ASSUMED_ROLE_ARN_REGEX);
private static final String GENERIC_ASSUMED_ROLE_REGEX =
"^arn:aws:sts::(?<accountId>\\d+?):assumed-role/.+$";
"^arn:(?<partition>aws|aws-cn):sts::(?<accountId>\\d+?):assumed-role/.+$";
public static final Pattern GENERIC_ASSUMED_ROLE_PATTERN =
Pattern.compile(GENERIC_ASSUMED_ROLE_REGEX);
private static final Pattern AWS_IAM_ARN_ACCOUNT_ID_PATTERN =
Pattern.compile("arn:aws:(iam|sts)::(?<accountId>\\d+?):.+");
Pattern.compile("^arn:(?<partition>aws|aws-cn):(iam|sts)::(?<accountId>\\d+?):.+");
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ public class AwsStsHttpClient {
MediaType.parse("application/x-www-form-urlencoded");
private static final MediaType DEFAULT_ACCEPTED_MEDIA_TYPE = MediaType.parse("application/json");
private static final String AWS_STS_ENDPOINT_TEMPLATE = "https://sts.%s.amazonaws.com";
private static final String AWS_CN_STS_ENDPOINT_TEMPLATE = "https://sts.%s.amazonaws.com.cn";
private static final String DEFAULT_GET_CALLER_IDENTITY_ACTION =
"Action=GetCallerIdentity&Version=2011-06-15";
private static final String DEFAULT_METHOD = "POST";
Expand Down Expand Up @@ -115,9 +116,15 @@ public <M> M execute(

/** Build the request */
protected Request buildRequest(String region, Map<String, String> headers) {
String stsEndpointUrl;
if (region.startsWith("cn-")) {
stsEndpointUrl = String.format(AWS_CN_STS_ENDPOINT_TEMPLATE, region);
} else {
stsEndpointUrl = String.format(AWS_STS_ENDPOINT_TEMPLATE, region);
}
Request.Builder requestBuilder =
new Request.Builder()
.url(String.format(AWS_STS_ENDPOINT_TEMPLATE, region))
.url(stsEndpointUrl)
.addHeader("Accept", DEFAULT_ACCEPTED_MEDIA_TYPE.toString());

if (headers != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.nike.cerberus.service;

import static com.nike.cerberus.domain.DomainConstants.AWS_GLOBAL_PARTITION_NAME;
import static com.nike.cerberus.domain.DomainConstants.AWS_IAM_ROLE_ARN_TEMPLATE;
import static com.nike.cerberus.security.CerberusPrincipal.*;

Expand Down Expand Up @@ -221,7 +222,10 @@ public EncryptedAuthDataWrapper authenticate(IamRoleCredentials credentials) {

final String iamPrincipalArn =
String.format(
AWS_IAM_ROLE_ARN_TEMPLATE, credentials.getAccountId(), credentials.getRoleName());
AWS_IAM_ROLE_ARN_TEMPLATE,
AWS_GLOBAL_PARTITION_NAME, // hardcoding this to AWS Global for backwards compatibility
credentials.getAccountId(),
credentials.getRoleName());
final String region = credentials.getRegion();

final AwsIamKmsAuthRequest awsIamKmsAuthRequest = new AwsIamKmsAuthRequest();
Expand All @@ -243,6 +247,7 @@ public EncryptedAuthDataWrapper authenticate(IamRoleCredentials credentials) {
public EncryptedAuthDataWrapper authenticate(AwsIamKmsAuthRequest awsIamKmsAuthRequest) {

final String iamPrincipalArn = awsIamKmsAuthRequest.getIamPrincipalArn();
awsIamRoleArnParser.iamPrincipalPartitionCheck(iamPrincipalArn);
final Map<String, String> authPrincipalMetadata =
generateCommonIamPrincipalAuthMetadata(iamPrincipalArn, awsIamKmsAuthRequest.getRegion());
authPrincipalMetadata.put(
Expand All @@ -258,6 +263,7 @@ public EncryptedAuthDataWrapper authenticate(AwsIamKmsAuthRequest awsIamKmsAuthR
* @return Unencrypted auth response
*/
public AuthTokenResponse stsAuthenticate(final String iamPrincipalArn) {
awsIamRoleArnParser.iamPrincipalPartitionCheck(iamPrincipalArn);
final Map<String, String> authPrincipalMetadata =
generateCommonIamPrincipalAuthMetadata(iamPrincipalArn);
authPrincipalMetadata.put(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ public SafeDepositBoxV2 createSafeDepositBoxV2(

final Set<IamPrincipalPermission> iamRolePermissionSet =
safeDepositBox.getIamPrincipalPermissions();
partitionCheck(iamRolePermissionSet);

final boolean isSlugUnique = safeDepositBoxDao.isSlugUnique(boxRecordToStore.getSdbNameSlug());

Expand Down Expand Up @@ -309,6 +310,8 @@ public SafeDepositBoxV2 updateSafeDepositBoxV2(
final Set<IamPrincipalPermission> iamRolePermissionSet =
safeDepositBox.getIamPrincipalPermissions();

partitionCheck(iamRolePermissionSet);

if (!StringUtils.equals(currentBox.getDescription(), boxToUpdate.getDescription())) {
safeDepositBoxDao.updateSafeDepositBox(boxToUpdate);
}
Expand Down Expand Up @@ -636,6 +639,9 @@ protected SafeDepositBoxV2 convertSafeDepositBoxV1ToV2(SafeDepositBoxV1 safeDepo
.withIamPrincipalArn(
String.format(
DomainConstants.AWS_IAM_ROLE_ARN_TEMPLATE,
DomainConstants
.AWS_GLOBAL_PARTITION_NAME, // hardcoding this to AWS Global for
// backwards compatibility
iamRolePermission.getAccountId(),
iamRolePermission.getIamRoleName()))
.withRoleId(iamRolePermission.getRoleId()))
Expand Down Expand Up @@ -793,4 +799,12 @@ public SafeDepositBoxV2 getSafeDepositBoxDangerouslyWithoutPermissionValidation(
() ->
ApiException.newBuilder().withApiErrors(DefaultApiError.ENTITY_NOT_FOUND).build()));
}

private void partitionCheck(Set<IamPrincipalPermission> iamRolePermissionSet) {
iamRolePermissionSet.stream()
.forEach(
iamRolePermission ->
awsIamRoleArnParser.iamPrincipalPartitionCheck(
iamRolePermission.getIamPrincipalArn()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,25 @@

import com.nike.backstopper.exception.ApiException;
import com.nike.cerberus.domain.DomainConstants;
import com.nike.cerberus.error.DefaultApiError;
import com.nike.cerberus.error.InvalidIamRoleArnApiError;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/** Utility class for concatenating and parsing AWS IAM role ARNs. */
@Component
public class AwsIamRoleArnParser {
private final boolean awsChinaEnabled;
private final boolean awsGlobalEnabled;

public AwsIamRoleArnParser(
@Value("${cerberus.partitions.awsGlobal.enabled}") boolean awsGlobalEnabled,
@Value("${cerberus.partitions.awsChina.enabled}") boolean awsChinaEnabled) {
this.awsGlobalEnabled = awsGlobalEnabled;
this.awsChinaEnabled = awsChinaEnabled;
}

/**
* Gets account ID from a 'role' ARN
Expand Down Expand Up @@ -119,8 +130,10 @@ public String convertPrincipalArnToRoleArn(final String principalArn) {
final String accountId =
getNamedGroupFromRegexPattern(patternToMatch, "accountId", principalArn);
final String roleName = getNamedGroupFromRegexPattern(patternToMatch, "roleName", principalArn);
final String partition =
getNamedGroupFromRegexPattern(patternToMatch, "partition", principalArn);

return String.format(DomainConstants.AWS_IAM_ROLE_ARN_TEMPLATE, accountId, roleName);
return String.format(DomainConstants.AWS_IAM_ROLE_ARN_TEMPLATE, partition, accountId, roleName);
}

public String convertPrincipalArnToRootArn(final String principalArn) {
Expand All @@ -133,7 +146,11 @@ public String convertPrincipalArnToRootArn(final String principalArn) {
getNamedGroupFromRegexPattern(
DomainConstants.IAM_PRINCIPAL_ARN_PATTERN_ALLOWED, "accountId", principalArn);

return String.format("arn:aws:iam::%s:root", accountId);
final String partition =
getNamedGroupFromRegexPattern(
DomainConstants.IAM_PRINCIPAL_ARN_PATTERN_ALLOWED, "partition", principalArn);

return String.format("arn:%s:iam::%s:root", partition, accountId);
}

/**
Expand All @@ -150,6 +167,17 @@ public String stripOutDescription(final String principalArn) {
}
}

/**
* Checks if the partition of an IAM principal ARN is enabled
*
* @param iamPrincipalArn The IAM principal ARN to be checked
* @throws ApiException Throws an exception if the partition of the IAM principal isn't enabled
*/
public void iamPrincipalPartitionCheck(String iamPrincipalArn) {
getNamedGroupFromRegexPattern(
DomainConstants.IAM_PRINCIPAL_ARN_PATTERN_ALLOWED, "partition", iamPrincipalArn);
}

private String getNamedGroupFromRegexPattern(
final Pattern pattern, final String groupName, final String input) {
final Matcher iamRoleArnMatcher = pattern.matcher(input);
Expand All @@ -160,7 +188,17 @@ private String getNamedGroupFromRegexPattern(
.withExceptionMessage("ARN does not match pattern: " + pattern.toString())
.build();
}
partitionCheck(iamRoleArnMatcher.group("partition"));

return iamRoleArnMatcher.group(groupName);
}

private void partitionCheck(String partition) {
if (DomainConstants.AWS_GLOBAL_PARTITION_NAME.equals(partition) && !awsGlobalEnabled) {
throw ApiException.newBuilder().withApiErrors(DefaultApiError.AWS_GLOBAL_NOT_ALLOWED).build();
}
if (DomainConstants.AWS_CHINA_PARTITION_NAME.equals(partition) && !awsChinaEnabled) {
throw ApiException.newBuilder().withApiErrors(DefaultApiError.AWS_CHINA_NOT_ALLOWED).build();
}
}
}
5 changes: 5 additions & 0 deletions cerberus-web/src/main/resources/cerberus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ c3p0:
preferredTestQuery: SELECT 1

cerberus:
partitions:
awsGlobal:
enabled: true
awsChina:
enabled: false
environmentName: TODO
admin:
# These are aws principal that you want to allow to use the admin API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,13 @@ public void test_getRegion_returns_region_as_expected() {
"AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/us-east-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf");

assertEquals("us-east-1", header.getRegion());

header =
new AwsStsHttpHeader(
"20180904T205115Z",
"FQoGZXIvYXdzEFYaDEYceadsfLKJLKlkj908098oB/rJIdxdo57fx3Ef2wW8WhFbSpLGg3hwNqhuepdkf/c0F7OXJutqM2yjgnZCiO7SPAdnMSJhoEgH7SJlkPaPfiRzZAf0yxxD6e4z0VJU74uQfbgfZpn5RL+JyDpgoYkUrjuyL8zRB1knGSOCi32Q75+asdfasd+7bWxMyJIKEb/HF2Le8xM/9F4WRqa5P0+asdfasdfasdf+MGlDlNG0KTzg1JT6QXf95ozWR5bBFSz5DbrFhXhMegMQ7+7Kvx+asdfasdl.jlkj++5NpRRlE54cct7+aG3HQskow9y73AU=",
"AWS4-HMAC-SHA256 Credential=ASIA5S2FQS2GYQLK5FFF/20180904/cn-northwest-1/sts/aws4_request, SignedHeaders=host;x-amz-date, Signature=ddb9417d2b9bfe6f8b03e31a8f5d8ab98e0f4alkj12312098asdf");

assertEquals("cn-northwest-1", header.getRegion());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nike.cerberus.service;

import static com.nike.cerberus.domain.DomainConstants.AWS_IAM_ROLE_ARN_TEMPLATE;
import static com.nike.cerberus.domain.DomainConstants.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotEquals;
Expand Down Expand Up @@ -266,7 +266,8 @@ public void test_that_getKeyId_only_validates_kms_policy_one_time_within_interva
String accountId = "0000000000";
String roleName = "role/path";
String principalArn = String.format("arn:aws:iam::%s:instance-profile/%s", accountId, roleName);
String roleArn = String.format(AWS_IAM_ROLE_ARN_TEMPLATE, accountId, roleName);
String roleArn =
String.format(AWS_IAM_ROLE_ARN_TEMPLATE, AWS_GLOBAL_PARTITION_NAME, accountId, roleName);

AwsIamRoleRecord awsIamRoleRecord = mock(AwsIamRoleRecord.class);
when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.empty());
Expand Down Expand Up @@ -311,7 +312,8 @@ public void test_that_findIamRoleAssociatedWithSdb_returns_empty_optional_when_r
String accountId = "0000000000";
String roleName = "role/path";
String principalArn = String.format("arn:aws:iam::%s:instance-profile/%s", accountId, roleName);
String roleArn = String.format(AWS_IAM_ROLE_ARN_TEMPLATE, accountId, roleName);
String roleArn =
String.format(AWS_IAM_ROLE_ARN_TEMPLATE, AWS_GLOBAL_PARTITION_NAME, accountId, roleName);
String rootArn = String.format("arn:aws:iam::%s:root", accountId);

AwsIamRoleRecord rootRecord = mock(AwsIamRoleRecord.class);
Expand All @@ -332,6 +334,36 @@ public void test_that_findIamRoleAssociatedWithSdb_returns_empty_optional_when_r
assertEquals(roleRecord, result.get());
}

@Test
public void
test_that_findIamRoleAssociatedWithSdb_returns_generic_role_when_iam_principal_not_found_and_root_found_for_aws_china() {

String accountId = "0000000000";
String roleName = "role/path";
String principalArn =
String.format("arn:aws-cn:iam::%s:instance-profile/%s", accountId, roleName);
String roleArn =
String.format(AWS_IAM_ROLE_ARN_TEMPLATE, AWS_CHINA_PARTITION_NAME, accountId, roleName);
String rootArn = String.format("arn:aws-cn:iam::%s:root", accountId);

AwsIamRoleRecord rootRecord = mock(AwsIamRoleRecord.class);
AwsIamRoleRecord roleRecord = mock(AwsIamRoleRecord.class);
when(awsIamRoleDao.getIamRole(principalArn)).thenReturn(Optional.empty());
when(awsIamRoleDao.getIamRole(roleArn)).thenReturn(Optional.empty());
when(awsIamRoleDao.getIamRole(rootArn)).thenReturn(Optional.of(rootRecord));

when(awsIamRoleArnParser.isRoleArn(principalArn)).thenReturn(false);
when(awsIamRoleArnParser.convertPrincipalArnToRoleArn(principalArn)).thenReturn(roleArn);
when(awsIamRoleArnParser.convertPrincipalArnToRootArn(roleArn)).thenReturn(rootArn);

when(awsIamRoleService.createIamRole(roleArn)).thenReturn(roleRecord);

Optional<AwsIamRoleRecord> result =
authenticationService.findIamRoleAssociatedWithSdb(principalArn);

assertEquals(roleRecord, result.get());
}

@Test
public void
tests_that_validateAuthPayloadSizeAndTruncateIfLargerThanMaxKmsSupportedSize_returns_the_original_payload_if_the_size_can_be_encrypted_by_kms()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,29 @@ public void before() {
String cmsRoleArn = "arn:aws:iam::1111111111:role/cms-iam-role";
kmsPolicyService =
new KmsPolicyService(
true, rootUserArn, adminRoleArn, cmsRoleArn, new AwsIamRoleArnParser());
true, rootUserArn, adminRoleArn, cmsRoleArn, new AwsIamRoleArnParser(true, false));
objectMapper = new ObjectMapper();
}

@Test(expected = NullPointerException.class)
public void test_that_KmsPolicyService_throws_error_when_required_field_null_rootUserArn() {
new KmsPolicyService(true, null, "foo", "bar", new AwsIamRoleArnParser());
new KmsPolicyService(true, null, "foo", "bar", new AwsIamRoleArnParser(true, false));
}

@Test(expected = NullPointerException.class)
public void test_that_KmsPolicyService_throws_error_when_required_field_null_adminRoleArn() {
new KmsPolicyService(true, "foo", null, "bar", new AwsIamRoleArnParser());
new KmsPolicyService(true, "foo", null, "bar", new AwsIamRoleArnParser(true, false));
}

@Test(expected = NullPointerException.class)
public void test_that_KmsPolicyService_throws_error_when_required_field_null_cmsRoleArn() {
new KmsPolicyService(true, "foo", "bar", null, new AwsIamRoleArnParser());
new KmsPolicyService(true, "foo", "bar", null, new AwsIamRoleArnParser(true, false));
}

@Test()
public void
test_that_KmsPolicyService_throws_no_error_when_required_fields_are_null_but_kms_auth_disabled() {
new KmsPolicyService(false, null, null, null, new AwsIamRoleArnParser());
new KmsPolicyService(false, null, null, null, new AwsIamRoleArnParser(true, false));
}

@Test
Expand Down
Loading

0 comments on commit b9132e2

Please sign in to comment.