diff --git a/packages/schema/package.json b/packages/schema/package.json index ef84f183c..9fcf2e7b2 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -97,6 +97,7 @@ "change-case": "^4.1.2", "colors": "1.4.0", "commander": "^8.3.0", + "deepmerge": "^4.3.1", "get-latest-version": "^5.0.1", "langium": "1.3.1", "lower-case-first": "^2.0.2", diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index 92c086005..0de3a5854 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -21,7 +21,7 @@ import pluralize from 'pluralize'; import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; -// a registry of function handlers marked with @func +// a registry of function handlers marked with @check const attributeCheckers = new Map(); // function handler decorator diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index a6af730f2..79b2b1a6b 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -1,5 +1,6 @@ import { Argument, + DataModel, DataModelAttribute, DataModelFieldAttribute, Expression, @@ -7,6 +8,7 @@ import { FunctionParam, InvocationExpr, isArrayExpr, + isDataModel, isDataModelAttribute, isDataModelFieldAttribute, isLiteralExpr, @@ -15,14 +17,29 @@ import { ExpressionContext, getDataModelFieldReference, getFunctionExpressionContext, + getLiteral, + isDataModelFieldReference, isEnumFieldReference, isFromStdlib, } from '@zenstackhq/sdk'; -import { AstNode, ValidationAcceptor } from 'langium'; -import { P, match } from 'ts-pattern'; +import { AstNode, streamAst, ValidationAcceptor } from 'langium'; +import { match, P } from 'ts-pattern'; +import { isCheckInvocation } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; +// a registry of function handlers marked with @func +const invocationCheckers = new Map(); + +// function handler decorator +function func(name: string) { + return function (_target: unknown, _propertyKey: string, descriptor: PropertyDescriptor) { + if (!invocationCheckers.get(name)) { + invocationCheckers.set(name, descriptor); + } + return descriptor; + }; +} /** * InvocationExpr validation */ @@ -104,6 +121,12 @@ export default class FunctionInvocationValidator implements AstValidator(opArg); + if (!operation || !['read', 'create', 'update', 'delete'].includes(operation)) { + accept('error', 'argument must be a "read", "create", "update", or "delete"', { node: expr.args[1] }); + valid = false; + } + } + + if (!valid) { + return; + } + + // check for cyclic relation checking + const start = fieldArg.$resolvedType?.decl as DataModel; + const tasks = [expr]; + const seen = new Set(); + + while (tasks.length > 0) { + const currExpr = tasks.pop()!; + const arg = currExpr.args[0]?.value; + + if (!isDataModel(arg?.$resolvedType?.decl)) { + continue; + } + + const currModel = arg.$resolvedType.decl; + + if (seen.has(currModel)) { + if (currModel === start) { + accept('error', 'cyclic dependency detected when following the `check()` call', { node: expr }); + } else { + // a cycle is detected but it doesn't start from the invocation expression we're checking, + // just break here and the cycle will be reported when we validate the start of it + } + break; + } else { + seen.add(currModel); + } + + const policyAttrs = currModel.attributes.filter( + (attr) => attr.decl.$refText === '@@allow' || attr.decl.$refText === '@@deny' + ); + for (const attr of policyAttrs) { + const rule = attr.args[1]; + if (!rule) { + continue; + } + streamAst(rule).forEach((node) => { + if (isCheckInvocation(node)) { + tasks.push(node as InvocationExpr); + } + }); + } + } + } } diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index f301e5b9d..66e5df73d 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -19,15 +19,17 @@ import { StringLiteral, UnaryExpr, } from '@zenstackhq/language/ast'; -import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; +import { DELEGATE_AUX_RELATION_PREFIX, PolicyOperationKind } from '@zenstackhq/runtime'; import { ExpressionContext, getFunctionExpressionContext, getIdFields, getLiteral, + getQueryGuardFunctionName, isAuthInvocation, isDataModelFieldReference, isDelegateModel, + isFromStdlib, isFutureExpr, PluginError, TypeScriptExpressionTransformer, @@ -37,6 +39,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '..'; +import { isCheckInvocation } from '../../../utils/ast-utils'; type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<='; @@ -60,6 +63,11 @@ type FilterOperators = export const TRUE = '{ AND: [] }'; export const FALSE = '{ OR: [] }'; +export type ExpressionWriterOptions = { + isPostGuard?: boolean; + operationContext: PolicyOperationKind; +}; + /** * Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer */ @@ -68,15 +76,14 @@ export class ExpressionWriter { /** * Constructs a new ExpressionWriter - * - * @param isPostGuard indicates if we're writing for post-update conditions */ - constructor(private readonly writer: CodeBlockWriter, private readonly isPostGuard = false) { + constructor(private readonly writer: CodeBlockWriter, private readonly options: ExpressionWriterOptions) { this.plainExprBuilder = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, - isPostGuard: this.isPostGuard, + isPostGuard: this.options.isPostGuard, // in post-guard context, `this` references pre-update value - thisExprContext: this.isPostGuard ? 'context.preValue' : undefined, + thisExprContext: this.options.isPostGuard ? 'context.preValue' : undefined, + operationContext: this.options.operationContext, }); } @@ -269,9 +276,9 @@ export class ExpressionWriter { // expression rooted to `auth()` is always compiled to plain expression !this.isAuthOrAuthMemberAccess(expr.left) && // `future()` in post-update context - ((this.isPostGuard && this.isFutureMemberAccess(expr.left)) || + ((this.options.isPostGuard && this.isFutureMemberAccess(expr.left)) || // non-`future()` in pre-update context - (!this.isPostGuard && !this.isFutureMemberAccess(expr.left))); + (!this.options.isPostGuard && !this.isFutureMemberAccess(expr.left))); if (compileToRelationQuery) { this.block(() => { @@ -279,7 +286,10 @@ export class ExpressionWriter { expr.left, () => { // inner scope of collection expression is always compiled as non-post-guard - const innerWriter = new ExpressionWriter(this.writer, false); + const innerWriter = new ExpressionWriter(this.writer, { + isPostGuard: false, + operationContext: this.options.operationContext, + }); innerWriter.write(expr.right); }, operator === '?' ? 'some' : operator === '!' ? 'every' : 'none' @@ -297,14 +307,14 @@ export class ExpressionWriter { } if (isMemberAccessExpr(expr)) { - if (isFutureExpr(expr.operand) && this.isPostGuard) { + if (isFutureExpr(expr.operand) && this.options.isPostGuard) { // when writing for post-update, future().field.x is a field access return true; } else { return this.isFieldAccess(expr.operand); } } - if (isDataModelFieldReference(expr) && !this.isPostGuard) { + if (isDataModelFieldReference(expr) && !this.options.isPostGuard) { return true; } return false; @@ -437,7 +447,7 @@ export class ExpressionWriter { this.writer.write(operator === '!=' ? TRUE : FALSE); } else { this.writeOperator(operator, fieldAccess, () => { - if (isDataModelFieldReference(operand) && !this.isPostGuard) { + if (isDataModelFieldReference(operand) && !this.options.isPostGuard) { // if operand is a field reference and we're not generating for post-update guard, // we should generate a field reference (comparing fields in the same model) this.writeFieldReference(operand); @@ -735,6 +745,11 @@ export class ExpressionWriter { functionAllowedContext.includes(ExpressionContext.AccessPolicy) || functionAllowedContext.includes(ExpressionContext.ValidationRule) ) { + if (isCheckInvocation(expr)) { + this.writeRelationCheck(expr); + return; + } + if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) { // filter functions without referencing fields this.guard(() => this.plain(expr)); @@ -744,13 +759,13 @@ export class ExpressionWriter { let valueArg = expr.args[1]?.value; // isEmpty function is zero arity, it's mapped to a boolean literal - if (funcDecl.name === 'isEmpty') { + if (isFromStdlib(funcDecl) && funcDecl.name === 'isEmpty') { valueArg = { $type: BooleanLiteral, value: true } as LiteralExpr; } // contains function has a 3rd argument that indicates whether the comparison should be case-insensitive let extraArgs: Record | undefined = undefined; - if (funcDecl.name === 'contains') { + if (isFromStdlib(funcDecl) && funcDecl.name === 'contains') { if (getLiteral(expr.args[2]?.value) === true) { extraArgs = { mode: { $type: StringLiteral, value: 'insensitive' } as LiteralExpr }; } @@ -770,4 +785,38 @@ export class ExpressionWriter { throw new PluginError(name, `Unsupported function ${funcDecl.name}`); } } + + private writeRelationCheck(expr: InvocationExpr) { + if (!isDataModelFieldReference(expr.args[0].value)) { + throw new PluginError(name, `First argument of check() must be a field`); + } + if (!isDataModel(expr.args[0].value.$resolvedType?.decl)) { + throw new PluginError(name, `First argument of check() must be a relation field`); + } + + const fieldRef = expr.args[0].value; + const targetModel = fieldRef.$resolvedType?.decl as DataModel; + + let operation: string; + if (expr.args[1]) { + const literal = getLiteral(expr.args[1].value); + if (!literal) { + throw new TypeScriptExpressionTransformerError(`Second argument of check() must be a string literal`); + } + if (!['read', 'create', 'update', 'delete'].includes(literal)) { + throw new TypeScriptExpressionTransformerError(`Invalid check() operation "${literal}"`); + } + operation = literal; + } else { + if (!this.options.operationContext) { + throw new TypeScriptExpressionTransformerError('Unable to determine CRUD operation from context'); + } + operation = this.options.operationContext; + } + + this.block(() => { + const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation); + this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`); + }); + } } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 0bf949329..2c54949f4 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -2,7 +2,9 @@ import { DataModel, DataModelField, Expression, + InvocationExpr, Model, + ReferenceExpr, isDataModel, isDataModelField, isEnum, @@ -28,9 +30,18 @@ import { getPrismaClientImportSpec } from '@zenstackhq/sdk/prisma'; import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; +import { + CodeBlockWriter, + FunctionDeclaration, + Project, + SourceFile, + VariableDeclarationKind, + WriterFunction, +} from 'ts-morph'; +import { isCheckInvocation } from '../../../utils/ast-utils'; import { ConstraintTransformer } from './constraint-transformer'; import { + generateConstantQueryGuardFunction, generateEntityCheckerFunction, generateNormalizedAuthRef, generateQueryGuardFunction, @@ -234,6 +245,7 @@ export class PolicyGenerator { const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', + operationContext: 'create', }); let expr = @@ -310,7 +322,7 @@ export class PolicyGenerator { private writePostUpdatePreValueSelector(model: DataModel, writer: CodeBlockWriter) { const allows = getPolicyExpressions(model, 'allow', 'postUpdate'); const denies = getPolicyExpressions(model, 'deny', 'postUpdate'); - const preValueSelect = generateSelectForRules([...allows, ...denies]); + const preValueSelect = generateSelectForRules([...allows, ...denies], 'postUpdate'); if (preValueSelect) { writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`); } @@ -350,17 +362,19 @@ export class PolicyGenerator { // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma - this.writeEntityChecker(model, kind, writer, sourceFile, true); + const { functionName, selector } = this.writeEntityChecker(model, kind, sourceFile, false); + + if (this.shouldUseEntityChecker(model, kind, true, false)) { + writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`); + } } - private writeEntityChecker( + private shouldUseEntityChecker( target: DataModel | DataModelField, kind: PolicyOperationKind, - writer: CodeBlockWriter, - sourceFile: SourceFile, - onlyCrossModelComparison = false, - forOverride = false - ) { + onlyCrossModelComparison: boolean, + forOverride: boolean + ): boolean { const allows = getPolicyExpressions( target, 'allow', @@ -376,10 +390,37 @@ export class PolicyGenerator { onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' ); - if (allows.length === 0 && denies.length === 0) { - return; + if (allows.length > 0 || denies.length > 0) { + return true; } + const allRules = [ + ...getPolicyExpressions(target, 'allow', kind, forOverride, 'all'), + ...getPolicyExpressions(target, 'deny', kind, forOverride, 'all'), + ]; + + return allRules.some((rule) => { + return streamAst(rule).some((node) => { + if (isCheckInvocation(node)) { + const expr = node as InvocationExpr; + const fieldRef = expr.args[0].value as ReferenceExpr; + const targetModel = fieldRef.$resolvedType?.decl as DataModel; + return this.shouldUseEntityChecker(targetModel, kind, onlyCrossModelComparison, forOverride); + } + return false; + }); + }); + } + + private writeEntityChecker( + target: DataModel | DataModelField, + kind: PolicyOperationKind, + sourceFile: SourceFile, + forOverride: boolean + ) { + const allows = getPolicyExpressions(target, 'allow', kind, forOverride, 'all'); + const denies = getPolicyExpressions(target, 'deny', kind, forOverride, 'all'); + const model = isDataModel(target) ? target : (target.$container as DataModel); const func = generateEntityCheckerFunction( sourceFile, @@ -390,9 +431,9 @@ export class PolicyGenerator { isDataModelField(target) ? target : undefined, forOverride ); - const selector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate') ?? {}; - const key = forOverride ? 'overrideEntityChecker' : 'entityChecker'; - writer.write(`${key}: { func: ${func.getName()!}, selector: ${JSON.stringify(selector)} },`); + const selector = generateSelectForRules([...allows, ...denies], kind, false, kind !== 'postUpdate') ?? {}; + + return { functionName: func.getName()!, selector }; } // writes `guard: ...` for a given policy operation kind @@ -408,23 +449,32 @@ export class PolicyGenerator { if (kind === 'update' && allows.length === 0) { // no allow rule for 'update', policy is constant based on if there's // post-update counterpart + let func: FunctionDeclaration; if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - writer.write(`guard: false,`); + func = generateConstantQueryGuardFunction(sourceFile, model, kind, false); } else { - writer.write(`guard: true,`); + func = generateConstantQueryGuardFunction(sourceFile, model, kind, true); } + writer.write(`guard: ${func.getName()!},`); return; } if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { // no 'postUpdate' rule, always allow - writer.write(`guard: true,`); + const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true); + writer.write(`guard: ${func.getName()},`); return; } if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') { // constant policy - writer.write(`guard: ${policies[kind as keyof typeof policies]},`); + const func = generateConstantQueryGuardFunction( + sourceFile, + model, + kind, + policies[kind as keyof typeof policies] as boolean + ); + writer.write(`guard: ${func.getName()!},`); return; } @@ -534,7 +584,13 @@ export class PolicyGenerator { // checker function // write all field-level rules as entity checker function - this.writeEntityChecker(field, 'read', writer, sourceFile, false, false); + const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, false); + + if (this.shouldUseEntityChecker(field, 'read', false, false)) { + writer.write( + `entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },` + ); + } if (overrideAllows.length > 0) { // override guard function @@ -551,7 +607,14 @@ export class PolicyGenerator { writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); // additional entity checker for override - this.writeEntityChecker(field, 'read', writer, sourceFile, false, true); + const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, true); + if (this.shouldUseEntityChecker(field, 'read', false, true)) { + writer.write( + `overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify( + selector + )} },` + ); + } } }); writer.writeLine(','); @@ -581,7 +644,12 @@ export class PolicyGenerator { // write cross-model comparison rules as entity checker functions // because they cannot be checked inside Prisma - this.writeEntityChecker(field, 'update', writer, sourceFile, true, false); + const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, false); + if (this.shouldUseEntityChecker(field, 'update', true, false)) { + writer.write( + `entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },` + ); + } if (overrideAllows.length > 0) { // override guard @@ -598,7 +666,14 @@ export class PolicyGenerator { // write cross-model comparison override rules as entity checker functions // because they cannot be checked inside Prisma - this.writeEntityChecker(field, 'update', writer, sourceFile, true, true); + const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, true); + if (this.shouldUseEntityChecker(field, 'update', true, true)) { + writer.write( + `overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify( + selector + )} },` + ); + } } }); writer.writeLine(','); @@ -649,7 +724,7 @@ export class PolicyGenerator { }); if (authRules.length > 0) { - return generateSelectForRules(authRules, true); + return generateSelectForRules(authRules, undefined, true); } else { return undefined; } diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index 3cc43223e..f09263dca 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -8,8 +8,10 @@ import { getAttributeArg, getAuthModel, getDataModels, + getEntityCheckerFunctionName, getIdFields, getLiteral, + getQueryGuardFunctionName, isAuthInvocation, isDataModelFieldReference, isEnumFieldReference, @@ -19,7 +21,9 @@ import { } from '@zenstackhq/sdk'; import { Enum, + InvocationExpr, Model, + ReferenceExpr, isBinaryExpr, isDataModel, isDataModelField, @@ -32,10 +36,11 @@ import { type DataModelField, type Expression, } from '@zenstackhq/sdk/ast'; +import deepmerge from 'deepmerge'; import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium'; import { SourceFile, WriterFunction } from 'ts-morph'; import { name } from '..'; -import { isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; +import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -115,8 +120,13 @@ function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { * Generates a "select" object that contains (recursively) fields referenced by the * given policy rules */ -export function generateSelectForRules(rules: Expression[], forAuthContext = false, ignoreFutureReference = true) { - const result: any = {}; +export function generateSelectForRules( + rules: Expression[], + forOperation: PolicyOperationKind | undefined, + forAuthContext = false, + ignoreFutureReference = true +) { + let result: any = {}; const addPath = (path: string[]) => { const thisIndex = path.lastIndexOf('$this'); if (thisIndex >= 0) { @@ -224,11 +234,62 @@ export function generateSelectForRules(rules: Expression[], forAuthContext = fal for (const rule of rules) { const paths = collectReferencePaths(rule); paths.forEach((p) => addPath(p)); + + // merge selectors from models referenced by `check()` calls + streamAst(rule).forEach((node) => { + if (isCheckInvocation(node)) { + const expr = node as InvocationExpr; + const fieldRef = expr.args[0].value as ReferenceExpr; + const targetModel = fieldRef.$resolvedType?.decl as DataModel; + const targetOperation = getLiteral(expr.args[1]?.value) ?? forOperation; + const targetSelector = generateSelectForRules( + [ + ...getPolicyExpressions(targetModel, 'allow', targetOperation as PolicyOperationKind), + ...getPolicyExpressions(targetModel, 'deny', targetOperation as PolicyOperationKind), + ], + targetOperation as PolicyOperationKind, + forAuthContext, + ignoreFutureReference + ); + if (targetSelector) { + result = deepmerge(result, { [fieldRef.target.$refText]: { select: targetSelector } }); + } + } + }); } return Object.keys(result).length === 0 ? undefined : result; } +/** + * Generates a constant query guard function + */ +export function generateConstantQueryGuardFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + value: boolean +) { + const func = sourceFile.addFunction({ + name: getQueryGuardFunctionName(model, undefined, false, kind), + returnType: 'any', + parameters: [ + { + name: 'context', + type: 'QueryContext', + }, + { + // for generating field references used by field comparison in the same model + name: 'db', + type: 'CrudContract', + }, + ], + statements: [`return ${value ? TRUE : FALSE};`], + }); + + return func; +} + /** * Generates a query guard function that returns a partial Prisma query for the given model or field */ @@ -267,6 +328,7 @@ export function generateQueryGuardFunction( const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, isPostGuard: kind === 'postUpdate', + operationContext: kind, }); try { denyRules.forEach((rule) => { @@ -309,7 +371,10 @@ export function generateQueryGuardFunction( } else { statements.push((writer) => { writer.write('return '); - const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); + const exprWriter = new ExpressionWriter(writer, { + isPostGuard: kind === 'postUpdate', + operationContext: kind, + }); const writeDenies = () => { writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); denyRules.forEach((expr, i) => { @@ -353,7 +418,7 @@ export function generateQueryGuardFunction( } const func = sourceFile.addFunction({ - name: `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + name: getQueryGuardFunctionName(model, forField, fieldOverride, kind), returnType: 'any', parameters: [ { @@ -391,6 +456,7 @@ export function generateEntityCheckerFunction( fieldReferenceContext: 'input', isPostGuard: kind === 'postUpdate', futureRefContext: 'input', + operationContext: kind, }); denies.forEach((rule) => { @@ -422,7 +488,7 @@ export function generateEntityCheckerFunction( } const func = sourceFile.addFunction({ - name: `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + name: getEntityCheckerFunctionName(model, forField, fieldOverride, kind), returnType: 'any', parameters: [ { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 696ac8552..1d2a88ba7 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -666,6 +666,16 @@ function datetime(field: String): Boolean { function url(field: String): Boolean { } @@@expressionContext([ValidationRule]) +/** + * Checks if the current user can perform the given operation on the given field. + * + * @param field: The field to check access for + * @param operation: The operation to check access for. Can be "read", "create", "update", or "delete". If the operation is not provided, + * it defaults the operation of the containing policy rule. + */ +function check(field: Any, operation: String?): Boolean { +} @@@expressionContext([AccessPolicy]) + ////////////////////////////////////////////// // End validation attributes and functions ////////////////////////////////////////////// diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 24af8862d..bf935ce20 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -151,6 +151,10 @@ export function isFutureInvocation(node: AstNode) { return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); } +export function isCheckInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref); +} + export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { return expr.target.ref; diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index 5820e92a8..63254047f 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -1369,7 +1369,7 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, declarations: [ { name: 'expr', - initializer: (writer) => new ExpressionWriter(writer).write(expr), + initializer: (writer) => new ExpressionWriter(writer, { operationContext: 'read' }).write(expr), }, ], }); diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 3c89805d9..27198da28 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -1,6 +1,7 @@ export * from './code-gen'; export * from './constants'; export { generate as generateModelMeta } from './model-meta-generator'; +export * from './names'; export * from './policy'; export * from './types'; export * from './typescript-expression-transformer'; diff --git a/packages/sdk/src/names.ts b/packages/sdk/src/names.ts new file mode 100644 index 000000000..be78396a3 --- /dev/null +++ b/packages/sdk/src/names.ts @@ -0,0 +1,25 @@ +import { DataModel, DataModelField } from './ast'; + +/** + * Gets the name of the function that computes a partial Prisma query guard. + */ +export function getQueryGuardFunctionName( + model: DataModel, + forField: DataModelField | undefined, + fieldOverride: boolean, + kind: string +) { + return `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`; +} + +/** + * Gets the name of the function that checks an entity for access policy rules. + */ +export function getEntityCheckerFunctionName( + model: DataModel, + forField: DataModelField | undefined, + fieldOverride: boolean, + kind: string +) { + return `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`; +} diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 28ce1d345..9a884ebdf 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -22,6 +22,7 @@ import { } from '@zenstackhq/language/ast'; import { P, match } from 'ts-pattern'; import { ExpressionContext } from './constants'; +import { getEntityCheckerFunctionName } from './names'; import { getIdFields, getLiteral, isDataModelFieldReference, isFromStdlib, isFutureExpr } from './utils'; export class TypeScriptExpressionTransformerError extends Error { @@ -36,6 +37,7 @@ type Options = { thisExprContext?: string; futureRefContext?: string; context: ExpressionContext; + operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete'; }; // a registry of function handlers marked with @func @@ -274,6 +276,39 @@ export class TypeScriptExpressionTransformer { return `(!${field} || ${field}?.length === 0)`; } + @func('check') + private _check(args: Expression[]) { + if (!isDataModelFieldReference(args[0])) { + throw new TypeScriptExpressionTransformerError(`First argument of check() must be a field`); + } + if (!isDataModel(args[0].$resolvedType?.decl)) { + throw new TypeScriptExpressionTransformerError(`First argument of check() must be a relation field`); + } + + const fieldRef = args[0] as ReferenceExpr; + const targetModel = fieldRef.$resolvedType?.decl as DataModel; + + let operation: string; + if (args[1]) { + const literal = getLiteral(args[1]); + if (!literal) { + throw new TypeScriptExpressionTransformerError(`Second argument of check() must be a string literal`); + } + if (!['read', 'create', 'update', 'delete'].includes(literal)) { + throw new TypeScriptExpressionTransformerError(`Invalid check() operation "${literal}"`); + } + operation = literal; + } else { + if (!this.options.operationContext) { + throw new TypeScriptExpressionTransformerError('Unable to determine CRUD operation from context'); + } + operation = this.options.operationContext; + } + + const entityCheckerFunc = getEntityCheckerFunctionName(targetModel, undefined, false, operation); + return `${entityCheckerFunc}(input.${fieldRef.target.$refText}, context)`; + } + private ensureBoolean(expr: string) { if (this.options.context === ExpressionContext.ValidationRule) { // all fields are optional in a validation context, so we treat undefined @@ -452,6 +487,7 @@ export class TypeScriptExpressionTransformer { ...this.options, isPostGuard: false, fieldReferenceContext: '_item', + operationContext: this.options.operationContext, }); const predicate = innerTransformer.transform(expr.right, normalizeUndefined); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 3dce134e2..750f2ef3a 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -501,6 +501,9 @@ importers: commander: specifier: ^8.3.0 version: 8.3.0 + deepmerge: + specifier: ^4.3.1 + version: 4.3.1 get-latest-version: specifier: ^5.0.1 version: 5.1.0 diff --git a/tests/integration/tests/enhancements/with-policy/relation-check.test.ts b/tests/integration/tests/enhancements/with-policy/relation-check.test.ts new file mode 100644 index 000000000..c08daa7d2 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/relation-check.test.ts @@ -0,0 +1,703 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('Relation checker', () => { + it('should work for read', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user, 'read')) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.findMany()).resolves.toHaveLength(1); + }); + + it('should work for simple create', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', check(user, 'read')) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + }, + }); + + const db = enhance(); + await expect(db.profile.create({ data: { user: { connect: { id: 1 } }, age: 18 } })).toResolveTruthy(); + await expect(db.profile.create({ data: { user: { connect: { id: 2 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for nested create', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('create', age < 30 && check(user, 'read')) + } + ` + ); + + const db = enhance(); + + await expect( + db.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 18 }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }) + ).toBeRejectedByPolicy(); + }); + + it('should work for update', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'read') && age < 30) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedByPolicy(); + }); + + it('should work for delete', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('create', true) + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('delete', check(user, 'read') && age < 30) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(db.profile.delete({ where: { id: 2 } })).toBeRejectedByPolicy(); + await expect(db.profile.delete({ where: { id: 3 } })).toBeRejectedByPolicy(); + }); + + it('should work for field-level', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read')) + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const db = enhance(); + + const p1 = await db.profile.findUnique({ where: { id: 1 } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 } }); + expect(p2.age).toBeUndefined(); + const p3 = await db.profile.findUnique({ where: { id: 3 } }); + expect(p3.age).toBeUndefined(); + }); + + it('should work for field-level with override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int @allow('read', age < 30 && check(user, 'read'), true) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 3, + public: true, + profile: { + create: { age: 30 }, + }, + }, + }); + + const db = enhance(); + + const p1 = await db.profile.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(p1.age).toBe(18); + const p2 = await db.profile.findUnique({ where: { id: 2 }, select: { age: true } }); + expect(p2).toBeNull(); + const p3 = await db.profile.findUnique({ where: { id: 3 }, select: { age: true } }); + expect(p3).toBeNull(); + }); + + it('should work for cross-model field comparison', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + age Int + @@allow('read', true) + @@allow('update', age == profile.age) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', true) + @@allow('update', check(user, 'update') && age < 30) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + age: 18, + profile: { + create: { id: 1, age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + age: 18, + profile: { + create: { id: 2, age: 20 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 3, + age: 30, + profile: { + create: { id: 3, age: 30 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.update({ where: { id: 1 }, data: { age: 21 } })).toResolveTruthy(); + await expect(db.profile.update({ where: { id: 2 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.profile.update({ where: { id: 3 }, data: { age: 21 } })).toBeRejectedByPolicy(); + }); + + it('should work for implicit specific operations', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('read', public) + @@allow('create', true) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('read', check(user)) + @@allow('create', check(user)) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await prisma.user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await prisma.user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should work for implicit all operations', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('all', public) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('all', check(user)) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + public: true, + profile: { + create: { age: 18 }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + public: false, + profile: { + create: { age: 20 }, + }, + }, + }); + + const db = enhance(); + await expect(db.profile.findMany()).resolves.toHaveLength(1); + + await prisma.user.create({ + data: { + id: 3, + public: true, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 3 } }, age: 18 } })).toResolveTruthy(); + + await prisma.user.create({ + data: { + id: 4, + public: false, + }, + }); + await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy(); + }); + + it('should report error for invalid args', async () => { + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + public Boolean + @@allow('read', check(public)) + } + ` + ) + ).resolves.toContain('argument must be a relation field'); + + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts)) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + ` + ) + ).resolves.toContain('argument cannot be an array field'); + + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + @@allow('read', check(profile.details)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + details ProfileDetails? + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int + age Int + } + ` + ) + ).resolves.toContain('argument must be a relation field'); + + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + @@allow('read', check(posts, 'all')) + } + model Post { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + } + ` + ) + ).resolves.toContain('argument must be a "read", "create", "update", or "delete"'); + }); + + it('should report error for cyclic relation check', async () => { + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + profileDetails ProfileDetails? + public Boolean + @@allow('all', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('all', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + user User @relation(fields: [userId], references: [id]) + userId Int @unique + age Int + @@allow('all', check(user)) + } + ` + ) + ).resolves.toContain('cyclic dependency detected when following the `check()` call'); + }); + + it('should report error for cyclic relation check indirect', async () => { + await expect( + loadModelWithError( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + public Boolean + @@allow('all', check(profile)) + } + + model Profile { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int @unique + details ProfileDetails? + @@allow('all', check(details)) + } + + model ProfileDetails { + id Int @id @default(autoincrement()) + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + @@allow('all', check(profile)) + } + ` + ) + ).resolves.toContain('cyclic dependency detected when following the `check()` call'); + }); +});