Skip to content

Commit

Permalink
fix(zod): zod create/update schemas should exclude discriminator fiel…
Browse files Browse the repository at this point in the history
…ds (#1609)
  • Loading branch information
ymc9 authored Jul 22, 2024
1 parent 91abbb8 commit 17fe8c3
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 40 deletions.
54 changes: 20 additions & 34 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
getDataModels,
getLiteral,
isDelegateModel,
isDiscriminatorField,
type PluginOptions,
} from '@zenstackhq/sdk';
import {
Expand Down Expand Up @@ -495,33 +496,34 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return source;
}

private readonly ModelCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;

private removeDiscriminatorFromConcreteInput(
typeAlias: TypeAliasDeclaration,
delegateInfo: DelegateInfo,
_delegateInfo: DelegateInfo,
source: string
) {
// remove discriminator field from the create/update input of concrete models because
// discriminator cannot be set directly
// remove discriminator field from the create/update input because discriminator cannot be set directly
const typeName = typeAlias.getName();
const concreteModelNames = delegateInfo.map(([, concretes]) => concretes.map((c) => c.name)).flatMap((c) => c);
const concreteCreateUpdateInputRegex = new RegExp(
`(${concreteModelNames.join('|')})(Unchecked)?(Create|Update).*Input`
);

const match = typeName.match(concreteCreateUpdateInputRegex);
const match = typeName.match(this.ModelCreateUpdateInputRegex);
if (match) {
const modelName = match[1];
const record = delegateInfo.find(([, concretes]) => concretes.some((c) => c.name === modelName));
if (record) {
// remove all discriminator fields recursively
const delegateOfConcrete = record[0];
const discriminators = this.getDiscriminatorFieldsRecursively(delegateOfConcrete);
discriminators.forEach((discriminatorDecl) => {
const discriminatorNode = this.findNamedProperty(typeAlias, discriminatorDecl.name);
if (discriminatorNode) {
source = source.replace(discriminatorNode.getText(), '');
const dataModel = this.model.declarations.find(
(d): d is DataModel => isDataModel(d) && d.name === modelName
);

if (!dataModel) {
return source;
}

for (const field of dataModel.fields) {
if (isDiscriminatorField(field)) {
const fieldDef = this.findNamedProperty(typeAlias, field.name);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
}
});
}
}
}
return source;
Expand Down Expand Up @@ -618,22 +620,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataModelField[] = []) {
if (isDelegateModel(delegate)) {
const discriminator = this.getDiscriminatorField(delegate);
if (discriminator) {
result.push(discriminator);
}

for (const superType of delegate.superTypes) {
if (superType.ref) {
result.push(...this.getDiscriminatorFieldsRecursively(superType.ref, result));
}
}
}
return result;
}

private async saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
await sf.save();
Expand Down
15 changes: 12 additions & 3 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import {
ensureEmptyDir,
getDataModels,
hasAttribute,
isDiscriminatorField,
isEnumFieldReference,
isForeignKeyField,
isFromStdlib,
Expand Down Expand Up @@ -368,6 +369,13 @@ export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T
);
}

// delegate discriminator fields are to be excluded from mutation schemas
const delegateFields = model.fields.filter((field) => isDiscriminatorField(field));
const omitDiscriminators =
delegateFields.length > 0
? `.omit({ ${delegateFields.map((f) => `${f.name}: true`).join(', ')} })`
: '';

////////////////////////////////////////////////
// 1. Model schema
////////////////////////////////////////////////
Expand Down Expand Up @@ -429,7 +437,7 @@ export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};
////////////////////////////////////////////////

