From 738bba6ef1edcd36c576df66a268b63d00741f2b Mon Sep 17 00:00:00 2001 From: Yiming Date: Sun, 22 Sep 2024 14:41:21 -0700 Subject: [PATCH] fix(delegate): enforcing concrete model policies when read from a delegate base (#1726) --- .../enhancements/node/create-enhancement.ts | 34 ++- .../runtime/src/enhancements/node/delegate.ts | 62 +++-- .../src/enhancements/node/policy/index.ts | 18 ++ .../enhancements/node/policy/policy-utils.ts | 3 + .../runtime/src/enhancements/node/proxy.ts | 4 +- .../with-delegate/policy-interaction.test.ts | 215 ++++++++++++++++++ 6 files changed, 312 insertions(+), 24 deletions(-) diff --git a/packages/runtime/src/enhancements/node/create-enhancement.ts b/packages/runtime/src/enhancements/node/create-enhancement.ts index 127574e26..263e12192 100644 --- a/packages/runtime/src/enhancements/node/create-enhancement.ts +++ b/packages/runtime/src/enhancements/node/create-enhancement.ts @@ -1,13 +1,19 @@ import semver from 'semver'; import { PRISMA_MINIMUM_VERSION } from '../../constants'; import { isDelegateModel, type ModelMeta } from '../../cross'; -import type { EnhancementContext, EnhancementKind, EnhancementOptions, ZodSchemas } from '../../types'; +import type { + DbClientContract, + EnhancementContext, + EnhancementKind, + EnhancementOptions, + ZodSchemas, +} from '../../types'; import { withDefaultAuth } from './default-auth'; import { withDelegate } from './delegate'; import { Logger } from './logger'; import { withOmit } from './omit'; import { withPassword } from './password'; -import { withPolicy } from './policy'; +import { policyProcessIncludeRelationPayload, withPolicy } from './policy'; import type { PolicyDef } from './types'; /** @@ -41,6 +47,18 @@ export type InternalEnhancementOptions = EnhancementOptions & { */ // eslint-disable-next-line @typescript-eslint/no-explicit-any prismaModule: any; + + /** + * A callback shared among enhancements to process the payload for including a relation + * field. e.g.: `{ author: true }`. + */ + processIncludeRelationPayload?: ( + prisma: DbClientContract, + model: string, + payload: unknown, + options: InternalEnhancementOptions, + context: EnhancementContext | undefined + ) => Promise; }; /** @@ -89,7 +107,7 @@ export function createEnhancement( 'Your ZModel contains delegate models but "delegate" enhancement kind is not enabled. This may result in unexpected behavior.' ); } else { - result = withDelegate(result, options); + result = withDelegate(result, options, context); } } @@ -103,6 +121,16 @@ export function createEnhancement( // 'policy' and 'validation' enhancements are both enabled by `withPolicy` if (kinds.includes('policy') || kinds.includes('validation')) { result = withPolicy(result, options, context); + + // if any enhancement is to introduce an inclusion of a relation field, the + // inclusion payload must be processed by the policy enhancement for injecting + // access control rules + + // TODO: this is currently a global callback shared among all enhancements, which + // is far from ideal + + options.processIncludeRelationPayload = policyProcessIncludeRelationPayload; + if (kinds.includes('policy') && hasDefaultAuth) { // @default(auth()) proxy result = withDefaultAuth(result, options, context); diff --git a/packages/runtime/src/enhancements/node/delegate.ts b/packages/runtime/src/enhancements/node/delegate.ts index 9cfc720c5..45e74ee36 100644 --- a/packages/runtime/src/enhancements/node/delegate.ts +++ b/packages/runtime/src/enhancements/node/delegate.ts @@ -15,18 +15,22 @@ import { isDelegateModel, resolveField, } from '../../cross'; -import type { CrudContract, DbClientContract } from '../../types'; +import type { CrudContract, DbClientContract, EnhancementContext } from '../../types'; import type { InternalEnhancementOptions } from './create-enhancement'; import { Logger } from './logger'; import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; import { QueryUtils } from './query-utils'; import { formatObject, prismaClientValidationError } from './utils'; -export function withDelegate(prisma: DbClient, options: InternalEnhancementOptions): DbClient { +export function withDelegate( + prisma: DbClient, + options: InternalEnhancementOptions, + context: EnhancementContext | undefined +): DbClient { return makeProxy( prisma, options.modelMeta, - (_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options), + (_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options, context), 'delegate' ); } @@ -35,7 +39,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { private readonly logger: Logger; private readonly queryUtils: QueryUtils; - constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + constructor( + prisma: DbClientContract, + model: string, + options: InternalEnhancementOptions, + private readonly context: EnhancementContext | undefined + ) { super(prisma, model, options); this.logger = new Logger(prisma); this.queryUtils = new QueryUtils(prisma, this.options); @@ -76,7 +85,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = args ? clone(args) : {}; this.injectWhereHierarchy(model, args?.where); - this.injectSelectIncludeHierarchy(model, args); + await this.injectSelectIncludeHierarchy(model, args); // discriminator field is needed during post process to determine the // actual concrete model type @@ -166,7 +175,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { }); } - private injectSelectIncludeHierarchy(model: string, args: any) { + private async injectSelectIncludeHierarchy(model: string, args: any) { if (!args || typeof args !== 'object') { return; } @@ -186,7 +195,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { // make sure the payload is an object args[kind][field] = {}; } - this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]); + await this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]); } } @@ -208,7 +217,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { // make sure the payload is an object args[kind][field] = nextValue = {}; } - this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); + await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); } } } @@ -220,11 +229,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { this.injectBaseIncludeRecursively(model, args); // include sub models downwards - this.injectConcreteIncludeRecursively(model, args); + await this.injectConcreteIncludeRecursively(model, args); } } - private buildSelectIncludeHierarchy(model: string, args: any) { + private async buildSelectIncludeHierarchy(model: string, args: any) { args = clone(args); const selectInclude: any = this.extractSelectInclude(args) || {}; @@ -248,7 +257,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { if (!selectInclude.select) { this.injectBaseIncludeRecursively(model, selectInclude); - this.injectConcreteIncludeRecursively(model, selectInclude); + await this.injectConcreteIncludeRecursively(model, selectInclude); } return selectInclude; } @@ -319,7 +328,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { this.injectBaseIncludeRecursively(base.name, selectInclude.include[baseRelationName]); } - private injectConcreteIncludeRecursively(model: string, selectInclude: any) { + private async injectConcreteIncludeRecursively(model: string, selectInclude: any) { const modelInfo = getModelInfo(this.options.modelMeta, model); if (!modelInfo) { return; @@ -333,13 +342,27 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { for (const subModel of subModels) { // include sub model relation field const subRelationName = this.makeAuxRelationName(subModel); + const includePayload: any = {}; + + if (this.options.processIncludeRelationPayload) { + // use the callback in options to process the include payload, so enhancements + // like 'policy' can do extra work (e.g., inject policy rules) + await this.options.processIncludeRelationPayload( + this.prisma, + subModel.name, + includePayload, + this.options, + this.context + ); + } + if (selectInclude.select) { - selectInclude.include = { [subRelationName]: {}, ...selectInclude.select }; + selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.select }; delete selectInclude.select; } else { - selectInclude.include = { [subRelationName]: {}, ...selectInclude.include }; + selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.include }; } - this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]); + await this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]); } } @@ -480,7 +503,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); await this.injectCreateHierarchy(model, args); - this.injectSelectIncludeHierarchy(model, args); + await this.injectSelectIncludeHierarchy(model, args); if (this.options.logPrismaQuery) { this.logger.info(`[delegate] \`create\` ${this.getModelName(model)}: ${formatObject(args)}`); @@ -702,7 +725,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); this.injectWhereHierarchy(this.model, (args as any)?.where); - this.injectSelectIncludeHierarchy(this.model, args); + await this.injectSelectIncludeHierarchy(this.model, args); if (args.create) { this.doProcessCreatePayload(this.model, args.create); } @@ -721,7 +744,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); await this.injectUpdateHierarchy(db, model, args); - this.injectSelectIncludeHierarchy(model, args); + await this.injectSelectIncludeHierarchy(model, args); if (this.options.logPrismaQuery) { this.logger.info(`[delegate] \`update\` ${this.getModelName(model)}: ${formatObject(args)}`); @@ -915,7 +938,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } return this.queryUtils.transaction(this.prisma, async (tx) => { - const selectInclude = this.buildSelectIncludeHierarchy(this.model, args); + const selectInclude = await this.buildSelectIncludeHierarchy(this.model, args); // make sure id fields are selected const idFields = this.getIdFields(this.model); @@ -967,6 +990,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { private async doDelete(db: CrudContract, model: string, args: any): Promise { this.injectWhereHierarchy(model, args.where); + await this.injectSelectIncludeHierarchy(model, args); if (this.options.logPrismaQuery) { this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`); diff --git a/packages/runtime/src/enhancements/node/policy/index.ts b/packages/runtime/src/enhancements/node/policy/index.ts index 66834a802..d5523e31b 100644 --- a/packages/runtime/src/enhancements/node/policy/index.ts +++ b/packages/runtime/src/enhancements/node/policy/index.ts @@ -7,6 +7,7 @@ import type { InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { makeProxy } from '../proxy'; import { PolicyProxyHandler } from './handler'; +import { PolicyUtil } from './policy-utils'; /** * Gets an enhanced Prisma client with access policy check. @@ -60,3 +61,20 @@ export function withPolicy( options?.errorTransformer ); } + +/** + * Function for processing a payload for including a relation field in a query. + * @param model The relation's model name + * @param payload The payload to process + */ +export async function policyProcessIncludeRelationPayload( + prisma: DbClientContract, + model: string, + payload: unknown, + options: InternalEnhancementOptions, + context: EnhancementContext | undefined +) { + const utils = new PolicyUtil(prisma, options, context); + await utils.injectForRead(prisma, model, payload); + await utils.injectReadCheckSelect(model, payload); +} diff --git a/packages/runtime/src/enhancements/node/policy/policy-utils.ts b/packages/runtime/src/enhancements/node/policy/policy-utils.ts index ed0ff5196..71985b30f 100644 --- a/packages/runtime/src/enhancements/node/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/node/policy/policy-utils.ts @@ -1098,6 +1098,9 @@ export class PolicyUtil extends QueryUtils { } const result = await db[model].findFirst(readArgs); if (!result) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] cannot read back ${model}`); + } return { error, result: undefined }; } diff --git a/packages/runtime/src/enhancements/node/proxy.ts b/packages/runtime/src/enhancements/node/proxy.ts index 3802e2390..ae4105301 100644 --- a/packages/runtime/src/enhancements/node/proxy.ts +++ b/packages/runtime/src/enhancements/node/proxy.ts @@ -69,7 +69,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { protected readonly options: InternalEnhancementOptions ) {} - protected withFluentCall(method: keyof PrismaProxyHandler, args: any, postProcess = true): Promise { + protected withFluentCall(method: PrismaProxyActions, args: any, postProcess = true): Promise { args = args ? clone(args) : {}; const promise = createFluentPromise( async () => { @@ -84,7 +84,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return promise; } - protected deferred(method: keyof PrismaProxyHandler, args: any, postProcess = true) { + protected deferred(method: PrismaProxyActions, args: any, postProcess = true) { return createDeferredPromise(async () => { args = await this.preprocessArgs(method, args); const r = await this.prisma[this.model][method](args); diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index 275e73853..ff791beb1 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -269,6 +269,221 @@ describe('Polymorphic Policy Test', () => { ).toResolveTruthy(); }); + it('respects base model policies when queried from a sub', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + assets Asset[] + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + deleted Boolean @default(false) + user User @relation(fields: [userId], references: [id]) + userId Int + type String + @@delegate(type) + @@allow('all', true) + @@deny('read', deleted) + } + + model Post extends Asset { + title String + } + ` + ); + + const db = enhance(); + const user = await db.user.create({ data: { id: 1 } }); + const post = await db.post.create({ data: { id: 1, title: 'Post1', userId: user.id } }); + + await expect(db.post.findUnique({ where: { id: post.id } })).toResolveTruthy(); + await expect(db.asset.findUnique({ where: { id: post.id } })).toResolveTruthy(); + let withAssets = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(withAssets.assets).toHaveLength(1); + + await prisma.asset.update({ where: { id: post.id }, data: { deleted: true } }); + await expect(db.post.findUnique({ where: { id: post.id } })).toResolveFalsy(); + await expect(db.asset.findUnique({ where: { id: post.id } })).toResolveFalsy(); + withAssets = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(withAssets.assets).toHaveLength(0); + + // unable to read back + await expect( + db.post.create({ data: { title: 'Post2', deleted: true, userId: user.id } }) + ).toBeRejectedByPolicy(); + // actually created + await expect(prisma.post.count()).resolves.toBe(2); + + // unable to read back + await expect(db.post.update({ where: { id: 2 }, data: { title: 'Post2-1' } })).toBeRejectedByPolicy(); + // actually updated + await expect(prisma.post.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ title: 'Post2-1' }); + + // unable to read back + await expect(db.post.delete({ where: { id: 2 } })).toBeRejectedByPolicy(); + // actually deleted + await expect(prisma.post.findUnique({ where: { id: 2 } })).toResolveFalsy(); + }); + + it('respects sub model policies when queried from a base: case 1', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + assets Asset[] + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + value Int @default(0) + deleted Boolean @default(false) + type String + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + @@deny('read', deleted) + } + ` + ); + + const db = enhance(); + const user = await db.user.create({ data: { id: 1 } }); + + // create read back + const post = await db.post.create({ data: { id: 1, title: 'Post1', userId: user.id } }); + expect(post.type).toBe('Post'); + expect(post.title).toBe('Post1'); + expect(post.value).toBe(0); + + // update read back + const updatedPost = await db.post.update({ where: { id: post.id }, data: { value: 1 } }); + expect(updatedPost.type).toBe('Post'); + expect(updatedPost.title).toBe('Post1'); + expect(updatedPost.value).toBe(1); + + // both asset and post fields are readable + const readPost = await db.post.findUnique({ where: { id: post.id } }); + expect(readPost.title).toBe('Post1'); + + const readAsset = await db.asset.findUnique({ where: { id: post.id } }); + expect(readAsset.type).toBe('Post'); + const userWithAssets = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(userWithAssets.assets[0].title).toBe('Post1'); + + await prisma.asset.update({ where: { id: post.id }, data: { deleted: true } }); + + // asset fields are readable, but not post fields + const readAsset1 = await db.asset.findUnique({ where: { id: post.id } }); + expect(readAsset1.type).toBe('Post'); + expect(readAsset1.title).toBeUndefined(); + + const userWithAssets1 = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(userWithAssets1.assets[0].type).toBe('Post'); + expect(userWithAssets1.assets[0].title).toBeUndefined(); + + // update read back + const updateRead = await db.asset.update({ where: { id: post.id }, data: { value: 2 } }); + expect(updateRead.value).toBe(2); + // cannot read back sub model + expect(updateRead.title).toBeUndefined(); + + // delete read back + const deleteRead = await db.asset.delete({ where: { id: post.id } }); + expect(deleteRead.value).toBe(2); + // cannot read back sub model + expect(deleteRead.title).toBeUndefined(); + // actually deleted + await expect(prisma.asset.findUnique({ where: { id: post.id } })).toResolveFalsy(); + await expect(prisma.post.findUnique({ where: { id: post.id } })).toResolveFalsy(); + }); + + it('respects sub model policies when queried from a base: case 2', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + assets Asset[] + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + value Int @default(0) + type String + @@delegate(type) + @@allow('all', true) + } + + model Post extends Asset { + title String + deleted Boolean @default(false) + @@deny('read', deleted) + } + ` + ); + + const db = enhance(); + const user = await db.user.create({ data: { id: 1 } }); + + // create read back + const post = await db.post.create({ data: { id: 1, title: 'Post1', userId: user.id } }); + expect(post.type).toBe('Post'); + expect(post.title).toBe('Post1'); + expect(post.value).toBe(0); + + // update read back + const updatedPost = await db.post.update({ where: { id: post.id }, data: { value: 1 } }); + expect(updatedPost.type).toBe('Post'); + expect(updatedPost.title).toBe('Post1'); + expect(updatedPost.value).toBe(1); + + // both asset and post fields are readable + const readPost = await db.post.findUnique({ where: { id: post.id } }); + expect(readPost.title).toBe('Post1'); + + const readAsset = await db.asset.findUnique({ where: { id: post.id } }); + expect(readAsset.type).toBe('Post'); + const userWithAssets = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(userWithAssets.assets[0].title).toBe('Post1'); + + await prisma.post.update({ where: { id: post.id }, data: { deleted: true } }); + + // asset fields are readable, but not post fields + const readAsset1 = await db.asset.findUnique({ where: { id: post.id } }); + expect(readAsset1.type).toBe('Post'); + expect(readAsset1.title).toBeUndefined(); + + const userWithAssets1 = await db.user.findUnique({ where: { id: user.id }, include: { assets: true } }); + expect(userWithAssets1.assets[0].type).toBe('Post'); + expect(userWithAssets1.assets[0].title).toBeUndefined(); + + // update read back + const updateRead = await db.asset.update({ where: { id: post.id }, data: { value: 2 } }); + expect(updateRead.value).toBe(2); + // cannot read back sub model + expect(updateRead.title).toBeUndefined(); + + // delete read back + const deleteRead = await db.asset.delete({ where: { id: post.id } }); + expect(deleteRead.value).toBe(2); + // cannot read back sub model + expect(deleteRead.title).toBeUndefined(); + // actually deleted + await expect(prisma.asset.findUnique({ where: { id: post.id } })).toResolveFalsy(); + await expect(prisma.post.findUnique({ where: { id: post.id } })).toResolveFalsy(); + }); + it('respects field-level policies', async () => { const { enhance } = await loadSchema(` model User {