Skip to content

Commit

Permalink
feat: implement relation check() function in ZModel
Browse files Browse the repository at this point in the history
Fixes #276
  • Loading branch information
ymc9 committed Jul 5, 2024
1 parent 6f92323 commit 4981197
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import pluralize from 'pluralize';
import { AstValidator } from '../types';
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';

// a registry of function handlers marked with @func
// a registry of function handlers marked with @check
const attributeCheckers = new Map<string, PropertyDescriptor>();

// function handler decorator
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import {
Argument,
DataModel,
DataModelAttribute,
DataModelFieldAttribute,
Expression,
FunctionDecl,
FunctionParam,
InvocationExpr,
isArrayExpr,
isDataModel,
isDataModelAttribute,
isDataModelFieldAttribute,
isLiteralExpr,
Expand All @@ -15,14 +17,29 @@ import {
ExpressionContext,
getDataModelFieldReference,
getFunctionExpressionContext,
getLiteral,
isDataModelFieldReference,
isEnumFieldReference,
isFromStdlib,
} from '@zenstackhq/sdk';
import { AstNode, ValidationAcceptor } from 'langium';
import { P, match } from 'ts-pattern';
import { AstNode, streamAst, ValidationAcceptor } from 'langium';
import { match, P } from 'ts-pattern';
import { isCheckInvocation } from '../../utils/ast-utils';
import { AstValidator } from '../types';
import { typeAssignable } from './utils';

// a registry of function handlers marked with @func
const invocationCheckers = new Map<string, PropertyDescriptor>();

// function handler decorator
function func(name: string) {
return function (_target: unknown, _propertyKey: string, descriptor: PropertyDescriptor) {
if (!invocationCheckers.get(name)) {
invocationCheckers.set(name, descriptor);
}
return descriptor;
};
}
/**
* InvocationExpr validation
*/
Expand Down Expand Up @@ -104,6 +121,12 @@ export default class FunctionInvocationValidator implements AstValidator<Express
}
}
}

// run checkers for specific functions
const checker = invocationCheckers.get(expr.function.$refText);
if (checker) {
checker.value.call(this, expr, accept);
}
}

private validateArgs(funcDecl: FunctionDecl, args: Argument[], accept: ValidationAcceptor) {
Expand Down Expand Up @@ -167,4 +190,76 @@ export default class FunctionInvocationValidator implements AstValidator<Express

return true;
}

@func('check')
private _checkCheck(expr: InvocationExpr, accept: ValidationAcceptor) {
let valid = true;

const fieldArg = expr.args[0].value;
if (!isDataModelFieldReference(fieldArg) || !isDataModel(fieldArg.$resolvedType?.decl)) {
accept('error', 'argument must be a relation field', { node: expr.args[0] });
valid = false;
}

if (fieldArg.$resolvedType?.array) {
accept('error', 'argument cannot be an array field', { node: expr.args[0] });
valid = false;
}

const opArg = expr.args[1]?.value;
if (opArg) {
const operation = getLiteral<string>(opArg);
if (!operation || !['read', 'create', 'update', 'delete'].includes(operation)) {
accept('error', 'argument must be a "read", "create", "update", or "delete"', { node: expr.args[1] });
valid = false;
}
}

if (!valid) {
return;
}

// check for cyclic relation checking
const start = fieldArg.$resolvedType?.decl as DataModel;
const tasks = [expr];
const seen = new Set<DataModel>();

while (tasks.length > 0) {
const currExpr = tasks.pop()!;
const arg = currExpr.args[0]?.value;

if (!isDataModel(arg?.$resolvedType?.decl)) {
continue;
}

const currModel = arg.$resolvedType.decl;

if (seen.has(currModel)) {
if (currModel === start) {
accept('error', 'cyclic dependency detected when following the `check()` call', { node: expr });
} else {
// a cycle is detected but it doesn't start from the invocation expression we're checking,
// just break here and the cycle will be reported when we validate the start of it
}
break;
} else {
seen.add(currModel);
}

const policyAttrs = currModel.attributes.filter(
(attr) => attr.decl.$refText === '@@allow' || attr.decl.$refText === '@@deny'
);
for (const attr of policyAttrs) {
const rule = attr.args[1];
if (!rule) {
continue;
}
streamAst(rule).forEach((node) => {
if (isCheckInvocation(node)) {
tasks.push(node as InvocationExpr);
}
});
}
}
}
}
2 changes: 1 addition & 1 deletion packages/sdk/src/typescript-expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type Options = {
thisExprContext?: string;
futureRefContext?: string;
context: ExpressionContext;
operationContext?: 'read' | 'create' | 'update' | 'delete';
operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete';
};

// a registry of function handlers marked with @func
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { loadSchema } from '@zenstackhq/testtools';
import { loadModelWithError, loadSchema } from '@zenstackhq/testtools';

describe('Relation checker', () => {
it('should work for read', async () => {
Expand Down Expand Up @@ -562,5 +562,142 @@ describe('Relation checker', () => {
await expect(db.profile.create({ data: { user: { connect: { id: 4 } }, age: 18 } })).toBeRejectedByPolicy();
});

it('should report error for cyclic relation check', async () => {});
it('should report error for invalid args', async () => {
await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
public Boolean
@@allow('read', check(public))
}
`
)
).resolves.toContain('argument must be a relation field');

await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
posts Post[]
@@allow('read', check(posts))
}
model Post {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int
}
`
)
).resolves.toContain('argument cannot be an array field');

await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
profile Profile?
@@allow('read', check(profile.details))
}
model Profile {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int
details ProfileDetails?
}
model ProfileDetails {
id Int @id @default(autoincrement())
profile Profile @relation(fields: [profileId], references: [id])
profileId Int
age Int
}
`
)
).resolves.toContain('argument must be a relation field');

await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
posts Post[]
@@allow('read', check(posts, 'all'))
}
model Post {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int
}
`
)
).resolves.toContain('argument must be a "read", "create", "update", or "delete"');
});

it('should report error for cyclic relation check', async () => {
await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
profile Profile?
profileDetails ProfileDetails?
public Boolean
@@allow('all', check(profile))
}
model Profile {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int @unique
details ProfileDetails?
@@allow('all', check(details))
}
model ProfileDetails {
id Int @id @default(autoincrement())
profile Profile @relation(fields: [profileId], references: [id])
profileId Int @unique
user User @relation(fields: [userId], references: [id])
userId Int @unique
age Int
@@allow('all', check(user))
}
`
)
).resolves.toContain('cyclic dependency detected when following the `check()` call');
});

it('should report error for cyclic relation check indirect', async () => {
await expect(
loadModelWithError(
`
model User {
id Int @id @default(autoincrement())
profile Profile?
public Boolean
@@allow('all', check(profile))
}
model Profile {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int @unique
details ProfileDetails?
@@allow('all', check(details))
}
model ProfileDetails {
id Int @id @default(autoincrement())
profile Profile @relation(fields: [profileId], references: [id])
profileId Int @unique
age Int
@@allow('all', check(profile))
}
`
)
).resolves.toContain('cyclic dependency detected when following the `check()` call');
});
});

0 comments on commit 4981197

Please sign in to comment.