Skip to content

Commit

Permalink
feat(api): rds auth model level rules
Browse files Browse the repository at this point in the history
  • Loading branch information
sundersc committed Sep 28, 2023
1 parent eb3c85d commit 1600579
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ export class AuthTransformer extends TransformerAuthBase implements TransformerA
? ctx.api.host.getDataSource(RDSLambdaDataSourceLogicalID)
: ctx.api.host.getDataSource(`${def.name.value}Table`)
) as DataSourceProvider;
const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression();
const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(ctx, def);
const authExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthExpressionForUpdate(
this.configuredAuthProviders,
totalRoles,
Expand All @@ -952,8 +952,13 @@ export class AuthTransformer extends TransformerAuthBase implements TransformerA
const resolver = ctx.resolvers.getResolver(typeName, fieldName) as TransformerResolverProvider;
// only roles with full delete on every field can delete
const deleteRoles = acm.getRolesPerOperation('delete', true).map((role) => this.roleMap.get(role)!);
const dataSource = ctx.api.host.getDataSource(`${def.name.value}Table`) as DataSourceProvider;
const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression();
const { RDSLambdaDataSourceLogicalID } = ResourceConstants.RESOURCES;
const dataSource = (
isRDSModel(ctx, def.name.value)
? ctx.api.host.getDataSource(RDSLambdaDataSourceLogicalID)
: ctx.api.host.getDataSource(`${def.name.value}Table`)
) as DataSourceProvider;
const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(ctx, def);
const authExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthExpressionForDelete(
this.configuredAuthProviders,
deleteRoles,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ export class DDBAuthVTLGenerator implements AuthVTLGenerator {
fields: ReadonlyArray<FieldDefinitionNode>,
): string => generateAuthExpressionForUpdate(providers, roles, fields);

generateAuthRequestExpression = (): string => generateAuthRequestExpression();
generateAuthRequestExpression = (
ctx: TransformerContextProvider,
def: ObjectTypeDefinitionNode,
): string => generateAuthRequestExpression();

generateAuthExpressionForDelete = (
providers: ConfiguredAuthProviders,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces';
import { FieldDefinitionNode, ObjectTypeDefinitionNode } from 'graphql';
import { ConfiguredAuthProviders, RoleDefinition, RelationalPrimaryMapConfig } from '../../utils';
import { ConfiguredAuthProviders, RoleDefinition } from '../../utils';
import { AuthVTLGenerator } from '../vtl-generator';
import { generateDefaultRDSExpression } from './resolvers';
import { generateAuthExpressionForQueries } from './resolvers/query';
import { generateAuthExpressionForCreate, generateAuthExpressionForDelete, generateAuthExpressionForUpdate, generateAuthRequestExpression } from './resolvers/mutation';
import { generateAuthExpressionForSubscriptions } from './resolvers/subscription';

export class RDSAuthVTLGenerator implements AuthVTLGenerator {
generateAuthExpressionForCreate = (
ctx: TransformerContextProvider,
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => generateDefaultRDSExpression();
): string => generateAuthExpressionForCreate(ctx, providers, roles, fields);

generateAuthExpressionForUpdate = (
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => generateDefaultRDSExpression();
): string => generateAuthExpressionForUpdate(providers, roles, fields);

generateAuthRequestExpression = (): string => generateDefaultRDSExpression();
generateAuthRequestExpression = (
ctx: TransformerContextProvider,
def: ObjectTypeDefinitionNode,
): string => generateAuthRequestExpression(ctx, def);

generateAuthExpressionForDelete = (
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => generateDefaultRDSExpression();
): string => generateAuthExpressionForDelete(providers, roles, fields);

generateAuthExpressionForField = (
providers: ConfiguredAuthProviders,
Expand All @@ -43,7 +49,7 @@ export class RDSAuthVTLGenerator implements AuthVTLGenerator {
fields: ReadonlyArray<FieldDefinitionNode>,
def: ObjectTypeDefinitionNode,
indexName: string | undefined,
): string => generateDefaultRDSExpression();
): string => generateAuthExpressionForQueries(ctx, providers, roles, fields, def, indexName);

generateAuthExpressionForSearchQueries = (
providers: ConfiguredAuthProviders,
Expand All @@ -53,7 +59,7 @@ export class RDSAuthVTLGenerator implements AuthVTLGenerator {
): string => generateDefaultRDSExpression();

generateAuthExpressionForSubscriptions = (providers: ConfiguredAuthProviders, roles: Array<RoleDefinition>): string =>
generateDefaultRDSExpression();
generateAuthExpressionForSubscriptions(providers, roles);

setDeniedFieldFlag = (operation: string, subscriptionsEnabled: boolean): string => generateDefaultRDSExpression();

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { compoundExpression, methodCall, obj, printBlock, ref, toJson } from 'graphql-mapping-template';
import { Expression, and, bool, compoundExpression, iff, list, methodCall, not, notEquals, obj, or, parens, printBlock, qref, raw, ref, set, str, toJson } from 'graphql-mapping-template';
import { FieldDefinitionNode } from 'graphql';
import { API_KEY_AUTH_TYPE, RoleDefinition } from '../../../utils';

/**
* Generates default RDS expression
Expand All @@ -7,3 +9,178 @@ export const generateDefaultRDSExpression = (): string => {
const exp = methodCall(ref('util.unauthorized'));
return printBlock('Default RDS Auth Resolver')(compoundExpression([exp, toJson(obj({}))]));
};

export const generateAuthRulesFromRoles = (roles: Array<RoleDefinition>, fields: Readonly<FieldDefinitionNode[]>): Expression[] => {
const expressions = [];
expressions.push(
qref(methodCall(ref('ctx.stash.put'), str('hasAuth'), bool(true))),
set(ref('authRules'), list([])),
);
const fieldNames = fields.map((field) => field.name.value);
roles.forEach((role) => {
expressions.push(convertAuthRoleToVtl(role, fieldNames));
});
return expressions;
};

const convertAuthRoleToVtl = (role: RoleDefinition, fields: string[]): Expression => {
const allowedFields = getAllowedFields(role, fields).map((field) => str(field));

// Api Key
if (role.provider === 'apiKey') {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str('public'),
provider: str('apiKey'),
...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) },
})
)
);
}

// Lambda Authorizer
else if (role.provider === 'function') {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str('custom'),
provider: str('function'),
...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) },
})
)
);
}

// IAM
else if (role.provider === 'iam') {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str(role.strategy),
provider: str('iam'),
roleArn: role.strategy === 'public' ? ref('ctx.stash.unauthRole') : ref('ctx.stash.authRole'),
cognitoIdentityPoolId: ref('ctx.identity.cognitoIdentityPoolId'),
...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) },
})
)
);
}

// User Pools or OIDC
else if (role.provider === 'userPools' || role.provider === 'oidc') {
if (role.strategy === 'private') {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str(role.strategy),
provider: str('userPools'),
allowedFields: list(getAllowedFields(role, fields).map((field) => str(field))),
})
)
);
}
else if (role.strategy === 'groups' && role.static) {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str(role.strategy),
provider: str('userPools'),
allowedGroups: list([str(role.entity)]),
identityClaim: str(role.claim),
allowedFields: list(getAllowedFields(role, fields).map((field) => str(field))),
})
)
);
}
else if (role.strategy === 'owner') {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str(role.strategy),
provider: str('userPools'),
ownerFieldName: str(role.entity),
ownerFieldType: str(role.isEntityList ? 'string[]' : 'string'),
identityClaim: str(role.claim),
...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) },
})
)
);
}
else if (role.strategy === 'groups' && !role.static) {
return qref(
methodCall(
ref('authRules.add'),
obj({
type: str(role.strategy),
provider: str('userPools'),
groupsFieldName: str(role.entity),
groupsFieldType: str(role.isEntityList ? 'string[]' : 'string'),
groupClaim: str(role.claim),
...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) },
})
)
);
}
}
throw new Error(`Invalid Auth Rule: Unable to process ${JSON.stringify(role)}`);
};

const getAllowedFields = (role: RoleDefinition, fields: string[]): string[] => {
if (role.allowedFields && role.allowedFields.length > 1) {
return role.allowedFields;
}
return fields;
};

export const validateAuthResult = (): Expression => {
return compoundExpression([
iff(
or ([
not(ref('authResult')),
parens(
and([
ref('authResult'),
not(ref('authResult.authorized')),
]),
),
]),
ref('util.unauthorized'),
),
]);
};

export const constructAuthFilter = (): Expression => {
return iff(
and([
ref('authResult'),
not(methodCall(ref('util.isNullOrEmpty'), ref('authResult.authFilter'))),
]),
set(ref('ctx.stash.authFilter'), ref('authResult.authFilter')),
);
};

export const constructAuthorizedInputStatement = (keyName: string): Expression =>
iff(
and([
ref('authResult'),
not(methodCall(ref('util.isNullOrEmpty'), ref('authResult.authorizedInput'))),
]),
set(ref(keyName), ref('authResult.authorizedInput')),
);

/**
* Generates sandbox expression for field
*/
export const generateSandboxExpressionForField = (sandboxEnabled: boolean): string => {
let exp: Expression;
if (sandboxEnabled) exp = iff(notEquals(methodCall(ref('util.authType')), str(API_KEY_AUTH_TYPE)), methodCall(ref('util.unauthorized')));
else exp = methodCall(ref('util.unauthorized'));
return printBlock(`Sandbox Mode ${sandboxEnabled ? 'Enabled' : 'Disabled'}`)(compoundExpression([exp, toJson(obj({}))]));
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { compoundExpression, list, methodCall, obj, printBlock, qref, ref, set, str } from 'graphql-mapping-template';
import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces';
import { FieldDefinitionNode, ObjectTypeDefinitionNode } from 'graphql';
import { ConfiguredAuthProviders, RoleDefinition } from '../../../utils';
import { constructAuthorizedInputStatement, generateAuthRulesFromRoles, validateAuthResult } from './common';

export const generateAuthExpressionForCreate = (
ctx: TransformerContextProvider,
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => {
const expressions = [];
const operation = 'create';
expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields)));
expressions.push(
set(
ref('authResult'),
methodCall(
ref('util.authRules.mutationAuth'),
ref('authRules'),
str(operation),
ref('ctx.args.input'),
),
),
);
expressions.push(
validateAuthResult(),
constructAuthorizedInputStatement('ctx.args.input'),
);
return printBlock('Authorization rules')(compoundExpression(expressions));
};

export const generateAuthExpressionForUpdate = (
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => {
const expressions = [];
const operation = 'update';
expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields)));
expressions.push(
set(
ref('authResult'),
methodCall(
ref('util.authRules.mutationAuth'),
ref('authRules'),
str(operation),
ref('ctx.args.input'),
ref('ctx.source'),
),
),
);
expressions.push(validateAuthResult());
return printBlock('Authorization rules')(compoundExpression(expressions));
};

export const generateAuthExpressionForDelete = (
providers: ConfiguredAuthProviders,
roles: Array<RoleDefinition>,
fields: ReadonlyArray<FieldDefinitionNode>,
): string => {
const expressions = [];
const operation = 'delete';
expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields)));
expressions.push(
set(
ref('authResult'),
methodCall(
ref('util.authRules.mutationAuth'),
ref('authRules'),
str(operation),
ref('ctx.args.input'),
ref('ctx.source'),
),
),
);
expressions.push(validateAuthResult());
return printBlock('Authorization rules')(compoundExpression(expressions));
};

