Skip to content

Commit

Permalink
fix(runtime): intercepts $extends to reproxy its result to make sur…
Browse files Browse the repository at this point in the history
…e enhancements persist (#1847)
  • Loading branch information
ymc9 authored Nov 16, 2024
1 parent f377441 commit 68a0eb3
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 38 deletions.
19 changes: 19 additions & 0 deletions packages/runtime/src/enhancements/node/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,25 @@ export function makeProxy<T extends PrismaProxyHandler>(
}
}

if (prop === '$extends') {
// Prisma's `$extends` API returns a new client instance, we need to recreate
// a proxy around it
const $extends = Reflect.get(target, prop, receiver);
if ($extends && typeof $extends === 'function') {
return (...args: any[]) => {
const result = $extends.bind(target)(...args);
if (!result[PRISMA_PROXY_ENHANCER]) {
return makeProxy(result, modelMeta, makeHandler, name + '$ext', errorTransformer);
} else {
// avoid double wrapping
return result;
}
};
} else {
return $extends;
}
}

if (typeof prop !== 'string' || prop.startsWith('$') || !models.includes(prop.toLowerCase())) {
// skip non-model fields
return Reflect.get(target, prop, receiver);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.getAll()).resolves.toHaveLength(2);

// FIXME: extending an enhanced client doesn't work for this case
// const db1 = enhance(prisma).$extends(ext);
// await expect(db1.model.getAll()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
});

it('one model new method', async () => {
Expand Down Expand Up @@ -84,9 +80,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.getAll()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).model.getAll()).resolves.toHaveLength(3);
await expect(enhanceRaw(prisma.$extends(ext)).model.getAll()).resolves.toHaveLength(2);
await expect(enhanceRaw(prisma).$extends(ext).model.getAll()).resolves.toHaveLength(2);
});

it('add client method', async () => {
Expand Down Expand Up @@ -115,8 +111,11 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
xprisma.$log('abc');
enhanceRaw(prisma).$extends(ext).$log('abc');
expect(logged).toBeTruthy();

logged = false;
enhanceRaw(prisma.$extends(ext)).$log('abc');
expect(logged).toBeTruthy();
});

Expand All @@ -143,7 +142,6 @@ describe('With Policy: client extensions', () => {
query: {
model: {
async findMany({ args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
return query(args);
},
Expand All @@ -152,9 +150,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override all models', async () => {
Expand All @@ -180,7 +177,6 @@ describe('With Policy: client extensions', () => {
query: {
$allModels: {
async findMany({ args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log('findMany args:', args);
return query(args);
Expand All @@ -190,9 +186,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override all operations', async () => {
Expand All @@ -218,7 +213,6 @@ describe('With Policy: client extensions', () => {
query: {
model: {
async $allOperations({ operation, args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log(`${operation} args:`, args);
return query(args);
Expand All @@ -228,9 +222,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('query override everything', async () => {
Expand All @@ -255,7 +248,6 @@ describe('With Policy: client extensions', () => {
name: 'prisma-extension-queryOverride',
query: {
async $allOperations({ operation, args, query }: any) {
// take incoming `where` and set `age`
args.where = { ...args.where, y: { lt: 300 } };
console.log(`${operation} args:`, args);
return query(args);
Expand All @@ -264,9 +256,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
await expect(db.model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toHaveLength(1);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toHaveLength(1);
});

it('result mutation', async () => {
Expand Down Expand Up @@ -301,11 +292,9 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
const r = await db.model.findMany();
expect(r).toHaveLength(1);
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ value: 2 })]));
const expected = [expect.objectContaining({ value: 2 })];
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
});

it('result custom fields', async () => {
Expand Down Expand Up @@ -339,10 +328,8 @@ describe('With Policy: client extensions', () => {
});
});

const xprisma = prisma.$extends(ext);
const db = enhanceRaw(xprisma);
const r = await db.model.findMany();
expect(r).toHaveLength(1);
expect(r).toEqual(expect.arrayContaining([expect.objectContaining({ doubleValue: 2 })]));
const expected = [expect.objectContaining({ doubleValue: 2 })];
await expect(enhanceRaw(prisma.$extends(ext)).model.findMany()).resolves.toEqual(expected);
await expect(enhanceRaw(prisma).$extends(ext).model.findMany()).resolves.toEqual(expected);
});
});
90 changes: 90 additions & 0 deletions tests/regression/tests/issue-1859.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1859', () => {
it('extend enhanced client', async () => {
const { enhance, prisma } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean
@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const db = enhance();
await expect(db.post.findMany()).resolves.toHaveLength(1);

const extended = db.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await db.post.findMany(args) };
},
},
},
});

await expect(extended.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [{ id: 1, title: 'post1', published: true }],
});
await expect(extended.post.findMany()).resolves.toHaveLength(1);
});

it('enhance extended client', async () => {
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean
@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const ext = prismaModule.defineExtension((_prisma: any) => {
return _prisma.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await prisma.post.findMany(args) };
},
},
},
});
});

await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});

const enhanced = enhanceRaw(prisma.$extends(ext));
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
// findManyListView internally uses the un-enhanced client
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});
});
});
90 changes: 90 additions & 0 deletions tests/regression/tests/issue-prisma-extension.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue prisma extension', () => {
it('extend enhanced client', async () => {
const { enhance, prisma } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean
@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const db = enhance();
await expect(db.post.findMany()).resolves.toHaveLength(1);

const extended = db.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await db.post.findMany(args) };
},
},
},
});

await expect(extended.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [{ id: 1, title: 'post1', published: true }],
});
await expect(extended.post.findMany()).resolves.toHaveLength(1);
});

it('enhance extended client', async () => {
const { enhanceRaw, prisma, prismaModule } = await loadSchema(
`
model Post {
id Int @id
title String
published Boolean
@@allow('create', true)
@@allow('read', published)
}
`
);

await prisma.post.create({ data: { id: 1, title: 'post1', published: true } });
await prisma.post.create({ data: { id: 2, title: 'post2', published: false } });

const ext = prismaModule.defineExtension((_prisma: any) => {
return _prisma.$extends({
model: {
post: {
findManyListView: async (args: any) => {
return { view: true, data: await prisma.post.findMany(args) };
},
},
},
});
});

await expect(prisma.$extends(ext).post.findMany()).resolves.toHaveLength(2);
await expect(prisma.$extends(ext).post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});

const enhanced = enhanceRaw(prisma.$extends(ext));
await expect(enhanced.post.findMany()).resolves.toHaveLength(1);
// findManyListView internally uses the un-enhanced client
await expect(enhanced.post.findManyListView()).resolves.toMatchObject({
view: true,
data: [
{ id: 1, title: 'post1', published: true },
{ id: 2, title: 'post2', published: false },
],
});
});
});

0 comments on commit 68a0eb3

Please sign in to comment.