From 6e7993afa8dde03ae12c44f198bcca04724dbc92 Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 5 Mar 2024 22:36:41 -0800 Subject: [PATCH] fix: clean up generation of logical prisma client (#1082) --- .../plugins/openapi/src/generator-base.ts | 4 +- packages/plugins/openapi/src/index.ts | 14 +- .../plugins/openapi/src/rest-generator.ts | 2 +- packages/plugins/openapi/src/rpc-generator.ts | 2 +- packages/plugins/swr/src/generator.ts | 9 +- packages/plugins/swr/src/index.ts | 13 +- .../plugins/tanstack-query/src/generator.ts | 9 +- packages/plugins/tanstack-query/src/index.ts | 13 +- packages/plugins/trpc/src/generator.ts | 18 +- packages/plugins/trpc/src/helpers.ts | 7 +- packages/plugins/trpc/src/index.ts | 13 +- packages/runtime/package.json | 3 + packages/runtime/res/prisma.d.ts | 1 + packages/runtime/src/prisma.d.ts | 2 + packages/schema/src/cli/plugin-runner.ts | 127 +-- .../src/plugins/enhancer/delegate/index.ts | 16 - .../src/plugins/enhancer/enhance/index.ts | 85 +- packages/schema/src/plugins/enhancer/index.ts | 57 +- .../src/plugins/enhancer/model-meta/index.ts | 4 + .../enhancer/policy/policy-guard-generator.ts | 10 +- packages/schema/src/plugins/plugin-utils.ts | 3 + packages/schema/src/plugins/prisma/index.ts | 94 +- .../src/plugins/prisma/schema-generator.ts | 65 +- packages/schema/src/plugins/zod/generator.ts | 815 +++++++++--------- packages/schema/src/plugins/zod/index.ts | 5 +- .../schema/src/plugins/zod/transformer.ts | 50 +- packages/schema/src/plugins/zod/types.ts | 2 - packages/schema/src/telemetry.ts | 6 +- packages/sdk/src/model-meta-generator.ts | 20 + packages/sdk/src/prisma.ts | 61 +- packages/sdk/src/types.ts | 45 +- packages/sdk/src/utils.ts | 2 +- packages/testtools/src/schema.ts | 20 +- ...rphism.test.ts => enhanced-client.test.ts} | 49 +- .../with-delegate/plugin-interaction.test.ts | 25 + ...icy.test.ts => policy-interaction.test.ts} | 0 .../tests/enhancements/with-delegate/utils.ts | 47 + 37 files changed, 961 insertions(+), 757 deletions(-) create mode 100644 packages/runtime/res/prisma.d.ts create mode 100644 packages/runtime/src/prisma.d.ts delete mode 100644 packages/schema/src/plugins/enhancer/delegate/index.ts rename tests/integration/tests/enhancements/with-delegate/{polymorphism.test.ts => enhanced-client.test.ts} (97%) create mode 100644 tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts rename tests/integration/tests/enhancements/with-delegate/{policy.test.ts => policy-interaction.test.ts} (100%) create mode 100644 tests/integration/tests/enhancements/with-delegate/utils.ts diff --git a/packages/plugins/openapi/src/generator-base.ts b/packages/plugins/openapi/src/generator-base.ts index 1a46fa528..e033a5206 100644 --- a/packages/plugins/openapi/src/generator-base.ts +++ b/packages/plugins/openapi/src/generator-base.ts @@ -1,5 +1,5 @@ import type { DMMF } from '@prisma/generator-helper'; -import { PluginError, PluginOptions, getDataModels, hasAttribute } from '@zenstackhq/sdk'; +import { PluginError, PluginOptions, PluginResult, getDataModels, hasAttribute } from '@zenstackhq/sdk'; import { Model } from '@zenstackhq/sdk/ast'; import type { OpenAPIV3_1 as OAPI } from 'openapi-types'; import semver from 'semver'; @@ -12,7 +12,7 @@ export abstract class OpenAPIGeneratorBase { constructor(protected model: Model, protected options: PluginOptions, protected dmmf: DMMF.Document) {} - abstract generate(): string[]; + abstract generate(): PluginResult; protected get includedModels() { return getDataModels(this.model).filter((d) => !hasAttribute(d, '@@openapi.ignore')); diff --git a/packages/plugins/openapi/src/index.ts b/packages/plugins/openapi/src/index.ts index ddc752d8c..264403c0a 100644 --- a/packages/plugins/openapi/src/index.ts +++ b/packages/plugins/openapi/src/index.ts @@ -1,12 +1,14 @@ -import type { DMMF } from '@prisma/generator-helper'; -import { PluginError, PluginOptions } from '@zenstackhq/sdk'; -import { Model } from '@zenstackhq/sdk/ast'; +import { PluginError, PluginFunction } from '@zenstackhq/sdk'; import { RESTfulOpenAPIGenerator } from './rest-generator'; import { RPCOpenAPIGenerator } from './rpc-generator'; export const name = 'OpenAPI'; -export default async function run(model: Model, options: PluginOptions, dmmf: DMMF.Document) { +const run: PluginFunction = async (model, options, dmmf) => { + if (!dmmf) { + throw new Error('DMMF is required'); + } + const flavor = options.flavor ? (options.flavor as string) : 'rpc'; switch (flavor) { @@ -17,4 +19,6 @@ export default async function run(model: Model, options: PluginOptions, dmmf: DM default: throw new PluginError(name, `Unknown flavor: ${flavor}`); } -} +}; + +export default run; diff --git a/packages/plugins/openapi/src/rest-generator.ts b/packages/plugins/openapi/src/rest-generator.ts index 9dceeec3e..0bb76251e 100644 --- a/packages/plugins/openapi/src/rest-generator.ts +++ b/packages/plugins/openapi/src/rest-generator.ts @@ -76,7 +76,7 @@ export class RESTfulOpenAPIGenerator extends OpenAPIGeneratorBase { fs.writeFileSync(output, JSON.stringify(openapi, undefined, 2)); } - return this.warnings; + return { warnings: this.warnings }; } private generatePaths(): OAPI.PathsObject { diff --git a/packages/plugins/openapi/src/rpc-generator.ts b/packages/plugins/openapi/src/rpc-generator.ts index c551a8aef..8aa9189d0 100644 --- a/packages/plugins/openapi/src/rpc-generator.ts +++ b/packages/plugins/openapi/src/rpc-generator.ts @@ -89,7 +89,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase { fs.writeFileSync(output, JSON.stringify(openapi, undefined, 2)); } - return this.warnings; + return { warnings: this.warnings }; } private generatePaths(components: OAPI.ComponentsObject): OAPI.PathsObject { diff --git a/packages/plugins/swr/src/generator.ts b/packages/plugins/swr/src/generator.ts index 3a47a1c87..ca84c101c 100644 --- a/packages/plugins/swr/src/generator.ts +++ b/packages/plugins/swr/src/generator.ts @@ -49,11 +49,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. warnings.push(`Unable to find mapping for model ${dataModel.name}`); return; } - generateModelHooks(project, outDir, dataModel, mapping, legacyMutations); + generateModelHooks(project, outDir, dataModel, mapping, legacyMutations, options); }); await saveProject(project); - return warnings; + return { warnings }; } function generateModelHooks( @@ -61,14 +61,15 @@ function generateModelHooks( outDir: string, model: DataModel, mapping: DMMF.ModelMapping, - legacyMutations: boolean + legacyMutations: boolean, + options: PluginOptions ) { const fileName = paramCase(model.name); const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); - const prismaImport = getPrismaClientImportSpec(model.$container, outDir); + const prismaImport = getPrismaClientImportSpec(outDir, options); sf.addImportDeclaration({ namedImports: ['Prisma'], isTypeOnly: true, diff --git a/packages/plugins/swr/src/index.ts b/packages/plugins/swr/src/index.ts index 43731f984..16ae20c05 100644 --- a/packages/plugins/swr/src/index.ts +++ b/packages/plugins/swr/src/index.ts @@ -1,10 +1,13 @@ -import type { DMMF } from '@prisma/generator-helper'; -import type { PluginOptions } from '@zenstackhq/sdk'; -import type { Model } from '@zenstackhq/sdk/ast'; +import type { PluginFunction } from '@zenstackhq/sdk'; import { generate } from './generator'; export const name = 'SWR'; -export default async function run(model: Model, options: PluginOptions, dmmf: DMMF.Document) { +const run: PluginFunction = async (model, options, dmmf) => { + if (!dmmf) { + throw new Error('DMMF is required'); + } return generate(model, options, dmmf); -} +}; + +export default run; diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index a6cd75a5c..4e3079db8 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -55,11 +55,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. warnings.push(`Unable to find mapping for model ${dataModel.name}`); return; } - generateModelHooks(target, version, project, outDir, dataModel, mapping); + generateModelHooks(target, version, project, outDir, dataModel, mapping, options); }); await saveProject(project); - return warnings; + return { warnings }; } function generateQueryHook( @@ -286,7 +286,8 @@ function generateModelHooks( project: Project, outDir: string, model: DataModel, - mapping: DMMF.ModelMapping + mapping: DMMF.ModelMapping, + options: PluginOptions ) { const modelNameCap = upperCaseFirst(model.name); const prismaVersion = getPrismaVersion(); @@ -295,7 +296,7 @@ function generateModelHooks( sf.addStatements('/* eslint-disable */'); - const prismaImport = getPrismaClientImportSpec(model.$container, outDir); + const prismaImport = getPrismaClientImportSpec(outDir, options); sf.addImportDeclaration({ namedImports: ['Prisma', model.name], isTypeOnly: true, diff --git a/packages/plugins/tanstack-query/src/index.ts b/packages/plugins/tanstack-query/src/index.ts index 181727a02..eb315e00c 100644 --- a/packages/plugins/tanstack-query/src/index.ts +++ b/packages/plugins/tanstack-query/src/index.ts @@ -1,10 +1,13 @@ -import type { DMMF } from '@prisma/generator-helper'; -import type { PluginOptions } from '@zenstackhq/sdk'; -import type { Model } from '@zenstackhq/sdk/ast'; +import type { PluginFunction } from '@zenstackhq/sdk'; import { generate } from './generator'; export const name = 'Tanstack Query'; -export default async function run(model: Model, options: PluginOptions, dmmf: DMMF.Document) { +const run: PluginFunction = async (model, options, dmmf) => { + if (!dmmf) { + throw new Error('DMMF is required'); + } return generate(model, options, dmmf); -} +}; + +export default run; diff --git a/packages/plugins/trpc/src/generator.ts b/packages/plugins/trpc/src/generator.ts index 0d252cabc..3487386a6 100644 --- a/packages/plugins/trpc/src/generator.ts +++ b/packages/plugins/trpc/src/generator.ts @@ -72,7 +72,8 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. generateModelActions, generateClientHelpers, model, - zodSchemasImport + zodSchemasImport, + options ); createHelper(outDir); @@ -86,7 +87,8 @@ function createAppRouter( generateModelActions: string[] | undefined, generateClientHelpers: string[] | undefined, zmodel: Model, - zodSchemasImport: string + zodSchemasImport: string, + options: PluginOptions ) { const indexFile = path.resolve(outDir, 'routers', `index.ts`); const appRouter = project.createSourceFile(indexFile, undefined, { @@ -95,7 +97,7 @@ function createAppRouter( appRouter.addStatements('/* eslint-disable */'); - const prismaImport = getPrismaClientImportSpec(zmodel, path.dirname(indexFile)); + const prismaImport = getPrismaClientImportSpec(path.dirname(indexFile), options); appRouter.addImportDeclarations([ { namedImports: [ @@ -169,8 +171,8 @@ function createAppRouter( outDir, generateModelActions, generateClientHelpers, - zmodel, - zodSchemasImport + zodSchemasImport, + options ); appRouter.addImportDeclaration({ @@ -239,8 +241,8 @@ function generateModelCreateRouter( outputDir: string, generateModelActions: string[] | undefined, generateClientHelpers: string[] | undefined, - zmodel: Model, - zodSchemasImport: string + zodSchemasImport: string, + options: PluginOptions ) { const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, { overwrite: true, @@ -258,7 +260,7 @@ function generateModelCreateRouter( generateRouterSchemaImport(modelRouter, zodSchemasImport); generateHelperImport(modelRouter); if (generateClientHelpers) { - generateRouterTypingImports(modelRouter, zmodel); + generateRouterTypingImports(modelRouter, options); } const createRouterFunc = modelRouter.addFunction({ diff --git a/packages/plugins/trpc/src/helpers.ts b/packages/plugins/trpc/src/helpers.ts index 54aec3ecb..62e2efafd 100644 --- a/packages/plugins/trpc/src/helpers.ts +++ b/packages/plugins/trpc/src/helpers.ts @@ -1,6 +1,5 @@ import type { DMMF } from '@prisma/generator-helper'; -import { PluginError, getPrismaClientImportSpec } from '@zenstackhq/sdk'; -import { Model } from '@zenstackhq/sdk/ast'; +import { PluginError, getPrismaClientImportSpec, type PluginOptions } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter, SourceFile } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; @@ -225,9 +224,9 @@ export function generateRouterTyping(writer: CodeBlockWriter, opType: string, mo }); } -export function generateRouterTypingImports(sourceFile: SourceFile, model: Model) { +export function generateRouterTypingImports(sourceFile: SourceFile, options: PluginOptions) { const importingDir = sourceFile.getDirectoryPath(); - const prismaImport = getPrismaClientImportSpec(model, importingDir); + const prismaImport = getPrismaClientImportSpec(importingDir, options); sourceFile.addStatements([ `import type { Prisma } from '${prismaImport}';`, `import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared';`, diff --git a/packages/plugins/trpc/src/index.ts b/packages/plugins/trpc/src/index.ts index 85d2a61d8..83125eb74 100644 --- a/packages/plugins/trpc/src/index.ts +++ b/packages/plugins/trpc/src/index.ts @@ -1,11 +1,14 @@ -import type { DMMF } from '@prisma/generator-helper'; -import { PluginOptions } from '@zenstackhq/sdk'; -import { Model } from '@zenstackhq/sdk/ast'; +import type { PluginFunction } from '@zenstackhq/sdk'; import { generate } from './generator'; export const name = 'tRPC'; export const dependencies = ['@core/zod']; -export default async function run(model: Model, options: PluginOptions, dmmf: DMMF.Document) { +const run: PluginFunction = async (model, options, dmmf) => { + if (!dmmf) { + throw new Error('DMMF is required'); + } return generate(model, options, dmmf); -} +}; + +export default run; diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 8292bcb7c..34f6ea0ed 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -46,6 +46,9 @@ "import": "./cross/index.mjs", "require": "./cross/index.js", "default": "./cross/index.js" + }, + "./prisma": { + "types": "./prisma.d.ts" } }, "publishConfig": { diff --git a/packages/runtime/res/prisma.d.ts b/packages/runtime/res/prisma.d.ts new file mode 100644 index 000000000..0068ce7ae --- /dev/null +++ b/packages/runtime/res/prisma.d.ts @@ -0,0 +1 @@ +export type * from '.zenstack/prisma'; diff --git a/packages/runtime/src/prisma.d.ts b/packages/runtime/src/prisma.d.ts new file mode 100644 index 000000000..c01cbe743 --- /dev/null +++ b/packages/runtime/src/prisma.d.ts @@ -0,0 +1,2 @@ +// @ts-expect-error stub for re-exporting PrismaClient +export type * from '.zenstack/prisma'; diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index 3e73932a1..b07a592f3 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -3,23 +3,25 @@ import type { DMMF } from '@prisma/generator-helper'; import { isPlugin, Model, Plugin } from '@zenstackhq/language/ast'; import { + createProject, + emitProject, getDataModels, - getDMMF, getLiteral, getLiteralArray, hasValidationAttributes, - OptionValue, - PluginDeclaredOptions, PluginError, - PluginFunction, resolvePath, + saveProject, + type OptionValue, + type PluginDeclaredOptions, + type PluginFunction, + type PluginResult, } from '@zenstackhq/sdk'; import colors from 'colors'; -import fs from 'fs'; import ora from 'ora'; import path from 'path'; +import type { Project } from 'ts-morph'; import { CorePlugins, ensureDefaultOutputFolder } from '../plugins/plugin-utils'; -import { getDefaultPrismaOutputFile } from '../plugins/prisma/schema-generator'; import telemetry from '../telemetry'; import { getVersion } from '../utils/version-utils'; @@ -57,8 +59,6 @@ export class PluginRunner { const plugins: PluginInfo[] = []; const pluginDecls = runnerOptions.schema.declarations.filter((d): d is Plugin => isPlugin(d)); - let prismaOutput = getDefaultPrismaOutputFile(runnerOptions.schemaPath); - for (const pluginDecl of pluginDecls) { const pluginProvider = this.getPluginProvider(pluginDecl); if (!pluginProvider) { @@ -103,15 +103,11 @@ export class PluginRunner { run: pluginModule.default as PluginFunction, module: pluginModule, }); - - if (pluginProvider === '@core/prisma' && typeof pluginOptions.output === 'string') { - // record custom prisma output path - prismaOutput = resolvePath(pluginOptions.output, { schemaPath: runnerOptions.schemaPath }); - } } // calculate all plugins (including core plugins implicitly enabled) - const allPlugins = this.calculateAllPlugins(runnerOptions, plugins); + const { corePlugins, userPlugins } = this.calculateAllPlugins(runnerOptions, plugins); + const allPlugins = [...corePlugins, ...userPlugins]; // check dependencies for (const plugin of allPlugins) { @@ -133,22 +129,38 @@ export class PluginRunner { const warnings: string[] = []; + // run core plugins first let dmmf: DMMF.Document | undefined = undefined; - for (const { name, description, provider, run, options: pluginOptions } of allPlugins) { - // const start = Date.now(); - await this.runPlugin(name, description, run, runnerOptions, pluginOptions, dmmf, warnings); - // console.log(`āœ… Plugin ${colors.bold(name)} (${provider}) completed in ${Date.now() - start}ms`); - if (provider === '@core/prisma') { - // load prisma DMMF - dmmf = await getDMMF({ - datamodel: fs.readFileSync(prismaOutput, { encoding: 'utf-8' }), - }); + let prismaClientPath = '@prisma/client'; + const project = createProject(); + for (const { name, description, run, options: pluginOptions } of corePlugins) { + const options = { ...pluginOptions, prismaClientPath }; + const r = await this.runPlugin(name, description, run, runnerOptions, options, dmmf, project); + warnings.push(...(r?.warnings ?? [])); // the null-check is for backward compatibility + + if (r.dmmf) { + // use the DMMF returned by the plugin + dmmf = r.dmmf; + } + + if (r.prismaClientPath) { + // use the prisma client path returned by the plugin + prismaClientPath = r.prismaClientPath; } } - console.log(colors.green(colors.bold('\nšŸ‘» All plugins completed successfully!'))); - warnings.forEach((w) => console.warn(colors.yellow(w))); + // compile code generated by core plugins + await compileProject(project, runnerOptions); + + // run user plugins + for (const { name, description, run, options: pluginOptions } of userPlugins) { + const options = { ...pluginOptions, prismaClientPath }; + const r = await this.runPlugin(name, description, run, runnerOptions, options, dmmf, project); + warnings.push(...(r?.warnings ?? [])); // the null-check is for backward compatibility + } + console.log(colors.green(colors.bold('\nšŸ‘» All plugins completed successfully!'))); + warnings.forEach((w) => console.warn(colors.yellow(w))); console.log(`Don't forget to restart your dev server to let the changes take effect.`); } @@ -168,7 +180,22 @@ export class PluginRunner { const hasValidation = this.hasValidation(options.schema); - // 2. @core/zod + // 2. @core/enhancer + const existingEnhancer = plugins.find((p) => p.provider === CorePlugins.Enhancer); + if (existingEnhancer) { + corePlugins.push(existingEnhancer); + plugins.splice(plugins.indexOf(existingEnhancer), 1); + } else { + if (options.defaultPlugins) { + corePlugins.push( + this.makeCorePlugin(CorePlugins.Enhancer, options.schemaPath, { + withZodSchemas: hasValidation, + }) + ); + } + } + + // 3. @core/zod const existingZod = plugins.find((p) => p.provider === CorePlugins.Zod); if (existingZod && !existingZod.options.output) { // we can reuse the user-provided zod plugin if it didn't specify a custom output path @@ -178,7 +205,7 @@ export class PluginRunner { if ( !corePlugins.some((p) => p.provider === CorePlugins.Zod) && - (options.defaultPlugins || plugins.some((p) => p.provider === CorePlugins.Enhancer)) && + (options.defaultPlugins || corePlugins.some((p) => p.provider === CorePlugins.Enhancer)) && hasValidation ) { // ensure "@core/zod" is enabled if "@core/enhancer" is enabled and there're validation rules @@ -186,21 +213,6 @@ export class PluginRunner { corePlugins.push(this.makeCorePlugin(CorePlugins.Zod, options.schemaPath, { modelOnly: true })); } - // 3. @core/enhancer - const existingEnhancer = plugins.find((p) => p.provider === CorePlugins.Enhancer); - if (existingEnhancer) { - corePlugins.push(existingEnhancer); - plugins.splice(plugins.indexOf(existingEnhancer), 1); - } else { - if (options.defaultPlugins) { - corePlugins.push( - this.makeCorePlugin(CorePlugins.Enhancer, options.schemaPath, { - withZodSchemas: hasValidation, - }) - ); - } - } - // collect core plugins introduced by dependencies plugins.forEach((plugin) => { // TODO: generalize this @@ -245,7 +257,7 @@ export class PluginRunner { } }); - return [...corePlugins, ...plugins]; + return { corePlugins, userPlugins: plugins }; } private makeCorePlugin( @@ -296,12 +308,12 @@ export class PluginRunner { runnerOptions: PluginRunnerOptions, options: PluginDeclaredOptions, dmmf: DMMF.Document | undefined, - warnings: string[] + project: Project ) { const title = description ?? `Running plugin ${colors.cyan(name)}`; const spinner = ora(title).start(); try { - await telemetry.trackSpan( + const r = await telemetry.trackSpan( 'cli:plugin:start', 'cli:plugin:complete', 'cli:plugin:error', @@ -310,19 +322,20 @@ export class PluginRunner { options, }, async () => { - let result = run(runnerOptions.schema, { ...options, schemaPath: runnerOptions.schemaPath }, dmmf, { + return await run(runnerOptions.schema, { ...options, schemaPath: runnerOptions.schemaPath }, dmmf, { output: runnerOptions.output, compile: runnerOptions.compile, + tsProject: project, }); - if (result instanceof Promise) { - result = await result; - } - if (Array.isArray(result)) { - warnings.push(...result); - } } ); spinner.succeed(); + + if (typeof r === 'object') { + return r; + } else { + return { warnings: [] }; + } } catch (err) { spinner.fail(); throw err; @@ -350,3 +363,13 @@ export class PluginRunner { return require(pluginModulePath); } } + +async function compileProject(project: Project, runnerOptions: PluginRunnerOptions) { + if (runnerOptions.compile !== false) { + // emit + await emitProject(project); + } else { + // otherwise save ts files + await saveProject(project); + } +} diff --git a/packages/schema/src/plugins/enhancer/delegate/index.ts b/packages/schema/src/plugins/enhancer/delegate/index.ts deleted file mode 100644 index d3f85576d..000000000 --- a/packages/schema/src/plugins/enhancer/delegate/index.ts +++ /dev/null @@ -1,16 +0,0 @@ -import { type PluginOptions } from '@zenstackhq/sdk'; -import type { Model } from '@zenstackhq/sdk/ast'; -import type { Project } from 'ts-morph'; -import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; -import path from 'path'; - -export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - const prismaGenerator = new PrismaSchemaGenerator(model); - await prismaGenerator.generate({ - provider: '@internal', - schemaPath: options.schemaPath, - output: path.join(outDir, 'delegate.prisma'), - overrideClientGenerationPath: path.join(outDir, '.delegate'), - mode: 'logical', - }); -} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 06caf6950..9488b24f7 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -1,12 +1,16 @@ +import type { DMMF } from '@prisma/generator-helper'; import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; import { getAttribute, getDataModels, + getDMMF, getPrismaClientImportSpec, isDelegateModel, + PluginError, type PluginOptions, } from '@zenstackhq/sdk'; import { DataModel, DataModelField, isDataModel, isReferenceExpr, type Model } from '@zenstackhq/sdk/ast'; +import fs from 'fs'; import path from 'path'; import { FunctionDeclarationStructure, @@ -19,31 +23,50 @@ import { TypeAliasDeclaration, VariableStatement, } from 'ts-morph'; +import { name } from '..'; +import { execPackage } from '../../../utils/exec-utils'; +import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - const outFile = path.join(outDir, 'enhance.ts'); let logicalPrismaClientDir: string | undefined; + let dmmf: DMMF.Document | undefined; if (hasDelegateModel(model)) { - logicalPrismaClientDir = await generateLogicalPrisma(model, options, outDir); + // schema contains delegate models, need to generate a logical prisma schema + const result = await generateLogicalPrisma(model, options, outDir); + + logicalPrismaClientDir = './.logical-prisma-client'; + dmmf = result.dmmf; + + // create a reexport of the logical prisma client + const prismaDts = project.createSourceFile( + path.join(outDir, 'prisma.d.ts'), + `export type * from '${logicalPrismaClientDir}/index-fixed';`, + { overwrite: true } + ); + await saveSourceFile(prismaDts, options); + } else { + // just reexport the prisma client + const prismaDts = project.createSourceFile( + path.join(outDir, 'prisma.d.ts'), + `export type * from '${getPrismaClientImportSpec(outDir, options)}';`, + { overwrite: true } + ); + await saveSourceFile(prismaDts, options); } - project.createSourceFile( - outFile, + const enhanceTs = project.createSourceFile( + path.join(outDir, 'enhance.ts'), `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime'; import modelMeta from './model-meta'; import policy from './policy'; ${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} -import { Prisma } from '${getPrismaClientImportSpec(model, outDir)}'; -${ - logicalPrismaClientDir - ? `import type { PrismaClient as EnhancedPrismaClient } from '${logicalPrismaClientDir}/index-fixed';` - : '' -} +import { Prisma } from '${getPrismaClientImportSpec(outDir, options)}'; +${logicalPrismaClientDir ? `import { type PrismaClient } from '${logicalPrismaClientDir}/index-fixed';` : ``} export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions) { return createEnhancement(prisma, { @@ -52,11 +75,15 @@ export function enhance(prisma: DbClient, context?: Enh zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), prismaModule: Prisma, ...options - }, context)${logicalPrismaClientDir ? ' as EnhancedPrismaClient' : ''}; + }, context)${logicalPrismaClientDir ? ' as PrismaClient' : ''}; } `, { overwrite: true } ); + + await saveSourceFile(enhanceTs, options); + + return { dmmf }; } function hasDelegateModel(model: Model) { @@ -68,19 +95,40 @@ function hasDelegateModel(model: Model) { async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) { const prismaGenerator = new PrismaSchemaGenerator(model); - const prismaClientOutDir = './.delegate'; + const prismaClientOutDir = './.logical-prisma-client'; + const logicalPrismaFile = path.join(outDir, 'logical.prisma'); await prismaGenerator.generate({ provider: '@internal', // doesn't matter schemaPath: options.schemaPath, - output: path.join(outDir, 'delegate.prisma'), + output: logicalPrismaFile, overrideClientGenerationPath: prismaClientOutDir, mode: 'logical', }); + // generate the prisma client + const generateCmd = `prisma generate --schema "${logicalPrismaFile}" --no-engine`; + try { + // run 'prisma generate' + await execPackage(generateCmd, { stdio: 'ignore' }); + } catch { + await trackPrismaSchemaError(logicalPrismaFile); + try { + // run 'prisma generate' again with output to the console + await execPackage(generateCmd); + } catch { + // noop + } + throw new PluginError(name, `Failed to run "prisma generate"`); + } + // make a bunch of typing fixes to the generated prisma client await processClientTypes(model, path.join(outDir, prismaClientOutDir)); - return prismaClientOutDir; + return { + prismaSchema: logicalPrismaFile, + // load the dmmf of the logical prisma schema + dmmf: await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }), + }; } async function processClientTypes(model: Model, prismaClientDir: string) { @@ -106,8 +154,7 @@ async function processClientTypes(model: Model, prismaClientDir: string) { }); transform(sf, sfNew, delegateInfo); sfNew.formatText(); - - await project.save(); + await sfNew.save(); } function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) { @@ -352,3 +399,9 @@ function getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataMode } return result; } + +async function saveSourceFile(sf: SourceFile, options: PluginOptions) { + if (options.preserveTsFiles) { + await sf.save(); + } +} diff --git a/packages/schema/src/plugins/enhancer/index.ts b/packages/schema/src/plugins/enhancer/index.ts index 86e3ecf39..64e2ad1a4 100644 --- a/packages/schema/src/plugins/enhancer/index.ts +++ b/packages/schema/src/plugins/enhancer/index.ts @@ -1,11 +1,5 @@ -import { - PluginError, - createProject, - emitProject, - resolvePath, - saveProject, - type PluginFunction, -} from '@zenstackhq/sdk'; +import { PluginError, createProject, resolvePath, type PluginFunction, RUNTIME_PACKAGE } from '@zenstackhq/sdk'; +import path from 'path'; import { getDefaultOutputFolder } from '../plugin-utils'; import { generate as generateEnhancer } from './enhance'; import { generate as generateModelMeta } from './model-meta'; @@ -15,34 +9,33 @@ export const name = 'Prisma Enhancer'; export const description = 'Generating PrismaClient enhancer'; const run: PluginFunction = async (model, options, _dmmf, globalOptions) => { - let ourDir = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions); - if (!ourDir) { + let outDir = options.output ? (options.output as string) : getDefaultOutputFolder(globalOptions); + if (!outDir) { throw new PluginError(name, `Unable to determine output path, not running plugin`); } - ourDir = resolvePath(ourDir, options); - - const project = createProject(); - - await generateModelMeta(model, options, project, ourDir); - await generatePolicy(model, options, project, ourDir); - await generateEnhancer(model, options, project, ourDir); - - let shouldCompile = true; - if (typeof options.compile === 'boolean') { - // explicit override - shouldCompile = options.compile; - } else if (globalOptions) { - // from CLI or config file - shouldCompile = globalOptions.compile; + outDir = resolvePath(outDir, options); + + const project = globalOptions?.tsProject ?? createProject(); + + await generateModelMeta(model, options, project, outDir); + await generatePolicy(model, options, project, outDir); + const { dmmf } = await generateEnhancer(model, options, project, outDir); + + let prismaClientPath: string | undefined; + if (dmmf) { + // a logical client is generated + if (typeof options.output === 'string') { + // get the absolute path of the logical prisma client + const prismaClientPathAbs = path.resolve(options.output, 'prisma'); + + // resolve it relative to the schema path + prismaClientPath = path.relative(path.dirname(options.schemaPath), prismaClientPathAbs); + } else { + prismaClientPath = `${RUNTIME_PACKAGE}/prisma`; + } } - if (!shouldCompile || options.preserveTsFiles === true) { - await saveProject(project); - } - - if (shouldCompile) { - await emitProject(project); - } + return { dmmf, warnings: [], prismaClientPath }; }; export default run; diff --git a/packages/schema/src/plugins/enhancer/model-meta/index.ts b/packages/schema/src/plugins/enhancer/model-meta/index.ts index 541106e24..9939ae346 100644 --- a/packages/schema/src/plugins/enhancer/model-meta/index.ts +++ b/packages/schema/src/plugins/enhancer/model-meta/index.ts @@ -6,8 +6,12 @@ import type { Project } from 'ts-morph'; export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { const outFile = path.join(outDir, 'model-meta.ts'); const dataModels = getDataModels(model); + + // save ts files if requested explicitly or the user provided + const preserveTsFiles = options.preserveTsFiles === true || !!options.output; await generateModelMeta(project, dataModels, { output: outFile, generateAttributes: true, + preserveTsFiles, }); } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 2032f2b99..fa1eb831a 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -60,7 +60,7 @@ import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; * Generates source file that contains Prisma query guard objects used for injecting database queries */ export class PolicyGenerator { - async generate(project: Project, model: Model, _options: PluginOptions, output: string) { + async generate(project: Project, model: Model, options: PluginOptions, output: string) { const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); @@ -75,7 +75,7 @@ export class PolicyGenerator { }); // import enums - const prismaImport = getPrismaClientImportSpec(model, output); + const prismaImport = getPrismaClientImportSpec(output, options); for (const e of model.declarations.filter((d) => isEnum(d) && this.isEnumReferenced(model, d))) { sf.addImportDeclaration({ namedImports: [{ name: e.name }], @@ -140,6 +140,12 @@ export class PolicyGenerator { }); sf.addStatements('export default policy'); + + // save ts files if requested explicitly or the user provided + const preserveTsFiles = options.preserveTsFiles === true || !!options.output; + if (preserveTsFiles) { + await sf.save(); + } } // Generates a { select: ... } object to select `auth()` fields used in policy rules diff --git a/packages/schema/src/plugins/plugin-utils.ts b/packages/schema/src/plugins/plugin-utils.ts index 00b806e7e..d6cc12403 100644 --- a/packages/schema/src/plugins/plugin-utils.ts +++ b/packages/schema/src/plugins/plugin-utils.ts @@ -55,6 +55,9 @@ export function ensureDefaultOutputFolder(options: PluginRunnerOptions) { types: './zod/objects/index.d.ts', default: './zod/objects/index.js', }, + './prisma': { + types: './prisma.d.ts', + }, }, }; fs.writeFileSync(path.join(output, 'package.json'), JSON.stringify(pkgJson, undefined, 4)); diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index 5aa64c145..478b6a54b 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -1,11 +1,101 @@ -import { PluginFunction } from '@zenstackhq/sdk'; +import { PluginError, PluginFunction, getDMMF, getLiteral, resolvePath } from '@zenstackhq/sdk'; +import { GeneratorDecl, isGeneratorDecl } from '@zenstackhq/sdk/ast'; +import fs from 'fs'; +import path from 'path'; +import stripColor from 'strip-color'; +import telemetry from '../../telemetry'; +import { execPackage } from '../../utils/exec-utils'; +import { findUp } from '../../utils/pkg-utils'; import { PrismaSchemaGenerator } from './schema-generator'; export const name = 'Prisma'; export const description = 'Generating Prisma schema'; const run: PluginFunction = async (model, options, _dmmf, _globalOptions) => { - return new PrismaSchemaGenerator(model).generate(options); + // deal with calculation of the default output location + const output = options.output + ? resolvePath(options.output as string, options) + : getDefaultPrismaOutputFile(options.schemaPath); + + const warnings = await new PrismaSchemaGenerator(model).generate({ ...options, output }); + let prismaClientPath = '@prisma/client'; + + if (options.generateClient !== false) { + let generateCmd = `prisma generate --schema "${output}"`; + if (typeof options.generateArgs === 'string') { + generateCmd += ` ${options.generateArgs}`; + } + try { + // run 'prisma generate' + await execPackage(generateCmd, { stdio: 'ignore' }); + } catch { + await trackPrismaSchemaError(output); + try { + // run 'prisma generate' again with output to the console + await execPackage(generateCmd); + } catch { + // noop + } + throw new PluginError(name, `Failed to run "prisma generate"`); + } + + // extract user-provided prisma client output path + const generator = model.declarations.find( + (d): d is GeneratorDecl => + isGeneratorDecl(d) && + d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') + ); + const clientOutputField = generator?.fields.find((f) => f.name === 'output'); + const clientOutput = getLiteral(clientOutputField?.value); + + if (clientOutput) { + if (path.isAbsolute(clientOutput)) { + prismaClientPath = clientOutput; + } else { + // first get absolute path based on prisma schema location + const absPath = path.resolve(path.dirname(output), clientOutput); + + // then make it relative to the zmodel schema location + prismaClientPath = path.relative(path.dirname(options.schemaPath), absPath); + } + } + } + + // load the result DMMF + const dmmf = await getDMMF({ + datamodel: fs.readFileSync(output, 'utf-8'), + }); + + return { warnings, dmmf, prismaClientPath }; }; +function getDefaultPrismaOutputFile(schemaPath: string) { + // handle override from package.json + const pkgJsonPath = findUp(['package.json'], path.dirname(schemaPath)); + if (pkgJsonPath) { + const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf-8')); + if (typeof pkgJson?.zenstack?.prisma === 'string') { + if (path.isAbsolute(pkgJson.zenstack.prisma)) { + return pkgJson.zenstack.prisma; + } else { + // resolve relative to package.json + return path.resolve(path.dirname(pkgJsonPath), pkgJson.zenstack.prisma); + } + } + } + + return resolvePath('./prisma/schema.prisma', { schemaPath }); +} + +export async function trackPrismaSchemaError(schema: string) { + try { + await getDMMF({ datamodel: fs.readFileSync(schema, 'utf-8') }); + } catch (err) { + if (err instanceof Error) { + // eslint-disable-next-line @typescript-eslint/no-var-requires + telemetry.track('prisma:error', { command: 'generate', message: stripColor(err.message) }); + } + } +} + export default run; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 4ac78c6e3..2519c3cd3 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -34,7 +34,6 @@ import { getIdFields } from '../../utils/ast-utils'; import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { getAttribute, - getDMMF, getLiteral, getPrismaVersion, isAuthInvocation, @@ -43,7 +42,6 @@ import { PluginError, PluginOptions, resolved, - resolvePath, ZModelCodeGenerator, } from '@zenstackhq/sdk'; import fs from 'fs'; @@ -52,13 +50,10 @@ import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import semver from 'semver'; -import stripColor from 'strip-color'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; -import telemetry from '../../telemetry'; import { execPackage } from '../../utils/exec-utils'; -import { findUp } from '../../utils/pkg-utils'; import { AttributeArgValue, ModelFieldType, @@ -100,6 +95,11 @@ export class PrismaSchemaGenerator { constructor(private readonly zmodel: Model) {} async generate(options: PluginOptions) { + if (!options.output) { + throw new PluginError(name, 'Output file is not specified'); + } + + const outFile = options.output as string; const warnings: string[] = []; if (options.mode) { this.mode = options.mode as 'logical' | 'physical'; @@ -134,10 +134,6 @@ export class PrismaSchemaGenerator { } } - const outFile = options.output - ? resolvePath(options.output as string, options) - : getDefaultPrismaOutputFile(options.schemaPath); - if (!fs.existsSync(path.dirname(outFile))) { fs.mkdirSync(path.dirname(outFile), { recursive: true }); } @@ -152,42 +148,9 @@ export class PrismaSchemaGenerator { } } - const generateClient = options.generateClient !== false; - - if (generateClient) { - let generateCmd = `prisma generate --schema "${outFile}"${this.mode === 'logical' ? ' --no-engine' : ''}`; - if (typeof options.generateArgs === 'string') { - generateCmd += ` ${options.generateArgs}`; - } - try { - // run 'prisma generate' - await execPackage(generateCmd, { stdio: 'ignore' }); - } catch { - await this.trackPrismaSchemaError(outFile); - try { - // run 'prisma generate' again with output to the console - await execPackage(generateCmd); - } catch { - // noop - } - throw new PluginError(name, `Failed to run "prisma generate"`); - } - } - return warnings; } - private async trackPrismaSchemaError(schema: string) { - try { - await getDMMF({ datamodel: fs.readFileSync(schema, 'utf-8') }); - } catch (err) { - if (err instanceof Error) { - // eslint-disable-next-line @typescript-eslint/no-var-requires - telemetry.track('prisma:error', { command: 'generate', message: stripColor(err.message) }); - } - } - } - private generateDataSource(prisma: PrismaModel, dataSource: DataSource) { const fields: SimpleField[] = dataSource.fields.map((f) => ({ name: f.name, @@ -693,21 +656,3 @@ export class PrismaSchemaGenerator { function isDescendantOf(model: DataModel, superModel: DataModel): boolean { return model.superTypes.some((s) => s.ref === superModel || isDescendantOf(s.ref!, superModel)); } - -export function getDefaultPrismaOutputFile(schemaPath: string) { - // handle override from package.json - const pkgJsonPath = findUp(['package.json'], path.dirname(schemaPath)); - if (pkgJsonPath) { - const pkgJson = JSON.parse(fs.readFileSync(pkgJsonPath, 'utf-8')); - if (typeof pkgJson?.zenstack?.prisma === 'string') { - if (path.isAbsolute(pkgJson.zenstack.prisma)) { - return pkgJson.zenstack.prisma; - } else { - // resolve relative to package.json - return path.resolve(path.dirname(pkgJsonPath), pkgJson.zenstack.prisma); - } - } - } - - return resolvePath('./prisma/schema.prisma', { schemaPath }); -} diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index a09c4ad73..f09e93951 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -2,8 +2,6 @@ import { ConnectorType, DMMF } from '@prisma/generator-helper'; import { PluginGlobalOptions, PluginOptions, - createProject, - emitProject, getDataModels, getLiteral, getPrismaClientImportSpec, @@ -13,14 +11,13 @@ import { isFromStdlib, parseOptionAsStrings, resolvePath, - saveProject, } from '@zenstackhq/sdk'; import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; import { promises as fs } from 'fs'; import { streamAllContents } from 'langium'; import path from 'path'; -import { Project } from 'ts-morph'; +import type { SourceFile } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getDefaultOutputFolder } from '../plugin-utils'; @@ -28,480 +25,490 @@ import Transformer from './transformer'; import removeDir from './utils/removeDir'; import { getFieldSchemaDefault, makeFieldSchema, makeValidationRefinements } from './utils/schema-gen'; -export async function generate( - model: Model, - options: PluginOptions, - dmmf: DMMF.Document, - globalOptions?: PluginGlobalOptions -) { - let output = options.output as string; - if (!output) { - const defaultOutputFolder = getDefaultOutputFolder(globalOptions); - if (defaultOutputFolder) { - output = path.join(defaultOutputFolder, 'zod'); - } else { - output = './generated/zod'; +export class ZodSchemaGenerator { + private readonly sourceFiles: SourceFile[] = []; + private readonly globalOptions: PluginGlobalOptions; + + constructor( + private readonly model: Model, + private readonly options: PluginOptions, + private readonly dmmf: DMMF.Document, + globalOptions: PluginGlobalOptions | undefined + ) { + if (!globalOptions) { + throw new Error('Global options are required'); } + this.globalOptions = globalOptions; } - output = resolvePath(output, options); - await handleGeneratorOutputValue(output); - // calculate the models to be excluded - const excludeModels = getExcludedModels(model, options); + async generate() { + let output = this.options.output as string; + if (!output) { + const defaultOutputFolder = getDefaultOutputFolder(this.globalOptions); + if (defaultOutputFolder) { + output = path.join(defaultOutputFolder, 'zod'); + } else { + output = './generated/zod'; + } + } + output = resolvePath(output, this.options); + await this.handleGeneratorOutputValue(output); - const prismaClientDmmf = dmmf; + // calculate the models to be excluded + const excludeModels = this.getExcludedModels(); - const modelOperations = prismaClientDmmf.mappings.modelOperations.filter( - (o) => !excludeModels.find((e) => e === o.model) - ); + const prismaClientDmmf = this.dmmf; - // TODO: better way of filtering than string startsWith? - const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma.filter( - (type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLocaleLowerCase())) - ); - const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma.filter( - (type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLowerCase())) - ); + const modelOperations = prismaClientDmmf.mappings.modelOperations.filter( + (o) => !excludeModels.find((e) => e === o.model) + ); - const models: DMMF.Model[] = prismaClientDmmf.datamodel.models.filter( - (m) => !excludeModels.find((e) => e === m.name) - ); + // TODO: better way of filtering than string startsWith? + const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma.filter( + (type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLocaleLowerCase())) + ); + const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma.filter( + (type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLowerCase())) + ); - // whether Prisma's Unchecked* series of input types should be generated - const generateUnchecked = options.noUncheckedInput !== true; + const models: DMMF.Model[] = prismaClientDmmf.datamodel.models.filter( + (m) => !excludeModels.find((e) => e === m.name) + ); - const project = createProject(); + // common schemas + await this.generateCommonSchemas(output); - // common schemas - await generateCommonSchemas(project, output); + // enums + await this.generateEnumSchemas( + prismaClientDmmf.schema.enumTypes.prisma, + prismaClientDmmf.schema.enumTypes.model ?? [] + ); - // enums - await generateEnumSchemas( - prismaClientDmmf.schema.enumTypes.prisma, - prismaClientDmmf.schema.enumTypes.model ?? [], - project, - model - ); + const dataSource = this.model.declarations.find((d): d is DataSource => isDataSource(d)); - const dataSource = model.declarations.find((d): d is DataSource => isDataSource(d)); + const dataSourceProvider = getLiteral( + dataSource?.fields.find((f) => f.name === 'provider')?.value + ) as ConnectorType; - const dataSourceProvider = getLiteral( - dataSource?.fields.find((f) => f.name === 'provider')?.value - ) as ConnectorType; + await this.generateModelSchemas(output, excludeModels); - await generateModelSchemas(project, model, output, excludeModels); + if (this.options.modelOnly !== true) { + // detailed object schemas referenced from input schemas + Transformer.provider = dataSourceProvider; + addMissingInputObjectTypes(inputObjectTypes, outputObjectTypes, models); + const aggregateOperationSupport = resolveAggregateOperationSupport(inputObjectTypes); + await this.generateObjectSchemas(inputObjectTypes, output); - if (options.modelOnly !== true) { - // detailed object schemas referenced from input schemas - Transformer.provider = dataSourceProvider; - addMissingInputObjectTypes(inputObjectTypes, outputObjectTypes, models); - const aggregateOperationSupport = resolveAggregateOperationSupport(inputObjectTypes); - await generateObjectSchemas(inputObjectTypes, project, output, model, generateUnchecked); + // input schemas + const transformer = new Transformer({ + models, + modelOperations, + aggregateOperationSupport, + project: this.project, + inputObjectTypes, + }); + await transformer.generateInputSchemas(this.options); + this.sourceFiles.push(...transformer.sourceFiles); + } - // input schemas - const transformer = new Transformer({ - models, - modelOperations, - aggregateOperationSupport, - project, - zmodel: model, - inputObjectTypes, - }); - await transformer.generateInputSchemas(generateUnchecked); - } + // create barrel file + const exports = [`export * as models from './models'`, `export * as enums from './enums'`]; + if (this.options.modelOnly !== true) { + exports.push(`export * as input from './input'`, `export * as objects from './objects'`); + } + this.sourceFiles.push( + this.project.createSourceFile(path.join(output, 'index.ts'), exports.join(';\n'), { overwrite: true }) + ); - // create barrel file - const exports = [`export * as models from './models'`, `export * as enums from './enums'`]; - if (options.modelOnly !== true) { - exports.push(`export * as input from './input'`, `export * as objects from './objects'`); - } - project.createSourceFile(path.join(output, 'index.ts'), exports.join(';\n'), { overwrite: true }); - - // emit - let shouldCompile = true; - if (typeof options.compile === 'boolean') { - // explicit override - shouldCompile = options.compile; - } else if (globalOptions) { - // from CLI or config file - shouldCompile = globalOptions.compile; + if (this.options.preserveTsFiles === true || this.options.output) { + // if preserveTsFiles is true or the user provided a custom output directory, + // save the generated files + await Promise.all( + this.sourceFiles.map(async (sf) => { + await sf.formatText(); + await sf.save(); + }) + ); + } } - if (!shouldCompile || options.preserveTsFiles === true) { - // save ts files - await saveProject(project); + private get project() { + return this.globalOptions.tsProject; } - if (shouldCompile) { - await emitProject(project); - } -} -function getExcludedModels(model: Model, options: PluginOptions) { - // resolve "generateModels" option - const generateModels = parseOptionAsStrings(options, 'generateModels', name); - if (generateModels) { - if (options.modelOnly === true) { - // no model reference needs to be considered, directly exclude any model not included - return model.declarations - .filter((d) => isDataModel(d) && !generateModels.includes(d.name)) - .map((m) => m.name); - } else { - // calculate a transitive closure of models to be included - const todo = getDataModels(model).filter((dm) => generateModels.includes(dm.name)); - const included = new Set(); - while (todo.length > 0) { - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const dm = todo.pop()!; - included.add(dm); - - // add referenced models to the todo list - dm.fields - .map((f) => f.type.reference?.ref) - .filter((type): type is DataModel => isDataModel(type)) - .forEach((type) => { - if (!included.has(type)) { - todo.push(type); - } - }); - } + private getExcludedModels() { + // resolve "generateModels" option + const generateModels = parseOptionAsStrings(this.options, 'generateModels', name); + if (generateModels) { + if (this.options.modelOnly === true) { + // no model reference needs to be considered, directly exclude any model not included + return this.model.declarations + .filter((d) => isDataModel(d) && !generateModels.includes(d.name)) + .map((m) => m.name); + } else { + // calculate a transitive closure of models to be included + const todo = getDataModels(this.model).filter((dm) => generateModels.includes(dm.name)); + const included = new Set(); + while (todo.length > 0) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const dm = todo.pop()!; + included.add(dm); + + // add referenced models to the todo list + dm.fields + .map((f) => f.type.reference?.ref) + .filter((type): type is DataModel => isDataModel(type)) + .forEach((type) => { + if (!included.has(type)) { + todo.push(type); + } + }); + } - // finally find the models to be excluded - return getDataModels(model) - .filter((dm) => !included.has(dm)) - .map((m) => m.name); + // finally find the models to be excluded + return getDataModels(this.model) + .filter((dm) => !included.has(dm)) + .map((m) => m.name); + } + } else { + return []; } - } else { - return []; } -} - -async function handleGeneratorOutputValue(output: string) { - // create the output directory and delete contents that might exist from a previous run - await fs.mkdir(output, { recursive: true }); - const isRemoveContentsOnly = true; - await removeDir(output, isRemoveContentsOnly); - Transformer.setOutputPath(output); -} + private async handleGeneratorOutputValue(output: string) { + // create the output directory and delete contents that might exist from a previous run + await fs.mkdir(output, { recursive: true }); + const isRemoveContentsOnly = true; + await removeDir(output, isRemoveContentsOnly); -async function generateCommonSchemas(project: Project, output: string) { - // Decimal - project.createSourceFile( - path.join(output, 'common', 'index.ts'), - ` -import { z } from 'zod'; -export const DecimalSchema = z.union([z.number(), z.string(), z.object({d: z.number().array(), e: z.number(), s: z.number()}).passthrough()]); -`, - { overwrite: true } - ); -} + Transformer.setOutputPath(output); + } -async function generateEnumSchemas( - prismaSchemaEnum: DMMF.SchemaEnum[], - modelSchemaEnum: DMMF.SchemaEnum[], - project: Project, - zmodel: Model -) { - const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum]; - const enumNames = enumTypes.map((enumItem) => upperCaseFirst(enumItem.name)); - Transformer.enumNames = enumNames ?? []; - const transformer = new Transformer({ - enumTypes, - project, - zmodel, - inputObjectTypes: [], - }); - await transformer.generateEnumSchemas(); -} + private async generateCommonSchemas(output: string) { + // Decimal + this.sourceFiles.push( + this.project.createSourceFile( + path.join(output, 'common', 'index.ts'), + ` + import { z } from 'zod'; + export const DecimalSchema = z.union([z.number(), z.string(), z.object({d: z.number().array(), e: z.number(), s: z.number()}).passthrough()]); + `, + { overwrite: true } + ) + ); + } -async function generateObjectSchemas( - inputObjectTypes: DMMF.InputType[], - project: Project, - output: string, - zmodel: Model, - generateUnchecked: boolean -) { - const moduleNames: string[] = []; - for (let i = 0; i < inputObjectTypes.length; i += 1) { - const fields = inputObjectTypes[i]?.fields; - const name = inputObjectTypes[i]?.name; - if (!generateUnchecked && name.includes('Unchecked')) { - continue; - } - const transformer = new Transformer({ name, fields, project, zmodel, inputObjectTypes }); - const moduleName = transformer.generateObjectSchema(generateUnchecked); - moduleNames.push(moduleName); + private async generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) { + const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum]; + const enumNames = enumTypes.map((enumItem) => upperCaseFirst(enumItem.name)); + Transformer.enumNames = enumNames ?? []; + const transformer = new Transformer({ + enumTypes, + project: this.project, + inputObjectTypes: [], + }); + await transformer.generateEnumSchemas(); + this.sourceFiles.push(...transformer.sourceFiles); } - project.createSourceFile( - path.join(output, 'objects/index.ts'), - moduleNames.map((name) => `export * from './${name}';`).join('\n'), - { overwrite: true } - ); -} -async function generateModelSchemas(project: Project, zmodel: Model, output: string, excludedModels: string[]) { - const schemaNames: string[] = []; - for (const dm of getDataModels(zmodel)) { - if (!excludedModels.includes(dm.name)) { - schemaNames.push(await generateModelSchema(dm, project, output)); + private async generateObjectSchemas(inputObjectTypes: DMMF.InputType[], output: string) { + // whether Prisma's Unchecked* series of input types should be generated + const generateUnchecked = this.options.noUncheckedInput !== true; + + const moduleNames: string[] = []; + for (let i = 0; i < inputObjectTypes.length; i += 1) { + const fields = inputObjectTypes[i]?.fields; + const name = inputObjectTypes[i]?.name; + if (!generateUnchecked && name.includes('Unchecked')) { + continue; + } + const transformer = new Transformer({ + name, + fields, + project: this.project, + inputObjectTypes, + }); + const moduleName = transformer.generateObjectSchema(generateUnchecked, this.options); + moduleNames.push(moduleName); + this.sourceFiles.push(...transformer.sourceFiles); } + + this.sourceFiles.push( + this.project.createSourceFile( + path.join(output, 'objects/index.ts'), + moduleNames.map((name) => `export * from './${name}';`).join('\n'), + { overwrite: true } + ) + ); } - project.createSourceFile( - path.join(output, 'models', 'index.ts'), - schemaNames.map((name) => `export * from './${name}';`).join('\n'), - { overwrite: true } - ); -} + private async generateModelSchemas(output: string, excludedModels: string[]) { + const schemaNames: string[] = []; + for (const dm of getDataModels(this.model)) { + if (!excludedModels.includes(dm.name)) { + schemaNames.push(await this.generateModelSchema(dm, output)); + } + } -async function generateModelSchema(model: DataModel, project: Project, output: string) { - const schemaName = `${upperCaseFirst(model.name)}.schema`; - const sf = project.createSourceFile(path.join(output, 'models', `${schemaName}.ts`), undefined, { - overwrite: true, - }); - sf.replaceWithText((writer) => { - const scalarFields = model.fields.filter( - (field) => - // regular fields only - !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) + this.sourceFiles.push( + this.project.createSourceFile( + path.join(output, 'models', 'index.ts'), + schemaNames.map((name) => `export * from './${name}';`).join('\n'), + { overwrite: true } + ) ); + } - const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref)); - const fkFields = model.fields.filter((field) => isForeignKeyField(field)); + private async generateModelSchema(model: DataModel, output: string) { + const schemaName = `${upperCaseFirst(model.name)}.schema`; + const sf = this.project.createSourceFile(path.join(output, 'models', `${schemaName}.ts`), undefined, { + overwrite: true, + }); + this.sourceFiles.push(sf); + sf.replaceWithText((writer) => { + const scalarFields = model.fields.filter( + (field) => + // regular fields only + !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) + ); - writer.writeLine('/* eslint-disable */'); - writer.writeLine(`import { z } from 'zod';`); + const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref)); + const fkFields = model.fields.filter((field) => isForeignKeyField(field)); - // import user-defined enums from Prisma as they might be referenced in the expressions - const importEnums = new Set(); - for (const node of streamAllContents(model)) { - if (isEnumFieldReference(node)) { - const field = node.target.ref as EnumField; - if (!isFromStdlib(field.$container)) { - importEnums.add(field.$container.name); + writer.writeLine('/* eslint-disable */'); + writer.writeLine(`import { z } from 'zod';`); + + // import user-defined enums from Prisma as they might be referenced in the expressions + const importEnums = new Set(); + for (const node of streamAllContents(model)) { + if (isEnumFieldReference(node)) { + const field = node.target.ref as EnumField; + if (!isFromStdlib(field.$container)) { + importEnums.add(field.$container.name); + } } } - } - if (importEnums.size > 0) { - const prismaImport = getPrismaClientImportSpec(model.$container, path.join(output, 'models')); - writer.writeLine(`import { ${[...importEnums].join(', ')} } from '${prismaImport}';`); - } + if (importEnums.size > 0) { + const prismaImport = getPrismaClientImportSpec(path.join(output, 'models'), this.options); + writer.writeLine(`import { ${[...importEnums].join(', ')} } from '${prismaImport}';`); + } - // import enum schemas - const importedEnumSchemas = new Set(); - for (const field of scalarFields) { - if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) { - const name = upperCaseFirst(field.type.reference?.ref.name); - if (!importedEnumSchemas.has(name)) { - writer.writeLine(`import { ${name}Schema } from '../enums/${name}.schema';`); - importedEnumSchemas.add(name); + // import enum schemas + const importedEnumSchemas = new Set(); + for (const field of scalarFields) { + if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) { + const name = upperCaseFirst(field.type.reference?.ref.name); + if (!importedEnumSchemas.has(name)) { + writer.writeLine(`import { ${name}Schema } from '../enums/${name}.schema';`); + importedEnumSchemas.add(name); + } } } - } - - // import Decimal - if (scalarFields.some((field) => field.type.type === 'Decimal')) { - writer.writeLine(`import { DecimalSchema } from '../common';`); - writer.writeLine(`import { Decimal } from 'decimal.js';`); - } - // base schema - writer.write(`const baseSchema = z.object(`); - writer.inlineBlock(() => { - scalarFields.forEach((field) => { - writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); - }); - }); - writer.writeLine(');'); - - // relation fields - - let relationSchema: string | undefined; - let fkSchema: string | undefined; + // import Decimal + if (scalarFields.some((field) => field.type.type === 'Decimal')) { + writer.writeLine(`import { DecimalSchema } from '../common';`); + writer.writeLine(`import { Decimal } from 'decimal.js';`); + } - if (relations.length > 0 || fkFields.length > 0) { - relationSchema = 'relationSchema'; - writer.write(`const ${relationSchema} = z.object(`); + // base schema + writer.write(`const baseSchema = z.object(`); writer.inlineBlock(() => { - [...relations, ...fkFields].forEach((field) => { - writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + scalarFields.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); }); }); writer.writeLine(');'); - } - if (fkFields.length > 0) { - fkSchema = 'fkSchema'; - writer.write(`const ${fkSchema} = z.object(`); - writer.inlineBlock(() => { - fkFields.forEach((field) => { - writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + // relation fields + + let relationSchema: string | undefined; + let fkSchema: string | undefined; + + if (relations.length > 0 || fkFields.length > 0) { + relationSchema = 'relationSchema'; + writer.write(`const ${relationSchema} = z.object(`); + writer.inlineBlock(() => { + [...relations, ...fkFields].forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + }); }); - }); - writer.writeLine(');'); - } + writer.writeLine(');'); + } - // compile "@@validate" to ".refine" - const refinements = makeValidationRefinements(model); - let refineFuncName: string | undefined; - if (refinements.length > 0) { - refineFuncName = `refine${upperCaseFirst(model.name)}`; - writer.writeLine( - `export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( - '\n' - )}; }` - ); - } + if (fkFields.length > 0) { + fkSchema = 'fkSchema'; + writer.write(`const ${fkSchema} = z.object(`); + writer.inlineBlock(() => { + fkFields.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + }); + }); + writer.writeLine(');'); + } - //////////////////////////////////////////////// - // 1. Model schema - //////////////////////////////////////////////// - const fieldsWithoutDefault = scalarFields.filter((f) => !getFieldSchemaDefault(f)); - // mark fields without default value as optional - let modelSchema = makePartial( - 'baseSchema', - fieldsWithoutDefault.length < scalarFields.length ? fieldsWithoutDefault.map((f) => f.name) : undefined - ); + // compile "@@validate" to ".refine" + const refinements = makeValidationRefinements(model); + let refineFuncName: string | undefined; + if (refinements.length > 0) { + refineFuncName = `refine${upperCaseFirst(model.name)}`; + writer.writeLine( + `export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( + '\n' + )}; }` + ); + } - // omit fields - const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit')); - if (fieldsToOmit.length > 0) { - modelSchema = makeOmit( - modelSchema, - fieldsToOmit.map((f) => f.name) + //////////////////////////////////////////////// + // 1. Model schema + //////////////////////////////////////////////// + const fieldsWithoutDefault = scalarFields.filter((f) => !getFieldSchemaDefault(f)); + // mark fields without default value as optional + let modelSchema = this.makePartial( + 'baseSchema', + fieldsWithoutDefault.length < scalarFields.length ? fieldsWithoutDefault.map((f) => f.name) : undefined ); - } - if (relationSchema) { - // export schema with only scalar fields - const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`; - writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`); - modelSchema = modelScalarSchema; + // omit fields + const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit')); + if (fieldsToOmit.length > 0) { + modelSchema = this.makeOmit( + modelSchema, + fieldsToOmit.map((f) => f.name) + ); + } - // merge relations - modelSchema = makeMerge(modelSchema, makePartial(relationSchema)); - } + if (relationSchema) { + // export schema with only scalar fields + const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`; + writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`); + modelSchema = modelScalarSchema; - // refine - if (refineFuncName) { - const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`; - writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`); - modelSchema = `${refineFuncName}(${noRefineSchema})`; - } - writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`); + // merge relations + modelSchema = this.makeMerge(modelSchema, this.makePartial(relationSchema)); + } - //////////////////////////////////////////////// - // 2. Prisma create & update - //////////////////////////////////////////////// + // refine + if (refineFuncName) { + const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`); + modelSchema = `${refineFuncName}(${noRefineSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`); - // schema for validating prisma create input (all fields optional) - let prismaCreateSchema = makePassthrough(makePartial('baseSchema')); - if (refineFuncName) { - prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; - } - writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`); - - // schema for validating prisma update input (all fields optional) - // note numeric fields can be simple update or atomic operations - let prismaUpdateSchema = `z.object({ - ${scalarFields - .map((field) => { - let fieldSchema = makeFieldSchema(field); - if (field.type.type === 'Int' || field.type.type === 'Float') { - fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`; - } - return `\t${field.name}: ${fieldSchema}`; - }) - .join(',\n')} -})`; - prismaUpdateSchema = makePartial(prismaUpdateSchema); - if (refineFuncName) { - prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`; - } - writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`); - - //////////////////////////////////////////////// - // 3. Create schema - //////////////////////////////////////////////// - let createSchema = 'baseSchema'; - const fieldsWithDefault = scalarFields.filter( - (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array - ); - if (fieldsWithDefault.length > 0) { - createSchema = makePartial( - createSchema, - fieldsWithDefault.map((f) => f.name) + //////////////////////////////////////////////// + // 2. Prisma create & update + //////////////////////////////////////////////// + + // schema for validating prisma create input (all fields optional) + let prismaCreateSchema = this.makePassthrough(this.makePartial('baseSchema')); + if (refineFuncName) { + prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`); + + // schema for validating prisma update input (all fields optional) + // note numeric fields can be simple update or atomic operations + let prismaUpdateSchema = `z.object({ + ${scalarFields + .map((field) => { + let fieldSchema = makeFieldSchema(field); + if (field.type.type === 'Int' || field.type.type === 'Float') { + fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`; + } + return `\t${field.name}: ${fieldSchema}`; + }) + .join(',\n')} + })`; + prismaUpdateSchema = this.makePartial(prismaUpdateSchema); + if (refineFuncName) { + prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`); + + //////////////////////////////////////////////// + // 3. Create schema + //////////////////////////////////////////////// + let createSchema = 'baseSchema'; + const fieldsWithDefault = scalarFields.filter( + (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array ); - } + if (fieldsWithDefault.length > 0) { + createSchema = this.makePartial( + createSchema, + fieldsWithDefault.map((f) => f.name) + ); + } - if (fkSchema) { - // export schema with only scalar fields - const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`; - writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`); + if (fkSchema) { + // export schema with only scalar fields + const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`; + writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`); - // merge fk fields - createSchema = makeMerge(createScalarSchema, fkSchema); - } + // merge fk fields + createSchema = this.makeMerge(createScalarSchema, fkSchema); + } - if (refineFuncName) { - // export a schema without refinement for extensibility - const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`; - writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`); - createSchema = `${refineFuncName}(${noRefineSchema})`; - } - writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`); + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`); + createSchema = `${refineFuncName}(${noRefineSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`); - //////////////////////////////////////////////// - // 3. Update schema - //////////////////////////////////////////////// - let updateSchema = makePartial('baseSchema'); + //////////////////////////////////////////////// + // 3. Update schema + //////////////////////////////////////////////// + let updateSchema = this.makePartial('baseSchema'); - if (fkSchema) { - // export schema with only scalar fields - const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; - writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`); - updateSchema = updateScalarSchema; + if (fkSchema) { + // export schema with only scalar fields + const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; + writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`); + updateSchema = updateScalarSchema; - // merge fk fields - updateSchema = makeMerge(updateSchema, makePartial(fkSchema)); - } + // merge fk fields + updateSchema = this.makeMerge(updateSchema, this.makePartial(fkSchema)); + } - if (refineFuncName) { - // export a schema without refinement for extensibility - const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`; - writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`); - updateSchema = `${refineFuncName}(${noRefineSchema})`; - } - writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`); - }); + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`); + updateSchema = `${refineFuncName}(${noRefineSchema})`; + } + writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`); + }); - return schemaName; -} + return schemaName; + } -function makePartial(schema: string, fields?: string[]) { - if (fields) { - if (fields.length === 0) { - return schema; + private makePartial(schema: string, fields?: string[]) { + if (fields) { + if (fields.length === 0) { + return schema; + } else { + return `${schema}.partial({ + ${fields.map((f) => `${f}: true`).join(', ')} + })`; + } } else { - return `${schema}.partial({ - ${fields.map((f) => `${f}: true`).join(', ')} - })`; + return `${schema}.partial()`; } - } else { - return `${schema}.partial()`; } -} -function makeOmit(schema: string, fields: string[]) { - return `${schema}.omit({ - ${fields.map((f) => `${f}: true`).join(', ')}, - })`; -} + private makeOmit(schema: string, fields: string[]) { + return `${schema}.omit({ + ${fields.map((f) => `${f}: true`).join(', ')}, + })`; + } -function makeMerge(schema1: string, schema2: string): string { - return `${schema1}.merge(${schema2})`; -} + private makeMerge(schema1: string, schema2: string): string { + return `${schema1}.merge(${schema2})`; + } -function makePassthrough(schema: string) { - return `${schema}.passthrough()`; + private makePassthrough(schema: string) { + return `${schema}.passthrough()`; + } } diff --git a/packages/schema/src/plugins/zod/index.ts b/packages/schema/src/plugins/zod/index.ts index 53a30b4e3..ffe198378 100644 --- a/packages/schema/src/plugins/zod/index.ts +++ b/packages/schema/src/plugins/zod/index.ts @@ -1,13 +1,14 @@ import { PluginFunction } from '@zenstackhq/sdk'; import invariant from 'tiny-invariant'; -import { generate } from './generator'; +import { ZodSchemaGenerator } from './generator'; export const name = 'Zod'; export const description = 'Generating Zod schemas'; const run: PluginFunction = async (model, options, dmmf, globalOptions) => { invariant(dmmf); - return generate(model, options, dmmf, globalOptions); + const generator = new ZodSchemaGenerator(model, options, dmmf, globalOptions); + return generator.generate(); }; export default run; diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index 878eff82b..0471dec3f 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,12 +1,11 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import type { DMMF, DMMF as PrismaDMMF } from '@prisma/generator-helper'; -import { Model } from '@zenstackhq/language/ast'; -import { getPrismaClientImportSpec, getPrismaVersion } from '@zenstackhq/sdk'; +import { getPrismaClientImportSpec, getPrismaVersion, type PluginOptions } from '@zenstackhq/sdk'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import { indentString } from '@zenstackhq/sdk/utils'; import path from 'path'; import * as semver from 'semver'; -import { Project } from 'ts-morph'; +import type { Project, SourceFile } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; import { AggregateOperationSupport, TransformerParams } from './types'; @@ -27,8 +26,8 @@ export default class Transformer { private hasJson = false; private hasDecimal = false; private project: Project; - private zmodel: Model; private inputObjectTypes: DMMF.InputType[]; + public sourceFiles: SourceFile[] = []; constructor(params: TransformerParams) { this.originalName = params.name ?? ''; @@ -39,7 +38,6 @@ export default class Transformer { this.aggregateOperationSupport = params.aggregateOperationSupport ?? {}; this.enumTypes = params.enumTypes ?? []; this.project = params.project; - this.zmodel = params.zmodel; this.inputObjectTypes = params.inputObjectTypes; } @@ -59,12 +57,17 @@ export default class Transformer { `${name}`, `z.enum(${JSON.stringify(enumType.values)})` )}`; - this.project.createSourceFile(filePath, content, { overwrite: true }); + this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true })); } - this.project.createSourceFile( - path.join(Transformer.outputPath, `enums/index.ts`), - this.enumTypes.map((enumType) => `export * from './${upperCaseFirst(enumType.name)}.schema';`).join('\n'), - { overwrite: true } + + this.sourceFiles.push( + this.project.createSourceFile( + path.join(Transformer.outputPath, `enums/index.ts`), + this.enumTypes + .map((enumType) => `export * from './${upperCaseFirst(enumType.name)}.schema';`) + .join('\n'), + { overwrite: true } + ) ); } @@ -76,13 +79,13 @@ export default class Transformer { return `export const ${name}Schema = ${schema}`; } - generateObjectSchema(generateUnchecked: boolean) { + generateObjectSchema(generateUnchecked: boolean, options: PluginOptions) { const zodObjectSchemaFields = this.generateObjectSchemaFields(generateUnchecked); - const objectSchema = this.prepareObjectSchema(zodObjectSchemaFields); + const objectSchema = this.prepareObjectSchema(zodObjectSchemaFields, options); const filePath = path.join(Transformer.outputPath, `objects/${this.name}.schema.ts`); const content = '/* eslint-disable */\n' + objectSchema; - this.project.createSourceFile(filePath, content, { overwrite: true }); + this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true })); return `${this.name}.schema`; } @@ -254,12 +257,12 @@ export default class Transformer { return zodStringWithMainType; } - prepareObjectSchema(zodObjectSchemaFields: string[]) { + prepareObjectSchema(zodObjectSchemaFields: string[], options: PluginOptions) { const objectSchema = `${this.generateExportObjectSchemaStatement( this.addFinalWrappers({ zodStringFields: zodObjectSchemaFields }) )}\n`; - const prismaImportStatement = this.generateImportPrismaStatement(); + const prismaImportStatement = this.generateImportPrismaStatement(options); const json = this.generateJsonSchemaImplementation(); @@ -285,10 +288,10 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`; return this.wrapWithZodObject(fields) + '.strict()'; } - generateImportPrismaStatement() { + generateImportPrismaStatement(options: PluginOptions) { const prismaClientImportPath = getPrismaClientImportSpec( - this.zmodel, - path.resolve(Transformer.outputPath, './objects') + path.resolve(Transformer.outputPath, './objects'), + options ); return `import type { Prisma } from '${prismaClientImportPath}';\n\n`; } @@ -384,9 +387,12 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`; return wrapped; } - async generateInputSchemas(generateUnchecked: boolean) { + async generateInputSchemas(options: PluginOptions) { const globalExports: string[] = []; + // whether Prisma's Unchecked* series of input types should be generated + const generateUnchecked = options.noUncheckedInput !== true; + for (const modelOperation of this.modelOperations) { const { model: origModelName, @@ -421,7 +427,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`; let imports = [ `import { z } from 'zod'`, - this.generateImportPrismaStatement(), + this.generateImportPrismaStatement(options), selectImport, includeImport, ]; @@ -666,7 +672,7 @@ ${operations } as ${modelName}InputSchemaType; `; - this.project.createSourceFile(filePath, content, { overwrite: true }); + this.sourceFiles.push(this.project.createSourceFile(filePath, content, { overwrite: true })); } const indexFilePath = path.join(Transformer.outputPath, 'input/index.ts'); @@ -674,7 +680,7 @@ ${operations /* eslint-disable */ ${globalExports.join(';\n')} `; - this.project.createSourceFile(indexFilePath, indexContent, { overwrite: true }); + this.sourceFiles.push(this.project.createSourceFile(indexFilePath, indexContent, { overwrite: true })); } generateImportStatements(imports: (string | undefined)[]) { diff --git a/packages/schema/src/plugins/zod/types.ts b/packages/schema/src/plugins/zod/types.ts index 72564c7ef..e71b3b03a 100644 --- a/packages/schema/src/plugins/zod/types.ts +++ b/packages/schema/src/plugins/zod/types.ts @@ -1,5 +1,4 @@ import { DMMF, DMMF as PrismaDMMF } from '@prisma/generator-helper'; -import { Model } from '@zenstackhq/language/ast'; import { Project } from 'ts-morph'; export type TransformerParams = { @@ -12,7 +11,6 @@ export type TransformerParams = { isDefaultPrismaClientOutput?: boolean; prismaClientOutputPath?: string; project: Project; - zmodel: Model; inputObjectTypes: DMMF.InputType[]; }; diff --git a/packages/schema/src/telemetry.ts b/packages/schema/src/telemetry.ts index 9cd8ba386..45983886d 100644 --- a/packages/schema/src/telemetry.ts +++ b/packages/schema/src/telemetry.ts @@ -111,18 +111,18 @@ export class Telemetry { } } - async trackSpan( + async trackSpan( startEvent: TelemetryEvents, completeEvent: TelemetryEvents, errorEvent: TelemetryEvents, properties: Record, - action: () => Promise | void + action: () => Promise | T ) { this.track(startEvent, properties); const start = Date.now(); let success = true; try { - await Promise.resolve(action()); + return await action(); // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (err: any) { this.track(errorEvent, { diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index cd516f5ec..8adf42c4c 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -34,9 +34,24 @@ import { TypeScriptExpressionTransformer, } from '.'; +/** + * Options for generating model metadata + */ export type ModelMetaGeneratorOptions = { + /** + * Output directory + */ output: string; + + /** + * Whether to generate all attributes + */ generateAttributes: boolean; + + /** + * Whether to preserve the pre-compilation TypeScript files + */ + preserveTsFiles?: boolean; }; export async function generate(project: Project, models: DataModel[], options: ModelMetaGeneratorOptions) { @@ -49,6 +64,11 @@ export async function generate(project: Project, models: DataModel[], options: M ], }); sf.addStatements('export default metadata;'); + + if (options.preserveTsFiles) { + await sf.save(); + } + return sf; } diff --git a/packages/sdk/src/prisma.ts b/packages/sdk/src/prisma.ts index 77db556b4..19b836cc2 100644 --- a/packages/sdk/src/prisma.ts +++ b/packages/sdk/src/prisma.ts @@ -3,62 +3,37 @@ import type { DMMF } from '@prisma/generator-helper'; import path from 'path'; import * as semver from 'semver'; -import { GeneratorDecl, Model, Plugin, isGeneratorDecl, isPlugin } from './ast'; -import { getLiteral } from './utils'; +import { RUNTIME_PACKAGE } from './constants'; +import type { PluginOptions } from './types'; /** - * Given a ZModel and an import context directory, compute the import spec for the Prisma Client. + * Given an import context directory and plugin options, compute the import spec for the Prisma Client. */ -export function getPrismaClientImportSpec(model: Model, importingFromDir: string) { - const generator = model.declarations.find( - (d) => - isGeneratorDecl(d) && - d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === 'prisma-client-js') - ) as GeneratorDecl; - - const clientOutputField = generator?.fields.find((f) => f.name === 'output'); - const clientOutput = getLiteral(clientOutputField?.value); - - if (!clientOutput) { - // no user-declared Prisma Client output location +export function getPrismaClientImportSpec(importingFromDir: string, options: PluginOptions) { + if (!options.prismaClientPath || options.prismaClientPath === '@prisma/client') { return '@prisma/client'; } - if (path.isAbsolute(clientOutput)) { - // absolute path - return clientOutput; + if (options.prismaClientPath.startsWith(RUNTIME_PACKAGE)) { + return options.prismaClientPath; } - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const zmodelDir = path.dirname(model.$document!.uri.fsPath); - - // compute prisma schema absolute output path - let prismaSchemaOutputDir = path.resolve(zmodelDir, './prisma'); - const prismaPlugin = model.declarations.find( - (d) => isPlugin(d) && d.fields.some((f) => f.name === 'provider' && getLiteral(f.value) === '@core/prisma') - ) as Plugin; - if (prismaPlugin) { - const output = getLiteral(prismaPlugin.fields.find((f) => f.name === 'output')?.value); - if (output) { - if (path.isAbsolute(output)) { - // absolute prisma schema output path - prismaSchemaOutputDir = path.dirname(output); - } else { - prismaSchemaOutputDir = path.dirname(path.resolve(zmodelDir, output)); - } - } + if (path.isAbsolute(options.prismaClientPath)) { + // absolute path + return options.prismaClientPath; } - // resolve the prisma client output path, which is relative to the prisma schema - const resolvedPrismaClientOutput = path.resolve(prismaSchemaOutputDir, clientOutput); + // resolve absolute path based on the zmodel file location + const resolvedPrismaClientOutput = path.resolve(path.dirname(options.schemaPath), options.prismaClientPath); + + // translate to path relative to the importing context directory + let result = path.relative(importingFromDir, resolvedPrismaClientOutput); - // DEBUG: - // console.log('PRISMA SCHEMA PATH:', prismaSchemaOutputDir); - // console.log('PRISMA CLIENT PATH:', resolvedPrismaClientOutput); - // console.log('IMPORTING PATH:', importingFromDir); + // remove leading `node_modules` (which may be provided by the user) + result = result.replace(/^([./\\]*)?node_modules\//, ''); // compute prisma client absolute output dir relative to the importing file - return normalizePath(path.relative(importingFromDir, resolvedPrismaClientOutput)); + return normalizePath(result); } function normalizePath(p: string) { diff --git a/packages/sdk/src/types.ts b/packages/sdk/src/types.ts index 9fbbd5553..a6a4b8629 100644 --- a/packages/sdk/src/types.ts +++ b/packages/sdk/src/types.ts @@ -1,5 +1,6 @@ import type { DMMF } from '@prisma/generator-helper'; import { Model } from '@zenstackhq/language/ast'; +import type { Project } from 'ts-morph'; /** * Plugin configuration option value type @@ -19,7 +20,17 @@ export type PluginDeclaredOptions = { /** * Plugin configuration options for execution */ -export type PluginOptions = { schemaPath: string } & PluginDeclaredOptions; +export type PluginOptions = { + /** + * ZModel schema absolute path + */ + schemaPath: string; + + /** + * PrismaClient import path, either relative to `schemaPath` or absolute + */ + prismaClientPath?: string; +} & PluginDeclaredOptions; /** * Global options that apply to all plugins @@ -34,6 +45,34 @@ export type PluginGlobalOptions = { * Whether to compile the generated code */ compile: boolean; + + /** + * The `ts-morph` project used for code generation. + * @private + */ + tsProject: Project; +}; + +/** + * Plugin run results. + */ +export type PluginResult = { + /** + * Warnings + */ + warnings: string[]; + + /** + * PrismaClient path, either relative to zmodel path or absolute, if the plugin + * generated a PrismaClient + */ + prismaClientPath?: string; + + /** + * An optional Prisma DMMF document that a plugin can generate + * @private + */ + dmmf?: DMMF.Document; }; /** @@ -42,9 +81,9 @@ export type PluginGlobalOptions = { export type PluginFunction = ( model: Model, options: PluginOptions, - dmmf?: DMMF.Document, + dmmf: DMMF.Document | undefined, globalOptions?: PluginGlobalOptions -) => Promise | string[] | Promise | void; +) => Promise | PluginResult | Promise | void; /** * Plugin error diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 01d5d274d..f73d2e12c 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -285,7 +285,7 @@ export function resolvePath(_path: string, options: Pick { - const schema = ` -model User { - id Int @id @default(autoincrement()) - level Int @default(0) - assets Asset[] - ratedVideos RatedVideo[] @relation('direct') - - @@allow('all', true) -} - -model Asset { - id Int @id @default(autoincrement()) - createdAt DateTime @default(now()) - viewCount Int @default(0) - owner User? @relation(fields: [ownerId], references: [id]) - ownerId Int? - assetType String - - @@delegate(assetType) - @@allow('all', true) -} - -model Video extends Asset { - duration Int - url String - videoType String - - @@delegate(videoType) -} - -model RatedVideo extends Video { - rating Int - user User? @relation(name: 'direct', fields: [userId], references: [id]) - userId Int? -} - -model Image extends Asset { - format String - gallery Gallery? @relation(fields: [galleryId], references: [id]) - galleryId Int? -} - -model Gallery { - id Int @id @default(autoincrement()) - images Image[] -} -`; + const schema = POLYMORPHIC_SCHEMA; async function setup() { const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); diff --git a/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts new file mode 100644 index 000000000..8e6562e20 --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts @@ -0,0 +1,25 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { POLYMORPHIC_SCHEMA } from './utils'; +import path from 'path'; + +describe('Polymorphic Plugin Interaction Test', () => { + it('tanstack-query', async () => { + const tanstackPlugin = path.resolve(__dirname, '../../../../../packages/plugins/tanstack-query/dist'); + const schema = ` + ${POLYMORPHIC_SCHEMA} + + plugin hooks { + provider = '${tanstackPlugin}' + output = '$projectRoot/hooks' + target = 'react' + version = 'v5' + } + `; + + await loadSchema(schema, { + compile: true, + copyDependencies: [tanstackPlugin], + extraDependencies: ['@tanstack/react-query'], + }); + }); +}); diff --git a/tests/integration/tests/enhancements/with-delegate/policy.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/policy.test.ts rename to tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/utils.ts b/tests/integration/tests/enhancements/with-delegate/utils.ts new file mode 100644 index 000000000..0de8a7e8b --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/utils.ts @@ -0,0 +1,47 @@ +export const POLYMORPHIC_SCHEMA = ` +model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + ratedVideos RatedVideo[] @relation('direct') + + @@allow('all', true) +} + +model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId Int? + assetType String + + @@delegate(assetType) + @@allow('all', true) +} + +model Video extends Asset { + duration Int + url String + videoType String + + @@delegate(videoType) +} + +model RatedVideo extends Video { + rating Int + user User? @relation(name: 'direct', fields: [userId], references: [id]) + userId Int? +} + +model Image extends Asset { + format String + gallery Gallery? @relation(fields: [galleryId], references: [id]) + galleryId Int? +} + +model Gallery { + id Int @id @default(autoincrement()) + images Image[] +} +`;