Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement relation check() function in ZModel #1556

Merged
merged 3 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"change-case": "^4.1.2",
"colors": "1.4.0",
"commander": "^8.3.0",
"deepmerge": "^4.3.1",
"get-latest-version": "^5.0.1",
"langium": "1.3.1",
"lower-case-first": "^2.0.2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import pluralize from 'pluralize';
import { AstValidator } from '../types';
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';

// a registry of function handlers marked with @func
// a registry of function handlers marked with @check
const attributeCheckers = new Map<string, PropertyDescriptor>();

// function handler decorator
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import {
Argument,
DataModel,
DataModelAttribute,
DataModelFieldAttribute,
Expression,
FunctionDecl,
FunctionParam,
InvocationExpr,
isArrayExpr,
isDataModel,
isDataModelAttribute,
isDataModelFieldAttribute,
isLiteralExpr,
Expand All @@ -15,14 +17,29 @@ import {
ExpressionContext,
getDataModelFieldReference,
getFunctionExpressionContext,
getLiteral,
isDataModelFieldReference,
isEnumFieldReference,
isFromStdlib,
} from '@zenstackhq/sdk';
import { AstNode, ValidationAcceptor } from 'langium';
import { P, match } from 'ts-pattern';
import { AstNode, streamAst, ValidationAcceptor } from 'langium';
import { match, P } from 'ts-pattern';
import { isCheckInvocation } from '../../utils/ast-utils';
import { AstValidator } from '../types';
import { typeAssignable } from './utils';

// a registry of function handlers marked with @func
const invocationCheckers = new Map<string, PropertyDescriptor>();

// function handler decorator
function func(name: string) {
return function (_target: unknown, _propertyKey: string, descriptor: PropertyDescriptor) {
if (!invocationCheckers.get(name)) {
invocationCheckers.set(name, descriptor);
}
return descriptor;
};
}
/**
* InvocationExpr validation
*/
Expand Down Expand Up @@ -104,6 +121,12 @@ export default class FunctionInvocationValidator implements AstValidator<Express
}
}
}

// run checkers for specific functions
const checker = invocationCheckers.get(expr.function.$refText);
if (checker) {
checker.value.call(this, expr, accept);
}
}

private validateArgs(funcDecl: FunctionDecl, args: Argument[], accept: ValidationAcceptor) {
Expand Down Expand Up @@ -167,4 +190,76 @@ export default class FunctionInvocationValidator implements AstValidator<Express

return true;
}

