From df41f2b50005b5bc96975e8e8a43591ca2bc03c6 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 31 May 2024 23:25:07 +0800 Subject: [PATCH 1/4] feat: allow comparing fields from different models in mutation policies - Generate TS checker functions to evaluate rules in JS runtime - Make sure fields needed in the checker are selected when reading entities Only supporting mutation rules (create, update, post-update, delete) because: 1. Evaluating read in JS runtime may result in reading lots of rows and then discard 2. Don't know how to support aggregation without reading all rows --- .../src/enhancements/policy/handler.ts | 205 ++++-- .../src/enhancements/policy/policy-utils.ts | 98 ++- packages/runtime/src/enhancements/types.ts | 18 +- .../validator/expression-validator.ts | 45 +- .../enhancer/policy/policy-guard-generator.ts | 48 +- .../src/plugins/enhancer/policy/utils.ts | 181 ++++- .../src/typescript-expression-transformer.ts | 6 +- .../cross-model-field-comparison.test.ts | 672 ++++++++++++++++++ .../with-policy/field-level-policy.test.ts | 8 + 9 files changed, 1163 insertions(+), 118 deletions(-) create mode 100644 tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 997e727d5..cb71b2bc1 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; import { P, match } from 'ts-pattern'; @@ -23,7 +24,7 @@ import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import type { CheckerConstraint } from '../types'; +import type { AdditionalCheckerFunc, CheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; @@ -152,8 +153,7 @@ export class PolicyProxyHandler implements Pr } const result = await this.modelClient[actionName](_args); - this.policyUtils.postProcessForRead(result, this.model, origArgs); - return result; + return this.policyUtils.postProcessForRead(result, this.model, origArgs); } //#endregion @@ -779,10 +779,27 @@ export class PolicyProxyHandler implements Pr } }; - const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { + const _connectDisconnect = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + operation: 'connect' | 'disconnect' + ) => { if (context.field?.backLink) { const backLinkField = this.policyUtils.getModelField(model, context.field.backLink); if (backLinkField?.isRelationOwner) { + let uniqueFilter = args; + if (operation === 'disconnect') { + // disconnect filter is not unique, need to build a reversed query to + // locate the entity and use its id fields as unique filter + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const found = await db[model].findUnique({ + where: reversedQuery, + select: this.policyUtils.makeIdSelection(model), + }); + uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found); + } + // update happens on the related model, require updatable, // translate args to foreign keys so field-level policies can be checked const checkArgs: any = {}; @@ -794,10 +811,15 @@ export class PolicyProxyHandler implements Pr } } } - await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs); - // register post-update check - await _registerPostUpdateCheck(model, args, args); + // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist + if (uniqueFilter) { + // check for update + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs); + + // register post-update check + await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + } } } }; @@ -970,14 +992,14 @@ export class PolicyProxyHandler implements Pr } }, - connect: async (model, args, context) => _connectDisconnect(model, args, context), + connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'), connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect - await _connectDisconnect(model, args.where, context); + await _connectDisconnect(model, args.where, context, 'connect'); return true; } else { // create @@ -997,7 +1019,7 @@ export class PolicyProxyHandler implements Pr } }, - disconnect: async (model, args, context) => _connectDisconnect(model, args, context), + disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'), set: async (model, args, context) => { // find the set of items to be replaced @@ -1012,10 +1034,10 @@ export class PolicyProxyHandler implements Pr const currentSet = await db[model].findMany(findCurrSetArgs); // register current set for update (foreign key) - await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context))); + await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect'))); // proceed with connecting the new set - await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context))); + await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect'))); }, delete: async (model, args, context) => { @@ -1160,48 +1182,78 @@ export class PolicyProxyHandler implements Pr args.data = this.validateUpdateInputSchema(this.model, args.data); - if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { - // use a transaction to do post-update checks - const postWriteChecks: PostWriteCheckRecord[] = []; - return this.queryUtils.transaction(this.prisma, async (tx) => { - // collect pre-update values - let select = this.policyUtils.makeIdSelection(this.model); - const preValueSelect = this.policyUtils.getPreValueSelect(this.model); - if (preValueSelect) { - select = { ...select, ...preValueSelect }; - } - const currentSetQuery = { select, where: args.where }; - this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); + const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'update'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); - } - const currentSet = await tx[this.model].findMany(currentSetQuery); + const canProceedWithoutTransaction = + // no post-update rules + !this.policyUtils.hasAuthGuard(this.model, 'postUpdate') && + // no Zod schema + !this.policyUtils.getZodSchema(this.model) && + // no additional checker + !additionalChecker; - postWriteChecks.push( - ...currentSet.map((preValue) => ({ - model: this.model, - operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), - preValue: preValueSelect ? preValue : undefined, - })) - ); - - // proceed with the update - const result = await tx[this.model].updateMany(args); - - // run post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); - - return result; - }); - } else { + if (canProceedWithoutTransaction) { // proceed without a transaction if (this.shouldLogQuery) { this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); } return this.modelClient.updateMany(args); } + + // collect post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + + return this.queryUtils.transaction(this.prisma, async (tx) => { + // collect pre-update values + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + + // merge selection required for running additional checker + const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(this.model, 'update'); + if (additionalCheckerSelector) { + select = deepmerge(select, additionalCheckerSelector); + } + + const currentSetQuery = { select, where: args.where }; + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update'); + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); + } + let candidates = await tx[this.model].findMany(currentSetQuery); + + if (additionalChecker) { + // filter candidates with additional checker and build an id filter + const r = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker); + candidates = r.filteredCandidates; + + // merge id filter into update's where clause + args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter; + } + + postWriteChecks.push( + ...candidates.map((preValue) => ({ + model: this.model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), + preValue: preValueSelect ? preValue : undefined, + })) + ); + + // proceed with the update + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`); + } + const result = await tx[this.model].updateMany(args); + + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + return result; + }); }); } @@ -1328,14 +1380,53 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions - args = args ?? {}; + args = clone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); - // conduct the deletion - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'delete'); + if (additionalChecker) { + // additional checker exists, need to run deletion inside a transaction + return this.queryUtils.transaction(this.prisma, async (tx) => { + // find the delete candidates, selecting id fields and fields needed for + // running the additional checker + let candidateSelect = this.policyUtils.makeIdSelection(this.model); + const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector( + this.model, + 'delete' + ); + if (additionalCheckerSelector) { + candidateSelect = deepmerge(candidateSelect, additionalCheckerSelector); + } + + if (this.shouldLogQuery) { + this.logger.info( + `[policy] \`findMany\` ${this.model}: ${formatObject({ + where: args.where, + select: candidateSelect, + })}` + ); + } + const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect }); + + // build a ID filter based on id values filtered by the additional checker + const { idFilter } = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker); + + // merge the ID filter into the where clause + args.where = args.where ? { AND: [args.where, idFilter] } : idFilter; + + // finally, conduct the deletion with the combined where clause + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`); + } + return tx[this.model].deleteMany(args); + }); + } else { + // conduct the deletion directly + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.deleteMany(args); } - return this.modelClient.deleteMany(args); }); } @@ -1599,5 +1690,17 @@ export class PolicyProxyHandler implements Pr } } + private buildIdFilterWithAdditionalChecker(candidates: any[], additionalChecker: AdditionalCheckerFunc) { + const filteredCandidates = candidates.filter((value) => additionalChecker({ user: this.context?.user }, value)); + const idFields = this.policyUtils.getIdFields(this.model); + let idFilter: any; + if (idFields.length === 1) { + idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } }; + } else { + idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) }; + } + return { filteredCandidates, idFilter }; + } + //#endregion } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 02bf87ebf..70daa5724 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -726,17 +726,27 @@ export class PolicyUtil extends QueryUtils { // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - if (this.isTrue(guard) && !schema) { + const additionalChecker = this.getAdditionalChecker(model, operation); + + if (this.isTrue(guard) && !schema && !additionalChecker) { // unconditionally allowed return; } - const select = schema + const additionalCheckerSelector = this.getAdditionalCheckerSelector(model, operation); + let select = schema ? // need to validate against schema, need to fetch all fields undefined : // only fetch id fields this.makeIdSelection(model); + if (additionalCheckerSelector) { + if (!select) { + select = this.makeAllScalarFieldSelect(model); + } + select = { ...select, ...additionalCheckerSelector }; + } + let where = this.clone(uniqueFilter); // query args may have be of combined-id form, need to flatten it to call findFirst this.flattenGeneratedUniqueField(model, where); @@ -758,6 +768,20 @@ export class PolicyUtil extends QueryUtils { ); } + if (additionalChecker) { + if (this.logger.enabled('info')) { + this.logger.info(`[policy] running additional checker on ${model} for ${operation}`); + } + if (!additionalChecker({ user: this.user, preValue }, result)) { + throw this.deniedByPolicy( + model, + operation, + `entity ${formatObject(uniqueFilter, false)} failed policy check`, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } + } + if (schema) { // TODO: push down schema check to the database const parseResult = schema.safeParse(result); @@ -777,6 +801,16 @@ export class PolicyUtil extends QueryUtils { } } + getAdditionalCheckerSelector(model: string, operation: PolicyOperationKind) { + const def = this.getModelPolicyDef(model); + return def.modelLevel[operation].additionalCheckerSelector; + } + + getAdditionalChecker(model: string, operation: PolicyOperationKind) { + const def = this.getModelPolicyDef(model); + return def.modelLevel[operation].additionalChecker; + } + private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { const allFields = Object.values(getFields(this.modelMeta, model)); @@ -934,8 +968,8 @@ export class PolicyUtil extends QueryUtils { } /** - * Injects field selection needed for checking field-level read policy into query args. - * @returns + * Injects field selection needed for checking field-level read policy check and evaluating + * additional checker into query args. */ injectReadCheckSelect(model: string, args: any) { // we need to recurse into relation fields before injecting the current level, because @@ -957,6 +991,11 @@ export class PolicyUtil extends QueryUtils { this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } } + + const additionalCheckerSelector = this.getAdditionalCheckerSelector(model, 'read'); + if (additionalCheckerSelector) { + this.doInjectReadCheckSelect(model, args, { select: additionalCheckerSelector }); + } } private doInjectReadCheckSelect(model: string, args: any, input: any) { @@ -1119,7 +1158,7 @@ export class PolicyUtil extends QueryUtils { // preserve the original data as it may be needed for checking field-level readability, // while the "data" will be manipulated during traversal (deleting unreadable fields) const origData = this.clone(data); - this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); + return this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); } private doPostProcessForRead( @@ -1131,12 +1170,46 @@ export class PolicyUtil extends QueryUtils { path = '' ) { if (data === null || data === undefined) { - return; + return data; } - for (const [entityData, entityFullData] of zip(data, fullData)) { + let filteredData = data; + let filteredFullData = fullData; + + const additionalChecker = this.getAdditionalChecker(model, 'read'); + if (additionalChecker) { + if (Array.isArray(data)) { + filteredData = []; + filteredFullData = []; + for (const [entityData, entityFullData] of zip(data, fullData)) { + if (!additionalChecker({ user: this.user }, entityData)) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${ + path ? ' at ' + path : '' + } due to additional checker` + ); + } + } else { + filteredData.push(entityData); + filteredFullData.push(entityFullData); + } + } + } else { + if (!additionalChecker({ user: this.user }, data)) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to additional checker` + ); + } + return null; + } + } + } + + for (const [entityData, entityFullData] of zip(filteredData, filteredFullData)) { if (typeof entityData !== 'object' || !entityData) { - return; + continue; } for (const [field, fieldData] of Object.entries(entityData)) { @@ -1192,7 +1265,7 @@ export class PolicyUtil extends QueryUtils { if (fieldInfo.isDataModel) { // recurse into nested fields const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; - this.doPostProcessForRead( + const nestedResult = this.doPostProcessForRead( fieldData, fieldInfo.type, entityFullData[field], @@ -1200,9 +1273,16 @@ export class PolicyUtil extends QueryUtils { this.hasFieldLevelPolicy(fieldInfo.type), path ? path + '.' + field : field ); + if (nestedResult === undefined) { + delete entityData[field]; + } else { + entityData[field] = nestedResult; + } } } } + + return filteredData; } /** diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index aa14555b8..cdc37d305 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -85,6 +85,11 @@ export type InputCheckFunc = (args: any, context: QueryContext) => boolean; */ export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; +/** + * Additional checker function for checking polices outside of Prisma + */ +export type AdditionalCheckerFunc = (input: any, context: QueryContext) => boolean; + /** * Policy definition */ @@ -137,6 +142,16 @@ type ModelCrudCommon = { */ guard: PolicyFunc | boolean; + /** + * Additional checker function for checking policies outside of Prisma + */ + additionalChecker?: AdditionalCheckerFunc; + + /** + * Field selections for evaluating `additionalChecker` + */ + additionalCheckerSelector?: object; + /** * Permission checker function or a constant condition */ @@ -172,8 +187,7 @@ type ModelDeleteDef = ModelCrudCommon; /** * Policy definition for post-update checking a model */ -type ModelPostUpdateDef = { - guard: PolicyFunc | boolean; +type ModelPostUpdateDef = Exclude & { preUpdateSelector?: object; }; diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index d65e304dc..f9aceaa6a 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,11 +1,14 @@ import { AstNode, BinaryExpr, + DataModelAttribute, + DataModelFieldAttribute, Expression, ExpressionType, isDataModel, isDataModelAttribute, isDataModelField, + isDataModelFieldAttribute, isEnum, isLiteralExpr, isMemberAccessExpr, @@ -13,7 +16,12 @@ import { isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { + getAttributeArgLiteral, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; import { ValidationAcceptor, streamAst } from 'langium'; import { findUpAst, getContainingDataModel } from '../../utils/ast-utils'; import { AstValidator } from '../types'; @@ -151,6 +159,7 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'incompatible operand types', { node: expr }); break; } + // not supported: // - foo.a == bar // - foo.user.id == userId @@ -169,10 +178,26 @@ export default class ExpressionValidator implements AstValidator { // foo.user.id == null // foo.user.id == EnumValue if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { - accept('error', 'comparison between fields of different models are not supported', { - node: expr, - }); - break; + const containingPolicyAttr = findUpAst( + expr, + (node) => + (isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText)) || + (isDataModelFieldAttribute(node) && ['@allow', '@deny'].includes(node.decl.$refText)) + ) as DataModelAttribute | DataModelFieldAttribute | undefined; + + if (containingPolicyAttr) { + const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); + if (operation?.split(',').includes('all') || operation?.split(',').includes('all')) { + accept( + 'error', + 'comparison between fields of different models is not supported in "read" rules', + { + node: expr, + } + ); + break; + } + } } } @@ -246,16 +271,6 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'collection predicate can only be used on an array of model type', { node: expr }); return; } - - // TODO: revisit this when we implement lambda inside collection predicate - const thisExpr = streamAst(expr).find(isThisExpr); - if (thisExpr) { - accept( - 'error', - 'using `this` in collection predicate is not supported. To compare entity identity, use id field comparison instead.', - { node: thisExpr } - ); - } } private isInValidationContext(node: AstNode) { diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 619543e44..7a9648543 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -34,6 +34,7 @@ import { generateNormalizedAuthRef, generateQueryGuardFunction, generateSelectForRules, + generateTypeScriptCheckerFunction, getPolicyExpressions, isEnumReferenced, } from './utils'; @@ -171,15 +172,16 @@ export class PolicyGenerator { // writes `inputChecker: [funcName]` for a given model private writeCreateInputChecker(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const allows = getPolicyExpressions(model, 'allow', 'create'); - const denies = getPolicyExpressions(model, 'deny', 'create'); - if (this.canCheckCreateBasedOnInput(model, allows, denies)) { - const inputCheckFunc = this.generateCreateInputCheckerFunction(model, allows, denies, sourceFile); + if (this.canCheckCreateBasedOnInput(model)) { + const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile); writer.write(`inputChecker: ${inputCheckFunc.getName()!},`); } } - private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { + private canCheckCreateBasedOnInput(model: DataModel) { + const allows = getPolicyExpressions(model, 'allow', 'create', false, 'all'); + const denies = getPolicyExpressions(model, 'deny', 'create', false, 'all'); + return [...allows, ...denies].every((rule) => { return streamAst(rule).every((expr) => { if (isThisExpr(expr)) { @@ -216,13 +218,10 @@ export class PolicyGenerator { } // generates a function for checking "create" input - private generateCreateInputCheckerFunction( - model: DataModel, - allows: Expression[], - denies: Expression[], - sourceFile: SourceFile - ) { + private generateCreateInputCheckerFunction(model: DataModel, sourceFile: SourceFile) { const statements: (string | WriterFunction)[] = []; + const allows = getPolicyExpressions(model, 'allow', 'create'); + const denies = getPolicyExpressions(model, 'deny', 'create'); generateNormalizedAuthRef(model, allows, denies, statements); @@ -348,6 +347,30 @@ export class PolicyGenerator { if (kind !== 'postUpdate') { this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); } + + this.writeAdditionalChecker(model, kind, writer, sourceFile); + } + + private writeAdditionalChecker( + model: DataModel, + kind: PolicyOperationKind, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + const allows = getPolicyExpressions(model, 'allow', kind, false, 'onlyCrossModelComparison'); + const denies = getPolicyExpressions(model, 'deny', kind, false, 'onlyCrossModelComparison'); + + if (allows.length === 0 && denies.length === 0) { + return; + } + + const additionalFunc = generateTypeScriptCheckerFunction(sourceFile, model, kind, allows, denies); + writer.write(`additionalChecker: ${additionalFunc.getName()!},`); + + const additionalSelector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate'); + if (additionalSelector) { + writer.write(`additionalCheckerSelector: ${JSON.stringify(additionalSelector)},`); + } } // writes `guard: ...` for a given policy operation kind @@ -413,11 +436,10 @@ export class PolicyGenerator { // post-update counterpart if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { writer.write(`permissionChecker: false,`); - return; } else { writer.write(`permissionChecker: true,`); - return; } + return; } const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index c8b75ffd8..0bba9a763 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -11,6 +11,7 @@ import { getIdFields, getLiteral, isAuthInvocation, + isDataModelFieldReference, isEnumFieldReference, isFromStdlib, isFutureExpr, @@ -19,6 +20,7 @@ import { import { Enum, Model, + isBinaryExpr, isDataModel, isDataModelField, isExpression, @@ -30,10 +32,10 @@ import { type DataModelField, type Expression, } from '@zenstackhq/sdk/ast'; -import { streamAllContents, streamAst, streamContents } from 'langium'; +import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium'; import { SourceFile, WriterFunction } from 'ts-morph'; import { name } from '..'; -import { isCollectionPredicate } from '../../../utils/ast-utils'; +import { isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -43,7 +45,8 @@ export function getPolicyExpressions( target: DataModel | DataModelField, kind: PolicyKind, operation: PolicyOperationKind, - override = false + override = false, + filter: 'all' | 'withoutCrossModelComparison' | 'onlyCrossModelComparison' = 'all' ) { const attributes = target.attributes; const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; @@ -73,6 +76,12 @@ export function getPolicyExpressions( }) .map((attr) => attr.args[1].value); + if (filter === 'onlyCrossModelComparison') { + result = result.filter((expr) => hasCrossModelComparison(expr)); + } else if (filter === 'withoutCrossModelComparison') { + result = result.filter((expr) => !hasCrossModelComparison(expr)); + } + if (operation === 'update') { result = processUpdatePolicies(result, false); } else if (operation === 'postUpdate') { @@ -108,9 +117,18 @@ function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { * Generates a "select" object that contains (recursively) fields referenced by the * given policy rules */ -export function generateSelectForRules(rules: Expression[], forAuthContext = false): object { +export function generateSelectForRules( + rules: Expression[], + forAuthContext = false, + ignoreFutureReference = true +): object { const result: any = {}; const addPath = (path: string[]) => { + const thisIndex = path.lastIndexOf('$this'); + if (thisIndex >= 0) { + // drop everything before $this + path = path.slice(thisIndex + 1); + } let curr = result; path.forEach((seg, i) => { if (i === path.length - 1) { @@ -128,6 +146,10 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal // selection path const visit = (node: Expression): string[] | undefined => { if (isThisExpr(node)) { + return ['$this']; + } + + if (isFutureExpr(node)) { return []; } @@ -144,7 +166,7 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal return [node.member.$refText]; } - if (isFutureExpr(node.operand)) { + if (isFutureExpr(node.operand) && ignoreFutureReference) { // future().field is not subject to pre-update select return undefined; } @@ -225,9 +247,12 @@ export function generateQueryGuardFunction( ) { const statements: (string | WriterFunction)[] = []; - generateNormalizedAuthRef(model, allows, denies, statements); + const filteredAllows = allows.filter((rule) => !hasCrossModelComparison(rule)); + const filteredDenies = denies.filter((rule) => !hasCrossModelComparison(rule)); - const hasFieldAccess = [...denies, ...allows].some((rule) => + generateNormalizedAuthRef(model, filteredAllows, filteredDenies, statements); + + const hasFieldAccess = [...filteredDenies, ...filteredAllows].some((rule) => streamAst(rule).some( (child) => // this.??? @@ -248,10 +273,10 @@ export function generateQueryGuardFunction( isPostGuard: kind === 'postUpdate', }); try { - denies.forEach((rule) => { + filteredDenies.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); }); - allows.forEach((rule) => { + filteredAllows.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); }); } catch (err) { @@ -267,12 +292,22 @@ export function generateQueryGuardFunction( // if there's no allow rule, for field-level rules, by default we allow writer.write(`return ${TRUE};`); } else { - // if there's any allow rule, we deny unless any allow rule evaluates to true - writer.write(`return ${FALSE};`); + if (filteredAllows.length < allows.length) { + writer.write(`return ${TRUE};`); + } else { + // if there's any allow rule, we deny unless any allow rule evaluates to true + writer.write(`return ${FALSE};`); + } } } else { - // for model-level rules, the default is always deny - writer.write(`return ${FALSE};`); + if (filteredAllows.length < allows.length) { + // some rules are filtered out here and will be generated as additional + // checker functions, so we allow here to avoid a premature denial + writer.write(`return ${TRUE};`); + } else { + // for model-level rules, the default is always deny unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + } } }); } else { @@ -280,42 +315,42 @@ export function generateQueryGuardFunction( writer.write('return '); const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); const writeDenies = () => { - writer.conditionalWrite(denies.length > 1, '{ AND: ['); - denies.forEach((expr, i) => { + writer.conditionalWrite(filteredDenies.length > 1, '{ AND: ['); + filteredDenies.forEach((expr, i) => { writer.inlineBlock(() => { writer.write('NOT: '); exprWriter.write(expr); }); - writer.conditionalWrite(i !== denies.length - 1, ','); + writer.conditionalWrite(i !== filteredDenies.length - 1, ','); }); - writer.conditionalWrite(denies.length > 1, ']}'); + writer.conditionalWrite(filteredDenies.length > 1, ']}'); }; const writeAllows = () => { - writer.conditionalWrite(allows.length > 1, '{ OR: ['); - allows.forEach((expr, i) => { + writer.conditionalWrite(filteredAllows.length > 1, '{ OR: ['); + filteredAllows.forEach((expr, i) => { exprWriter.write(expr); - writer.conditionalWrite(i !== allows.length - 1, ','); + writer.conditionalWrite(i !== filteredAllows.length - 1, ','); }); - writer.conditionalWrite(allows.length > 1, ']}'); + writer.conditionalWrite(filteredAllows.length > 1, ']}'); }; - if (allows.length > 0 && denies.length > 0) { + if (filteredAllows.length > 0 && filteredDenies.length > 0) { // include both allow and deny rules writer.write('{ AND: ['); writeDenies(); writer.write(','); writeAllows(); writer.write(']}'); - } else if (denies.length > 0) { + } else if (filteredDenies.length > 0) { // only deny rules writeDenies(); - } else if (allows.length > 0) { + } else if (filteredAllows.length > 0) { // only allow rules writeAllows(); } else { - // disallow any operation - writer.write(`{ OR: [] }`); + // disallow any operation unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); } writer.write(';'); }); @@ -341,6 +376,59 @@ export function generateQueryGuardFunction( return func; } +export function generateTypeScriptCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[], + forField?: DataModelField, + fieldOverride = false +) { + const statements: (string | WriterFunction)[] = []; + + generateNormalizedAuthRef(model, allows, denies, statements); + + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + thisExprContext: 'input', + fieldReferenceContext: 'input', + isPostGuard: kind === 'postUpdate', + futureRefContext: 'input', + }); + + denies.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return false; }`); + }); + + allows.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return true; }`); + }); + + // default: deny unless for 'postUpdate' + statements.push(kind === 'postUpdate' ? 'return true;' : 'return false;'); + + const func = sourceFile.addFunction({ + name: `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + returnType: 'any', + parameters: [ + { + name: 'context', + type: 'QueryContext', + }, + { + name: 'input', + type: 'any', + }, + ], + statements, + }); + + return func; +} + /** * Generates a normalized auth reference for the given policy rules */ @@ -384,3 +472,44 @@ export function isEnumReferenced(model: Model, decl: Enum): unknown { return false; }); } + +function hasCrossModelComparison(expr: Expression) { + return streamAst(expr).some((node) => { + if (isBinaryExpr(node) && ['==', '!=', '>', '<', '>=', '<=', 'in'].includes(node.operator)) { + const leftRoot = getSourceModelOfFieldAccess(node.left); + const rightRoot = getSourceModelOfFieldAccess(node.right); + if (leftRoot && rightRoot && leftRoot !== rightRoot) { + return true; + } + } + return false; + }); +} + +function getSourceModelOfFieldAccess(expr: Expression) { + if (isDataModel(expr.$resolvedType?.decl)) { + return expr.$resolvedType?.decl; + } + + // `this` reference + if (isThisExpr(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // `future()` + if (isFutureInvocation(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // direct field reference + if (isDataModelFieldReference(expr)) { + return (expr.target.ref as DataModelField).$container; + } + + // member access + if (isMemberAccessExpr(expr)) { + return getSourceModelOfFieldAccess(expr.operand); + } + + return undefined; +} diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 8e33eb4a7..28ce1d345 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -34,6 +34,7 @@ type Options = { isPostGuard?: boolean; fieldReferenceContext?: string; thisExprContext?: string; + futureRefContext?: string; context: ExpressionContext; }; @@ -116,7 +117,9 @@ export class TypeScriptExpressionTransformer { if (this.options?.isPostGuard !== true) { throw new TypeScriptExpressionTransformerError(`future() is only supported in postUpdate rules`); } - return expr.member.ref.name; + return this.options.futureRefContext + ? `${this.options.futureRefContext}.${expr.member.ref.name}` + : expr.member.ref.name; } else { if (normalizeUndefined) { // normalize field access to null instead of undefined to avoid accidentally use undefined in filter @@ -449,7 +452,6 @@ export class TypeScriptExpressionTransformer { ...this.options, isPostGuard: false, fieldReferenceContext: '_item', - thisExprContext: '_item', }); const predicate = innerTransformer.transform(expr.right, normalizeUndefined); diff --git a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts new file mode 100644 index 000000000..19886bb5b --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts @@ -0,0 +1,672 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Cross-model field comparison', () => { + it('to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + + @@allow('read', true) + @@allow('create,update,delete', age == profile.age) + @@deny('update', future().age < future().profile.age && age > 0) + } + + model Profile { + id Int @id @default(autoincrement()) + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + await prisma.profile.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ data: [{ id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }] }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.createMany({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profile: { create: { age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profile: { create: { age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile? + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profile: true } })).resolves.toMatchObject({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }); + // await expect(db.user.findMany({ include: { profile: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + // expect(r.profile).toBeUndefined(); + // r = await db.user.findMany({ include: { profile: true } }); + // expect(r[0].profile).toBeUndefined(); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + let r = await prisma.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + expect(r.profile).toMatchObject({ age: 20 }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 18 } } } }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 15 } } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } })).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect(db.user.update({ where: { id: 3 }, data: { profile: { connect: { id: 1 } } } })).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('read', true) + @@allow('create,update,delete', profiles![this.age == age]) + @@deny('update', future().profiles?[this.age < age]) + } + + model Profile { + id Int @id + age Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int + + @@allow('all', true) + } + `, + { preserveTsFiles: true } + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 20 }, + ], + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 18 }, + ], + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + db.user.create({ + data: { id: 2, age: 18 }, + }) + ).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } }, + { id: 2, age: 18, profiles: { create: { id: 2, age: 20 } } }, + ], + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } }, + { id: 2, age: 19, profiles: { create: { id: 2, age: 19 } } }, + ], + }) + ).resolves.toEqual({ count: 2 }); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profiles: true } })).resolves.toMatchObject({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }); + // await expect(db.user.findMany({ include: { profiles: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + // expect(r.profiles).toHaveLength(0); + // r = await db.user.findMany({ include: { profiles: true } }); + // expect(r[0].profiles).toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + let r = await prisma.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + expect(r.profiles[0]).toMatchObject({ age: 20 }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 18 } } } }, + }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 15 } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profiles: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect( + db.user.update({ where: { id: 3 }, data: { profiles: { connect: { id: 1 } } } }) + ).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('field-level', async () => {}); +}); diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index de778e8e8..0297116a0 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -915,6 +915,10 @@ describe('Policy: field-level policy', () => { data: { models: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { models: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 }, @@ -1015,6 +1019,10 @@ describe('Policy: field-level policy', () => { data: { model: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { model: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 }, From c5ba68c2d18e83e4f622f543daf808c3d7ed6cc0 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 1 Jun 2024 10:09:40 +0800 Subject: [PATCH 2/4] fix tests --- .../validator/expression-validator.ts | 2 +- .../validation/attribute-validation.test.ts | 31 ++++++++++++++++++- .../validation/datamodel-validation.test.ts | 2 +- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index f9aceaa6a..4e2ce1207 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -187,7 +187,7 @@ export default class ExpressionValidator implements AstValidator { if (containingPolicyAttr) { const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); - if (operation?.split(',').includes('all') || operation?.split(',').includes('all')) { + if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { accept( 'error', 'comparison between fields of different models is not supported in "read" rules', diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 380836e21..3956b489e 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -699,7 +699,36 @@ describe('Attribute tests', () => { } `) - ).toContain('comparison between fields of different models are not supported'); + ).toContain('comparison between fields of different models is not supported in "read" rules'); + + expect( + await loadModel(` + ${prelude} + model User { + id Int @id + lists List[] + todos Todo[] + } + + model List { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + todos Todo[] + } + + model Todo { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + list List @relation(fields: [listId], references: [id]) + listId Int + + @@allow('create', list.user.id == userId) + } + + `) + ).toBeTruthy(); expect( await loadModelWithError(` diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index e0778da51..e7dd6bf84 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -88,7 +88,7 @@ describe('Data Model Validation Tests', () => { @@allow('all', members?[this == auth()]) } `) - ).toMatchObject(errorLike('using `this` in collection predicate is not supported')); + ).toBeTruthy(); expect( await loadModel(` From 8b65a517ffd493b74c12359cb74912620abc2528 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 1 Jun 2024 22:36:54 +0800 Subject: [PATCH 3/4] support for field-level policies, fix tests --- .../enhancements/policy/constraint-solver.ts | 10 +- .../src/enhancements/policy/handler.ts | 38 ++- .../src/enhancements/policy/policy-utils.ts | 182 +++++++++---- packages/runtime/src/enhancements/types.ts | 123 +++++---- packages/runtime/src/types.ts | 2 +- .../validator/expression-validator.ts | 10 +- .../enhancer/policy/policy-guard-generator.ts | 251 ++++++++---------- .../src/plugins/enhancer/policy/utils.ts | 66 +++-- .../validation/attribute-validation.test.ts | 2 +- .../cross-model-field-comparison.test.ts | 109 +++++++- 10 files changed, 477 insertions(+), 316 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index c87a528e7..9b792a0fa 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -1,10 +1,10 @@ import Logic from 'logic-solver'; import { match } from 'ts-pattern'; import type { - CheckerConstraint, ComparisonConstraint, ComparisonTerm, LogicalConstraint, + PermissionCheckerConstraint, ValueConstraint, VariableConstraint, } from '../types'; @@ -22,7 +22,7 @@ export class ConstraintSolver { /** * Check the satisfiability of the given constraint. */ - checkSat(constraint: CheckerConstraint): boolean { + checkSat(constraint: PermissionCheckerConstraint): boolean { // reset state this.stringTable = []; this.variables = new Map(); @@ -46,7 +46,7 @@ export class ConstraintSolver { return !!solver.solve(); } - private buildFormula(constraint: CheckerConstraint): Logic.Formula { + private buildFormula(constraint: PermissionCheckerConstraint): Logic.Formula { return match(constraint) .when( (c): c is ValueConstraint => c.kind === 'value', @@ -100,11 +100,11 @@ export class ConstraintSolver { return Logic.not(this.buildFormula(constraint.children[0])); } - private isTrue(constraint: CheckerConstraint): unknown { + private isTrue(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === true; } - private isFalse(constraint: CheckerConstraint): unknown { + private isFalse(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === false; } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index cb71b2bc1..7ce3a8987 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -24,7 +24,7 @@ import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import type { AdditionalCheckerFunc, CheckerConstraint } from '../types'; +import type { EntityCheckerFunc, PermissionCheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; @@ -1182,15 +1182,15 @@ export class PolicyProxyHandler implements Pr args.data = this.validateUpdateInputSchema(this.model, args.data); - const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'update'); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); const canProceedWithoutTransaction = // no post-update rules !this.policyUtils.hasAuthGuard(this.model, 'postUpdate') && // no Zod schema !this.policyUtils.getZodSchema(this.model) && - // no additional checker - !additionalChecker; + // no entity checker + !entityChecker; if (canProceedWithoutTransaction) { // proceed without a transaction @@ -1212,9 +1212,9 @@ export class PolicyProxyHandler implements Pr } // merge selection required for running additional checker - const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(this.model, 'update'); - if (additionalCheckerSelector) { - select = deepmerge(select, additionalCheckerSelector); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); + if (entityChecker?.selector) { + select = deepmerge(select, entityChecker.selector); } const currentSetQuery = { select, where: args.where }; @@ -1225,9 +1225,9 @@ export class PolicyProxyHandler implements Pr } let candidates = await tx[this.model].findMany(currentSetQuery); - if (additionalChecker) { + if (entityChecker) { // filter candidates with additional checker and build an id filter - const r = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker); + const r = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); candidates = r.filteredCandidates; // merge id filter into update's where clause @@ -1383,19 +1383,15 @@ export class PolicyProxyHandler implements Pr args = clone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); - const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'delete'); - if (additionalChecker) { + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'delete'); + if (entityChecker) { // additional checker exists, need to run deletion inside a transaction return this.queryUtils.transaction(this.prisma, async (tx) => { // find the delete candidates, selecting id fields and fields needed for // running the additional checker let candidateSelect = this.policyUtils.makeIdSelection(this.model); - const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector( - this.model, - 'delete' - ); - if (additionalCheckerSelector) { - candidateSelect = deepmerge(candidateSelect, additionalCheckerSelector); + if (entityChecker.selector) { + candidateSelect = deepmerge(candidateSelect, entityChecker.selector); } if (this.shouldLogQuery) { @@ -1409,7 +1405,7 @@ export class PolicyProxyHandler implements Pr const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect }); // build a ID filter based on id values filtered by the additional checker - const { idFilter } = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker); + const { idFilter } = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); // merge the ID filter into the where clause args.where = args.where ? { AND: [args.where, idFilter] } : idFilter; @@ -1560,7 +1556,7 @@ export class PolicyProxyHandler implements Pr if (args.where) { // combine runtime filters with generated constraints - const extraConstraints: CheckerConstraint[] = []; + const extraConstraints: PermissionCheckerConstraint[] = []; for (const [field, value] of Object.entries(args.where)) { if (value === undefined) { continue; @@ -1690,8 +1686,8 @@ export class PolicyProxyHandler implements Pr } } - private buildIdFilterWithAdditionalChecker(candidates: any[], additionalChecker: AdditionalCheckerFunc) { - const filteredCandidates = candidates.filter((value) => additionalChecker({ user: this.context?.user }, value)); + private buildIdFilterWithEntityChecker(candidates: any[], entityChecker: EntityCheckerFunc) { + const filteredCandidates = candidates.filter((value) => entityChecker(value, { user: this.context?.user })); const idFields = this.policyUtils.getIdFields(this.model); let idFilter: any; if (idFields.length === 1) { diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 70daa5724..b76875d28 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1,18 +1,26 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import deepcopy from 'deepcopy'; +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { ZodError } from 'zod'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason, PrismaErrorCode } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; +import { + AuthUser, + CrudContract, + DbClientContract, + PolicyCrudKind, + PolicyOperationKind, + QueryContext, +} from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { CheckerFunc, ModelPolicyDef, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; +import type { EntityChecker, ModelPolicyDef, PermissionCheckerFunc, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -272,7 +280,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.read?.overrideGuard?.[field]; + const guard = def.fieldLevel?.read?.[field]?.overrideGuard; if (guard === undefined) { // field access is denied by default in override mode @@ -292,7 +300,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.update?.guard?.[field]; + const guard = def.fieldLevel?.update?.[field]?.guard; if (guard === undefined) { // field access is allowed by default @@ -312,7 +320,7 @@ export class PolicyUtil extends QueryUtils { */ getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.update?.overrideGuard?.[field]; + const guard = def.fieldLevel?.update?.[field]?.overrideGuard; if (guard === undefined) { // field access is denied by default in override mode @@ -343,8 +351,13 @@ export class PolicyUtil extends QueryUtils { return false; } const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.[operation]?.overrideGuard; - return guard && Object.keys(guard).length > 0; + if (def.fieldLevel?.[operation]) { + return Object.values(def.fieldLevel[operation]).some( + (f) => f.overrideGuard !== undefined || f.overrideEntityChecker !== undefined + ); + } else { + return false; + } } /** @@ -551,7 +564,7 @@ export class PolicyUtil extends QueryUtils { /** * Gets checker constraints for the given model and operation. */ - getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { if (this.options.kinds && !this.options.kinds.includes('policy')) { // policy enhancement not enabled, return a constant true checker result return true; @@ -697,6 +710,8 @@ export class PolicyUtil extends QueryUtils { ); } + let entityChecker: EntityChecker | undefined; + if (operation === 'update' && args) { // merge field-level policy guards const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); @@ -710,41 +725,45 @@ export class PolicyUtil extends QueryUtils { }"`, CrudFailureReason.ACCESS_POLICY_VIOLATION ); - } else { - if (fieldUpdateGuard.guard) { - // merge field-level guard - guard = this.and(guard, fieldUpdateGuard.guard); - } + } - if (fieldUpdateGuard.overrideGuard) { - // merge field-level override guard - guard = this.or(guard, fieldUpdateGuard.overrideGuard); - } + if (fieldUpdateGuard.guard) { + // merge field-level guard with AND + guard = this.and(guard, fieldUpdateGuard.guard); } + + if (fieldUpdateGuard.overrideGuard) { + // merge field-level override guard with OR + guard = this.or(guard, fieldUpdateGuard.overrideGuard); + } + + // field-level entity checker + entityChecker = fieldUpdateGuard.entityChecker; } // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - const additionalChecker = this.getAdditionalChecker(model, operation); + // combine field-level entity checker with model-level + const modelEntityChecker = this.getEntityChecker(model, operation); + entityChecker = this.combineEntityChecker(entityChecker, modelEntityChecker, 'and'); - if (this.isTrue(guard) && !schema && !additionalChecker) { + if (this.isTrue(guard) && !schema && !entityChecker) { // unconditionally allowed return; } - const additionalCheckerSelector = this.getAdditionalCheckerSelector(model, operation); let select = schema ? // need to validate against schema, need to fetch all fields undefined : // only fetch id fields this.makeIdSelection(model); - if (additionalCheckerSelector) { + if (entityChecker?.selector) { if (!select) { select = this.makeAllScalarFieldSelect(model); } - select = { ...select, ...additionalCheckerSelector }; + select = { ...select, ...entityChecker.selector }; } let where = this.clone(uniqueFilter); @@ -768,11 +787,11 @@ export class PolicyUtil extends QueryUtils { ); } - if (additionalChecker) { + if (entityChecker) { if (this.logger.enabled('info')) { - this.logger.info(`[policy] running additional checker on ${model} for ${operation}`); + this.logger.info(`[policy] running entity checker on ${model} for ${operation}`); } - if (!additionalChecker({ user: this.user, preValue }, result)) { + if (!entityChecker.func(result, { user: this.user, preValue })) { throw this.deniedByPolicy( model, operation, @@ -801,14 +820,18 @@ export class PolicyUtil extends QueryUtils { } } - getAdditionalCheckerSelector(model: string, operation: PolicyOperationKind) { + getEntityChecker(model: string, operation: PolicyOperationKind, field?: string) { const def = this.getModelPolicyDef(model); - return def.modelLevel[operation].additionalCheckerSelector; + if (field) { + return def.fieldLevel?.[operation as 'read' | 'update']?.[field]?.entityChecker; + } else { + return def.modelLevel[operation].entityChecker; + } } - getAdditionalChecker(model: string, operation: PolicyOperationKind) { + getUpdateOverrideEntityCheckerForField(model: string, field: string) { const def = this.getModelPolicyDef(model); - return def.modelLevel[operation].additionalChecker; + return def.fieldLevel?.update?.[field]?.overrideEntityChecker; } private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { @@ -837,19 +860,20 @@ export class PolicyUtil extends QueryUtils { private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { const allFieldGuards = []; const allOverrideFieldGuards = []; + let entityChecker: EntityChecker | undefined; - for (const [k, v] of Object.entries(args.data ?? args)) { - if (typeof v === 'undefined') { + for (const [field, value] of Object.entries(args.data ?? args)) { + if (typeof value === 'undefined') { continue; } - const field = resolveField(this.modelMeta, model, k); + const fieldInfo = resolveField(this.modelMeta, model, field); - if (field?.isDataModel) { + if (fieldInfo?.isDataModel) { // relation field update should be treated as foreign key update, // fetch and merge all foreign key guards - if (field.isRelationOwner && field.foreignKeyMapping) { - const foreignKeys = Object.values(field.foreignKeyMapping); + if (fieldInfo.isRelationOwner && fieldInfo.foreignKeyMapping) { + const foreignKeys = Object.values(fieldInfo.foreignKeyMapping); for (const fk of foreignKeys) { const fieldGuard = this.getFieldUpdateAuthGuard(db, model, fk); if (this.isFalse(fieldGuard)) { @@ -865,18 +889,26 @@ export class PolicyUtil extends QueryUtils { } } } else { - const fieldGuard = this.getFieldUpdateAuthGuard(db, model, k); + const fieldGuard = this.getFieldUpdateAuthGuard(db, model, field); if (this.isFalse(fieldGuard)) { - return { guard: fieldGuard, rejectedByField: k }; + return { guard: fieldGuard, rejectedByField: field }; } // add field guard allFieldGuards.push(fieldGuard); // add field override guard - const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, k); + const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, field); allOverrideFieldGuards.push(overrideFieldGuard); } + + // merge regular and override entity checkers with OR + let checker = this.getEntityChecker(model, 'update', field); + const overrideChecker = this.getUpdateOverrideEntityCheckerForField(model, field); + checker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + // accumulate entity checker across fields + entityChecker = this.combineEntityChecker(entityChecker, checker, 'and'); } const allFieldsCombined = this.and(...allFieldGuards); @@ -887,6 +919,31 @@ export class PolicyUtil extends QueryUtils { guard: allFieldsCombined, overrideGuard: allOverrideFieldsCombined, rejectedByField: undefined, + entityChecker, + }; + } + + private combineEntityChecker( + left: EntityChecker | undefined, + right: EntityChecker | undefined, + combiner: 'and' | 'or' + ): EntityChecker | undefined { + if (!left) { + return right; + } + + if (!right) { + return left; + } + + const func = + combiner === 'and' + ? (entity: any, context: QueryContext) => left.func(entity, context) && right.func(entity, context) + : (entity: any, context: QueryContext) => left.func(entity, context) || right.func(entity, context); + + return { + func, + selector: deepmerge(left.selector ?? {}, right.selector ?? {}), }; } @@ -969,7 +1026,7 @@ export class PolicyUtil extends QueryUtils { /** * Injects field selection needed for checking field-level read policy check and evaluating - * additional checker into query args. + * entity checker into query args. */ injectReadCheckSelect(model: string, args: any) { // we need to recurse into relation fields before injecting the current level, because @@ -992,9 +1049,9 @@ export class PolicyUtil extends QueryUtils { } } - const additionalCheckerSelector = this.getAdditionalCheckerSelector(model, 'read'); - if (additionalCheckerSelector) { - this.doInjectReadCheckSelect(model, args, { select: additionalCheckerSelector }); + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker?.selector) { + this.doInjectReadCheckSelect(model, args, { select: entityChecker.selector }); } } @@ -1113,19 +1170,36 @@ export class PolicyUtil extends QueryUtils { return def.modelLevel.postUpdate.preUpdateSelector; } + // get a merged selector object for all field-level read policies private getFieldReadCheckSelector(model: string) { const def = this.getModelPolicyDef(model); - return def.fieldLevel?.read?.selector; + let result: any = {}; + const fieldLevel = def.fieldLevel?.read; + if (fieldLevel) { + for (const def of Object.values(fieldLevel)) { + if (def.entityChecker?.selector) { + result = deepmerge(result, def.entityChecker.selector); + } + if (def.overrideEntityChecker?.selector) { + result = deepmerge(result, def.overrideEntityChecker.selector); + } + } + } + return Object.keys(result).length > 0 ? result : undefined; } private checkReadField(model: string, field: string, entity: any) { const def = this.getModelPolicyDef(model); - const guard = def.fieldLevel?.read?.checker?.[field]; - if (guard === undefined) { + // combine regular and override field-level entity checkers with OR + const checker = def.fieldLevel?.read?.[field]?.entityChecker; + const overrideChecker = def.fieldLevel?.read?.[field]?.overrideEntityChecker; + const combinedChecker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + if (combinedChecker === undefined) { return true; } else { - return guard(entity, { user: this.user }); + return combinedChecker.func(entity, { user: this.user }); } } @@ -1135,7 +1209,7 @@ export class PolicyUtil extends QueryUtils { private hasFieldLevelPolicy(model: string) { const def = this.getModelPolicyDef(model); - return !!def.fieldLevel?.read?.checker; + return Object.keys(def.fieldLevel?.read ?? {}).length > 0; } /** @@ -1176,18 +1250,16 @@ export class PolicyUtil extends QueryUtils { let filteredData = data; let filteredFullData = fullData; - const additionalChecker = this.getAdditionalChecker(model, 'read'); - if (additionalChecker) { + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker) { if (Array.isArray(data)) { filteredData = []; filteredFullData = []; for (const [entityData, entityFullData] of zip(data, fullData)) { - if (!additionalChecker({ user: this.user }, entityData)) { + if (!entityChecker.func(entityData, { user: this.user })) { if (this.shouldLogQuery) { this.logger.info( - `[policy] dropping ${model} entity${ - path ? ' at ' + path : '' - } due to additional checker` + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` ); } } else { @@ -1196,10 +1268,10 @@ export class PolicyUtil extends QueryUtils { } } } else { - if (!additionalChecker({ user: this.user }, data)) { + if (!entityChecker.func(data, { user: this.user })) { if (this.shouldLogQuery) { this.logger.info( - `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to additional checker` + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` ); } return null; diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index cdc37d305..8aefcd8ed 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; -import type { CheckerContext, CrudContract, QueryContext } from '../types'; +import type { CrudContract, PermissionCheckerContext, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -24,10 +24,15 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking an entity's data for permission + */ +export type EntityCheckerFunc = (input: any, context: QueryContext) => boolean; + /** * Function for checking if an operation is possibly allowed. */ -export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; +export type PermissionCheckerFunc = (context: PermissionCheckerContext) => PermissionCheckerConstraint; /** * Supported checker constraint checking value types. @@ -67,28 +72,17 @@ export type ComparisonConstraint = { */ export type LogicalConstraint = { kind: 'and' | 'or' | 'not'; - children: CheckerConstraint[]; + children: PermissionCheckerConstraint[]; }; /** * Operation allowability checking constraint */ -export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; - -/** - * Function for getting policy guard with a given context - */ -export type InputCheckFunc = (args: any, context: QueryContext) => boolean; - -/** - * Function for getting policy guard with a given context - */ -export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; - -/** - * Additional checker function for checking polices outside of Prisma - */ -export type AdditionalCheckerFunc = (input: any, context: QueryContext) => boolean; +export type PermissionCheckerConstraint = + | ValueConstraint + | VariableConstraint + | ComparisonConstraint + | LogicalConstraint; /** * Policy definition @@ -133,6 +127,21 @@ export type ModelCrudDef = { postUpdate: ModelPostUpdateDef; }; +/** + * Information for checking entity data outside of Prisma + */ +export type EntityChecker = { + /** + * Checker function + */ + func: EntityCheckerFunc; + + /** + * Selector for fetching entity data + */ + selector?: object; +}; + /** * Common policy definition for a CRUD operation */ @@ -145,17 +154,15 @@ type ModelCrudCommon = { /** * Additional checker function for checking policies outside of Prisma */ - additionalChecker?: AdditionalCheckerFunc; - /** - * Field selections for evaluating `additionalChecker` + * Additional checker function for checking policies outside of Prisma */ - additionalCheckerSelector?: object; + entityChecker?: EntityChecker; /** * Permission checker function or a constant condition */ - permissionChecker?: CheckerFunc | boolean; + permissionChecker?: PermissionCheckerFunc | boolean; }; /** @@ -171,7 +178,7 @@ type ModelCreateDef = ModelCrudCommon & { * Create input validation function. Only generated when a create * can be approved or denied based on input values. */ - inputChecker?: InputCheckFunc | boolean; + inputChecker?: EntityCheckerFunc | boolean; }; /** @@ -198,37 +205,51 @@ type FieldCrudDef = { /** * Field-level read policy */ - read?: { - /** - * Selector for reading fields needed for evaluating the policy - */ - selector?: object; + read: Record; + + /** + * Field-level update policy + */ + update: Record; +}; - /** - * Field-level Prisma query guard - */ - checker?: Record; +type FieldReadDef = { + /** + * Entity checker + */ + entityChecker?: EntityChecker; - /** - * Field-level read override Prisma query guard - */ - overrideGuard?: Record; - }; + /** + * Field-level read override Prisma query guard + */ + overrideGuard?: PolicyFunc; /** - * Field-level update policy + * Entity checker for override policies + */ + overrideEntityChecker?: EntityChecker; +}; + +type FieldUpdateDef = { + /** + * Field-level update Prisma query guard + */ + guard?: PolicyFunc; + + /** + * Additional entity checker + */ + entityChecker?: EntityChecker; + + /** + * Field-level update override Prisma query guard + */ + overrideGuard?: PolicyFunc; + + /** + * Additional entity checker for override policies */ - update?: { - /** - * Field-level update Prisma query guard - */ - guard?: Record; - - /** - * Field-level update override Prisma query guard - */ - overrideGuard?: Record; - }; + overrideEntityChecker?: EntityChecker; }; /** diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4c32480ba..b9497b7ee 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -62,7 +62,7 @@ export type QueryContext = { /** * Context for checking operation allowability. */ -export type CheckerContext = { +export type PermissionCheckerContext = { /** * Current user */ diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 4e2ce1207..478db5ff7 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -2,13 +2,11 @@ import { AstNode, BinaryExpr, DataModelAttribute, - DataModelFieldAttribute, Expression, ExpressionType, isDataModel, isDataModelAttribute, isDataModelField, - isDataModelFieldAttribute, isEnum, isLiteralExpr, isMemberAccessExpr, @@ -180,17 +178,15 @@ export default class ExpressionValidator implements AstValidator { if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { const containingPolicyAttr = findUpAst( expr, - (node) => - (isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText)) || - (isDataModelFieldAttribute(node) && ['@allow', '@deny'].includes(node.decl.$refText)) - ) as DataModelAttribute | DataModelFieldAttribute | undefined; + (node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText) + ) as DataModelAttribute | undefined; if (containingPolicyAttr) { const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { accept( 'error', - 'comparison between fields of different models is not supported in "read" rules', + 'comparison between fields of different models is not supported in model-level "read" rules', { node: expr, } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 7a9648543..0f65c76c0 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -31,10 +31,10 @@ import path from 'path'; import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { ConstraintTransformer } from './constraint-transformer'; import { + generateEntityCheckerFunction, generateNormalizedAuthRef, generateQueryGuardFunction, generateSelectForRules, - generateTypeScriptCheckerFunction, getPolicyExpressions, isEnumReferenced, } from './utils'; @@ -86,8 +86,8 @@ export class PolicyGenerator { { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, - { name: 'type CheckerContext' }, - { name: 'type CheckerConstraint' }, + { name: 'type PermissionCheckerContext' }, + { name: 'type PermissionCheckerConstraint' }, ], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -348,29 +348,51 @@ export class PolicyGenerator { this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); } - this.writeAdditionalChecker(model, kind, writer, sourceFile); + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(model, kind, writer, sourceFile, true); } - private writeAdditionalChecker( - model: DataModel, + private writeEntityChecker( + target: DataModel | DataModelField, kind: PolicyOperationKind, writer: CodeBlockWriter, - sourceFile: SourceFile + sourceFile: SourceFile, + onlyCrossModelComparison = false, + forOverride = false ) { - const allows = getPolicyExpressions(model, 'allow', kind, false, 'onlyCrossModelComparison'); - const denies = getPolicyExpressions(model, 'deny', kind, false, 'onlyCrossModelComparison'); + const allows = getPolicyExpressions( + target, + 'allow', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); + const denies = getPolicyExpressions( + target, + 'deny', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); if (allows.length === 0 && denies.length === 0) { return; } - const additionalFunc = generateTypeScriptCheckerFunction(sourceFile, model, kind, allows, denies); - writer.write(`additionalChecker: ${additionalFunc.getName()!},`); - - const additionalSelector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate'); - if (additionalSelector) { - writer.write(`additionalCheckerSelector: ${JSON.stringify(additionalSelector)},`); - } + const model = isDataModel(target) ? target : (target.$container as DataModel); + const func = generateEntityCheckerFunction( + sourceFile, + model, + kind, + allows, + denies, + isDataModelField(target) ? target : undefined, + forOverride + ); + const selector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate') ?? {}; + const key = forOverride ? 'overrideEntityChecker' : 'entityChecker'; + writer.write(`${key}: { func: ${func.getName()!}, selector: ${JSON.stringify(selector)} },`); } // writes `guard: ...` for a given policy operation kind @@ -465,11 +487,11 @@ export class PolicyGenerator { const func = sourceFile.addFunction({ name: `${model.name}$checker$${kind}`, - returnType: 'CheckerConstraint', + returnType: 'PermissionCheckerConstraint', parameters: [ { name: 'context', - type: 'CheckerContext', + type: 'PermissionCheckerContext', }, ], statements, @@ -492,132 +514,93 @@ export class PolicyGenerator { } private writeFieldReadDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const fieldCheckers: Record = {}; - const overrideGuards: Record = {}; - const allFieldsAllows: Expression[] = []; - const allFieldsDenies: Expression[] = []; - - // generate field read checkers - for (const field of model.fields) { - const allows = getPolicyExpressions(field, 'allow', 'read'); - const denies = getPolicyExpressions(field, 'deny', 'read'); - if (denies.length === 0 && allows.length === 0) { - continue; - } + writer.writeLine('read:'); + writer.block(() => { + for (const field of model.fields) { + const policyAttrs = field.attributes.filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); - - const guardFunc = this.generateFieldReadCheckerFunction(sourceFile, field, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - fieldCheckers[field.name] = guardFunc.getName()!; - - const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); - if (overrideAllows.length > 0) { - const denies = getPolicyExpressions(field, 'deny', 'read'); - const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, - model, - 'read', - overrideAllows, - denies, - field, - true - ); - overrideGuards[field.name] = overrideGuardFunc.getName()!; - } - } - - if (Object.keys(fieldCheckers).length > 0 || Object.keys(overrideGuards).length > 0) { - writer.write('read:'); - writer.block(() => { - if (Object.keys(fieldCheckers).length > 0) { - writer.write('checker:'); + if (policyAttrs.length === 0) { + continue; + } - // write checkers - writer.inlineBlock(() => { - Object.entries(fieldCheckers).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); + writer.write(`${field.name}:`); - // write field selector - const readFieldCheckSelect = generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]); - if (readFieldCheckSelect) { - writer.write(`selector: ${JSON.stringify(readFieldCheckSelect)},`); + writer.block(() => { + // checker function + // write all field-level rules as entity checker function + this.writeEntityChecker(field, 'read', writer, sourceFile, false, false); + + const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); + if (overrideAllows.length > 0) { + // override guard function + const denies = getPolicyExpressions(field, 'deny', 'read'); + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'read', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // additional entity checker for override + this.writeEntityChecker(field, 'read', writer, sourceFile, false, true); } - } - - if (Object.keys(overrideGuards).length > 0) { - // write override guards - writer.write('overrideGuard:'); - writer.inlineBlock(() => { - Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - }); - writer.writeLine(','); - } + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); } private writeFieldUpdateDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { - const guards: Record = {}; - const overrideGuards: Record = {}; - - for (const field of model.fields) { - const allows = getPolicyExpressions(field, 'allow', 'update'); - const denies = getPolicyExpressions(field, 'deny', 'update'); + writer.writeLine('update:'); + writer.block(() => { + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'update'); + const denies = getPolicyExpressions(field, 'deny', 'update'); + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + + if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { + continue; + } - if (denies.length === 0 && allows.length === 0) { - continue; - } + writer.write(`${field.name}:`); - const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); - guards[field.name] = guardFunc.getName()!; - - const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); - if (overrideAllows.length > 0) { - const overrideGuardFunc = generateQueryGuardFunction( - sourceFile, - model, - 'update', - overrideAllows, - denies, - field, - true - ); - overrideGuards[field.name] = overrideGuardFunc.getName()!; + writer.block(() => { + // guard + const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); + writer.write(`guard: ${guardFunc.getName()},`); + + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, false); + + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + if (overrideAllows.length > 0) { + // override guard + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'update', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // write cross-model comparison override rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, true); + } + }); + writer.writeLine(','); } - } - - if (Object.keys(guards).length > 0 || Object.keys(overrideGuards).length > 0) { - writer.write('update:'); - writer.block(() => { - if (Object.keys(guards).length > 0) { - writer.write('guard:'); - writer.inlineBlock(() => { - Object.entries(guards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - - if (Object.keys(overrideGuards).length > 0) { - writer.write('overrideGuard:'); - writer.inlineBlock(() => { - Object.entries(overrideGuards).forEach(([fieldName, funcName]) => { - writer.write(`${fieldName}: ${funcName},`); - }); - }); - writer.writeLine(','); - } - }); - } + }); + writer.writeLine(','); } private generateFieldReadCheckerFunction( diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index 0bba9a763..7e25738ce 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -45,7 +45,7 @@ export function getPolicyExpressions( target: DataModel | DataModelField, kind: PolicyKind, operation: PolicyOperationKind, - override = false, + forOverride = false, filter: 'all' | 'withoutCrossModelComparison' | 'onlyCrossModelComparison' = 'all' ) { const attributes = target.attributes; @@ -55,12 +55,10 @@ export function getPolicyExpressions( return false; } - if (override) { - const overrideArg = getAttributeArg(attr, 'override'); - return overrideArg && getLiteral(overrideArg) === true; - } else { - return true; - } + const overrideArg = getAttributeArg(attr, 'override'); + const isOverride = !!overrideArg && getLiteral(overrideArg) === true; + + return (forOverride && isOverride) || (!forOverride && !isOverride); }); const checkOperation = operation === 'postUpdate' ? 'update' : operation; @@ -117,11 +115,7 @@ function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { * Generates a "select" object that contains (recursively) fields referenced by the * given policy rules */ -export function generateSelectForRules( - rules: Expression[], - forAuthContext = false, - ignoreFutureReference = true -): object { +export function generateSelectForRules(rules: Expression[], forAuthContext = false, ignoreFutureReference = true) { const result: any = {}; const addPath = (path: string[]) => { const thisIndex = path.lastIndexOf('$this'); @@ -247,12 +241,12 @@ export function generateQueryGuardFunction( ) { const statements: (string | WriterFunction)[] = []; - const filteredAllows = allows.filter((rule) => !hasCrossModelComparison(rule)); - const filteredDenies = denies.filter((rule) => !hasCrossModelComparison(rule)); + const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule)); + const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule)); - generateNormalizedAuthRef(model, filteredAllows, filteredDenies, statements); + generateNormalizedAuthRef(model, allowRules, denyRules, statements); - const hasFieldAccess = [...filteredDenies, ...filteredAllows].some((rule) => + const hasFieldAccess = [...denyRules, ...allowRules].some((rule) => streamAst(rule).some( (child) => // this.??? @@ -273,10 +267,10 @@ export function generateQueryGuardFunction( isPostGuard: kind === 'postUpdate', }); try { - filteredDenies.forEach((rule) => { + denyRules.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); }); - filteredAllows.forEach((rule) => { + allowRules.forEach((rule) => { writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); }); } catch (err) { @@ -292,7 +286,7 @@ export function generateQueryGuardFunction( // if there's no allow rule, for field-level rules, by default we allow writer.write(`return ${TRUE};`); } else { - if (filteredAllows.length < allows.length) { + if (allowRules.length < allows.length) { writer.write(`return ${TRUE};`); } else { // if there's any allow rule, we deny unless any allow rule evaluates to true @@ -300,7 +294,7 @@ export function generateQueryGuardFunction( } } } else { - if (filteredAllows.length < allows.length) { + if (allowRules.length < allows.length) { // some rules are filtered out here and will be generated as additional // checker functions, so we allow here to avoid a premature denial writer.write(`return ${TRUE};`); @@ -315,37 +309,37 @@ export function generateQueryGuardFunction( writer.write('return '); const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); const writeDenies = () => { - writer.conditionalWrite(filteredDenies.length > 1, '{ AND: ['); - filteredDenies.forEach((expr, i) => { + writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); + denyRules.forEach((expr, i) => { writer.inlineBlock(() => { writer.write('NOT: '); exprWriter.write(expr); }); - writer.conditionalWrite(i !== filteredDenies.length - 1, ','); + writer.conditionalWrite(i !== denyRules.length - 1, ','); }); - writer.conditionalWrite(filteredDenies.length > 1, ']}'); + writer.conditionalWrite(denyRules.length > 1, ']}'); }; const writeAllows = () => { - writer.conditionalWrite(filteredAllows.length > 1, '{ OR: ['); - filteredAllows.forEach((expr, i) => { + writer.conditionalWrite(allowRules.length > 1, '{ OR: ['); + allowRules.forEach((expr, i) => { exprWriter.write(expr); - writer.conditionalWrite(i !== filteredAllows.length - 1, ','); + writer.conditionalWrite(i !== allowRules.length - 1, ','); }); - writer.conditionalWrite(filteredAllows.length > 1, ']}'); + writer.conditionalWrite(allowRules.length > 1, ']}'); }; - if (filteredAllows.length > 0 && filteredDenies.length > 0) { + if (allowRules.length > 0 && denyRules.length > 0) { // include both allow and deny rules writer.write('{ AND: ['); writeDenies(); writer.write(','); writeAllows(); writer.write(']}'); - } else if (filteredDenies.length > 0) { + } else if (denyRules.length > 0) { // only deny rules writeDenies(); - } else if (filteredAllows.length > 0) { + } else if (allowRules.length > 0) { // only allow rules writeAllows(); } else { @@ -376,7 +370,7 @@ export function generateQueryGuardFunction( return func; } -export function generateTypeScriptCheckerFunction( +export function generateEntityCheckerFunction( sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind, @@ -414,14 +408,14 @@ export function generateTypeScriptCheckerFunction( name: `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, returnType: 'any', parameters: [ - { - name: 'context', - type: 'QueryContext', - }, { name: 'input', type: 'any', }, + { + name: 'context', + type: 'QueryContext', + }, ], statements, }); diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 3956b489e..b2ac1544b 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -699,7 +699,7 @@ describe('Attribute tests', () => { } `) - ).toContain('comparison between fields of different models is not supported in "read" rules'); + ).toContain('comparison between fields of different models is not supported in model-level "read" rules'); expect( await loadModel(` diff --git a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts index 19886bb5b..490cfcaf2 100644 --- a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts +++ b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts @@ -16,7 +16,7 @@ describe('Cross-model field comparison', () => { } model Profile { - id Int @id @default(autoincrement()) + id Int @id age Int user User? @@ -85,7 +85,7 @@ describe('Cross-model field comparison', () => { await expect( db.user.upsert({ where: { id: 2 }, - create: { id: 2, age: 18, profile: { create: { age: 25 } } }, + create: { id: 2, age: 18, profile: { create: { id: 2, age: 25 } } }, update: { age: 25 }, }) ).toBeRejectedByPolicy(); @@ -97,7 +97,7 @@ describe('Cross-model field comparison', () => { await expect( db.user.upsert({ where: { id: 2 }, - create: { id: 2, age: 25, profile: { create: { age: 25 } } }, + create: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } }, update: { age: 25 }, }) ).toResolveTruthy(); @@ -202,7 +202,7 @@ describe('Cross-model field comparison', () => { await expect( db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) ).toResolveTruthy(); - let r = await prisma.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + const r = await prisma.user.findUnique({ where: { id: 1 }, include: { profile: true } }); expect(r.profile).toMatchObject({ age: 20 }); await expect( db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 18 } } } }) @@ -668,5 +668,104 @@ describe('Cross-model field comparison', () => { await reset(); }); - it('field-level', async () => {}); + it('field-level simple', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age) @allow('update', age > profile.age) + level Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + // read + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r.age).toBeUndefined(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r.age).toBeUndefined(); + + // update + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toResolveTruthy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); + + it('field-level read override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age, true) + level Int + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r).toBeNull(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(Object.keys(r).length).toBe(0); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r).toMatchObject({ age: 20 }); + }); + + it('field-level update override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('update', age > profile.age, true) + level Int + @@allow('read', true) + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); }); From 2ab0b6ff334be08ec307c190634df73ec4bea15b Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 1 Jun 2024 23:16:18 +0800 Subject: [PATCH 4/4] add test with `auth` --- .../src/plugins/enhancer/policy/utils.ts | 8 +-- .../cross-model-field-comparison.test.ts | 52 +++++++++++++++++++ 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index 7e25738ce..f6f8bd801 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -199,13 +199,15 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal } } else if (isCollectionPredicate(expr)) { const path = visit(expr.left); + // recurse into RHS + const rhs = collectReferencePaths(expr.right); if (path) { - // recurse into RHS - const rhs = collectReferencePaths(expr.right); // combine path of LHS and RHS return rhs.map((r) => [...path, ...r]); } else { - return []; + // LHS is not rooted from the current model, + // only keep RHS items that contains '$this' + return rhs.filter((r) => r.includes('$this')); } } else if (isInvocationExpr(expr)) { // recurse into function arguments diff --git a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts index 490cfcaf2..1ebfaeba6 100644 --- a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts +++ b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts @@ -768,4 +768,56 @@ describe('Cross-model field comparison', () => { await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); }); + + it('with auth', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + permissions Permission[] + @@allow('all', true) + } + + model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + model String + level Int + @@allow('all', true) + } + + model Post { + id Int @id @default(autoincrement()) + title String + permission PostPermission? + + @@allow('read', true) + @@allow("create", auth().permissions?[model == 'Post' && level == this.permission.level]) + } + + model PostPermission { + id Int @id @default(autoincrement()) + post Post @relation(fields: [postId], references: [id]) + postId Int @unique + level Int + @@allow('all', true) + } + `, + { preserveTsFiles: true } + ); + + await expect(enhance().post.create({ data: { title: 'P1' } })).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Foo', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ + data: { title: 'P1', permission: { create: { level: 1 } } }, + }) + ).toResolveTruthy(); + }); });