Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: generate foreign key field in zod schemas #868

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 170 additions & 16 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import {
PluginOptions,
createProject,
emitProject,
getAttribute,
getAttributeArg,
getDataModels,
getLiteral,
getPrismaClientImportSpec,
Expand All @@ -15,7 +17,16 @@ import {
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
import {
DataModel,
DataModelField,
DataSource,
EnumField,
Model,
isDataModel,
isDataSource,
isEnum,
} from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
import { promises as fs } from 'fs';
import { streamAllContents } from 'langium';
Expand Down Expand Up @@ -262,10 +273,17 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
sf.replaceWithText((writer) => {
const fields = model.fields.filter(
(field) =>
// scalar fields only
// regular fields only
!isDataModel(field.type.reference?.ref) && !isForeignKeyField(field)
);

const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
// unsafe version of relations: including foreign keys and relation fields without fk
const unsafeRelations = model.fields.filter(
(field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field))
);

writer.writeLine('/* eslint-disable */');
writer.writeLine(`import { z } from 'zod';`);

Expand Down Expand Up @@ -302,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
writer.writeLine(`import { Decimal } from 'decimal.js';`);
}

// create base schema
// base schema
writer.write(`const baseSchema = z.object(`);
writer.inlineBlock(() => {
fields.forEach((field) => {
Expand All @@ -311,31 +329,92 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
});
writer.writeLine(');');

// relation fields

let allRelationSchema: string | undefined;
let safeRelationSchema: string | undefined;
let unsafeRelationSchema: string | undefined;

if (relations.length > 0 || fkFields.length > 0) {
allRelationSchema = 'allRelationSchema';
writer.write(`const ${allRelationSchema} = z.object(`);
writer.inlineBlock(() => {
[...relations, ...fkFields].forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
});
});
writer.writeLine(');');
}

if (relations.length > 0) {
safeRelationSchema = 'safeRelationSchema';
writer.write(`const ${safeRelationSchema} = z.object(`);
writer.inlineBlock(() => {
relations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
});
});
writer.writeLine(');');
}

if (unsafeRelations.length > 0) {
unsafeRelationSchema = 'unsafeRelationSchema';
writer.write(`const ${unsafeRelationSchema} = z.object(`);
writer.inlineBlock(() => {
unsafeRelations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
});
});
writer.writeLine(');');
}

// compile "@@validate" to ".refine"
const refinements = makeValidationRefinements(model);
let refineFuncName: string | undefined;
if (refinements.length > 0) {
refineFuncName = `refine${upperCaseFirst(model.name)}`;
writer.writeLine(
`function refine<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
`export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
'\n'
)}; }`
);
}

// model schema
////////////////////////////////////////////////
// 1. Model schema
////////////////////////////////////////////////
let modelSchema = 'baseSchema';

// omit fields
const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit'));
if (fieldsToOmit.length > 0) {
modelSchema = makeOmit(
modelSchema,
fieldsToOmit.map((f) => f.name)
);
}
if (refinements.length > 0) {
modelSchema = `refine(${modelSchema})`;

if (allRelationSchema) {
// export schema with only scalar fields
const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`;
writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`);
modelSchema = modelScalarSchema;

// merge relations
modelSchema = makeMerge(modelSchema, allRelationSchema);
}

// refine
if (refineFuncName) {
const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`;
writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`);
modelSchema = `${refineFuncName}(${noRefineSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`);

// create schema
////////////////////////////////////////////////
// 2. Create schema
////////////////////////////////////////////////
let createSchema = 'baseSchema';
const fieldsWithDefault = fields.filter(
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
Expand All @@ -346,29 +425,104 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
fieldsWithDefault.map((f) => f.name)
);
}
if (refinements.length > 0) {
createSchema = `refine(${createSchema})`;

if (safeRelationSchema || unsafeRelationSchema) {
// export schema with only scalar fields
const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`;
writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`);
createSchema = createScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeUnion(
makeMerge(createSchema, makePartial(safeRelationSchema)),
makeMerge(createSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeMerge(createSchema, makePartial(safeRelationSchema));
}
}

if (refineFuncName) {
// export a schema without refinement for extensibility
const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`;
writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`);
createSchema = `${refineFuncName}(${noRefineSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`);

// update schema
let updateSchema = 'baseSchema.partial()';
if (refinements.length > 0) {
updateSchema = `refine(${updateSchema})`;
////////////////////////////////////////////////
// 3. Update schema
////////////////////////////////////////////////
let updateSchema = makePartial('baseSchema');

if (safeRelationSchema || unsafeRelationSchema) {
// export schema with only scalar fields
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`);
updateSchema = updateScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)
updateSchema = makeUnion(
makeMerge(updateSchema, makePartial(safeRelationSchema)),
makeMerge(updateSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation
updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema));
}
}

if (refineFuncName) {
// export a schema without refinement for extensibility
const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`;
writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`);
updateSchema = `${refineFuncName}(${noRefineSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`);
});

return schemaName;
}

function makePartial(schema: string, fields: string[]) {
return `${schema}.partial({
function makePartial(schema: string, fields?: string[]) {
if (fields) {
return `${schema}.partial({
${fields.map((f) => `${f}: true`).join(', ')},
})`;
} else {
return `${schema}.partial()`;
}
}

function makeOmit(schema: string, fields: string[]) {
return `${schema}.omit({
${fields.map((f) => `${f}: true`).join(', ')},
})`;
}

function makeMerge(schema1: string, schema2: string): string {
return `${schema1}.merge(${schema2})`;
}

function makeUnion(...schemas: string[]): string {
return `z.union([${schemas.join(', ')}])`;
}

function hasForeignKey(field: DataModelField) {
const relAttr = getAttribute(field, '@relation');
if (!relAttr) {
return false;
}
return !!getAttributeArg(relAttr, 'fields');
}
20 changes: 18 additions & 2 deletions packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
import { ExpressionContext, PluginError, getAttributeArg, getAttributeArgLiteral, getLiteral } from '@zenstackhq/sdk';
import { DataModel, DataModelField, DataModelFieldAttribute, isEnum } from '@zenstackhq/sdk/ast';
import { DataModel, DataModelField, DataModelFieldAttribute, isDataModel, isEnum } from '@zenstackhq/sdk/ast';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import {
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
} from '../../../utils/typescript-expression-transformer';

export function makeFieldSchema(field: DataModelField) {
export function makeFieldSchema(field: DataModelField, forMutation = false) {
if (isDataModel(field.type.reference?.ref)) {
if (!forMutation) {
// read schema, always optional
if (field.type.array) {
return `z.array(z.unknown()).optional()`;
} else {
return `z.record(z.unknown()).optional()`;
}
} else {
// write schema
return `${
field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())'
}`;
}
}

let schema = makeZodSchema(field);
const isDecimal = field.type.type === 'Decimal';

Expand Down
Loading