@func('check')
private _checkCheck(expr: InvocationExpr, accept: ValidationAcceptor) {
let valid = true;

const fieldArg = expr.args[0].value;
if (!isDataModelFieldReference(fieldArg) || !isDataModel(fieldArg.$resolvedType?.decl)) {
accept('error', 'argument must be a relation field', { node: expr.args[0] });
valid = false;
}

if (fieldArg.$resolvedType?.array) {
accept('error', 'argument cannot be an array field', { node: expr.args[0] });
valid = false;
}

const opArg = expr.args[1]?.value;
if (opArg) {
const operation = getLiteral<string>(opArg);
if (!operation || !['read', 'create', 'update', 'delete'].includes(operation)) {
accept('error', 'argument must be a "read", "create", "update", or "delete"', { node: expr.args[1] });
valid = false;
}
}

if (!valid) {
return;
}

// check for cyclic relation checking
const start = fieldArg.$resolvedType?.decl as DataModel;
const tasks = [expr];
const seen = new Set<DataModel>();

while (tasks.length > 0) {
const currExpr = tasks.pop()!;
const arg = currExpr.args[0]?.value;

if (!isDataModel(arg?.$resolvedType?.decl)) {
continue;
}

const currModel = arg.$resolvedType.decl;

if (seen.has(currModel)) {
if (currModel === start) {
accept('error', 'cyclic dependency detected when following the `check()` call', { node: expr });
} else {
// a cycle is detected but it doesn't start from the invocation expression we're checking,
// just break here and the cycle will be reported when we validate the start of it
}
break;
} else {
seen.add(currModel);
}

const policyAttrs = currModel.attributes.filter(
(attr) => attr.decl.$refText === '@@allow' || attr.decl.$refText === '@@deny'
);
for (const attr of policyAttrs) {
const rule = attr.args[1];
if (!rule) {
continue;
}
streamAst(rule).forEach((node) => {
if (isCheckInvocation(node)) {
tasks.push(node as InvocationExpr);
}
});
}
}
}
}
77 changes: 63 additions & 14 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ import {
StringLiteral,
UnaryExpr,
} from '@zenstackhq/language/ast';
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
import { DELEGATE_AUX_RELATION_PREFIX, PolicyOperationKind } from '@zenstackhq/runtime';
import {
ExpressionContext,
getFunctionExpressionContext,
getIdFields,
getLiteral,
getQueryGuardFunctionName,
isAuthInvocation,
isDataModelFieldReference,
isDelegateModel,
isFromStdlib,
isFutureExpr,
PluginError,
TypeScriptExpressionTransformer,
Expand All @@ -37,6 +39,7 @@ import { lowerCaseFirst } from 'lower-case-first';
import invariant from 'tiny-invariant';
import { CodeBlockWriter } from 'ts-morph';
import { name } from '..';
import { isCheckInvocation } from '../../../utils/ast-utils';

type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<=';

Expand All @@ -60,6 +63,11 @@ type FilterOperators =
export const TRUE = '{ AND: [] }';
export const FALSE = '{ OR: [] }';

export type ExpressionWriterOptions = {
isPostGuard?: boolean;
operationContext: PolicyOperationKind;
};

/**
* Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer
*/
Expand All @@ -68,15 +76,14 @@ export class ExpressionWriter {

/**
* Constructs a new ExpressionWriter
*
* @param isPostGuard indicates if we're writing for post-update conditions
*/
constructor(private readonly writer: CodeBlockWriter, private readonly isPostGuard = false) {
constructor(private readonly writer: CodeBlockWriter, private readonly options: ExpressionWriterOptions) {
this.plainExprBuilder = new TypeScriptExpressionTransformer({
context: ExpressionContext.AccessPolicy,
isPostGuard: this.isPostGuard,
isPostGuard: this.options.isPostGuard,
// in post-guard context, `this` references pre-update value
thisExprContext: this.isPostGuard ? 'context.preValue' : undefined,
thisExprContext: this.options.isPostGuard ? 'context.preValue' : undefined,
operationContext: this.options.operationContext,
});
}

Expand Down Expand Up @@ -269,17 +276,20 @@ export class ExpressionWriter {
// expression rooted to `auth()` is always compiled to plain expression
!this.isAuthOrAuthMemberAccess(expr.left) &&
// `future()` in post-update context
((this.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
((this.options.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
// non-`future()` in pre-update context
(!this.isPostGuard && !this.isFutureMemberAccess(expr.left)));
(!this.options.isPostGuard && !this.isFutureMemberAccess(expr.left)));

if (compileToRelationQuery) {
this.block(() => {
this.writeFieldCondition(
expr.left,
() => {
// inner scope of collection expression is always compiled as non-post-guard
const innerWriter = new ExpressionWriter(this.writer, false);
const innerWriter = new ExpressionWriter(this.writer, {
isPostGuard: false,
operationContext: this.options.operationContext,
});
innerWriter.write(expr.right);
},
operator === '?' ? 'some' : operator === '!' ? 'every' : 'none'
Expand All @@ -297,14 +307,14 @@ export class ExpressionWriter {
}

if (isMemberAccessExpr(expr)) {
if (isFutureExpr(expr.operand) && this.isPostGuard) {
if (isFutureExpr(expr.operand) && this.options.isPostGuard) {
// when writing for post-update, future().field.x is a field access
return true;
} else {
return this.isFieldAccess(expr.operand);
}
}
if (isDataModelFieldReference(expr) && !this.isPostGuard) {
if (isDataModelFieldReference(expr) && !this.options.isPostGuard) {
return true;
}
return false;
Expand Down Expand Up @@ -437,7 +447,7 @@ export class ExpressionWriter {
this.writer.write(operator === '!=' ? TRUE : FALSE);
} else {
this.writeOperator(operator, fieldAccess, () => {
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
if (isDataModelFieldReference(operand) && !this.options.isPostGuard) {
// if operand is a field reference and we're not generating for post-update guard,
// we should generate a field reference (comparing fields in the same model)
this.writeFieldReference(operand);
Expand Down Expand Up @@ -735,6 +745,11 @@ export class ExpressionWriter {
functionAllowedContext.includes(ExpressionContext.AccessPolicy) ||
functionAllowedContext.includes(ExpressionContext.ValidationRule)
) {
if (isCheckInvocation(expr)) {
this.writeRelationCheck(expr);
return;
}

if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) {
// filter functions without referencing fields
this.guard(() => this.plain(expr));
Expand All @@ -744,13 +759,13 @@ export class ExpressionWriter {
let valueArg = expr.args[1]?.value;

// isEmpty function is zero arity, it's mapped to a boolean literal
if (funcDecl.name === 'isEmpty') {
if (isFromStdlib(funcDecl) && funcDecl.name === 'isEmpty') {
valueArg = { $type: BooleanLiteral, value: true } as LiteralExpr;
}

// contains function has a 3rd argument that indicates whether the comparison should be case-insensitive
let extraArgs: Record<string, Expression> | undefined = undefined;
if (funcDecl.name === 'contains') {
if (isFromStdlib(funcDecl) && funcDecl.name === 'contains') {
if (getLiteral<boolean>(expr.args[2]?.value) === true) {
extraArgs = { mode: { $type: StringLiteral, value: 'insensitive' } as LiteralExpr };
}
Expand All @@ -770,4 +785,38 @@ export class ExpressionWriter {
throw new PluginError(name, `Unsupported function ${funcDecl.name}`);
}
}

private writeRelationCheck(expr: InvocationExpr) {
if (!isDataModelFieldReference(expr.args[0].value)) {
throw new PluginError(name, `First argument of check() must be a field`);
}
if (!isDataModel(expr.args[0].value.$resolvedType?.decl)) {
throw new PluginError(name, `First argument of check() must be a relation field`);
}

const fieldRef = expr.args[0].value;
const targetModel = fieldRef.$resolvedType?.decl as DataModel;

let operation: string;
if (expr.args[1]) {
const literal = getLiteral<string>(expr.args[1].value);
if (!literal) {
throw new TypeScriptExpressionTransformerError(`Second argument of check() must be a string literal`);
}
if (!['read', 'create', 'update', 'delete'].includes(literal)) {
throw new TypeScriptExpressionTransformerError(`Invalid check() operation "${literal}"`);
}
operation = literal;
} else {
if (!this.options.operationContext) {
throw new TypeScriptExpressionTransformerError('Unable to determine CRUD operation from context');
}
operation = this.options.operationContext;
}

this.block(() => {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
});
}
}
Loading
Loading