export const generateAuthRequestExpression = (ctx: TransformerContextProvider, def: ObjectTypeDefinitionNode): string => {
const mappedTableName = ctx.resourceHelper.getModelNameMapping(def.name.value);
const operation = 'GET';
const operationName = 'GET_EXISTING_RECORD';
return printBlock('Get existing record')(
compoundExpression([
set(ref('lambdaInput'), obj({})),
set(ref('lambdaInput.args'), obj({})),
set(ref('lambdaInput.table'), str(mappedTableName)),
set(ref('lambdaInput.operation'), str(operation)),
set(ref('lambdaInput.operationName'), str(operationName)),
set(ref('lambdaInput.args.metadata'), obj({})),
set(ref('lambdaInput.args.metadata.keys'), list([])),
qref(
methodCall(ref('lambdaInput.args.metadata.keys.addAll'),
methodCall(ref('util.defaultIfNull'), ref('ctx.stash.keys'), list([]))),
),
set(
ref('lambdaInput.args.input'),
methodCall(ref('util.map.copyAndRetainAllKeys'), ref('context.arguments.input'), ref('ctx.stash.keys')),
),
obj({
version: str('2018-05-29'),
operation: str('Invoke'),
payload: methodCall(ref('util.toJson'), ref('lambdaInput')),
}),
]),
);
};
Loading

0 comments on commit 1600579

Please sign in to comment.