From 909281f8090734322c0cab09d0187b6b5e813c9a Mon Sep 17 00:00:00 2001 From: Yiming Date: Fri, 8 Dec 2023 14:39:57 +0800 Subject: [PATCH] fix: clean up zod generation (#883) --- package.json | 2 +- packages/language/package.json | 2 +- packages/plugins/openapi/package.json | 2 +- packages/plugins/swr/package.json | 2 +- packages/plugins/tanstack-query/package.json | 2 +- packages/plugins/trpc/package.json | 2 +- packages/runtime/package.json | 2 +- .../src/enhancements/policy/policy-utils.ts | 2 +- packages/schema/package.json | 2 +- packages/schema/src/plugins/zod/generator.ts | 148 +++++++----------- .../src/plugins/zod/utils/schema-gen.ts | 17 +- packages/sdk/package.json | 2 +- packages/server/package.json | 2 +- packages/testtools/package.json | 2 +- .../with-policy/field-validation.test.ts | 19 +++ tests/integration/tests/plugins/zod.test.ts | 17 +- tests/integration/tests/schema/todo.zmodel | 6 + 17 files changed, 114 insertions(+), 117 deletions(-) diff --git a/package.json b/package.json index 41c0c67c1..74b232e35 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.4.0", + "version": "1.4.1", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/language/package.json b/packages/language/package.json index ff9c6a466..a07c23f75 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.4.0", + "version": "1.4.1", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 148b50bcd..4c9ceaefa 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index bdab48881..7eb32245c 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 80b4e7e38..166d08937 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 54f6e96cb..e6840a064 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index fd37e28f4..78d018a83 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.4.0", + "version": "1.4.1", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 63acf7a57..ea00e1f31 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1051,7 +1051,7 @@ export class PolicyUtil { if (!this.hasFieldValidation(model)) { return undefined; } - const schemaKey = `${upperCaseFirst(model)}${kind ? upperCaseFirst(kind) : ''}Schema`; + const schemaKey = `${upperCaseFirst(model)}${kind ? 'Prisma' + upperCaseFirst(kind) : ''}Schema`; return this.zodSchemas?.models?.[schemaKey]; } diff --git a/packages/schema/package.json b/packages/schema/package.json index 7892fd9fe..3527cf1a2 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database", - "version": "1.4.0", + "version": "1.4.1", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 04c903577..ad857fbda 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -4,8 +4,6 @@ import { PluginOptions, createProject, emitProject, - getAttribute, - getAttributeArg, getDataModels, getLiteral, getPrismaClientImportSpec, @@ -17,16 +15,7 @@ import { resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { - DataModel, - DataModelField, - DataSource, - EnumField, - Model, - isDataModel, - isDataSource, - isEnum, -} from '@zenstackhq/sdk/ast'; +import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; import { promises as fs } from 'fs'; import { streamAllContents } from 'langium'; @@ -271,7 +260,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s overwrite: true, }); sf.replaceWithText((writer) => { - const fields = model.fields.filter( + const scalarFields = model.fields.filter( (field) => // regular fields only !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) @@ -279,10 +268,6 @@ async function generateModelSchema(model: DataModel, project: Project, output: s const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref)); const fkFields = model.fields.filter((field) => isForeignKeyField(field)); - // unsafe version of relations: including foreign keys and relation fields without fk - const unsafeRelations = model.fields.filter( - (field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field)) - ); writer.writeLine('/* eslint-disable */'); writer.writeLine(`import { z } from 'zod';`); @@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s // import enum schemas const importedEnumSchemas = new Set(); - for (const field of fields) { + for (const field of scalarFields) { if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) { const name = upperCaseFirst(field.type.reference?.ref.name); if (!importedEnumSchemas.has(name)) { @@ -315,7 +300,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s } // import Decimal - if (fields.some((field) => field.type.type === 'Decimal')) { + if (scalarFields.some((field) => field.type.type === 'Decimal')) { writer.writeLine(`import { DecimalSchema } from '../common';`); writer.writeLine(`import { Decimal } from 'decimal.js';`); } @@ -323,7 +308,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s // base schema writer.write(`const baseSchema = z.object(`); writer.inlineBlock(() => { - fields.forEach((field) => { + scalarFields.forEach((field) => { writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); }); }); @@ -331,13 +316,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s // relation fields - let allRelationSchema: string | undefined; - let safeRelationSchema: string | undefined; - let unsafeRelationSchema: string | undefined; + let relationSchema: string | undefined; + let fkSchema: string | undefined; if (relations.length > 0 || fkFields.length > 0) { - allRelationSchema = 'allRelationSchema'; - writer.write(`const ${allRelationSchema} = z.object(`); + relationSchema = 'relationSchema'; + writer.write(`const ${relationSchema} = z.object(`); writer.inlineBlock(() => { [...relations, ...fkFields].forEach((field) => { writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); @@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s writer.writeLine(');'); } - if (relations.length > 0) { - safeRelationSchema = 'safeRelationSchema'; - writer.write(`const ${safeRelationSchema} = z.object(`); + if (fkFields.length > 0) { + fkSchema = 'fkSchema'; + writer.write(`const ${fkSchema} = z.object(`); writer.inlineBlock(() => { - relations.forEach((field) => { - writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); - }); - }); - writer.writeLine(');'); - } - - if (unsafeRelations.length > 0) { - unsafeRelationSchema = 'unsafeRelationSchema'; - writer.write(`const ${unsafeRelationSchema} = z.object(`); - writer.inlineBlock(() => { - unsafeRelations.forEach((field) => { - writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); + fkFields.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); }); }); writer.writeLine(');'); @@ -383,10 +356,10 @@ async function generateModelSchema(model: DataModel, project: Project, output: s //////////////////////////////////////////////// // 1. Model schema //////////////////////////////////////////////// - let modelSchema = 'baseSchema'; + let modelSchema = makePartial('baseSchema'); // omit fields - const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit')); + const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit')); if (fieldsToOmit.length > 0) { modelSchema = makeOmit( modelSchema, @@ -394,14 +367,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s ); } - if (allRelationSchema) { + if (relationSchema) { // export schema with only scalar fields const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`; writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`); modelSchema = modelScalarSchema; // merge relations - modelSchema = makeMerge(modelSchema, allRelationSchema); + modelSchema = makeMerge(modelSchema, makePartial(relationSchema)); } // refine @@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`); //////////////////////////////////////////////// - // 2. Create schema + // 2. Prisma create & update + //////////////////////////////////////////////// + + // schema for validating prisma create input (all fields optional) + let prismaCreateSchema = makePartial('baseSchema'); + if (refineFuncName) { + prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`); + + // schema for validating prisma update input (all fields optional) + // note numeric fields can be simple update or atomic operations + let prismaUpdateSchema = `z.object({ + ${scalarFields + .map((field) => { + let fieldSchema = makeFieldSchema(field); + if (field.type.type === 'Int' || field.type.type === 'Float') { + fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`; + } + return `\t${field.name}: ${fieldSchema}`; + }) + .join(',\n')} +})`; + prismaUpdateSchema = makePartial(prismaUpdateSchema); + if (refineFuncName) { + prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`); + + //////////////////////////////////////////////// + // 3. Create schema //////////////////////////////////////////////// let createSchema = 'baseSchema'; - const fieldsWithDefault = fields.filter( + const fieldsWithDefault = scalarFields.filter( (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array ); if (fieldsWithDefault.length > 0) { @@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s ); } - if (safeRelationSchema || unsafeRelationSchema) { + if (fkSchema) { // export schema with only scalar fields const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`; writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`); - createSchema = createScalarSchema; - - if (safeRelationSchema && unsafeRelationSchema) { - // build a union of with relation object fields and with fk fields (mutually exclusive) - - // TODO: we make all relation fields partial for now because in case of - // nested create, not all relation/fk fields are inside payload, need a - // better solution - createSchema = makeUnion( - makeMerge(createSchema, makePartial(safeRelationSchema)), - makeMerge(createSchema, makePartial(unsafeRelationSchema)) - ); - } else if (safeRelationSchema) { - // just relation - - // TODO: we make all relation fields partial for now because in case of - // nested create, not all relation/fk fields are inside payload, need a - // better solution - createSchema = makeMerge(createSchema, makePartial(safeRelationSchema)); - } + + // merge fk fields + createSchema = makeMerge(createScalarSchema, fkSchema); } if (refineFuncName) { @@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s //////////////////////////////////////////////// let updateSchema = makePartial('baseSchema'); - if (safeRelationSchema || unsafeRelationSchema) { + if (fkSchema) { // export schema with only scalar fields const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`); updateSchema = updateScalarSchema; - if (safeRelationSchema && unsafeRelationSchema) { - // build a union of with relation object fields and with fk fields (mutually exclusive) - updateSchema = makeUnion( - makeMerge(updateSchema, makePartial(safeRelationSchema)), - makeMerge(updateSchema, makePartial(unsafeRelationSchema)) - ); - } else if (safeRelationSchema) { - // just relation - updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema)); - } + // merge fk fields + updateSchema = makeMerge(updateSchema, makePartial(fkSchema)); } if (refineFuncName) { @@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) { function makeMerge(schema1: string, schema2: string): string { return `${schema1}.merge(${schema2})`; } - -function makeUnion(...schemas: string[]): string { - return `z.union([${schemas.join(', ')}])`; -} - -function hasForeignKey(field: DataModelField) { - const relAttr = getAttribute(field, '@relation'); - if (!relAttr) { - return false; - } - return !!getAttributeArg(relAttr, 'fields'); -} diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index a73b34924..f84b25a2d 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -7,20 +7,13 @@ import { TypeScriptExpressionTransformerError, } from '../../../utils/typescript-expression-transformer'; -export function makeFieldSchema(field: DataModelField, forMutation = false) { +export function makeFieldSchema(field: DataModelField) { if (isDataModel(field.type.reference?.ref)) { - if (!forMutation) { - // read schema, always optional - if (field.type.array) { - return `z.array(z.unknown()).optional()`; - } else { - return `z.record(z.unknown()).optional()`; - } + if (field.type.array) { + // array field is always optional + return `z.array(z.unknown()).optional()`; } else { - // write schema - return `${ - field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())' - }`; + return field.type.optional ? `z.record(z.unknown()).optional()` : `z.record(z.unknown())`; } } diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 16dcb04f1..8896e8d0d 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/server/package.json b/packages/server/package.json index d33a4cc2e..b91f99b89 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "1.4.0", + "version": "1.4.1", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 7dafca343..2e202ae0e 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.4.0", + "version": "1.4.1", "description": "ZenStack Test Tools", "main": "index.js", "private": true, diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index c6cebed39..8727f1561 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -13,6 +13,7 @@ describe('With Policy: field validation', () => { email String? @email @endsWith("@myorg.com") profileImage String? @url handle String? @regex("^[0-9a-zA-Z]{4,16}$") + age Int @default(18) @gt(0) @lt(100) userData UserData? tasks Task[] @@ -127,6 +128,24 @@ describe('With Policy: field validation', () => { 'String must contain at least 8 character(s) at "password"', 'must end with "@myorg.com" at "email"', ]); + + await expect( + db.user.update({ + where: { id: '1' }, + data: { + age: { increment: 100 }, + }, + }) + ).toBeRejectedByPolicy(['Number must be less than 100 at "age"']); + + await expect( + db.user.update({ + where: { id: '1' }, + data: { + age: { increment: 10 }, + }, + }) + ).toResolveTruthy(); }); it('direct write more', async () => { diff --git a/tests/integration/tests/plugins/zod.test.ts b/tests/integration/tests/plugins/zod.test.ts index 3081a1524..fd12d8b60 100644 --- a/tests/integration/tests/plugins/zod.test.ts +++ b/tests/integration/tests/plugins/zod.test.ts @@ -64,6 +64,8 @@ describe('Zod plugin tests', () => { expect(schemas.UserSchema).toBeTruthy(); expect(schemas.UserCreateSchema).toBeTruthy(); expect(schemas.UserUpdateSchema).toBeTruthy(); + expect(schemas.UserPrismaCreateSchema).toBeTruthy(); + expect(schemas.UserPrismaUpdateSchema).toBeTruthy(); // create expect(schemas.UserCreateSchema.safeParse({ email: 'abc' }).success).toBeFalsy(); @@ -77,7 +79,6 @@ describe('Zod plugin tests', () => { schemas.UserCreateSchema.safeParse({ email: 'abc@zenstack.dev', role: 'ADMIN', password: 'abc123' }).success ).toBeTruthy(); - // create unchecked // create unchecked expect( zodSchemas.input.UserInputSchema.create.safeParse({ @@ -97,7 +98,7 @@ describe('Zod plugin tests', () => { ).toBeTruthy(); // model schema - expect(schemas.UserSchema.safeParse({ email: 'abc@zenstack.dev', role: 'ADMIN' }).success).toBeFalsy(); + expect(schemas.UserSchema.safeParse({ email: 'abc@zenstack.dev', role: 'ADMIN' }).success).toBeTruthy(); // without omitted field expect( schemas.UserSchema.safeParse({ @@ -126,6 +127,18 @@ describe('Zod plugin tests', () => { expect(schemas.PostCreateSchema.safeParse({ title: 'abc' }).success).toBeFalsy(); expect(schemas.PostCreateSchema.safeParse({ title: 'abcabcabcabc' }).success).toBeFalsy(); expect(schemas.PostCreateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy(); + schemas.PostCreateSchema.parse({ title: 'abcde', authorId: 1 }); + expect(schemas.PostCreateSchema.safeParse({ title: 'abcde', authorId: 1 }).data.authorId).toBe(1); + expect(schemas.PostUpdateSchema.safeParse({ authorId: 1 }).data.authorId).toBe(1); + + expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'a' }).success).toBeFalsy(); + expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy(); + expect(schemas.PostPrismaCreateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy(); + + expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'a' }).success).toBeFalsy(); + expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy(); + expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy(); + expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: { increment: 1 } }).success).toBeTruthy(); }); it('mixed casing', async () => { diff --git a/tests/integration/tests/schema/todo.zmodel b/tests/integration/tests/schema/todo.zmodel index bb68c185f..733391bd1 100644 --- a/tests/integration/tests/schema/todo.zmodel +++ b/tests/integration/tests/schema/todo.zmodel @@ -12,6 +12,11 @@ generator js { previewFeatures = ['clientExtensions'] } +plugin zod { + provider = '@core/zod' + preserveTsFiles = true +} + /* * Model for a space in which users can collaborate on Lists and Todos */ @@ -104,6 +109,7 @@ model List { title String @length(1, 100) private Boolean @default(false) todos Todo[] + revision Int @default(0) // require login @@deny('all', auth() == null)