diff --git a/packages/amplify-graphql-auth-transformer/src/graphql-auth-transformer.ts b/packages/amplify-graphql-auth-transformer/src/graphql-auth-transformer.ts index c45745bfc2..6f7da03dcc 100644 --- a/packages/amplify-graphql-auth-transformer/src/graphql-auth-transformer.ts +++ b/packages/amplify-graphql-auth-transformer/src/graphql-auth-transformer.ts @@ -928,7 +928,7 @@ export class AuthTransformer extends TransformerAuthBase implements TransformerA ? ctx.api.host.getDataSource(RDSLambdaDataSourceLogicalID) : ctx.api.host.getDataSource(`${def.name.value}Table`) ) as DataSourceProvider; - const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(); + const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(ctx, def); const authExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthExpressionForUpdate( this.configuredAuthProviders, totalRoles, @@ -952,8 +952,13 @@ export class AuthTransformer extends TransformerAuthBase implements TransformerA const resolver = ctx.resolvers.getResolver(typeName, fieldName) as TransformerResolverProvider; // only roles with full delete on every field can delete const deleteRoles = acm.getRolesPerOperation('delete', true).map((role) => this.roleMap.get(role)!); - const dataSource = ctx.api.host.getDataSource(`${def.name.value}Table`) as DataSourceProvider; - const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(); + const { RDSLambdaDataSourceLogicalID } = ResourceConstants.RESOURCES; + const dataSource = ( + isRDSModel(ctx, def.name.value) + ? ctx.api.host.getDataSource(RDSLambdaDataSourceLogicalID) + : ctx.api.host.getDataSource(`${def.name.value}Table`) + ) as DataSourceProvider; + const requestExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthRequestExpression(ctx, def); const authExpression = this.getVtlGenerator(ctx, def.name.value).generateAuthExpressionForDelete( this.configuredAuthProviders, deleteRoles, diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/ddb/ddb-vtl-generator.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/ddb/ddb-vtl-generator.ts index b7e1298ef4..c62a9022af 100644 --- a/packages/amplify-graphql-auth-transformer/src/vtl-generator/ddb/ddb-vtl-generator.ts +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/ddb/ddb-vtl-generator.ts @@ -33,7 +33,10 @@ export class DDBAuthVTLGenerator implements AuthVTLGenerator { fields: ReadonlyArray, ): string => generateAuthExpressionForUpdate(providers, roles, fields); - generateAuthRequestExpression = (): string => generateAuthRequestExpression(); + generateAuthRequestExpression = ( + ctx: TransformerContextProvider, + def: ObjectTypeDefinitionNode, + ): string => generateAuthRequestExpression(); generateAuthExpressionForDelete = ( providers: ConfiguredAuthProviders, diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/rds-vtl-generator.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/rds-vtl-generator.ts index ca712160ce..f5927e5a08 100644 --- a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/rds-vtl-generator.ts +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/rds-vtl-generator.ts @@ -1,8 +1,11 @@ import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; import { FieldDefinitionNode, ObjectTypeDefinitionNode } from 'graphql'; -import { ConfiguredAuthProviders, RoleDefinition, RelationalPrimaryMapConfig } from '../../utils'; +import { ConfiguredAuthProviders, RoleDefinition } from '../../utils'; import { AuthVTLGenerator } from '../vtl-generator'; import { generateDefaultRDSExpression } from './resolvers'; +import { generateAuthExpressionForQueries } from './resolvers/query'; +import { generateAuthExpressionForCreate, generateAuthExpressionForDelete, generateAuthExpressionForUpdate, generateAuthRequestExpression } from './resolvers/mutation'; +import { generateAuthExpressionForSubscriptions } from './resolvers/subscription'; export class RDSAuthVTLGenerator implements AuthVTLGenerator { generateAuthExpressionForCreate = ( @@ -10,21 +13,24 @@ export class RDSAuthVTLGenerator implements AuthVTLGenerator { providers: ConfiguredAuthProviders, roles: Array, fields: ReadonlyArray, - ): string => generateDefaultRDSExpression(); + ): string => generateAuthExpressionForCreate(ctx, providers, roles, fields); generateAuthExpressionForUpdate = ( providers: ConfiguredAuthProviders, roles: Array, fields: ReadonlyArray, - ): string => generateDefaultRDSExpression(); + ): string => generateAuthExpressionForUpdate(providers, roles, fields); - generateAuthRequestExpression = (): string => generateDefaultRDSExpression(); + generateAuthRequestExpression = ( + ctx: TransformerContextProvider, + def: ObjectTypeDefinitionNode, + ): string => generateAuthRequestExpression(ctx, def); generateAuthExpressionForDelete = ( providers: ConfiguredAuthProviders, roles: Array, fields: ReadonlyArray, - ): string => generateDefaultRDSExpression(); + ): string => generateAuthExpressionForDelete(providers, roles, fields); generateAuthExpressionForField = ( providers: ConfiguredAuthProviders, @@ -43,7 +49,7 @@ export class RDSAuthVTLGenerator implements AuthVTLGenerator { fields: ReadonlyArray, def: ObjectTypeDefinitionNode, indexName: string | undefined, - ): string => generateDefaultRDSExpression(); + ): string => generateAuthExpressionForQueries(ctx, providers, roles, fields, def, indexName); generateAuthExpressionForSearchQueries = ( providers: ConfiguredAuthProviders, @@ -53,7 +59,7 @@ export class RDSAuthVTLGenerator implements AuthVTLGenerator { ): string => generateDefaultRDSExpression(); generateAuthExpressionForSubscriptions = (providers: ConfiguredAuthProviders, roles: Array): string => - generateDefaultRDSExpression(); + generateAuthExpressionForSubscriptions(providers, roles); setDeniedFieldFlag = (operation: string, subscriptionsEnabled: boolean): string => generateDefaultRDSExpression(); diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/common.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/common.ts index cdbd143ab4..a250c3a316 100644 --- a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/common.ts +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/common.ts @@ -1,4 +1,6 @@ -import { compoundExpression, methodCall, obj, printBlock, ref, toJson } from 'graphql-mapping-template'; +import { Expression, and, bool, compoundExpression, iff, list, methodCall, not, notEquals, obj, or, parens, printBlock, qref, raw, ref, set, str, toJson } from 'graphql-mapping-template'; +import { FieldDefinitionNode } from 'graphql'; +import { API_KEY_AUTH_TYPE, RoleDefinition } from '../../../utils'; /** * Generates default RDS expression @@ -7,3 +9,178 @@ export const generateDefaultRDSExpression = (): string => { const exp = methodCall(ref('util.unauthorized')); return printBlock('Default RDS Auth Resolver')(compoundExpression([exp, toJson(obj({}))])); }; + +export const generateAuthRulesFromRoles = (roles: Array, fields: Readonly): Expression[] => { + const expressions = []; + expressions.push( + qref(methodCall(ref('ctx.stash.put'), str('hasAuth'), bool(true))), + set(ref('authRules'), list([])), + ); + const fieldNames = fields.map((field) => field.name.value); + roles.forEach((role) => { + expressions.push(convertAuthRoleToVtl(role, fieldNames)); + }); + return expressions; +}; + +const convertAuthRoleToVtl = (role: RoleDefinition, fields: string[]): Expression => { + const allowedFields = getAllowedFields(role, fields).map((field) => str(field)); + + // Api Key + if (role.provider === 'apiKey') { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str('public'), + provider: str('apiKey'), + ...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) }, + }) + ) + ); + } + + // Lambda Authorizer + else if (role.provider === 'function') { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str('custom'), + provider: str('function'), + ...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) }, + }) + ) + ); + } + + // IAM + else if (role.provider === 'iam') { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str(role.strategy), + provider: str('iam'), + roleArn: role.strategy === 'public' ? ref('ctx.stash.unauthRole') : ref('ctx.stash.authRole'), + cognitoIdentityPoolId: ref('ctx.identity.cognitoIdentityPoolId'), + ...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) }, + }) + ) + ); + } + + // User Pools or OIDC + else if (role.provider === 'userPools' || role.provider === 'oidc') { + if (role.strategy === 'private') { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str(role.strategy), + provider: str('userPools'), + allowedFields: list(getAllowedFields(role, fields).map((field) => str(field))), + }) + ) + ); + } + else if (role.strategy === 'groups' && role.static) { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str(role.strategy), + provider: str('userPools'), + allowedGroups: list([str(role.entity)]), + identityClaim: str(role.claim), + allowedFields: list(getAllowedFields(role, fields).map((field) => str(field))), + }) + ) + ); + } + else if (role.strategy === 'owner') { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str(role.strategy), + provider: str('userPools'), + ownerFieldName: str(role.entity), + ownerFieldType: str(role.isEntityList ? 'string[]' : 'string'), + identityClaim: str(role.claim), + ...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) }, + }) + ) + ); + } + else if (role.strategy === 'groups' && !role.static) { + return qref( + methodCall( + ref('authRules.add'), + obj({ + type: str(role.strategy), + provider: str('userPools'), + groupsFieldName: str(role.entity), + groupsFieldType: str(role.isEntityList ? 'string[]' : 'string'), + groupClaim: str(role.claim), + ...(allowedFields && allowedFields.length > 0) && { allowedFields: list(allowedFields) }, + }) + ) + ); + } + } + throw new Error(`Invalid Auth Rule: Unable to process ${JSON.stringify(role)}`); +}; + +const getAllowedFields = (role: RoleDefinition, fields: string[]): string[] => { + if (role.allowedFields && role.allowedFields.length > 1) { + return role.allowedFields; + } + return fields; +}; + +export const validateAuthResult = (): Expression => { + return compoundExpression([ + iff( + or ([ + not(ref('authResult')), + parens( + and([ + ref('authResult'), + not(ref('authResult.authorized')), + ]), + ), + ]), + ref('util.unauthorized'), + ), + ]); +}; + +export const constructAuthFilter = (): Expression => { + return iff( + and([ + ref('authResult'), + not(methodCall(ref('util.isNullOrEmpty'), ref('authResult.authFilter'))), + ]), + set(ref('ctx.stash.authFilter'), ref('authResult.authFilter')), + ); +}; + +export const constructAuthorizedInputStatement = (keyName: string): Expression => +iff( + and([ + ref('authResult'), + not(methodCall(ref('util.isNullOrEmpty'), ref('authResult.authorizedInput'))), + ]), + set(ref(keyName), ref('authResult.authorizedInput')), +); + +/** + * Generates sandbox expression for field + */ +export const generateSandboxExpressionForField = (sandboxEnabled: boolean): string => { + let exp: Expression; + if (sandboxEnabled) exp = iff(notEquals(methodCall(ref('util.authType')), str(API_KEY_AUTH_TYPE)), methodCall(ref('util.unauthorized'))); + else exp = methodCall(ref('util.unauthorized')); + return printBlock(`Sandbox Mode ${sandboxEnabled ? 'Enabled' : 'Disabled'}`)(compoundExpression([exp, toJson(obj({}))])); +}; diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/mutation.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/mutation.ts new file mode 100644 index 0000000000..201910a923 --- /dev/null +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/mutation.ts @@ -0,0 +1,110 @@ +import { compoundExpression, list, methodCall, obj, printBlock, qref, ref, set, str } from 'graphql-mapping-template'; +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { FieldDefinitionNode, ObjectTypeDefinitionNode } from 'graphql'; +import { ConfiguredAuthProviders, RoleDefinition } from '../../../utils'; +import { constructAuthorizedInputStatement, generateAuthRulesFromRoles, validateAuthResult } from './common'; + +export const generateAuthExpressionForCreate = ( + ctx: TransformerContextProvider, + providers: ConfiguredAuthProviders, + roles: Array, + fields: ReadonlyArray, +): string => { + const expressions = []; + const operation = 'create'; + expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields))); + expressions.push( + set( + ref('authResult'), + methodCall( + ref('util.authRules.mutationAuth'), + ref('authRules'), + str(operation), + ref('ctx.args.input'), + ), + ), + ); + expressions.push( + validateAuthResult(), + constructAuthorizedInputStatement('ctx.args.input'), + ); + return printBlock('Authorization rules')(compoundExpression(expressions)); +}; + +export const generateAuthExpressionForUpdate = ( + providers: ConfiguredAuthProviders, + roles: Array, + fields: ReadonlyArray, +): string => { + const expressions = []; + const operation = 'update'; + expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields))); + expressions.push( + set( + ref('authResult'), + methodCall( + ref('util.authRules.mutationAuth'), + ref('authRules'), + str(operation), + ref('ctx.args.input'), + ref('ctx.source'), + ), + ), + ); + expressions.push(validateAuthResult()); + return printBlock('Authorization rules')(compoundExpression(expressions)); +}; + +export const generateAuthExpressionForDelete = ( + providers: ConfiguredAuthProviders, + roles: Array, + fields: ReadonlyArray, +): string => { + const expressions = []; + const operation = 'delete'; + expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields))); + expressions.push( + set( + ref('authResult'), + methodCall( + ref('util.authRules.mutationAuth'), + ref('authRules'), + str(operation), + ref('ctx.args.input'), + ref('ctx.source'), + ), + ), + ); + expressions.push(validateAuthResult()); + return printBlock('Authorization rules')(compoundExpression(expressions)); +}; + +export const generateAuthRequestExpression = (ctx: TransformerContextProvider, def: ObjectTypeDefinitionNode): string => { + const mappedTableName = ctx.resourceHelper.getModelNameMapping(def.name.value); + const operation = 'GET'; + const operationName = 'GET_EXISTING_RECORD'; + return printBlock('Get existing record')( + compoundExpression([ + set(ref('lambdaInput'), obj({})), + set(ref('lambdaInput.args'), obj({})), + set(ref('lambdaInput.table'), str(mappedTableName)), + set(ref('lambdaInput.operation'), str(operation)), + set(ref('lambdaInput.operationName'), str(operationName)), + set(ref('lambdaInput.args.metadata'), obj({})), + set(ref('lambdaInput.args.metadata.keys'), list([])), + qref( + methodCall(ref('lambdaInput.args.metadata.keys.addAll'), + methodCall(ref('util.defaultIfNull'), ref('ctx.stash.keys'), list([]))), + ), + set( + ref('lambdaInput.args.input'), + methodCall(ref('util.map.copyAndRetainAllKeys'), ref('context.arguments.input'), ref('ctx.stash.keys')), + ), + obj({ + version: str('2018-05-29'), + operation: str('Invoke'), + payload: methodCall(ref('util.toJson'), ref('lambdaInput')), + }), + ]), + ); +}; diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/query.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/query.ts new file mode 100644 index 0000000000..979f347634 --- /dev/null +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/query.ts @@ -0,0 +1,28 @@ +import { compoundExpression, methodCall, printBlock, ref, set } from 'graphql-mapping-template'; +import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import { FieldDefinitionNode, ObjectTypeDefinitionNode } from 'graphql'; +import { ConfiguredAuthProviders, RoleDefinition } from '../../../utils'; +import { constructAuthFilter, generateAuthRulesFromRoles, validateAuthResult } from './common'; + +export const generateAuthExpressionForQueries = ( + ctx: TransformerContextProvider, + providers: ConfiguredAuthProviders, + roles: Array, + fields: ReadonlyArray, + def: ObjectTypeDefinitionNode, + indexName: string | undefined, +): string => { + const expressions = []; + expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, fields))); + expressions.push( + set( + ref('authResult'), + methodCall(ref('util.authRules.queryAuth'), ref('authRules')), + ), + ); + expressions.push( + validateAuthResult(), + constructAuthFilter(), + ); + return printBlock('Authorization rules')(compoundExpression(expressions)); +}; diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/subscription.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/subscription.ts new file mode 100644 index 0000000000..9311f3dc03 --- /dev/null +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/rds/resolvers/subscription.ts @@ -0,0 +1,37 @@ +import { and, compoundExpression, ifElse, iff, list, methodCall, not, obj, printBlock, ref, set } from 'graphql-mapping-template'; +import { ConfiguredAuthProviders, RoleDefinition } from '../../../utils'; +import { generateAuthRulesFromRoles, validateAuthResult } from './common'; + +export const generateAuthExpressionForSubscriptions = (providers: ConfiguredAuthProviders, roles: Array): string => { + const expressions = []; + expressions.push(compoundExpression(generateAuthRulesFromRoles(roles, []))); + expressions.push( + set( + ref('authResult'), + methodCall(ref('util.authRules.subscriptionAuth'), ref('authRules')), + ), + ); + expressions.push(validateAuthResult()); + + // Construct auth filter to set it as runtime filter + expressions.push( + iff( + and([ + ref('authResult'), + not(methodCall(ref('util.isNullOrEmpty'), ref('authResult.authFilter'))), + ]), + ifElse( + methodCall(ref('util.isNullOrEmpty'), ref('ctx.args.filter')), + set(ref('ctx.args.filter'), ref('authResult.authFilter')), + set(ref('ctx.args.filter'), obj({ + and: list([ + ref('authResult.authFilter'), + ref('ctx.args.filter'), + ]), + })), + ), + ), + ); + + return printBlock('Authorization rules')(compoundExpression(expressions)); +}; diff --git a/packages/amplify-graphql-auth-transformer/src/vtl-generator/vtl-generator.ts b/packages/amplify-graphql-auth-transformer/src/vtl-generator/vtl-generator.ts index 2e4e22a7a7..b477d60df1 100644 --- a/packages/amplify-graphql-auth-transformer/src/vtl-generator/vtl-generator.ts +++ b/packages/amplify-graphql-auth-transformer/src/vtl-generator/vtl-generator.ts @@ -17,7 +17,7 @@ export interface AuthVTLGenerator { fields: ReadonlyArray, ) => string; - generateAuthRequestExpression: () => string; + generateAuthRequestExpression: (ctx: TransformerContextProvider, def: ObjectTypeDefinitionNode) => string; generateAuthExpressionForDelete: ( providers: ConfiguredAuthProviders, diff --git a/packages/amplify-graphql-model-transformer/src/resolvers/rds/query.ts b/packages/amplify-graphql-model-transformer/src/resolvers/rds/query.ts index b0d3faf5a6..c5ab2b0865 100644 --- a/packages/amplify-graphql-model-transformer/src/resolvers/rds/query.ts +++ b/packages/amplify-graphql-model-transformer/src/resolvers/rds/query.ts @@ -1,6 +1,6 @@ import { TransformerContextProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { compoundExpression, list, methodCall, obj, printBlock, qref, ref, set, str } from 'graphql-mapping-template'; -import { constructNonScalarFieldsStatement } from './resolver'; +import { compoundExpression, iff, list, methodCall, not, obj, printBlock, qref, ref, set, str } from 'graphql-mapping-template'; +import { constructAuthFilterStatement, constructNonScalarFieldsStatement } from './resolver'; export const generateLambdaListRequestTemplate = ( tableName: string, @@ -18,6 +18,7 @@ export const generateLambdaListRequestTemplate = ( set(ref('lambdaInput.operationName'), str(operationName)), set(ref('lambdaInput.args.metadata'), obj({})), set(ref('lambdaInput.args.metadata.keys'), list([])), + constructAuthFilterStatement('lambdaInput.args.metadata.authFilter'), constructNonScalarFieldsStatement(tableName, ctx), qref( methodCall(ref('lambdaInput.args.metadata.keys.addAll'), methodCall(ref('util.defaultIfNull'), ref('ctx.stash.keys'), list([]))), diff --git a/packages/amplify-graphql-model-transformer/src/resolvers/rds/resolver.ts b/packages/amplify-graphql-model-transformer/src/resolvers/rds/resolver.ts index c0c0f9fbe3..19116b5291 100644 --- a/packages/amplify-graphql-model-transformer/src/resolvers/rds/resolver.ts +++ b/packages/amplify-graphql-model-transformer/src/resolvers/rds/resolver.ts @@ -3,8 +3,10 @@ import { Expression, compoundExpression, ifElse, + iff, list, methodCall, + not, obj, printBlock, qref, @@ -377,6 +379,7 @@ export const generateLambdaRequestTemplate = ( set(ref('lambdaInput.operationName'), str(operationName)), set(ref('lambdaInput.args.metadata'), obj({})), set(ref('lambdaInput.args.metadata.keys'), list([])), + constructAuthFilterStatement('lambdaInput.args.metadata.authFilter'), constructNonScalarFieldsStatement(tableName, ctx), qref( methodCall(ref('lambdaInput.args.metadata.keys.addAll'), methodCall(ref('util.defaultIfNull'), ref('ctx.stash.keys'), list([]))), @@ -451,3 +454,10 @@ export const getNonScalarFields = (object: ObjectTypeDefinitionNode | undefined, export const constructNonScalarFieldsStatement = (tableName: string, ctx: TransformerContextProvider): Expression => set(ref('lambdaInput.args.metadata.nonScalarFields'), list(getNonScalarFields(ctx.output.getObject(tableName), ctx).map(str))); + +export const constructAuthFilterStatement = (keyName: string): Expression => + iff( + not(methodCall(ref('util.isNullOrEmpty'), ref('ctx.stash.authFilter'))), + set(ref(keyName), ref('ctx.stash.authFilter')), + ); + \ No newline at end of file