// schema for validating prisma create input (all fields optional)
let prismaCreateSchema = this.makePassthrough(this.makePartial('baseSchema'));
let prismaCreateSchema = this.makePassthrough(this.makePartial(`baseSchema${omitDiscriminators}`));
if (refineFuncName) {
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
}
Expand All @@ -445,6 +453,7 @@ export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSch
// note numeric fields can be simple update or atomic operations
let prismaUpdateSchema = `z.object({
${scalarFields
.filter((f) => !isDiscriminatorField(f))
.map((field) => {
let fieldSchema = makeFieldSchema(field);
if (field.type.type === 'Int' || field.type.type === 'Float') {
Expand Down Expand Up @@ -472,7 +481,7 @@ export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSch
// 3. Create schema
////////////////////////////////////////////////

let createSchema = 'baseSchema';
let createSchema = `baseSchema${omitDiscriminators}`;
const fieldsWithDefault = scalarFields.filter(
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
);
Expand Down Expand Up @@ -524,7 +533,7 @@ export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};
////////////////////////////////////////////////

// for update all fields are optional
let updateSchema = this.makePartial('baseSchema');
let updateSchema = this.makePartial(`baseSchema${omitDiscriminators}`);

// export schema with only scalar fields: `[Model]UpdateScalarSchema`
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
Expand Down
29 changes: 26 additions & 3 deletions packages/schema/src/plugins/zod/transformer.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { indentString, type PluginOptions } from '@zenstackhq/sdk';
import type { Model } from '@zenstackhq/sdk/ast';
import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk';
import { DataModel, isDataModel, type Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import path from 'path';
Expand Down Expand Up @@ -90,8 +90,31 @@ export default class Transformer {
return `${this.name}.schema`;
}

private delegateCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;

generateObjectSchemaFields(generateUnchecked: boolean) {
const zodObjectSchemaFields = this.fields
let fields = this.fields;

// exclude discriminator fields from create/update input schemas
const createUpdateMatch = this.delegateCreateUpdateInputRegex.exec(this.name);
if (createUpdateMatch) {
const modelName = createUpdateMatch[1];
const dataModel = this.zmodel.declarations.find(
(d): d is DataModel => isDataModel(d) && d.name === modelName
);
if (dataModel) {
const discriminatorFields = dataModel.fields.filter(isDiscriminatorField);
if (discriminatorFields.length > 0) {
fields = fields.filter((field) => {
return !discriminatorFields.some(
(discriminatorField) => discriminatorField.name === field.name
);
});
}
}
}

const zodObjectSchemaFields = fields
.map((field) => this.generateObjectSchemaField(field, generateUnchecked))
.flatMap((item) => item)
.map((item) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,94 @@ describe('Polymorphic Plugin Interaction Test', () => {
extraDependencies: ['@trpc/client', '@trpc/server', '@trpc/react-query'],
});
});

it('zod', async () => {
const { zodSchemas } = await loadSchema(POLYMORPHIC_SCHEMA, { fullZod: true });

// model schema
expect(
zodSchemas.models.AssetSchema.parse({
id: 1,
assetType: 'video',
createdAt: new Date(),
viewCount: 100,
})
).toBeTruthy();

expect(
zodSchemas.models.AssetSchema.parse({
id: 1,
assetType: 'video',
createdAt: new Date(),
viewCount: 100,
videoType: 'ratedVideo', // should be stripped
}).videoType
).toBeUndefined();

expect(
zodSchemas.models.VideoSchema.parse({
id: 1,
assetType: 'video',
videoType: 'ratedVideo',
duration: 100,
url: 'http://example.com',
createdAt: new Date(),
viewCount: 100,
})
).toBeTruthy();

expect(() =>
zodSchemas.models.VideoSchema.parse({
id: 1,
assetType: 'video',
videoType: 'ratedVideo',
url: 'http://example.com',
createdAt: new Date(),
viewCount: 100,
})
).toThrow('duration');

// create schema
expect(
zodSchemas.models.VideoCreateSchema.parse({
duration: 100,
url: 'http://example.com',
}).assetType // discriminator should not be set
).toBeUndefined();

// update schema
expect(
zodSchemas.models.VideoUpdateSchema.parse({
duration: 100,
url: 'http://example.com',
}).assetType // discriminator should not be set
).toBeUndefined();

// prisma create schema
expect(
zodSchemas.models.VideoPrismaCreateSchema.strip().parse({
assetType: 'video',
}).assetType // discriminator should not be set
).toBeUndefined();

// input object schema
expect(
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
duration: 100,
viewCount: 200,
url: 'http://www.example.com',
rating: 5,
})
).toBeTruthy();

expect(() =>
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
duration: 100,
viewCount: 200,
url: 'http://www.example.com',
rating: 5,
videoType: 'ratedVideo',
})
).toThrow('videoType');
});
});

0 comments on commit 17fe8c3

Please sign in to comment.