From 13bc04187c5131f3a8285a6d0e50ea42dbcc6ac1 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:01:59 -0800 Subject: [PATCH] more fixes --- .../access-policy/expression-writer.ts | 4 +- .../access-policy/policy-guard-generator.ts | 3 + .../with-policy/post-update.test.ts | 60 +++++++++++++++++-- 3 files changed, 60 insertions(+), 7 deletions(-) diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 4aa414729..a88c1dc03 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -234,7 +234,9 @@ export class ExpressionWriter { this.writeFieldCondition( expr.left, () => { - this.write(expr.right); + // inner scope of collection expression is always compiled as non-post-guard + const innerWriter = new ExpressionWriter(this.writer, false); + innerWriter.write(expr.right); }, operator === '?' ? 'some' : operator === '!' ? 'every' : 'none' ); diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index f46750b87..597a132d0 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -550,6 +550,9 @@ export default class PolicyGenerator { } else { return []; } + } else if (isInvocationExpr(expr)) { + // recurse into function arguments + return expr.args.flatMap((arg) => collectReferencePaths(arg.value)); } else { // recurse const children = streamContents(expr) diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index cc05a37bd..c40d338a3 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -75,6 +75,54 @@ describe('With Policy: post update', () => { await expect(db.model.update({ where: { id: '2' }, data: { value: 4 } })).toResolveTruthy(); }); + it('functions pre-update', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id String @id @default(uuid()) + value String + x Int + + @@allow('create,read', true) + @@allow('update', startsWith(value, 'hello') && future().x > 0) + } + ` + ); + + const db = withPolicy(); + + await prisma.model.create({ data: { id: '1', value: 'good', x: 1 } }); + await expect(db.model.update({ where: { id: '1' }, data: { value: 'hello' } })).toBeRejectedByPolicy(); + + await prisma.model.update({ where: { id: '1' }, data: { value: 'hello world' } }); + const r = await db.model.update({ where: { id: '1' }, data: { value: 'hello new world' } }); + expect(r.value).toBe('hello new world'); + }); + + it('functions post-update', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id String @id @default(uuid()) + value String + x Int + + @@allow('create,read', true) + @@allow('update', x > 0 && startsWith(future().value, 'hello')) + } + `, + { logPrismaQuery: true } + ); + + const db = withPolicy(); + + await prisma.model.create({ data: { id: '1', value: 'good', x: 1 } }); + await expect(db.model.update({ where: { id: '1' }, data: { value: 'nice' } })).toBeRejectedByPolicy(); + + const r = await db.model.update({ where: { id: '1' }, data: { x: 0, value: 'hello world' } }); + expect(r.value).toBe('hello world'); + }); + it('collection predicate pre-update', async () => { const { prisma, withPolicy } = await loadSchema( ` @@ -109,7 +157,7 @@ describe('With Policy: post update', () => { }, }); - expect( + await expect( db.m1.update({ where: { id: '1' }, data: { value: 1 }, @@ -124,7 +172,7 @@ describe('With Policy: post update', () => { }, }); - expect( + await expect( db.m1.update({ where: { id: '1' }, data: { value: 1 }, @@ -159,14 +207,14 @@ describe('With Policy: post update', () => { await prisma.m1.create({ data: { id: '1', - value: 0, + value: 1, m2: { - create: [{ id: '1', value: 1 }], + create: [{ id: '1', value: 0 }], }, }, }); - expect( + await expect( db.m1.update({ where: { id: '1' }, data: { value: 2 }, @@ -181,7 +229,7 @@ describe('With Policy: post update', () => { }, }); - expect( + await expect( db.m1.update({ where: { id: '1' }, data: { value: 2 },