From 17fe8c31d0fdc5b620cd190deedcadfae6567b08 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 22 Jul 2024 16:08:16 -0700 Subject: [PATCH] fix(zod): zod create/update schemas should exclude discriminator fields (#1609) --- .../src/plugins/enhancer/enhance/index.ts | 54 +++++------ packages/schema/src/plugins/zod/generator.ts | 15 +++- .../schema/src/plugins/zod/transformer.ts | 29 +++++- .../with-delegate/plugin-interaction.test.ts | 90 +++++++++++++++++++ 4 files changed, 148 insertions(+), 40 deletions(-) diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 0da689e49..0c8219c43 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -7,6 +7,7 @@ import { getDataModels, getLiteral, isDelegateModel, + isDiscriminatorField, type PluginOptions, } from '@zenstackhq/sdk'; import { @@ -495,33 +496,34 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara return source; } + private readonly ModelCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/; + private removeDiscriminatorFromConcreteInput( typeAlias: TypeAliasDeclaration, - delegateInfo: DelegateInfo, + _delegateInfo: DelegateInfo, source: string ) { - // remove discriminator field from the create/update input of concrete models because - // discriminator cannot be set directly + // remove discriminator field from the create/update input because discriminator cannot be set directly const typeName = typeAlias.getName(); - const concreteModelNames = delegateInfo.map(([, concretes]) => concretes.map((c) => c.name)).flatMap((c) => c); - const concreteCreateUpdateInputRegex = new RegExp( - `(${concreteModelNames.join('|')})(Unchecked)?(Create|Update).*Input` - ); - const match = typeName.match(concreteCreateUpdateInputRegex); + const match = typeName.match(this.ModelCreateUpdateInputRegex); if (match) { const modelName = match[1]; - const record = delegateInfo.find(([, concretes]) => concretes.some((c) => c.name === modelName)); - if (record) { - // remove all discriminator fields recursively - const delegateOfConcrete = record[0]; - const discriminators = this.getDiscriminatorFieldsRecursively(delegateOfConcrete); - discriminators.forEach((discriminatorDecl) => { - const discriminatorNode = this.findNamedProperty(typeAlias, discriminatorDecl.name); - if (discriminatorNode) { - source = source.replace(discriminatorNode.getText(), ''); + const dataModel = this.model.declarations.find( + (d): d is DataModel => isDataModel(d) && d.name === modelName + ); + + if (!dataModel) { + return source; + } + + for (const field of dataModel.fields) { + if (isDiscriminatorField(field)) { + const fieldDef = this.findNamedProperty(typeAlias, field.name); + if (fieldDef) { + source = source.replace(fieldDef.getText(), ''); } - }); + } } } return source; @@ -618,22 +620,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; } - private getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataModelField[] = []) { - if (isDelegateModel(delegate)) { - const discriminator = this.getDiscriminatorField(delegate); - if (discriminator) { - result.push(discriminator); - } - - for (const superType of delegate.superTypes) { - if (superType.ref) { - result.push(...this.getDiscriminatorFieldsRecursively(superType.ref, result)); - } - } - } - return result; - } - private async saveSourceFile(sf: SourceFile) { if (this.options.preserveTsFiles) { await sf.save(); diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 91f152af8..3f656668d 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -5,6 +5,7 @@ import { ensureEmptyDir, getDataModels, hasAttribute, + isDiscriminatorField, isEnumFieldReference, isForeignKeyField, isFromStdlib, @@ -368,6 +369,13 @@ export function ${refineFuncName}(schema: z.ZodType isDiscriminatorField(field)); + const omitDiscriminators = + delegateFields.length > 0 + ? `.omit({ ${delegateFields.map((f) => `${f.name}: true`).join(', ')} })` + : ''; + //////////////////////////////////////////////// // 1. Model schema //////////////////////////////////////////////// @@ -429,7 +437,7 @@ export const ${upperCaseFirst(model.name)}Schema = ${modelSchema}; //////////////////////////////////////////////// // schema for validating prisma create input (all fields optional) - let prismaCreateSchema = this.makePassthrough(this.makePartial('baseSchema')); + let prismaCreateSchema = this.makePassthrough(this.makePartial(`baseSchema${omitDiscriminators}`)); if (refineFuncName) { prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; } @@ -445,6 +453,7 @@ export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSch // note numeric fields can be simple update or atomic operations let prismaUpdateSchema = `z.object({ ${scalarFields + .filter((f) => !isDiscriminatorField(f)) .map((field) => { let fieldSchema = makeFieldSchema(field); if (field.type.type === 'Int' || field.type.type === 'Float') { @@ -472,7 +481,7 @@ export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSch // 3. Create schema //////////////////////////////////////////////// - let createSchema = 'baseSchema'; + let createSchema = `baseSchema${omitDiscriminators}`; const fieldsWithDefault = scalarFields.filter( (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array ); @@ -524,7 +533,7 @@ export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema}; //////////////////////////////////////////////// // for update all fields are optional - let updateSchema = this.makePartial('baseSchema'); + let updateSchema = this.makePartial(`baseSchema${omitDiscriminators}`); // export schema with only scalar fields: `[Model]UpdateScalarSchema` const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index 09804f42b..ca714f1ad 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ -import { indentString, type PluginOptions } from '@zenstackhq/sdk'; -import type { Model } from '@zenstackhq/sdk/ast'; +import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk'; +import { DataModel, isDataModel, type Model } from '@zenstackhq/sdk/ast'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; import path from 'path'; @@ -90,8 +90,31 @@ export default class Transformer { return `${this.name}.schema`; } + private delegateCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/; + generateObjectSchemaFields(generateUnchecked: boolean) { - const zodObjectSchemaFields = this.fields + let fields = this.fields; + + // exclude discriminator fields from create/update input schemas + const createUpdateMatch = this.delegateCreateUpdateInputRegex.exec(this.name); + if (createUpdateMatch) { + const modelName = createUpdateMatch[1]; + const dataModel = this.zmodel.declarations.find( + (d): d is DataModel => isDataModel(d) && d.name === modelName + ); + if (dataModel) { + const discriminatorFields = dataModel.fields.filter(isDiscriminatorField); + if (discriminatorFields.length > 0) { + fields = fields.filter((field) => { + return !discriminatorFields.some( + (discriminatorField) => discriminatorField.name === field.name + ); + }); + } + } + } + + const zodObjectSchemaFields = fields .map((field) => this.generateObjectSchemaField(field, generateUnchecked)) .flatMap((item) => item) .map((item) => { diff --git a/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts index a44e69aac..b0fb0d343 100644 --- a/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts @@ -55,4 +55,94 @@ describe('Polymorphic Plugin Interaction Test', () => { extraDependencies: ['@trpc/client', '@trpc/server', '@trpc/react-query'], }); }); + + it('zod', async () => { + const { zodSchemas } = await loadSchema(POLYMORPHIC_SCHEMA, { fullZod: true }); + + // model schema + expect( + zodSchemas.models.AssetSchema.parse({ + id: 1, + assetType: 'video', + createdAt: new Date(), + viewCount: 100, + }) + ).toBeTruthy(); + + expect( + zodSchemas.models.AssetSchema.parse({ + id: 1, + assetType: 'video', + createdAt: new Date(), + viewCount: 100, + videoType: 'ratedVideo', // should be stripped + }).videoType + ).toBeUndefined(); + + expect( + zodSchemas.models.VideoSchema.parse({ + id: 1, + assetType: 'video', + videoType: 'ratedVideo', + duration: 100, + url: 'http://example.com', + createdAt: new Date(), + viewCount: 100, + }) + ).toBeTruthy(); + + expect(() => + zodSchemas.models.VideoSchema.parse({ + id: 1, + assetType: 'video', + videoType: 'ratedVideo', + url: 'http://example.com', + createdAt: new Date(), + viewCount: 100, + }) + ).toThrow('duration'); + + // create schema + expect( + zodSchemas.models.VideoCreateSchema.parse({ + duration: 100, + url: 'http://example.com', + }).assetType // discriminator should not be set + ).toBeUndefined(); + + // update schema + expect( + zodSchemas.models.VideoUpdateSchema.parse({ + duration: 100, + url: 'http://example.com', + }).assetType // discriminator should not be set + ).toBeUndefined(); + + // prisma create schema + expect( + zodSchemas.models.VideoPrismaCreateSchema.strip().parse({ + assetType: 'video', + }).assetType // discriminator should not be set + ).toBeUndefined(); + + // input object schema + expect( + zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({ + duration: 100, + viewCount: 200, + url: 'http://www.example.com', + rating: 5, + }) + ).toBeTruthy(); + + expect(() => + zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({ + duration: 100, + viewCount: 200, + url: 'http://www.example.com', + rating: 5, + videoType: 'ratedVideo', + }) + ).toThrow('videoType'); + }); });