Skip to content

Commit

Permalink
fix: clean up zod generation
Browse files Browse the repository at this point in the history
- Add PrismaCreate/PrismaUpdate schemas for internal use
- Make Create/Update schemas only include foreign keys but not relation
  • Loading branch information
ymc9 committed Dec 7, 2023
1 parent aa705a4 commit 29c5e81
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 106 deletions.
2 changes: 1 addition & 1 deletion packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ export class PolicyUtil {
if (!this.hasFieldValidation(model)) {
return undefined;
}
const schemaKey = `${upperCaseFirst(model)}${kind ? upperCaseFirst(kind) : ''}Schema`;
const schemaKey = `${upperCaseFirst(model)}${kind ? 'Prisma' + upperCaseFirst(kind) : ''}Schema`;
return this.zodSchemas?.models?.[schemaKey];
}

Expand Down
148 changes: 57 additions & 91 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import {
PluginOptions,
createProject,
emitProject,
getAttribute,
getAttributeArg,
getDataModels,
getLiteral,
getPrismaClientImportSpec,
Expand All @@ -17,16 +15,7 @@ import {
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
DataSource,
EnumField,
Model,
isDataModel,
isDataSource,
isEnum,
} from '@zenstackhq/sdk/ast';
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
import { promises as fs } from 'fs';
import { streamAllContents } from 'langium';
Expand Down Expand Up @@ -271,18 +260,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
overwrite: true,
});
sf.replaceWithText((writer) => {
const fields = model.fields.filter(
const scalarFields = model.fields.filter(
(field) =>
// regular fields only
!isDataModel(field.type.reference?.ref) && !isForeignKeyField(field)
);

const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
// unsafe version of relations: including foreign keys and relation fields without fk
const unsafeRelations = model.fields.filter(
(field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field))
);

writer.writeLine('/* eslint-disable */');
writer.writeLine(`import { z } from 'zod';`);
Expand All @@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s

// import enum schemas
const importedEnumSchemas = new Set<string>();
for (const field of fields) {
for (const field of scalarFields) {
if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) {
const name = upperCaseFirst(field.type.reference?.ref.name);
if (!importedEnumSchemas.has(name)) {
Expand All @@ -315,29 +300,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
}

// import Decimal
if (fields.some((field) => field.type.type === 'Decimal')) {
if (scalarFields.some((field) => field.type.type === 'Decimal')) {
writer.writeLine(`import { DecimalSchema } from '../common';`);
writer.writeLine(`import { Decimal } from 'decimal.js';`);
}

// base schema
writer.write(`const baseSchema = z.object(`);
writer.inlineBlock(() => {
fields.forEach((field) => {
scalarFields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
});
});
writer.writeLine(');');

// relation fields

let allRelationSchema: string | undefined;
let safeRelationSchema: string | undefined;
let unsafeRelationSchema: string | undefined;
let relationSchema: string | undefined;
let fkSchema: string | undefined;

if (relations.length > 0 || fkFields.length > 0) {
allRelationSchema = 'allRelationSchema';
writer.write(`const ${allRelationSchema} = z.object(`);
relationSchema = 'relationSchema';
writer.write(`const ${relationSchema} = z.object(`);
writer.inlineBlock(() => {
[...relations, ...fkFields].forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
Expand All @@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
writer.writeLine(');');
}

if (relations.length > 0) {
safeRelationSchema = 'safeRelationSchema';
writer.write(`const ${safeRelationSchema} = z.object(`);
if (fkFields.length > 0) {
fkSchema = 'fkSchema';
writer.write(`const ${fkSchema} = z.object(`);
writer.inlineBlock(() => {
relations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
});
});
writer.writeLine(');');
}

if (unsafeRelations.length > 0) {
unsafeRelationSchema = 'unsafeRelationSchema';
writer.write(`const ${unsafeRelationSchema} = z.object(`);
writer.inlineBlock(() => {
unsafeRelations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
fkFields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
});
});
writer.writeLine(');');
Expand All @@ -383,25 +356,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
////////////////////////////////////////////////
// 1. Model schema
////////////////////////////////////////////////
let modelSchema = 'baseSchema';
let modelSchema = makePartial('baseSchema');

// omit fields
const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit'));
const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit'));
if (fieldsToOmit.length > 0) {
modelSchema = makeOmit(
modelSchema,
fieldsToOmit.map((f) => f.name)
);
}

if (allRelationSchema) {
if (relationSchema) {
// export schema with only scalar fields
const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`;
writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`);
modelSchema = modelScalarSchema;

// merge relations
modelSchema = makeMerge(modelSchema, allRelationSchema);
modelSchema = makeMerge(modelSchema, makePartial(relationSchema));
}

// refine
Expand All @@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`);

////////////////////////////////////////////////
// 2. Create schema
// 2. Prisma create & update
////////////////////////////////////////////////

// schema for validating prisma create input (all fields optional)
let prismaCreateSchema = makePartial('baseSchema');
if (refineFuncName) {
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`);

// schema for validating prisma update input (all fields optional)
// note numeric fields can be simple update or atomic operations
let prismaUpdateSchema = `z.object({
${scalarFields
.map((field) => {
let fieldSchema = makeFieldSchema(field);
if (field.type.type === 'Int' || field.type.type === 'Float') {
fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`;
}
return `\t${field.name}: ${fieldSchema}`;
})
.join(',\n')}
})`;
prismaUpdateSchema = makePartial(prismaUpdateSchema);
if (refineFuncName) {
prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`);

