Skip to content

Commit

Permalink
feat: support @@validate in type declarations (#1868)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Nov 17, 2024
1 parent c7f333d commit 6df80b2
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
} from '@zenstackhq/language/ast';
import {
ExpressionContext,
getDataModelFieldReference,
getFieldReference,
getFunctionExpressionContext,
getLiteral,
isDataModelFieldReference,
Expand Down Expand Up @@ -96,7 +96,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
// first argument must refer to a model field
const firstArg = expr.args?.[0]?.value;
if (firstArg) {
if (!getDataModelFieldReference(firstArg)) {
if (!getFieldReference(firstArg)) {
accept('error', 'first argument must be a field reference', { node: firstArg });
}
}
Expand Down
5 changes: 3 additions & 2 deletions packages/schema/src/language-server/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,14 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
const node = context.container as MemberAccessExpr;

// typedef's fields are only added to the scope if the access starts with `auth().`
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand);
// or the member access resides inside a typedef
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand) || !!getContainerOfType(node, isTypeDef);

return match(node.operand)
.when(isReferenceExpr, (operand) => {
// operand is a reference, it can only be a model/type-def field
const ref = operand.target.ref;
if (isDataModelField(ref)) {
if (isDataModelField(ref) || isTypeDefField(ref)) {
return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope);
}
return EMPTY_SCOPE;
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import {
ReferenceExpr,
StringLiteral,
} from '@zenstackhq/language/ast';
import { getIdFields } from '@zenstackhq/sdk';
import { getPrismaVersion } from '@zenstackhq/sdk/prisma';
import { match } from 'ts-pattern';
import { getIdFields } from '../../utils/ast-utils';

import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
import {
Expand Down
117 changes: 96 additions & 21 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import {
ExpressionContext,
PluginError,
PluginGlobalOptions,
PluginOptions,
RUNTIME_PACKAGE,
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
ensureEmptyDir,
getAttributeArg,
getAttributeArgLiteral,
getDataModels,
getLiteralArray,
hasAttribute,
isDataModelFieldReference,
isDiscriminatorField,
isEnumFieldReference,
isForeignKeyField,
Expand All @@ -15,7 +22,7 @@ import {
resolvePath,
saveSourceFile,
} from '@zenstackhq/sdk';
import { DataModel, EnumField, Model, TypeDef, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { DataModel, EnumField, Model, TypeDef, isArrayExpr, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { streamAllContents } from 'langium';
Expand All @@ -26,7 +33,7 @@ import { name } from '.';
import { getDefaultOutputFolder } from '../plugin-utils';
import Transformer from './transformer';
import { ObjectMode } from './types';
import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen';
import { makeFieldSchema } from './utils/schema-gen';

export class ZodSchemaGenerator {
private readonly sourceFiles: SourceFile[] = [];
Expand Down Expand Up @@ -294,7 +301,7 @@ export class ZodSchemaGenerator {
sf.replaceWithText((writer) => {
this.addPreludeAndImports(typeDef, writer, output);

writer.write(`export const ${typeDef.name}Schema = z.object(`);
writer.write(`const baseSchema = z.object(`);
writer.inlineBlock(() => {
typeDef.fields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
Expand All @@ -313,9 +320,24 @@ export class ZodSchemaGenerator {
writer.writeLine(').strict();');
break;
}
});

// TODO: "@@validate" refinements
// compile "@@validate" to a function calling zod's `.refine()`
const refineFuncName = this.createRefineFunction(typeDef, writer);

if (refineFuncName) {
// export a schema without refinement for extensibility: `[Model]WithoutRefineSchema`
const noRefineSchema = `${upperCaseFirst(typeDef.name)}WithoutRefineSchema`;
writer.writeLine(`
/**
* \`${typeDef.name}\` schema prior to calling \`.refine()\` for extensibility.
*/
export const ${noRefineSchema} = baseSchema;
export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
`);
} else {
writer.writeLine(`export const ${typeDef.name}Schema = baseSchema;`);
}
});

return schemaName;
}
Expand Down Expand Up @@ -436,22 +458,7 @@ export class ZodSchemaGenerator {
}

// compile "@@validate" to ".refine"
const refinements = makeValidationRefinements(model);
let refineFuncName: string | undefined;
if (refinements.length > 0) {
refineFuncName = `refine${upperCaseFirst(model.name)}`;
writer.writeLine(
`
/**
* Schema refinement function for applying \`@@validate\` rules.
*/
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
'\n'
)};
}
`
);
}
const refineFuncName = this.createRefineFunction(model, writer);

// delegate discriminator fields are to be excluded from mutation schemas
const delegateDiscriminatorFields = model.fields.filter((field) => isDiscriminatorField(field));
Expand Down Expand Up @@ -658,6 +665,74 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
return schemaName;
}

private createRefineFunction(decl: DataModel | TypeDef, writer: CodeBlockWriter) {
const refinements = this.makeValidationRefinements(decl);
let refineFuncName: string | undefined;
if (refinements.length > 0) {
refineFuncName = `refine${upperCaseFirst(decl.name)}`;
writer.writeLine(
`
/**
* Schema refinement function for applying \`@@validate\` rules.
*/
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
'\n'
)};
}
`
);
return refineFuncName;
} else {
return undefined;
}
}

private makeValidationRefinements(decl: DataModel | TypeDef) {
const attrs = decl.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
const refinements = attrs
.map((attr) => {
const valueArg = getAttributeArg(attr, 'value');
if (!valueArg) {
return undefined;
}

const messageArg = getAttributeArgLiteral<string>(attr, 'message');
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';

const pathArg = getAttributeArg(attr, 'path');
const path =
pathArg && isArrayExpr(pathArg)
? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],`
: '';

const options = `, { ${message} ${path} }`;

try {
let expr = new TypeScriptExpressionTransformer({
context: ExpressionContext.ValidationRule,
fieldReferenceContext: 'value',
}).transform(valueArg);

if (isDataModelFieldReference(valueArg)) {
// if the expression is a simple field reference, treat undefined
// as true since the all fields are optional in validation context
expr = `${expr} ?? true`;
}

return `.refine((value: any) => ${expr}${options})`;
} catch (err) {
if (err instanceof TypeScriptExpressionTransformerError) {
throw new PluginError(name, err.message);
} else {
throw err;
}
}
})
.filter((r) => !!r);

return refinements;
}

private makePartial(schema: string, fields?: string[]) {
if (fields) {
if (fields.length === 0) {
Expand Down
60 changes: 1 addition & 59 deletions packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
import { getLiteral, isFromStdlib } from '@zenstackhq/sdk';
import {
ExpressionContext,
getAttributeArg,
getAttributeArgLiteral,
getLiteral,
getLiteralArray,
isDataModelFieldReference,
isFromStdlib,
PluginError,
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
} from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
DataModelFieldAttribute,
isArrayExpr,
isBooleanLiteral,
isDataModel,
isEnum,
Expand All @@ -25,7 +12,6 @@ import {
TypeDefField,
} from '@zenstackhq/sdk/ast';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { isDefaultWithAuth } from '../../enhancer/enhancer-utils';

export function makeFieldSchema(field: DataModelField | TypeDefField) {
Expand Down Expand Up @@ -222,50 +208,6 @@ function makeZodSchema(field: DataModelField | TypeDefField) {
return schema;
}

export function makeValidationRefinements(model: DataModel) {
const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
const refinements = attrs
.map((attr) => {
const valueArg = getAttributeArg(attr, 'value');
if (!valueArg) {
return undefined;
}

const messageArg = getAttributeArgLiteral<string>(attr, 'message');
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';

const pathArg = getAttributeArg(attr, 'path');
const path =
pathArg && isArrayExpr(pathArg) ? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],` : '';

const options = `, { ${message} ${path} }`;

try {
let expr = new TypeScriptExpressionTransformer({
context: ExpressionContext.ValidationRule,
fieldReferenceContext: 'value',
}).transform(valueArg);

if (isDataModelFieldReference(valueArg)) {
// if the expression is a simple field reference, treat undefined
// as true since the all fields are optional in validation context
expr = `${expr} ?? true`;
}

return `.refine((value: any) => ${expr}${options})`;
} catch (err) {
if (err instanceof TypeScriptExpressionTransformerError) {
throw new PluginError(name, err.message);
} else {
throw err;
}
}
})
.filter((r) => !!r);

return refinements;
}

function getAttrLiteralArg<T extends string | number>(attr: DataModelFieldAttribute, paramName: string) {
const arg = attr.args.find((arg) => arg.$resolvedParam?.name === paramName);
return arg && getLiteral<T>(arg.value);
Expand Down
46 changes: 1 addition & 45 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,19 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isArrayExpr,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isMemberAccessExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
ReferenceExpr,
TypeDef,
} from '@zenstackhq/language/ast';
import {
getInheritanceChain,
getModelFieldsWithBases,
getRecursiveBases,
isDelegateModel,
isFromStdlib,
} from '@zenstackhq/sdk';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -151,29 +140,6 @@ function cloneAst<T extends InheritableNode>(
return clone;
}

export function getIdFields(dataModel: DataModel) {
const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) =>
f.attributes.some((attr) => attr.decl.$refText === '@id')
);
if (fieldLevelId) {
return [fieldLevelId];
} else {
// get model level @@id attribute
const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id');
if (modelIdAttr) {
// get fields referenced in the attribute: @@id([field1, field2]])
if (!isArrayExpr(modelIdAttr.args[0]?.value)) {
return [];
}
const argValue = modelIdAttr.args[0].value;
return argValue.items
.filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr))
.map((expr) => expr.target.ref as DataModelField);
}
}
return [];
}

export function isAuthInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref);
}
Expand All @@ -186,16 +152,6 @@ export function isCheckInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref);
}

export function getDataModelFieldReference(expr: Expression): DataModelField | undefined {
if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) {
return expr.target.ref;
} else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) {
return expr.member.ref;
} else {
return undefined;
}
}

export function resolveImportUri(imp: ModelImport): URI | undefined {
if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string ("").

Expand Down
Loading

0 comments on commit 6df80b2

Please sign in to comment.