Skip to content

Commit

Permalink
feat: field-level checks
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Jul 2, 2024
1 parent 6a67626 commit 6f92323
Show file tree
Hide file tree
Showing 10 changed files with 477 additions and 55 deletions.
1 change: 1 addition & 0 deletions packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
"change-case": "^4.1.2",
"colors": "1.4.0",
"commander": "^8.3.0",
"deepmerge": "^4.3.1",
"get-latest-version": "^5.0.1",
"langium": "1.3.1",
"lower-case-first": "^2.0.2",
Expand Down
55 changes: 36 additions & 19 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
StringLiteral,
UnaryExpr,
} from '@zenstackhq/language/ast';
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
import { DELEGATE_AUX_RELATION_PREFIX, PolicyOperationKind } from '@zenstackhq/runtime';
import {
ExpressionContext,
getFunctionExpressionContext,
Expand All @@ -39,6 +39,7 @@ import { lowerCaseFirst } from 'lower-case-first';
import invariant from 'tiny-invariant';
import { CodeBlockWriter } from 'ts-morph';
import { name } from '..';
import { isCheckInvocation } from '../../../utils/ast-utils';

type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<=';

Expand All @@ -62,6 +63,11 @@ type FilterOperators =
export const TRUE = '{ AND: [] }';
export const FALSE = '{ OR: [] }';

export type ExpressionWriterOptions = {
isPostGuard?: boolean;
operationContext: PolicyOperationKind;
};

/**
* Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer
*/
Expand All @@ -70,15 +76,14 @@ export class ExpressionWriter {

/**
* Constructs a new ExpressionWriter
*
* @param isPostGuard indicates if we're writing for post-update conditions
*/
constructor(private readonly writer: CodeBlockWriter, private readonly isPostGuard = false) {
constructor(private readonly writer: CodeBlockWriter, private readonly options: ExpressionWriterOptions) {
this.plainExprBuilder = new TypeScriptExpressionTransformer({
context: ExpressionContext.AccessPolicy,
isPostGuard: this.isPostGuard,
isPostGuard: this.options.isPostGuard,
// in post-guard context, `this` references pre-update value
thisExprContext: this.isPostGuard ? 'context.preValue' : undefined,
thisExprContext: this.options.isPostGuard ? 'context.preValue' : undefined,
operationContext: this.options.operationContext,
});
}

Expand Down Expand Up @@ -271,17 +276,20 @@ export class ExpressionWriter {
// expression rooted to `auth()` is always compiled to plain expression
!this.isAuthOrAuthMemberAccess(expr.left) &&
// `future()` in post-update context
((this.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
((this.options.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
// non-`future()` in pre-update context
(!this.isPostGuard && !this.isFutureMemberAccess(expr.left)));
(!this.options.isPostGuard && !this.isFutureMemberAccess(expr.left)));

if (compileToRelationQuery) {
this.block(() => {
this.writeFieldCondition(
expr.left,
() => {
// inner scope of collection expression is always compiled as non-post-guard
const innerWriter = new ExpressionWriter(this.writer, false);
const innerWriter = new ExpressionWriter(this.writer, {
isPostGuard: false,
operationContext: this.options.operationContext,
});
innerWriter.write(expr.right);
},
operator === '?' ? 'some' : operator === '!' ? 'every' : 'none'
Expand All @@ -299,14 +307,14 @@ export class ExpressionWriter {
}

if (isMemberAccessExpr(expr)) {
if (isFutureExpr(expr.operand) && this.isPostGuard) {
if (isFutureExpr(expr.operand) && this.options.isPostGuard) {
// when writing for post-update, future().field.x is a field access
return true;
} else {
return this.isFieldAccess(expr.operand);
}
}
if (isDataModelFieldReference(expr) && !this.isPostGuard) {
if (isDataModelFieldReference(expr) && !this.options.isPostGuard) {
return true;
}
return false;
Expand Down Expand Up @@ -439,7 +447,7 @@ export class ExpressionWriter {
this.writer.write(operator === '!=' ? TRUE : FALSE);
} else {
this.writeOperator(operator, fieldAccess, () => {
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
if (isDataModelFieldReference(operand) && !this.options.isPostGuard) {
// if operand is a field reference and we're not generating for post-update guard,
// we should generate a field reference (comparing fields in the same model)
this.writeFieldReference(operand);
Expand Down Expand Up @@ -737,7 +745,7 @@ export class ExpressionWriter {
functionAllowedContext.includes(ExpressionContext.AccessPolicy) ||
functionAllowedContext.includes(ExpressionContext.ValidationRule)
) {
if (isFromStdlib(funcDecl) && funcDecl.name === 'check') {
if (isCheckInvocation(expr)) {
this.writeRelationCheck(expr);
return;
}
Expand Down Expand Up @@ -789,12 +797,21 @@ export class ExpressionWriter {
const fieldRef = expr.args[0].value;
const targetModel = fieldRef.$resolvedType?.decl as DataModel;

const operation = getLiteral<string>(expr.args[1].value);
if (!operation) {
throw new PluginError(name, `Second argument of check() must be a string literal`);
}
if (!['read', 'create', 'update', 'delete'].includes(operation)) {
throw new PluginError(name, `Invalid check() operation "${operation}"`);
let operation: string;
if (expr.args[1]) {
const literal = getLiteral<string>(expr.args[1].value);
if (!literal) {
throw new TypeScriptExpressionTransformerError(`Second argument of check() must be a string literal`);
}
if (!['read', 'create', 'update', 'delete'].includes(literal)) {
throw new TypeScriptExpressionTransformerError(`Invalid check() operation "${literal}"`);
}
operation = literal;
} else {
if (!this.options.operationContext) {
throw new TypeScriptExpressionTransformerError('Unable to determine CRUD operation from context');
}
operation = this.options.operationContext;
}

this.block(() => {
Expand Down
121 changes: 98 additions & 23 deletions packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ import {
DataModel,
DataModelField,
Expression,
InvocationExpr,
Model,
ReferenceExpr,
isDataModel,
isDataModelField,
isEnum,
Expand All @@ -28,9 +30,18 @@ import { getPrismaClientImportSpec } from '@zenstackhq/sdk/prisma';
import { streamAst } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph';
import {
CodeBlockWriter,
FunctionDeclaration,
Project,
SourceFile,
VariableDeclarationKind,
WriterFunction,
} from 'ts-morph';
import { isCheckInvocation } from '../../../utils/ast-utils';
import { ConstraintTransformer } from './constraint-transformer';
import {
generateConstantQueryGuardFunction,
generateEntityCheckerFunction,
generateNormalizedAuthRef,
generateQueryGuardFunction,
Expand Down Expand Up @@ -234,6 +245,7 @@ export class PolicyGenerator {
const transformer = new TypeScriptExpressionTransformer({
context: ExpressionContext.AccessPolicy,
fieldReferenceContext: 'input',
operationContext: 'create',
});

let expr =
Expand Down Expand Up @@ -310,7 +322,7 @@ export class PolicyGenerator {
private writePostUpdatePreValueSelector(model: DataModel, writer: CodeBlockWriter) {
const allows = getPolicyExpressions(model, 'allow', 'postUpdate');
const denies = getPolicyExpressions(model, 'deny', 'postUpdate');
const preValueSelect = generateSelectForRules([...allows, ...denies]);
const preValueSelect = generateSelectForRules([...allows, ...denies], 'postUpdate');
if (preValueSelect) {
writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`);
}
Expand Down Expand Up @@ -350,17 +362,19 @@ export class PolicyGenerator {

// write cross-model comparison rules as entity checker functions
// because they cannot be checked inside Prisma
this.writeEntityChecker(model, kind, writer, sourceFile, true);
const { functionName, selector } = this.writeEntityChecker(model, kind, sourceFile, false);

if (this.shouldUseEntityChecker(model, kind, true, false)) {
writer.write(`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`);
}
}

private writeEntityChecker(
private shouldUseEntityChecker(
target: DataModel | DataModelField,
kind: PolicyOperationKind,
writer: CodeBlockWriter,
sourceFile: SourceFile,
onlyCrossModelComparison = false,
forOverride = false
) {
onlyCrossModelComparison: boolean,
forOverride: boolean
): boolean {
const allows = getPolicyExpressions(
target,
'allow',
Expand All @@ -376,10 +390,37 @@ export class PolicyGenerator {
onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all'
);

if (allows.length === 0 && denies.length === 0) {
return;
if (allows.length > 0 || denies.length > 0) {
return true;
}

const allRules = [
...getPolicyExpressions(target, 'allow', kind, forOverride, 'all'),
...getPolicyExpressions(target, 'deny', kind, forOverride, 'all'),
];

return allRules.some((rule) => {
return streamAst(rule).some((node) => {
if (isCheckInvocation(node)) {
const expr = node as InvocationExpr;
const fieldRef = expr.args[0].value as ReferenceExpr;
const targetModel = fieldRef.$resolvedType?.decl as DataModel;
return this.shouldUseEntityChecker(targetModel, kind, onlyCrossModelComparison, forOverride);
}
return false;
});
});
}

private writeEntityChecker(
target: DataModel | DataModelField,
kind: PolicyOperationKind,
sourceFile: SourceFile,
forOverride: boolean
) {
const allows = getPolicyExpressions(target, 'allow', kind, forOverride, 'all');
const denies = getPolicyExpressions(target, 'deny', kind, forOverride, 'all');

const model = isDataModel(target) ? target : (target.$container as DataModel);
const func = generateEntityCheckerFunction(
sourceFile,
Expand All @@ -390,9 +431,9 @@ export class PolicyGenerator {
isDataModelField(target) ? target : undefined,
forOverride
);
const selector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate') ?? {};
const key = forOverride ? 'overrideEntityChecker' : 'entityChecker';
writer.write(`${key}: { func: ${func.getName()!}, selector: ${JSON.stringify(selector)} },`);
const selector = generateSelectForRules([...allows, ...denies], kind, false, kind !== 'postUpdate') ?? {};

return { functionName: func.getName()!, selector };
}

// writes `guard: ...` for a given policy operation kind
Expand All @@ -408,23 +449,32 @@ export class PolicyGenerator {
if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
writer.write(`guard: false,`);
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
writer.write(`guard: true,`);
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
}
writer.write(`guard: ${func.getName()!},`);
return;
}

if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
writer.write(`guard: true,`);
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}

if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
writer.write(`guard: ${policies[kind as keyof typeof policies]},`);
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
}

Expand Down Expand Up @@ -534,7 +584,13 @@ export class PolicyGenerator {

// checker function
// write all field-level rules as entity checker function
this.writeEntityChecker(field, 'read', writer, sourceFile, false, false);
const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, false);

if (this.shouldUseEntityChecker(field, 'read', false, false)) {
writer.write(
`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`
);
}

if (overrideAllows.length > 0) {
// override guard function
Expand All @@ -551,7 +607,14 @@ export class PolicyGenerator {
writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`);

// additional entity checker for override
this.writeEntityChecker(field, 'read', writer, sourceFile, false, true);
const { functionName, selector } = this.writeEntityChecker(field, 'read', sourceFile, true);
if (this.shouldUseEntityChecker(field, 'read', false, true)) {
writer.write(
`overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify(
selector
)} },`
);
}
}
});
writer.writeLine(',');
Expand Down Expand Up @@ -581,7 +644,12 @@ export class PolicyGenerator {

// write cross-model comparison rules as entity checker functions
// because they cannot be checked inside Prisma
this.writeEntityChecker(field, 'update', writer, sourceFile, true, false);
const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, false);
if (this.shouldUseEntityChecker(field, 'update', true, false)) {
writer.write(
`entityChecker: { func: ${functionName}, selector: ${JSON.stringify(selector)} },`
);
}

if (overrideAllows.length > 0) {
// override guard
Expand All @@ -598,7 +666,14 @@ export class PolicyGenerator {

// write cross-model comparison override rules as entity checker functions
// because they cannot be checked inside Prisma
this.writeEntityChecker(field, 'update', writer, sourceFile, true, true);
const { functionName, selector } = this.writeEntityChecker(field, 'update', sourceFile, true);
if (this.shouldUseEntityChecker(field, 'update', true, true)) {
writer.write(
`overrideEntityChecker: { func: ${functionName}, selector: ${JSON.stringify(
selector
)} },`
);
}
}
});
writer.writeLine(',');
Expand Down Expand Up @@ -649,7 +724,7 @@ export class PolicyGenerator {
});

if (authRules.length > 0) {
return generateSelectForRules(authRules, true);
return generateSelectForRules(authRules, undefined, true);
} else {
return undefined;
}
Expand Down
Loading

0 comments on commit 6f92323

Please sign in to comment.