////////////////////////////////////////////////
// 3. Create schema
////////////////////////////////////////////////
let createSchema = 'baseSchema';
const fieldsWithDefault = fields.filter(
const fieldsWithDefault = scalarFields.filter(
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
);
if (fieldsWithDefault.length > 0) {
Expand All @@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
);
}

if (safeRelationSchema || unsafeRelationSchema) {
if (fkSchema) {
// export schema with only scalar fields
const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`;
writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`);
createSchema = createScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeUnion(
makeMerge(createSchema, makePartial(safeRelationSchema)),
makeMerge(createSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeMerge(createSchema, makePartial(safeRelationSchema));
}

// merge fk fields
createSchema = makeMerge(createScalarSchema, fkSchema);
}

if (refineFuncName) {
Expand All @@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
////////////////////////////////////////////////
let updateSchema = makePartial('baseSchema');

if (safeRelationSchema || unsafeRelationSchema) {
if (fkSchema) {
// export schema with only scalar fields
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`);
updateSchema = updateScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)
updateSchema = makeUnion(
makeMerge(updateSchema, makePartial(safeRelationSchema)),
makeMerge(updateSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation
updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema));
}
// merge fk fields
updateSchema = makeMerge(updateSchema, makePartial(fkSchema));
}

if (refineFuncName) {
Expand Down Expand Up @@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) {
function makeMerge(schema1: string, schema2: string): string {
return `${schema1}.merge(${schema2})`;
}

function makeUnion(...schemas: string[]): string {
return `z.union([${schemas.join(', ')}])`;
}

function hasForeignKey(field: DataModelField) {
const relAttr = getAttribute(field, '@relation');
if (!relAttr) {
return false;
}
return !!getAttributeArg(relAttr, 'fields');
}
17 changes: 5 additions & 12 deletions packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@ import {
TypeScriptExpressionTransformerError,
} from '../../../utils/typescript-expression-transformer';

export function makeFieldSchema(field: DataModelField, forMutation = false) {
export function makeFieldSchema(field: DataModelField) {
if (isDataModel(field.type.reference?.ref)) {
if (!forMutation) {
// read schema, always optional
if (field.type.array) {
return `z.array(z.unknown()).optional()`;
} else {
return `z.record(z.unknown()).optional()`;
}
if (field.type.array) {
// array field is always optional
return `z.array(z.unknown()).optional()`;
} else {
// write schema
return `${
field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())'
}`;
return field.type.optional ? `z.record(z.unknown()).optional()` : `z.record(z.unknown())`;
}
}

Expand Down
17 changes: 15 additions & 2 deletions tests/integration/tests/plugins/zod.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ describe('Zod plugin tests', () => {
expect(schemas.UserSchema).toBeTruthy();
expect(schemas.UserCreateSchema).toBeTruthy();
expect(schemas.UserUpdateSchema).toBeTruthy();
expect(schemas.UserPrismaCreateSchema).toBeTruthy();
expect(schemas.UserPrismaUpdateSchema).toBeTruthy();

// create
expect(schemas.UserCreateSchema.safeParse({ email: 'abc' }).success).toBeFalsy();
Expand All @@ -77,7 +79,6 @@ describe('Zod plugin tests', () => {
schemas.UserCreateSchema.safeParse({ email: '[email protected]', role: 'ADMIN', password: 'abc123' }).success
).toBeTruthy();

// create unchecked
// create unchecked
expect(
zodSchemas.input.UserInputSchema.create.safeParse({
Expand All @@ -97,7 +98,7 @@ describe('Zod plugin tests', () => {
).toBeTruthy();

// model schema
expect(schemas.UserSchema.safeParse({ email: '[email protected]', role: 'ADMIN' }).success).toBeFalsy();
expect(schemas.UserSchema.safeParse({ email: '[email protected]', role: 'ADMIN' }).success).toBeTruthy();
// without omitted field
expect(
schemas.UserSchema.safeParse({
Expand Down Expand Up @@ -126,6 +127,18 @@ describe('Zod plugin tests', () => {
expect(schemas.PostCreateSchema.safeParse({ title: 'abc' }).success).toBeFalsy();
expect(schemas.PostCreateSchema.safeParse({ title: 'abcabcabcabc' }).success).toBeFalsy();
expect(schemas.PostCreateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy();
schemas.PostCreateSchema.parse({ title: 'abcde', authorId: 1 });
expect(schemas.PostCreateSchema.safeParse({ title: 'abcde', authorId: 1 }).data.authorId).toBe(1);
expect(schemas.PostUpdateSchema.safeParse({ authorId: 1 }).data.authorId).toBe(1);

expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'a' }).success).toBeFalsy();
expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy();
expect(schemas.PostPrismaCreateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy();

expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'a' }).success).toBeFalsy();
expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy();
expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy();
expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: { increment: 1 } }).success).toBeTruthy();
});

it('mixed casing', async () => {
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/tests/schema/todo.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ generator js {
previewFeatures = ['clientExtensions']
}

plugin zod {
provider = '@core/zod'
preserveTsFiles = true
}

/*
* Model for a space in which users can collaborate on Lists and Todos
*/
Expand Down Expand Up @@ -104,6 +109,7 @@ model List {
title String @length(1, 100)
private Boolean @default(false)
todos Todo[]
revision Int @default(0)

// require login
@@deny('all', auth() == null)
Expand Down

0 comments on commit 29c5e81

Please sign in to comment.