From 2d822182717f7ae1aa313d90500387abd6507a8d Mon Sep 17 00:00:00 2001 From: Yiming Date: Sun, 17 Mar 2024 11:32:17 -0700 Subject: [PATCH] chore: improve `enhance` API code generation (#1150) --- .../src/plugins/enhancer/enhance/index.ts | 86 +++++++++++++------ 1 file changed, 62 insertions(+), 24 deletions(-) diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index fbb8a442a..7b042f05a 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -56,10 +56,9 @@ export class EnhancerGenerator { let logicalPrismaClientDir: string | undefined; let dmmf: DMMF.Document | undefined; - const withLogicalClient = this.needsLogicalClient(); const prismaImport = getPrismaClientImportSpec(this.outDir, this.options); - if (withLogicalClient) { + if (this.needsLogicalClient()) { // schema contains delegate models, need to generate a logical prisma schema const result = await this.generateLogicalPrisma(); @@ -90,22 +89,67 @@ export class EnhancerGenerator { const enhanceTs = this.project.createSourceFile( path.join(this.outDir, 'enhance.ts'), `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime'; - import modelMeta from './model-meta'; - import policy from './policy'; - import { Prisma as _Prisma, PrismaClient as _PrismaClient } from '${prismaImport}'; - import type { InternalArgs, TypeMapDef, TypeMapCbDef, DynamicClientExtensionThis } from '${prismaImport}/runtime/library'; - ${ - withLogicalClient - ? `import type * as _P from '${logicalPrismaClientDir}/index-fixed'; - import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed'; - ` - : `import type * as _P from '${prismaImport}'; - import type { Prisma, PrismaClient } from '${prismaImport}'; - ` +import modelMeta from './model-meta'; +import policy from './policy'; +${this.options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} + +${ + logicalPrismaClientDir + ? this.createLogicalPrismaImports(prismaImport, logicalPrismaClientDir) + : this.createSimplePrismaImports(prismaImport) +} + +${authTypes} + +${ + logicalPrismaClientDir + ? this.createLogicalPrismaEnhanceFunction(authTypeParam) + : this.createSimplePrismaEnhanceFunction(authTypeParam) +} + `, + { overwrite: true } + ); + + await this.saveSourceFile(enhanceTs); + + return { dmmf }; } - ${this.options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} - - ${authTypes} + + private createSimplePrismaImports(prismaImport: string) { + return `import { Prisma } from '${prismaImport}'; +import type * as _P from '${prismaImport}'; + `; + } + + private createSimplePrismaEnhanceFunction(authTypeParam: string) { + return ` +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions) { + return createEnhancement(prisma, { + modelMeta, + policy, + zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), + prismaModule: Prisma, + ...options + }, context); +} + `; + } + + private createLogicalPrismaImports(prismaImport: string, logicalPrismaClientDir: string) { + return `import { Prisma as _Prisma, PrismaClient as _PrismaClient } from '${prismaImport}'; +import type { + InternalArgs, + TypeMapDef, + TypeMapCbDef, + DynamicClientExtensionThis, +} from '${prismaImport}/runtime/library'; +import type * as _P from '${logicalPrismaClientDir}/index-fixed'; +import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed'; +`; + } + + private createLogicalPrismaEnhanceFunction(authTypeParam: string) { + return ` // overload for plain PrismaClient export function enhance & InternalArgs>( prisma: _PrismaClient, @@ -125,13 +169,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara ...options }, context); } - `, - { overwrite: true } - ); - - await this.saveSourceFile(enhanceTs); - - return { dmmf }; +`; } private needsLogicalClient() {