Skip to content

Commit

Permalink
feat(zmodel): add new functions currentModel and currentOperation
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Dec 25, 2024
1 parent b41fd93 commit 594bb7b
Show file tree
Hide file tree
Showing 6 changed files with 441 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ export default class FunctionInvocationValidator implements AstValidator<Express
return;
}

if (
// TODO: express function validation rules declaratively in ZModel

const allCasing = ['original', 'upper', 'lower', 'capitalize', 'uncapitalize'];
if (['currentModel', 'currentOperation'].includes(funcDecl.name)) {
const arg = getLiteral<string>(expr.args[0]?.value);
if (arg && !allCasing.includes(arg)) {
accept('error', `argument must be one of: ${allCasing.map((c) => '"' + c + '"').join(', ')}`, {
node: expr.args[0],
});
}
} else if (
funcAllowedContext.includes(ExpressionContext.AccessPolicy) ||
funcAllowedContext.includes(ExpressionContext.ValidationRule)
) {
Expand Down
23 changes: 23 additions & 0 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,29 @@ function hasSome(field: Any[], search: Any[]): Boolean {
function isEmpty(field: Any[]): Boolean {
} @@@expressionContext([AccessPolicy, ValidationRule])

/**
* The name of the model for which the policy rule is defined. If the rule is
* inherited to a sub model, this function returns the name of the sub model.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentModel(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* The operation for which the policy rule is defined for. Note that a rule with
* "all" operation is expanded to "create", "read", "update", and "delete" rules,
* and the function returns corresponding value for each expanded version.
*
* @param optional parameter to control the casing of the returned value. Valid
* values are "original", "upper", "lower", "capitalize", "uncapitalize". Defaults
* to "original".
*/
function currentOperation(casing: String?): String {
} @@@expressionContext([AccessPolicy])

/**
* Marks an attribute to be only applicable to certain field types.
*/
Expand Down
5 changes: 5 additions & 0 deletions packages/sdk/src/code-gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ export async function saveProject(project: Project) {
* Emit a TS project to JS files.
*/
export async function emitProject(project: Project) {
// ignore type checking for all source files
for (const sf of project.getSourceFiles()) {
sf.insertStatements(0, '// @ts-nocheck');
}

const errors = project.getPreEmitDiagnostics().filter((d) => d.getCategory() === DiagnosticCategory.Error);
if (errors.length > 0) {
console.error('Error compiling generated code:');
Expand Down
77 changes: 63 additions & 14 deletions packages/sdk/src/typescript-expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
isNullExpr,
isThisExpr,
} from '@zenstackhq/language/ast';
import { getContainerOfType } from 'langium';
import { P, match } from 'ts-pattern';
import { ExpressionContext } from './constants';
import { getEntityCheckerFunctionName } from './names';
Expand All @@ -40,6 +41,8 @@ type Options = {
operationContext?: 'read' | 'create' | 'update' | 'postUpdate' | 'delete';
};

type Casing = 'original' | 'upper' | 'lower' | 'capitalize' | 'uncapitalize';

// a registry of function handlers marked with @func
const functionHandlers = new Map<string, PropertyDescriptor>();

Expand Down Expand Up @@ -150,7 +153,7 @@ export class TypeScriptExpressionTransformer {
}

const args = expr.args.map((arg) => arg.value);
return handler.value.call(this, args, normalizeUndefined);
return handler.value.call(this, expr, args, normalizeUndefined);
}

// #region function invocation handlers
Expand All @@ -168,7 +171,7 @@ export class TypeScriptExpressionTransformer {
}

@func('length')
private _length(args: Expression[]) {
private _length(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
const min = getLiteral<number>(args[1]);
const max = getLiteral<number>(args[2]);
Expand All @@ -188,7 +191,7 @@ export class TypeScriptExpressionTransformer {
}

@func('contains')
private _contains(args: Expression[], normalizeUndefined: boolean) {
private _contains(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
const caseInsensitive = getLiteral<boolean>(args[2]) === true;
let result: string;
Expand All @@ -201,34 +204,34 @@ export class TypeScriptExpressionTransformer {
}

@func('startsWith')
private _startsWith(args: Expression[], normalizeUndefined: boolean) {
private _startsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
const result = `${field}?.startsWith(${this.transform(args[1], normalizeUndefined)})`;
return this.ensureBoolean(result);
}

@func('endsWith')
private _endsWith(args: Expression[], normalizeUndefined: boolean) {
private _endsWith(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
const result = `${field}?.endsWith(${this.transform(args[1], normalizeUndefined)})`;
return this.ensureBoolean(result);
}

@func('regex')
private _regex(args: Expression[]) {
private _regex(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
const pattern = getLiteral<string>(args[1]);
return this.ensureBooleanTernary(args[0], field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`);
}

@func('email')
private _email(args: Expression[]) {
private _email(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
return this.ensureBooleanTernary(args[0], field, `z.string().email().safeParse(${field}).success`);
}

@func('datetime')
private _datetime(args: Expression[]) {
private _datetime(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
return this.ensureBooleanTernary(
args[0],
Expand All @@ -238,20 +241,20 @@ export class TypeScriptExpressionTransformer {
}

@func('url')
private _url(args: Expression[]) {
private _url(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
return this.ensureBooleanTernary(args[0], field, `z.string().url().safeParse(${field}).success`);
}

@func('has')
private _has(args: Expression[], normalizeUndefined: boolean) {
private _has(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
const result = `${field}?.includes(${this.transform(args[1], normalizeUndefined)})`;
return this.ensureBoolean(result);
}

@func('hasEvery')
private _hasEvery(args: Expression[], normalizeUndefined: boolean) {
private _hasEvery(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
return this.ensureBooleanTernary(
args[0],
Expand All @@ -261,7 +264,7 @@ export class TypeScriptExpressionTransformer {
}

@func('hasSome')
private _hasSome(args: Expression[], normalizeUndefined: boolean) {
private _hasSome(_invocation: InvocationExpr, args: Expression[], normalizeUndefined: boolean) {
const field = this.transform(args[0], false);
return this.ensureBooleanTernary(
args[0],
Expand All @@ -271,13 +274,13 @@ export class TypeScriptExpressionTransformer {
}

@func('isEmpty')
private _isEmpty(args: Expression[]) {
private _isEmpty(_invocation: InvocationExpr, args: Expression[]) {
const field = this.transform(args[0], false);
return `(!${field} || ${field}?.length === 0)`;
}

@func('check')
private _check(args: Expression[]) {
private _check(_invocation: InvocationExpr, args: Expression[]) {
if (!isDataModelFieldReference(args[0])) {
throw new TypeScriptExpressionTransformerError(`First argument of check() must be a field`);
}
Expand Down Expand Up @@ -309,6 +312,52 @@ export class TypeScriptExpressionTransformer {
return `${entityCheckerFunc}(input.${fieldRef.target.$refText}, context)`;
}

private toStringWithCaseChange(value: string, casing: Casing) {
if (!value) {
return "''";
}
return match(casing)
.with('original', () => `'${value}'`)
.with('upper', () => `'${value.toUpperCase()}'`)
.with('lower', () => `'${value.toLowerCase()}'`)
.with('capitalize', () => `'${value.charAt(0).toUpperCase() + value.slice(1)}'`)
.with('uncapitalize', () => `'${value.charAt(0).toLowerCase() + value.slice(1)}'`)
.exhaustive();
}

@func('currentModel')
private _currentModel(invocation: InvocationExpr, args: Expression[]) {
let casing: Casing = 'original';
if (args[0]) {
casing = getLiteral<string>(args[0]) as Casing;
}

const containingModel = getContainerOfType(invocation, isDataModel);
if (!containingModel) {
throw new TypeScriptExpressionTransformerError('currentModel() must be called inside a model');
}
return this.toStringWithCaseChange(containingModel.name, casing);
}

@func('currentOperation')
private _currentOperation(_invocation: InvocationExpr, args: Expression[]) {
let casing: Casing = 'original';
if (args[0]) {
casing = getLiteral<string>(args[0]) as Casing;
}

if (!this.options.operationContext) {
throw new TypeScriptExpressionTransformerError(
'currentOperation() must be called inside an access policy rule'
);
}
let contextOperation = this.options.operationContext;
if (contextOperation === 'postUpdate') {
contextOperation = 'update';
}
return this.toStringWithCaseChange(contextOperation, casing);
}

private ensureBoolean(expr: string) {
if (this.options.context === ExpressionContext.ValidationRule) {
// all fields are optional in a validation context, so we treat undefined
Expand Down
Loading

0 comments on commit 594bb7b

Please sign in to comment.