From 594bb7b5e14a981e259dfd16817110beb484f906 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 25 Dec 2024 22:11:06 +0800 Subject: [PATCH] feat(zmodel): add new functions `currentModel` and `currentOperation` --- .../function-invocation-validator.ts | 12 +- packages/schema/src/res/stdlib.zmodel | 23 +++ packages/sdk/src/code-gen.ts | 5 + .../src/typescript-expression-transformer.ts | 77 ++++++-- .../with-policy/currentModel.test.ts | 185 ++++++++++++++++++ .../with-policy/currentOperation.test.ts | 154 +++++++++++++++ 6 files changed, 441 insertions(+), 15 deletions(-) create mode 100644 tests/integration/tests/enhancements/with-policy/currentModel.test.ts create mode 100644 tests/integration/tests/enhancements/with-policy/currentOperation.test.ts diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 8c11a2a72..343c75cad 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -87,7 +87,17 @@ export default class FunctionInvocationValidator implements AstValidator(expr.args[0]?.value); + if (arg && !allCasing.includes(arg)) { + accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, { + node: expr.args[0], + }); + } + } else if ( funcAllowedContext.includes(ExpressionContext.AccessPolicy) || funcAllowedContext.includes(ExpressionContext.ValidationRule) ) { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 3316a90a9..483993d92 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -171,6 +171,29 @@ function hasSome(field: Any[], search: Any[]): Boolean { function isEmpty(field: Any[]): Boolean { } @@@expressionContext([AccessPolicy, ValidationRule]) +/** + * The name of the model for which the policy rule is defined. If the rule is + * inherited to a sub model, this function returns the name of the sub model. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentModel(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + +/** + * The operation for which the policy rule is defined for. Note that a rule with + * "all" operation is expanded to "create", "read", "update", and "delete" rules, + * and the function returns corresponding value for each expanded version. + * + * @param optional parameter to control the casing of the returned value. Valid + * values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults + * to "original". + */ +function currentOperation(casing: String?): String { +} @@@expressionContext([AccessPolicy]) + /** * Marks an attribute to be only applicable to certain field types. */ diff --git a/packages/sdk/src/code-gen.ts b/packages/sdk/src/code-gen.ts index 7b26cc0c4..67833b788 100644 --- a/packages/sdk/src/code-gen.ts +++ b/packages/sdk/src/code-gen.ts @@ -47,6 +47,11 @@ export async function saveProject(project: Project) { * Emit a TS project to JS files. */ export async function emitProject(project: Project) { + // ignore type checking for all source files + for (const sf of project.getSourceFiles()) { + sf.insertStatements(0, '// @ts-nocheck'); + } + const errors = project.getPreEmitDiagnostics().filter((d) => d.getCategory() === DiagnosticCategory.Error); if (errors.length > 0) { console.error('Error compiling generated code:'); diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 9a884ebdf..801db4d4f 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -20,6 +20,7 @@ import { isNullExpr, isThisExpr, } from '@zenstackhq/language/ast'; +import { getContainerOfType } from 'langium'; import { P, match } from 'ts-pattern'; import { ExpressionContext } from './constants'; import { getEntityCheckerFunctionName } from './names'; @@ -40,6 +41,8 @@ type Options = { operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete'; }; +type Casing = 'original' | 'upper' | 'lower' | 'capitalize' | 'uncapitalize'; + // a registry of function handlers marked with @func const functionHandlers = new Map(); @@ -150,7 +153,7 @@ export class TypeScriptExpressionTransformer { } const args = expr.args.map((arg) => arg.value); - return handler.value.call(this, args, normalizeUndefined); + return handler.value.call(this, expr, args, normalizeUndefined); } // #region function invocation handlers @@ -168,7 +171,7 @@ export class TypeScriptExpressionTransformer { } @func('length') - private _length(args: Expression[]) { + private _length(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const min = getLiteral(args[1]); const max = getLiteral(args[2]); @@ -188,7 +191,7 @@ export class TypeScriptExpressionTransformer { } @func('contains') - private _contains(args: Expression[], normalizeUndefined: boolean) { + private _contains(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const caseInsensitive = getLiteral(args[2]) === true; let result: string; @@ -201,34 +204,34 @@ export class TypeScriptExpressionTransformer { } @func('startsWith') - private _startsWith(args: Expression[], normalizeUndefined: boolean) { + private _startsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.startsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('endsWith') - private _endsWith(args: Expression[], normalizeUndefined: boolean) { + private _endsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.endsWith(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('regex') - private _regex(args: Expression[]) { + private _regex(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); return this.ensureBooleanTernary(args[0], field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); } @func('email') - private _email(args: Expression[]) { + private _email(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().email().safeParse(${field}).success`); } @func('datetime') - private _datetime(args: Expression[]) { + private _datetime(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -238,20 +241,20 @@ export class TypeScriptExpressionTransformer { } @func('url') - private _url(args: Expression[]) { + private _url(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return this.ensureBooleanTernary(args[0], field, `z.string().url().safeParse(${field}).success`); } @func('has') - private _has(args: Expression[], normalizeUndefined: boolean) { + private _has(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); const result = `${field}?.includes(${this.transform(args[1], normalizeUndefined)})`; return this.ensureBoolean(result); } @func('hasEvery') - private _hasEvery(args: Expression[], normalizeUndefined: boolean) { + private _hasEvery(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -261,7 +264,7 @@ export class TypeScriptExpressionTransformer { } @func('hasSome') - private _hasSome(args: Expression[], normalizeUndefined: boolean) { + private _hasSome(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( args[0], @@ -271,13 +274,13 @@ export class TypeScriptExpressionTransformer { } @func('isEmpty') - private _isEmpty(args: Expression[]) { + private _isEmpty(_invocation: InvocationExpr, args: Expression[]) { const field = this.transform(args[0], false); return `(!${field} || ${field}?.length === 0)`; } @func('check') - private _check(args: Expression[]) { + private _check(_invocation: InvocationExpr, args: Expression[]) { if (!isDataModelFieldReference(args[0])) { throw new TypeScriptExpressionTransformerError(`First argument of check() must be a field`); } @@ -309,6 +312,52 @@ export class TypeScriptExpressionTransformer { return `${entityCheckerFunc}(input.${fieldRef.target.$refText}, context)`; } + private toStringWithCaseChange(value: string, casing: Casing) { + if (!value) { + return "''"; + } + return match(casing) + .with('original', () => `'${value}'`) + .with('upper', () => `'${value.toUpperCase()}'`) + .with('lower', () => `'${value.toLowerCase()}'`) + .with('capitalize', () => `'${value.charAt(0).toUpperCase() + value.slice(1)}'`) + .with('uncapitalize', () => `'${value.charAt(0).toLowerCase() + value.slice(1)}'`) + .exhaustive(); + } + + @func('currentModel') + private _currentModel(invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + const containingModel = getContainerOfType(invocation, isDataModel); + if (!containingModel) { + throw new TypeScriptExpressionTransformerError('currentModel() must be called inside a model'); + } + return this.toStringWithCaseChange(containingModel.name, casing); + } + + @func('currentOperation') + private _currentOperation(_invocation: InvocationExpr, args: Expression[]) { + let casing: Casing = 'original'; + if (args[0]) { + casing = getLiteral(args[0]) as Casing; + } + + if (!this.options.operationContext) { + throw new TypeScriptExpressionTransformerError( + 'currentOperation() must be called inside an access policy rule' + ); + } + let contextOperation = this.options.operationContext; + if (contextOperation === 'postUpdate') { + contextOperation = 'update'; + } + return this.toStringWithCaseChange(contextOperation, casing); + } + private ensureBoolean(expr: string) { if (this.options.context === ExpressionContext.ValidationRule) { // all fields are optional in a validation context, so we treat undefined diff --git a/tests/integration/tests/enhancements/with-policy/currentModel.test.ts b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts new file mode 100644 index 000000000..0b98314a4 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentModel.test.ts @@ -0,0 +1,185 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentModel tests', () => { + it('works in models', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'USER') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('upper') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'user') + } + + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('lower') == 'Post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model user { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'User') + } + + model post { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('capitalize') == 'post') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model USER { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'uSER') + } + + model POST { + id Int @id + @@allow('read', true) + @@allow('create', currentModel('uncapitalize') == 'POST') + } + ` + ); + + const db = enhance(); + await expect(db.USER.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.POST.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from abstract base', async () => { + const { enhance } = await loadSchema( + ` + abstract model Base { + id Int @id + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works when inherited from delegate base', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id + type String + @@delegate(type) + + @@allow('read', true) + @@allow('create', currentModel() == 'User') + } + + model User extends Base { + } + + model Post extends Base { + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentModel()) + } + ` + ) + ).resolves.toContain('function "currentModel" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentModel('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts new file mode 100644 index 000000000..a2c2f2792 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/currentOperation.test.ts @@ -0,0 +1,154 @@ +import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; + +describe('currentOperation tests', () => { + it('works with specific rules', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with all rule', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('all', currentOperation() == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation() == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with upper case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'CREATE') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('upper') == 'READ') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with lower case', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('lower') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with capitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'Create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'create') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('works with uncapitalization', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'create') + } + model Post { + id Int @id + @@allow('read', true) + @@allow('create', currentOperation('capitalize') == 'read') + } + ` + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: 1 } })).toResolveTruthy(); + await expect(db.post.create({ data: { id: 1 } })).toBeRejectedByPolicy(); + }); + + it('complains when used outside policies', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @default(currentOperation()) + } + ` + ) + ).resolves.toContain('function "currentOperation" is not allowed in the current context: DefaultValue'); + }); + + it('complains when casing argument is invalid', async () => { + await expect( + loadModelWithError( + ` + model User { + id String @id + @@allow('create', currentOperation('foo') == 'User') + } + ` + ) + ).resolves.toContain('argument must be one of: "original", "upper", "lower", "capitalize", "uncapitalize"'); + }); +});