diff --git a/packages/server/src/nestjs/zenstack.module.ts b/packages/server/src/nestjs/zenstack.module.ts index f2ae601c6..a113fb84d 100644 --- a/packages/server/src/nestjs/zenstack.module.ts +++ b/packages/server/src/nestjs/zenstack.module.ts @@ -12,7 +12,7 @@ export interface ZenStackModuleOptions { /** * A callback for getting an enhanced `PrismaClient`. */ - getEnhancedPrisma: () => unknown; + getEnhancedPrisma: (model?: string | symbol ) => unknown; } /** @@ -79,7 +79,7 @@ export class ZenStackModule { { get(_target, prop) { // eslint-disable-next-line @typescript-eslint/no-explicit-any - const enhancedPrisma: any = getEnhancedPrisma(); + const enhancedPrisma: any = getEnhancedPrisma(prop); if (!enhancedPrisma) { throw new Error('`getEnhancedPrisma` must return a valid Prisma client'); } diff --git a/packages/server/tests/adapter/nestjs.test.ts b/packages/server/tests/adapter/nestjs.test.ts index 6cfa48617..d28a3ecc8 100644 --- a/packages/server/tests/adapter/nestjs.test.ts +++ b/packages/server/tests/adapter/nestjs.test.ts @@ -160,4 +160,53 @@ describe('NestJS adapter tests', () => { const postSvc = app.get('PostService'); await expect(postSvc.findAll()).resolves.toHaveLength(1); }); + + it('pass property', async () => { + const { prisma, enhanceRaw } = await loadSchema(schema); + + await prisma.user.create({ + data: { + posts: { + create: [ + { title: 'post1', published: true }, + { title: 'post2', published: false }, + ], + }, + }, + }); + + const moduleRef = await Test.createTestingModule({ + imports: [ + ZenStackModule.registerAsync({ + useFactory: (prismaService) => ({ + getEnhancedPrisma: (prop) => { + return prop === 'post' ? prismaService : enhanceRaw(prismaService, { user: { id: 2 } }); + }, + }), + inject: ['PrismaService'], + extraProviders: [ + { + provide: 'PrismaService', + useValue: prisma, + }, + ], + }), + ], + providers: [ + { + provide: 'PostService', + useFactory: (enhancedPrismaService) => ({ + findAll: () => enhancedPrismaService.post.findMany(), + }), + inject: [ENHANCED_PRISMA], + }, + ], + }).compile(); + + const app = moduleRef.createNestApplication(); + await app.init(); + + const postSvc = app.get('PostService'); + await expect(postSvc.findAll()).resolves.toHaveLength(2); + }); });