From 68a0eb38a27ed6aa80ad77dc25741517c7a9b766 Mon Sep 17 00:00:00 2001 From: Yiming Date: Fri, 15 Nov 2024 22:29:10 -0800 Subject: [PATCH] fix(runtime): intercepts `$extends` to reproxy its result to make sure enhancements persist (#1847) --- .../runtime/src/enhancements/node/proxy.ts | 19 ++++ .../with-policy/client-extensions.test.ts | 63 ++++++------- tests/regression/tests/issue-1859.test.ts | 90 +++++++++++++++++++ .../tests/issue-prisma-extension.test.ts | 90 +++++++++++++++++++ 4 files changed, 224 insertions(+), 38 deletions(-) create mode 100644 tests/regression/tests/issue-1859.test.ts create mode 100644 tests/regression/tests/issue-prisma-extension.test.ts diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index ae4105301..cfbc0eb7c 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -254,6 +254,25 @@ export function makeProxy( } } + if (prop === '$extends') { + // Prisma's `$extends` API returns a new client instance, we need to recreate + // a proxy around it + const $extends = Reflect.get(target, prop, receiver); + if ($extends && typeof $extends === 'function') { + return (...args: any[]) => { + const result = $extends.bind(target)(...args); + if (!result[PRISMA_PROXY_ENHANCER]) { + return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer); + } else { + // avoid double wrapping + return result; + } + }; + } else { + return $extends; + } + } + if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) { // skip non-model fields return Reflect.get(target, prop, receiver); diff --git a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts index 13f05aa51..1d907a4f2 100644 --- a/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts +++ b/tests/integration/tests/enhancements/with-policy/client-extensions.test.ts @@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.getAll()).resolves.toHaveLength(2); - - // FIXME: extending an enhanced client doesn't work for this case - // const db1 = enhance(prisma).$extends(ext); - // await expect(db1.model.getAll()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3); + await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2); + await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2); }); it('one model new method', async () => { @@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.getAll()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3); + await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2); + await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2); }); it('add client method', async () => { @@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - xprisma.$log('abc'); + enhanceRaw(prisma).$extends(ext).$log('abc'); + expect(logged).toBeTruthy(); + + logged = false; + enhanceRaw(prisma.$extends(ext)).$log('abc'); expect(logged).toBeTruthy(); }); @@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => { query: { model: { async findMany({ args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; return query(args); }, @@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override all models', async () => { @@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => { query: { $allModels: { async findMany({ args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log('findMany args:', args); return query(args); @@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override all operations', async () => { @@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => { query: { model: { async $allOperations({ operation, args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log(`${operation} args:`, args); return query(args); @@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('query override everything', async () => { @@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => { name: 'prisma-extension-queryOverride', query: { async $allOperations({ operation, args, query }: any) { - // take incoming `where` and set `age` args.where = { ...args.where, y: { lt: 300 } }; console.log(`${operation} args:`, args); return query(args); @@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - await expect(db.model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1); }); it('result mutation', async () => { @@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - const r = await db.model.findMany(); - expect(r).toHaveLength(1); - expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })])); + const expected = [expect.objectContaining({ value: 2 })]; + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected); }); it('result custom fields', async () => { @@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => { }); }); - const xprisma = prisma.$extends(ext); - const db = enhanceRaw(xprisma); - const r = await db.model.findMany(); - expect(r).toHaveLength(1); - expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })])); + const expected = [expect.objectContaining({ doubleValue: 2 })]; + await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected); + await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected); }); }); diff --git a/tests/regression/tests/issue-1859.test.ts b/tests/regression/tests/issue-1859.test.ts new file mode 100644 index 000000000..2b9d4538b --- /dev/null +++ b/tests/regression/tests/issue-1859.test.ts @@ -0,0 +1,90 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1859', () => { + it('extend enhanced client', async () => { + const { enhance, prisma } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const db = enhance(); + await expect(db.post.findMany()).resolves.toHaveLength(1); + + const extended = db.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await db.post.findMany(args) }; + }, + }, + }, + }); + + await expect(extended.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [{ id: 1, title: 'post1', published: true }], + }); + await expect(extended.post.findMany()).resolves.toHaveLength(1); + }); + + it('enhance extended client', async () => { + const { enhanceRaw, prisma, prismaModule } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const ext = prismaModule.defineExtension((_prisma: any) => { + return _prisma.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await prisma.post.findMany(args) }; + }, + }, + }, + }); + }); + + await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const enhanced = enhanceRaw(prisma.$extends(ext)); + await expect(enhanced.post.findMany()).resolves.toHaveLength(1); + // findManyListView internally uses the un-enhanced client + await expect(enhanced.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + }); +}); diff --git a/tests/regression/tests/issue-prisma-extension.test.ts b/tests/regression/tests/issue-prisma-extension.test.ts new file mode 100644 index 000000000..fa041a18a --- /dev/null +++ b/tests/regression/tests/issue-prisma-extension.test.ts @@ -0,0 +1,90 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue prisma extension', () => { + it('extend enhanced client', async () => { + const { enhance, prisma } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const db = enhance(); + await expect(db.post.findMany()).resolves.toHaveLength(1); + + const extended = db.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await db.post.findMany(args) }; + }, + }, + }, + }); + + await expect(extended.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [{ id: 1, title: 'post1', published: true }], + }); + await expect(extended.post.findMany()).resolves.toHaveLength(1); + }); + + it('enhance extended client', async () => { + const { enhanceRaw, prisma, prismaModule } = await loadSchema( + ` + model Post { + id Int @id + title String + published Boolean + + @@allow('create', true) + @@allow('read', published) + } + ` + ); + + await prisma.post.create({ data: { id: 1, title: 'post1', published: true } }); + await prisma.post.create({ data: { id: 2, title: 'post2', published: false } }); + + const ext = prismaModule.defineExtension((_prisma: any) => { + return _prisma.$extends({ + model: { + post: { + findManyListView: async (args: any) => { + return { view: true, data: await prisma.post.findMany(args) }; + }, + }, + }, + }); + }); + + await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2); + await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + + const enhanced = enhanceRaw(prisma.$extends(ext)); + await expect(enhanced.post.findMany()).resolves.toHaveLength(1); + // findManyListView internally uses the un-enhanced client + await expect(enhanced.post.findManyListView()).resolves.toMatchObject({ + view: true, + data: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }); + }); +});