diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 382825aa0..f429066f9 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -19,11 +19,11 @@ import { import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; +import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { PolicyUtil } from './policy-utils'; -import { createDeferredPromise } from './promise'; // a record for post-write policy check type PostWriteCheckRecord = { @@ -80,7 +80,8 @@ export class PolicyProxyHandler implements Pr 'where field is required in query argument' ); } - return this.findWithFluentCallStubs(args, 'findUnique', false, () => null); + + return this.findWithFluent('findUnique', args, () => null); } findUniqueOrThrow(args: any) { @@ -94,17 +95,18 @@ export class PolicyProxyHandler implements Pr 'where field is required in query argument' ); } - return this.findWithFluentCallStubs(args, 'findUniqueOrThrow', true, () => { + + return this.findWithFluent('findUniqueOrThrow', args, () => { throw this.policyUtils.notFound(this.model); }); } findFirst(args?: any) { - return this.findWithFluentCallStubs(args, 'findFirst', false, () => null); + return this.findWithFluent('findFirst', args, () => null); } findFirstOrThrow(args: any) { - return this.findWithFluentCallStubs(args, 'findFirstOrThrow', true, () => { + return this.findWithFluent('findFirstOrThrow', args, () => { throw this.policyUtils.notFound(this.model); }); } @@ -113,20 +115,24 @@ export class PolicyProxyHandler implements Pr return createDeferredPromise(() => this.doFind(args, 'findMany', () => [])); } - // returns a promise for the given find operation, together with function stubs for fluent API calls - private findWithFluentCallStubs( - args: any, - actionName: FindOperations, - resolveRoot: boolean, - handleRejection: () => any - ) { - // create a deferred promise so it's only evaluated when awaited or .then() is called - const result = createDeferredPromise(() => this.doFind(args, actionName, handleRejection)); - this.addFluentFunctions(result, this.model, args?.where, resolveRoot ? result : undefined); - return result; + // make a find query promise with fluent API call stubs installed + private findWithFluent(method: FindOperations, args: any, handleRejection: () => any) { + args = this.policyUtils.clone(args); + return createFluentPromise( + () => this.doFind(args, method, handleRejection), + args, + this.options.modelMeta, + this.model + ); + } + + private addFluentSelect(args: any, field: string, fluentArgs: any) { + // overwrite include/select with the fluent field + delete args.include; + args.select = { [field]: fluentArgs ?? true }; } - private doFind(args: any, actionName: FindOperations, handleRejection: () => any) { + private async doFind(args: any, actionName: FindOperations, handleRejection: () => any) { const origArgs = args; const _args = this.policyUtils.clone(args); if (!this.policyUtils.injectForRead(this.prisma, this.model, _args)) { @@ -142,88 +148,16 @@ export class PolicyProxyHandler implements Pr this.logger.info(`[policy] \`${actionName}\` ${this.model}:\n${formatObject(_args)}`); } - return new Promise((resolve, reject) => { - this.modelClient[actionName](_args).then( - (value: any) => { - this.policyUtils.postProcessForRead(value, this.model, origArgs); - resolve(value); - }, - (err: any) => reject(err) - ); - }); - } - - // returns a fluent API call function - private fluentCall(filter: any, fieldInfo: FieldInfo, rootPromise?: Promise) { - return (args: any) => { - args = this.policyUtils.clone(args); - - // combine the parent filter with the current one - const backLinkField = this.requireBackLink(fieldInfo); - const condition = backLinkField.isArray - ? { [backLinkField.name]: { some: filter } } - : { [backLinkField.name]: { is: filter } }; - args.where = this.policyUtils.and(args.where, condition); - - const promise = createDeferredPromise(() => { - // Promise for fetching - const fetchFluent = (resolve: (value: unknown) => void, reject: (reason?: any) => void) => { - const handler = this.makeHandler(fieldInfo.type); - if (fieldInfo.isArray) { - // fluent call stops here - handler.findMany(args).then( - (value: any) => resolve(value), - (err: any) => reject(err) - ); - } else { - handler.findFirst(args).then( - (value) => resolve(value), - (err) => reject(err) - ); - } - }; - - return new Promise((resolve, reject) => { - if (rootPromise) { - // if a root promise exists, resolve it before fluent API call, - // so that fluent calls start with `findUniqueOrThrow` and `findFirstOrThrow` - // can throw error properly if the root promise is rejected - rootPromise.then( - () => fetchFluent(resolve, reject), - (err) => reject(err) - ); - } else { - fetchFluent(resolve, reject); - } - }); - }); - - if (!fieldInfo.isArray) { - // prepare for a chained fluent API call - this.addFluentFunctions(promise, fieldInfo.type, args.where, rootPromise); - } - - return promise; - }; - } - - // add fluent API functions to the given promise - private addFluentFunctions(promise: any, model: string, filter: any, rootPromise?: Promise) { - const fields = this.policyUtils.getModelFields(model); - if (fields) { - for (const [field, fieldInfo] of Object.entries(fields)) { - if (fieldInfo.isDataModel) { - promise[field] = this.fluentCall(filter, fieldInfo, rootPromise); - } - } - } + const result = await this.modelClient[actionName](_args); + this.policyUtils.postProcessForRead(result, this.model, origArgs); + return result; } //#endregion //#region Create - async create(args: any) { + create(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -235,63 +169,65 @@ export class PolicyProxyHandler implements Pr ); } - this.policyUtils.tryReject(this.prisma, this.model, 'create'); + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'create'); - const origArgs = args; - args = this.policyUtils.clone(args); + const origArgs = args; + args = this.policyUtils.clone(args); - // static input policy check for top-level create data - const inputCheck = this.policyUtils.checkInputGuard(this.model, args.data, 'create'); - if (inputCheck === false) { - throw this.policyUtils.deniedByPolicy( - this.model, - 'create', - undefined, - CrudFailureReason.ACCESS_POLICY_VIOLATION - ); - } + // static input policy check for top-level create data + const inputCheck = this.policyUtils.checkInputGuard(this.model, args.data, 'create'); + if (inputCheck === false) { + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } - const hasNestedCreateOrConnect = await this.hasNestedCreateOrConnect(args); + const hasNestedCreateOrConnect = await this.hasNestedCreateOrConnect(args); - const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { - if ( - // MUST check true here since inputCheck can be undefined (meaning static input check not possible) - inputCheck === true && - // simple create: no nested create/connect - !hasNestedCreateOrConnect - ) { - // there's no nested write and we've passed input check, proceed with the create directly + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { + if ( + // MUST check true here since inputCheck can be undefined (meaning static input check not possible) + inputCheck === true && + // simple create: no nested create/connect + !hasNestedCreateOrConnect + ) { + // there's no nested write and we've passed input check, proceed with the create directly - // validate zod schema if any - args.data = this.validateCreateInputSchema(this.model, args.data); + // validate zod schema if any + args.data = this.validateCreateInputSchema(this.model, args.data); - // make a create args only containing data and ID selection - const createArgs: any = { data: args.data, select: this.policyUtils.makeIdSelection(this.model) }; + // make a create args only containing data and ID selection + const createArgs: any = { data: args.data, select: this.policyUtils.makeIdSelection(this.model) }; - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`create\` ${this.model}: ${formatObject(createArgs)}`); - } - const result = await tx[this.model].create(createArgs); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`create\` ${this.model}: ${formatObject(createArgs)}`); + } + const result = await tx[this.model].create(createArgs); - // filter the read-back data - return this.policyUtils.readBack(tx, this.model, 'create', args, result); - } else { - // proceed with a complex create and collect post-write checks - const { result, postWriteChecks } = await this.doCreate(this.model, args, tx); + // filter the read-back data + return this.policyUtils.readBack(tx, this.model, 'create', args, result); + } else { + // proceed with a complex create and collect post-write checks + const { result, postWriteChecks } = await this.doCreate(this.model, args, tx); - // execute post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); + // execute post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); - // filter the read-back data - return this.policyUtils.readBack(tx, this.model, 'create', origArgs, result); + // filter the read-back data + return this.policyUtils.readBack(tx, this.model, 'create', origArgs, result); + } + }); + + if (error) { + throw error; + } else { + return result; } }); - - if (error) { - throw error; - } else { - return result; - } } // create with nested write @@ -488,7 +424,7 @@ export class PolicyProxyHandler implements Pr } } - async createMany(args: { data: any; skipDuplicates?: boolean }) { + createMany(args: { data: any; skipDuplicates?: boolean }) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -500,47 +436,49 @@ export class PolicyProxyHandler implements Pr ); } - this.policyUtils.tryReject(this.prisma, this.model, 'create'); + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'create'); - args = this.policyUtils.clone(args); + args = this.policyUtils.clone(args); - // go through create items, statically check input to determine if post-create - // check is needed, and also validate zod schema - let needPostCreateCheck = false; - for (const item of enumerate(args.data)) { - const validationResult = this.validateCreateInputSchema(this.model, item); - if (validationResult !== item) { - this.policyUtils.replace(item, validationResult); - } + // go through create items, statically check input to determine if post-create + // check is needed, and also validate zod schema + let needPostCreateCheck = false; + for (const item of enumerate(args.data)) { + const validationResult = this.validateCreateInputSchema(this.model, item); + if (validationResult !== item) { + this.policyUtils.replace(item, validationResult); + } - const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); - if (inputCheck === false) { - // unconditionally deny - throw this.policyUtils.deniedByPolicy( - this.model, - 'create', - undefined, - CrudFailureReason.ACCESS_POLICY_VIOLATION - ); - } else if (inputCheck === true) { - // unconditionally allow - } else if (inputCheck === undefined) { - // static policy check is not possible, need to do post-create check - needPostCreateCheck = true; + const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); + if (inputCheck === false) { + // unconditionally deny + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } else if (inputCheck === true) { + // unconditionally allow + } else if (inputCheck === undefined) { + // static policy check is not possible, need to do post-create check + needPostCreateCheck = true; + } } - } - if (!needPostCreateCheck) { - return this.modelClient.createMany(args); - } else { - // create entities in a transaction with post-create checks - return this.queryUtils.transaction(this.prisma, async (tx) => { - const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); - // post-create check - await this.runPostWriteChecks(postWriteChecks, tx); - return result; - }); - } + if (!needPostCreateCheck) { + return this.modelClient.createMany(args); + } else { + // create entities in a transaction with post-create checks + return this.queryUtils.transaction(this.prisma, async (tx) => { + const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); + // post-create check + await this.runPostWriteChecks(postWriteChecks, tx); + return result; + }); + } + }); } private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) { @@ -662,7 +600,7 @@ export class PolicyProxyHandler implements Pr // "updateMany" works against a set of entities, entities not passing policy check are silently // ignored - async update(args: any) { + update(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -681,24 +619,26 @@ export class PolicyProxyHandler implements Pr ); } - args = this.policyUtils.clone(args); + return createDeferredPromise(async () => { + args = this.policyUtils.clone(args); - const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { - // proceed with nested writes and collect post-write checks - const { result, postWriteChecks } = await this.doUpdate(args, tx); + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { + // proceed with nested writes and collect post-write checks + const { result, postWriteChecks } = await this.doUpdate(args, tx); - // post-write check - await this.runPostWriteChecks(postWriteChecks, tx); + // post-write check + await this.runPostWriteChecks(postWriteChecks, tx); - // filter the read-back data - return this.policyUtils.readBack(tx, this.model, 'update', args, result); - }); + // filter the read-back data + return this.policyUtils.readBack(tx, this.model, 'update', args, result); + }); - if (error) { - throw error; - } else { - return result; - } + if (error) { + throw error; + } else { + return result; + } + }); } private async doUpdate(args: any, db: CrudContract) { @@ -1131,7 +1071,7 @@ export class PolicyProxyHandler implements Pr } } - async updateMany(args: any) { + updateMany(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -1147,58 +1087,60 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument'); } - this.policyUtils.tryReject(this.prisma, this.model, 'update'); + return createDeferredPromise(() => { + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.policyUtils.clone(args); - this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); - - args.data = this.validateUpdateInputSchema(this.model, args.data); - - if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { - // use a transaction to do post-update checks - const postWriteChecks: PostWriteCheckRecord[] = []; - return this.queryUtils.transaction(this.prisma, async (tx) => { - // collect pre-update values - let select = this.policyUtils.makeIdSelection(this.model); - const preValueSelect = this.policyUtils.getPreValueSelect(this.model); - if (preValueSelect) { - select = { ...select, ...preValueSelect }; - } - const currentSetQuery = { select, where: args.where }; - this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); + args = this.policyUtils.clone(args); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); - } - const currentSet = await tx[this.model].findMany(currentSetQuery); - - postWriteChecks.push( - ...currentSet.map((preValue) => ({ - model: this.model, - operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), - preValue: preValueSelect ? preValue : undefined, - })) - ); + args.data = this.validateUpdateInputSchema(this.model, args.data); - // proceed with the update - const result = await tx[this.model].updateMany(args); + if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { + // use a transaction to do post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + return this.queryUtils.transaction(this.prisma, async (tx) => { + // collect pre-update values + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + const currentSetQuery = { select, where: args.where }; + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); - // run post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); + } + const currentSet = await tx[this.model].findMany(currentSetQuery); - return result; - }); - } else { - // proceed without a transaction - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); + postWriteChecks.push( + ...currentSet.map((preValue) => ({ + model: this.model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), + preValue: preValueSelect ? preValue : undefined, + })) + ); + + // proceed with the update + const result = await tx[this.model].updateMany(args); + + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + return result; + }); + } else { + // proceed without a transaction + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); + } + return this.modelClient.updateMany(args); } - return this.modelClient.updateMany(args); - } + }); } - async upsert(args: any) { + upsert(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -1224,36 +1166,38 @@ export class PolicyProxyHandler implements Pr ); } - this.policyUtils.tryReject(this.prisma, this.model, 'create'); - this.policyUtils.tryReject(this.prisma, this.model, 'update'); + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.policyUtils.clone(args); + args = this.policyUtils.clone(args); - // We can call the native "upsert" because we can't tell if an entity was created or updated - // for doing post-write check accordingly. Instead, decompose it into create or update. + // We can call the native "upsert" because we can't tell if an entity was created or updated + // for doing post-write check accordingly. Instead, decompose it into create or update. - const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { - const { where, create, update, ...rest } = args; - const existing = await this.policyUtils.checkExistence(tx, this.model, args.where); + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { + const { where, create, update, ...rest } = args; + const existing = await this.policyUtils.checkExistence(tx, this.model, args.where); - if (existing) { - // update case - const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); - await this.runPostWriteChecks(postWriteChecks, tx); - return this.policyUtils.readBack(tx, this.model, 'update', args, result); + if (existing) { + // update case + const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); + await this.runPostWriteChecks(postWriteChecks, tx); + return this.policyUtils.readBack(tx, this.model, 'update', args, result); + } else { + // create case + const { result, postWriteChecks } = await this.doCreate(this.model, { data: create, ...rest }, tx); + await this.runPostWriteChecks(postWriteChecks, tx); + return this.policyUtils.readBack(tx, this.model, 'create', args, result); + } + }); + + if (error) { + throw error; } else { - // create case - const { result, postWriteChecks } = await this.doCreate(this.model, { data: create, ...rest }, tx); - await this.runPostWriteChecks(postWriteChecks, tx); - return this.policyUtils.readBack(tx, this.model, 'create', args, result); + return result; } }); - - if (error) { - throw error; - } else { - return result; - } } //#endregion @@ -1263,7 +1207,7 @@ export class PolicyProxyHandler implements Pr // "delete" works against a single entity, and is rejected if the entity fails policy check. // "deleteMany" works against a set of entities, entities that fail policy check are filtered out. - async delete(args: any) { + delete(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } @@ -1275,144 +1219,156 @@ export class PolicyProxyHandler implements Pr ); } - this.policyUtils.tryReject(this.prisma, this.model, 'delete'); + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); - const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { - // do a read-back before delete - const r = await this.policyUtils.readBack(tx, this.model, 'delete', args, args.where); - const error = r.error; - const read = r.result; + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { + // do a read-back before delete + const r = await this.policyUtils.readBack(tx, this.model, 'delete', args, args.where); + const error = r.error; + const read = r.result; - // check existence - await this.policyUtils.checkExistence(tx, this.model, args.where, true); + // check existence + await this.policyUtils.checkExistence(tx, this.model, args.where, true); - // inject delete guard - await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); + // inject delete guard + await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); - // proceed with the deletion - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`delete\` ${this.model}:\n${formatObject(args)}`); - } - await tx[this.model].delete(args); + // proceed with the deletion + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`delete\` ${this.model}:\n${formatObject(args)}`); + } + await tx[this.model].delete(args); - return { result: read, error }; - }); + return { result: read, error }; + }); - if (error) { - throw error; - } else { - return result; - } + if (error) { + throw error; + } else { + return result; + } + }); } - async deleteMany(args: any) { - this.policyUtils.tryReject(this.prisma, this.model, 'delete'); + deleteMany(args: any) { + return createDeferredPromise(() => { + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); - // inject policy conditions - args = args ?? {}; - this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); + // inject policy conditions + args = args ?? {}; + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); - // conduct the deletion - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); - } - return this.modelClient.deleteMany(args); + // conduct the deletion + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.deleteMany(args); + }); } //#endregion //#region Aggregation - async aggregate(args: any) { + aggregate(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.policyUtils.clone(args); + return createDeferredPromise(() => { + args = this.policyUtils.clone(args); - // inject policy conditions - this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + // inject policy conditions + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); - } - return this.modelClient.aggregate(args); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.aggregate(args); + }); } - async groupBy(args: any) { + groupBy(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.policyUtils.clone(args); + return createDeferredPromise(() => { + args = this.policyUtils.clone(args); - // inject policy conditions - this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + // inject policy conditions + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); - } - return this.modelClient.groupBy(args); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.groupBy(args); + }); } - async count(args: any) { - // inject policy conditions - args = args ? this.policyUtils.clone(args) : {}; - this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + count(args: any) { + return createDeferredPromise(() => { + // inject policy conditions + args = args ? this.policyUtils.clone(args) : {}; + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); - } - return this.modelClient.count(args); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.count(args); + }); } //#endregion //#region Subscribe (Prisma Pulse) - async subscribe(args: any) { - const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read'); - if (this.policyUtils.isTrue(readGuard)) { - // no need to inject - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); + subscribe(args: any) { + return createDeferredPromise(() => { + const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read'); + if (this.policyUtils.isTrue(readGuard)) { + // no need to inject + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.subscribe(args); } - return this.modelClient.subscribe(args); - } - if (!args) { - // include all - args = { create: {}, update: {}, delete: {} }; - } else { - if (typeof args !== 'object') { - throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object'); - } - if (Object.keys(args).length === 0) { + if (!args) { // include all args = { create: {}, update: {}, delete: {} }; } else { - args = this.policyUtils.clone(args); + if (typeof args !== 'object') { + throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object'); + } + if (Object.keys(args).length === 0) { + // include all + args = { create: {}, update: {}, delete: {} }; + } else { + args = this.policyUtils.clone(args); + } } - } - // inject into subscribe conditions + // inject into subscribe conditions - if (args.create) { - args.create.after = this.policyUtils.and(args.create.after, readGuard); - } + if (args.create) { + args.create.after = this.policyUtils.and(args.create.after, readGuard); + } - if (args.update) { - args.update.after = this.policyUtils.and(args.update.after, readGuard); - } + if (args.update) { + args.update.after = this.policyUtils.and(args.update.after, readGuard); + } - if (args.delete) { - args.delete.before = this.policyUtils.and(args.delete.before, readGuard); - } + if (args.delete) { + args.delete.before = this.policyUtils.and(args.delete.before, readGuard); + } - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); - } - return this.modelClient.subscribe(args); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.subscribe(args); + }); } //#endregion @@ -1431,10 +1387,6 @@ export class PolicyProxyHandler implements Pr ); } - private makeHandler(model: string) { - return new PolicyProxyHandler(this.prisma, model, this.options, this.context); - } - private requireBackLink(fieldInfo: FieldInfo) { invariant(fieldInfo.backLink, `back link not found for field ${fieldInfo.name}`); return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 5df9bed70..bc313f7c3 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -897,16 +897,21 @@ export class PolicyUtil extends QueryUtils { * @returns */ injectReadCheckSelect(model: string, args: any) { - if (!this.hasFieldLevelPolicy(model)) { - return; + if (this.hasFieldLevelPolicy(model)) { + // recursively inject selection for fields needed for field-level read checks + const readFieldSelect = this.getReadFieldSelect(model); + if (readFieldSelect) { + this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); + } } - const readFieldSelect = this.getReadFieldSelect(model); - if (!readFieldSelect) { - return; + // recurse into relation fields + for (const [k, v] of Object.entries(args.select ?? args.include ?? {})) { + const field = resolveField(this.modelMeta, model, k); + if (field?.isDataModel && v && typeof v === 'object') { + this.injectReadCheckSelect(field.type, v); + } } - - this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } private doInjectReadCheckSelect(model: string, args: any, input: any) { diff --git a/packages/runtime/src/enhancements/policy/promise.ts b/packages/runtime/src/enhancements/policy/promise.ts deleted file mode 100644 index b6d7baff9..000000000 --- a/packages/runtime/src/enhancements/policy/promise.ts +++ /dev/null @@ -1,38 +0,0 @@ -/* eslint-disable @typescript-eslint/no-explicit-any */ - -/** - * Creates a promise that only executes when it's awaited or .then() is called. - * @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts - */ -export function createDeferredPromise(callback: () => Promise): Promise { - let promise: Promise | undefined; - const cb = () => { - try { - return (promise ??= valueToPromise(callback())); - } catch (err) { - // deal with synchronous errors - return Promise.reject(err); - } - }; - - return { - then(onFulfilled, onRejected) { - return cb().then(onFulfilled, onRejected); - }, - catch(onRejected) { - return cb().catch(onRejected); - }, - finally(onFinally) { - return cb().finally(onFinally); - }, - [Symbol.toStringTag]: 'ZenStackPromise', - }; -} - -function valueToPromise(thing: any): Promise { - if (typeof thing === 'object' && typeof thing?.then === 'function') { - return thing; - } else { - return Promise.resolve(thing); - } -} diff --git a/packages/runtime/src/enhancements/promise.ts b/packages/runtime/src/enhancements/promise.ts new file mode 100644 index 000000000..28a211146 --- /dev/null +++ b/packages/runtime/src/enhancements/promise.ts @@ -0,0 +1,99 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import { getModelInfo, type ModelMeta } from '../cross'; + +/** + * Creates a promise that only executes when it's awaited or .then() is called. + * @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts + */ +export function createDeferredPromise(callback: () => Promise): Promise { + let promise: Promise | undefined; + const cb = () => { + try { + return (promise ??= valueToPromise(callback())); + } catch (err) { + // deal with synchronous errors + return Promise.reject(err); + } + }; + + return { + then(onFulfilled, onRejected) { + return cb().then(onFulfilled, onRejected); + }, + catch(onRejected) { + return cb().catch(onRejected); + }, + finally(onFinally) { + return cb().finally(onFinally); + }, + [Symbol.toStringTag]: 'ZenStackPromise', + }; +} + +function valueToPromise(thing: any): Promise { + if (typeof thing === 'object' && typeof thing?.then === 'function') { + return thing; + } else { + return Promise.resolve(thing); + } +} + +/** + * Create a deferred promise with fluent API call stub installed. + * + * @param callback The callback to execute when the promise is awaited. + * @param parentArgs The parent promise's query args. + * @param modelMeta The model metadata. + * @param model The model name. + */ +export function createFluentPromise( + callback: () => Promise, + parentArgs: any, + modelMeta: ModelMeta, + model: string +): Promise { + const promise: any = createDeferredPromise(callback); + + const modelInfo = getModelInfo(modelMeta, model); + if (!modelInfo) { + return promise; + } + + // install fluent call stub for model fields + Object.values(modelInfo.fields) + .filter((field) => field.isDataModel) + .forEach((field) => { + // e.g., `posts` in `db.user.findUnique(...).posts()` + promise[field.name] = (fluentArgs: any) => { + if (field.isArray) { + // an array relation terminates fluent call chain + return createDeferredPromise(async () => { + setFluentSelect(parentArgs, field.name, fluentArgs ?? true); + const parentResult: any = await promise; + return parentResult?.[field.name] ?? null; + }); + } else { + fluentArgs = { ...fluentArgs }; + // create a chained subsequent fluent call promise + return createFluentPromise( + async () => { + setFluentSelect(parentArgs, field.name, fluentArgs); + const parentResult: any = await promise; + return parentResult?.[field.name] ?? null; + }, + fluentArgs, + modelMeta, + field.type + ); + } + }; + }); + + return promise; +} + +function setFluentSelect(args: any, fluentFieldName: any, fluentArgs: any) { + delete args.include; + args.select = { [fluentFieldName]: fluentArgs }; +} diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index a3141ad0a..e7f55a88c 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -1,10 +1,11 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import deepcopy from 'deepcopy'; import { PRISMA_PROXY_ENHANCER } from '../constants'; import type { ModelMeta } from '../cross'; import type { DbClientContract } from '../types'; -import { InternalEnhancementOptions } from './create-enhancement'; -import { createDeferredPromise } from './policy/promise'; +import type { InternalEnhancementOptions } from './create-enhancement'; +import { createDeferredPromise, createFluentPromise } from './promise'; /** * Prisma batch write operation result @@ -70,93 +71,91 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { protected readonly options: InternalEnhancementOptions ) {} - async findUnique(args: any): Promise { - args = await this.preprocessArgs('findUnique', args); - const r = await this.prisma[this.model].findUnique(args); - return this.processResultEntity(r); + protected withFluentCall(method: keyof PrismaProxyHandler, args: any, postProcess = true): Promise { + args = args ? deepcopy(args) : {}; + const promise = createFluentPromise( + async () => { + args = await this.preprocessArgs(method, args); + const r = await this.prisma[this.model][method](args); + return postProcess ? this.processResultEntity(r) : r; + }, + args, + this.options.modelMeta, + this.model + ); + return promise; } - async findUniqueOrThrow(args: any): Promise { - args = await this.preprocessArgs('findUniqueOrThrow', args); - const r = await this.prisma[this.model].findUniqueOrThrow(args); - return this.processResultEntity(r); + protected deferred(method: keyof PrismaProxyHandler, args: any, postProcess = true) { + return createDeferredPromise(async () => { + args = await this.preprocessArgs(method, args); + const r = await this.prisma[this.model][method](args); + return postProcess ? this.processResultEntity(r) : r; + }); } - async findFirst(args: any): Promise { - args = await this.preprocessArgs('findFirst', args); - const r = await this.prisma[this.model].findFirst(args); - return this.processResultEntity(r); + findUnique(args: any) { + return this.withFluentCall('findUnique', args); } - async findFirstOrThrow(args: any): Promise { - args = await this.preprocessArgs('findFirstOrThrow', args); - const r = await this.prisma[this.model].findFirstOrThrow(args); - return this.processResultEntity(r); + findUniqueOrThrow(args: any) { + return this.withFluentCall('findUniqueOrThrow', args); } - async findMany(args: any): Promise { - args = await this.preprocessArgs('findMany', args); - const r = await this.prisma[this.model].findMany(args); - return this.processResultEntity(r); + findFirst(args: any) { + return this.withFluentCall('findFirst', args); } - async create(args: any): Promise { - args = await this.preprocessArgs('create', args); - const r = await this.prisma[this.model].create(args); - return this.processResultEntity(r); + findFirstOrThrow(args: any) { + return this.withFluentCall('findFirstOrThrow', args); } - async createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> { - args = await this.preprocessArgs('createMany', args); - return this.prisma[this.model].createMany(args); + findMany(args: any) { + return this.deferred('findMany', args); } - async update(args: any): Promise { - args = await this.preprocessArgs('update', args); - const r = await this.prisma[this.model].update(args); - return this.processResultEntity(r); + create(args: any): Promise { + return this.deferred('create', args); } - async updateMany(args: any): Promise<{ count: number }> { - args = await this.preprocessArgs('updateMany', args); - return this.prisma[this.model].updateMany(args); + createMany(args: { data: any; skipDuplicates?: boolean }) { + return this.deferred<{ count: number }>('createMany', args, false); } - async upsert(args: any): Promise { - args = await this.preprocessArgs('upsert', args); - const r = await this.prisma[this.model].upsert(args); - return this.processResultEntity(r); + update(args: any) { + return this.deferred('update', args); } - async delete(args: any): Promise { - args = await this.preprocessArgs('delete', args); - const r = await this.prisma[this.model].delete(args); - return this.processResultEntity(r); + updateMany(args: any) { + return this.deferred<{ count: number }>('updateMany', args, false); } - async deleteMany(args: any): Promise<{ count: number }> { - args = await this.preprocessArgs('deleteMany', args); - return this.prisma[this.model].deleteMany(args); + upsert(args: any) { + return this.deferred('upsert', args); } - async aggregate(args: any): Promise { - args = await this.preprocessArgs('aggregate', args); - return this.prisma[this.model].aggregate(args); + delete(args: any) { + return this.deferred('delete', args); } - async groupBy(args: any): Promise { - args = await this.preprocessArgs('groupBy', args); - return this.prisma[this.model].groupBy(args); + deleteMany(args: any) { + return this.deferred<{ count: number }>('deleteMany', args, false); } - async count(args: any): Promise { - args = await this.preprocessArgs('count', args); - return this.prisma[this.model].count(args); + aggregate(args: any) { + return this.deferred('aggregate', args, false); } - async subscribe(args: any): Promise { - args = await this.preprocessArgs('subscribe', args); - return this.prisma[this.model].subscribe(args); + groupBy(args: any) { + return this.deferred('groupBy', args, false); + } + + count(args: any): Promise { + return this.deferred('count', args, false); + } + + subscribe(args: any) { + return this.deferred('subscribe', args, false); } /** @@ -177,6 +176,8 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { // a marker for filtering error stack trace const ERROR_MARKER = '__error_marker__'; +const customInspect = Symbol.for('nodejs.util.inspect.custom'); + /** * Makes a Prisma client proxy. */ @@ -196,10 +197,6 @@ export function makeProxy( return name; } - if (prop === 'toString') { - return () => `$zenstack_prisma_${prisma._clientVersion}`; - } - if (prop === '$transaction') { // for interactive transactions, we need to proxy the transaction function so that // when it runs the callback, it provides a proxy to the Prisma client wrapped with @@ -245,6 +242,8 @@ export function makeProxy( }, }); + proxy[customInspect] = `$zenstack_prisma_${prisma._clientVersion}`; + return proxy; } diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index ebaf2d858..de778e8e8 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -51,6 +51,18 @@ describe('Policy: field-level policy', () => { r = await db.model.findUnique({ where: { id: 1 } }); expect(r.y).toBeUndefined(); + r = await db.user.findUnique({ where: { id: 1 }, select: { models: true } }); + expect(r.models[0].y).toBeUndefined(); + + r = await db.user.findUnique({ where: { id: 1 }, select: { models: { select: { y: true } } } }); + expect(r.models[0].y).toBeUndefined(); + + r = await db.user.findUnique({ where: { id: 1 } }).models(); + expect(r[0].y).toBeUndefined(); + + r = await db.user.findUnique({ where: { id: 1 } }).models({ select: { y: true } }); + expect(r[0].y).toBeUndefined(); + r = await db.model.findUnique({ select: { x: true }, where: { id: 1 } }); expect(r.x).toEqual(0); expect(r.y).toBeUndefined(); @@ -82,6 +94,21 @@ describe('Policy: field-level policy', () => { r = await db.model.findUnique({ where: { id: 2 } }); expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + r = await db.user.findUnique({ where: { id: 1 }, select: { models: { where: { id: 2 } } } }); + expect(r.models[0]).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.user.findUnique({ + where: { id: 1 }, + select: { models: { where: { id: 2 }, select: { y: true } } }, + }); + expect(r.models[0]).toEqual(expect.objectContaining({ y: 0 })); + + r = await db.user.findUnique({ where: { id: 1 } }).models({ where: { id: 2 } }); + expect(r[0]).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.user.findUnique({ where: { id: 1 } }).models({ where: { id: 2 }, select: { y: true } }); + expect(r[0]).toEqual(expect.objectContaining({ y: 0 })); + r = await db.model.findUnique({ select: { x: true }, where: { id: 2 } }); expect(r.x).toEqual(1); expect(r.y).toBeUndefined(); diff --git a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts index 6c27aab1c..9dd247d65 100644 --- a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts +++ b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts @@ -12,33 +12,181 @@ describe('With Policy: fluent API', () => { process.chdir(origDir); }); - it('fluent api', async () => { + it('policy tests', async () => { const { enhance, prisma } = await loadSchema( ` model User { id Int @id email String @unique + profile Profile? posts Post[] @@allow('all', true) } +model Profile { + id Int @id + age Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + @@allow('all', auth() == user) +} + model Post { id Int @id title String author User? @relation(fields: [authorId], references: [id]) authorId Int? published Boolean @default(false) - secret String @default("secret") @allow('read', published == false) + secret String @default("secret") @allow('read', published == false, true) - @@allow('all', author == auth()) -}` + @@allow('read', published) +}`, + { logPrismaQuery: true } ); await prisma.user.create({ data: { id: 1, email: 'a@test.com', + profile: { + create: { id: 1, age: 18 }, + }, + posts: { + create: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: true }, + { id: 3, title: 'post3', published: false }, + ], + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + email: 'b@test.com', + posts: { + create: [{ id: 4, title: 'post4' }], + }, + }, + }); + + const db1 = enhance({ id: 1 }); + const db2 = enhance({ id: 2 }); + + // check policies + await expect(db1.user.findUnique({ where: { id: 1 } }).posts()).resolves.toHaveLength(2); + await expect(db2.user.findUnique({ where: { id: 2 } }).posts()).resolves.toHaveLength(0); + await expect( + db1.user.findUnique({ where: { id: 1 } }).posts({ where: { published: true } }) + ).resolves.toHaveLength(2); + await expect(db1.user.findUnique({ where: { id: 1 } }).posts({ take: 1 })).resolves.toHaveLength(1); + + // field-level policies + let p = ( + await db1.user + .findUnique({ where: { id: 1 } }) + .posts({ where: { published: true }, select: { secret: true } }) + )[0]; + expect(p.secret).toBeUndefined(); + p = ( + await db1.user + .findUnique({ where: { id: 1 } }) + .posts({ where: { published: false }, select: { secret: true } }) + )[0]; + expect(p.secret).toBeTruthy(); + + // to-one optional + await expect(db1.post.findFirst({ where: { id: 1 } }).author()).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + await expect(db1.post.findFirst({ where: { id: 1 } }).author({ where: { id: 1 } })).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + await expect(db1.post.findFirst({ where: { id: 1 } }).author({ where: { id: 2 } })).toResolveNull(); + + // to-one required + await expect(db1.profile.findUnique({ where: { userId: 1 } }).user()).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + // not found + await expect(db1.profile.findUnique({ where: { userId: 2 } }).user()).toResolveNull(); + // not readable + await expect(db2.profile.findUnique({ where: { userId: 1 } }).user()).toResolveNull(); + + // unresolved promise + db1.user.findUniqueOrThrow({ where: { id: 5 } }); + db1.user.findUniqueOrThrow({ where: { id: 5 } }).posts(); + + // not-found + await expect(db1.user.findUniqueOrThrow({ where: { id: 5 } }).posts()).toBeNotFound(); + await expect(db1.user.findFirstOrThrow({ where: { id: 5 } }).posts()).toBeNotFound(); + await expect(db1.post.findUniqueOrThrow({ where: { id: 5 } }).author()).toBeNotFound(); + await expect(db1.post.findFirstOrThrow({ where: { id: 5 } }).author()).toBeNotFound(); + + // chaining + await expect( + db1.post + .findFirst({ where: { id: 1 } }) + .author() + .posts() + ).resolves.toHaveLength(2); + await expect( + db1.post + .findFirst({ where: { id: 1 } }) + .author() + .posts({ where: { published: true } }) + ).resolves.toHaveLength(2); + + // chaining broken + expect((db1.post.findMany() as any).author).toBeUndefined(); + expect( + db1.post + .findFirst({ where: { id: 1 } }) + .author() + .posts().author + ).toBeUndefined(); + }); + + it('non-policy tests', async () => { + const { enhance, prisma } = await loadSchema( + ` +model User { + id Int @id + email String @unique + password String? @omit + profile Profile? + posts Post[] +} + +model Profile { + id Int @id + age Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique +} + +model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + published Boolean @default(false) +}`, + { enhancements: ['omit'] } + ); + + await prisma.user.create({ + data: { + id: 1, + email: 'a@test.com', + profile: { + create: { id: 1, age: 18 }, + }, posts: { create: [ { id: 1, title: 'post1', published: true }, @@ -58,7 +206,7 @@ model Post { }, }); - const db = enhance({ id: 1 }); + const db = enhance(); // check policies await expect(db.user.findUnique({ where: { id: 1 } }).posts()).resolves.toHaveLength(2); @@ -67,16 +215,24 @@ model Post { ).resolves.toHaveLength(1); await expect(db.user.findUnique({ where: { id: 1 } }).posts({ take: 1 })).resolves.toHaveLength(1); - // field-level policies - let p = (await db.user.findUnique({ where: { id: 1 } }).posts({ where: { published: true } }))[0]; - expect(p.secret).toBeUndefined(); - p = (await db.user.findUnique({ where: { id: 1 } }).posts({ where: { published: false } }))[0]; - expect(p.secret).toBeTruthy(); + // to-one optional + await expect(db.post.findFirst({ where: { id: 1 } }).author()).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + await expect(db.post.findFirst({ where: { id: 1 } }).author({ where: { id: 1 } })).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + await expect(db.post.findFirst({ where: { id: 1 } }).author({ where: { id: 2 } })).toResolveNull(); - // to-one - await expect(db.post.findFirst({ where: { id: 1 } }).author()).resolves.toEqual( - expect.objectContaining({ id: 1, email: 'a@test.com' }) - ); + // to-one required + await expect(db.profile.findUnique({ where: { userId: 1 } }).user()).resolves.toMatchObject({ + id: 1, + email: 'a@test.com', + }); + // not found + await expect(db.profile.findUnique({ where: { userId: 2 } }).user()).toResolveNull(); // not-found await expect(db.user.findUniqueOrThrow({ where: { id: 5 } }).posts()).toBeNotFound(); @@ -91,6 +247,12 @@ model Post { .author() .posts() ).resolves.toHaveLength(2); + await expect( + db.post + .findFirst({ where: { id: 1 } }) + .author() + .posts({ where: { published: true } }) + ).resolves.toHaveLength(1); // chaining broken expect((db.post.findMany() as any).author).toBeUndefined();