Skip to content

Commit

Permalink
fix(delegate): enforcing concrete model policies when read from a del…
Browse files Browse the repository at this point in the history
…egate base (#1726)
  • Loading branch information
ymc9 authored Sep 22, 2024
1 parent cb68815 commit 738bba6
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 24 deletions.
34 changes: 31 additions & 3 deletions packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import semver from 'semver';
import { PRISMA_MINIMUM_VERSION } from '../../constants';
import { isDelegateModel, type ModelMeta } from '../../cross';
import type { EnhancementContext, EnhancementKind, EnhancementOptions, ZodSchemas } from '../../types';
import type {
DbClientContract,
EnhancementContext,
EnhancementKind,
EnhancementOptions,
ZodSchemas,
} from '../../types';
import { withDefaultAuth } from './default-auth';
import { withDelegate } from './delegate';
import { Logger } from './logger';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withPolicy } from './policy';
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
import type { PolicyDef } from './types';

/**
Expand Down Expand Up @@ -41,6 +47,18 @@ export type InternalEnhancementOptions = EnhancementOptions & {
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
prismaModule: any;

/**
* A callback shared among enhancements to process the payload for including a relation
* field. e.g.: `{ author: true }`.
*/
processIncludeRelationPayload?: (
prisma: DbClientContract,
model: string,
payload: unknown,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
) => Promise<void>;
};

/**
Expand Down Expand Up @@ -89,7 +107,7 @@ export function createEnhancement<DbClient extends object>(
'Your ZModel contains delegate models but "delegate" enhancement kind is not enabled. This may result in unexpected behavior.'
);
} else {
result = withDelegate(result, options);
result = withDelegate(result, options, context);
}
}

Expand All @@ -103,6 +121,16 @@ export function createEnhancement<DbClient extends object>(
// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
if (kinds.includes('policy') || kinds.includes('validation')) {
result = withPolicy(result, options, context);

// if any enhancement is to introduce an inclusion of a relation field, the
// inclusion payload must be processed by the policy enhancement for injecting
// access control rules

// TODO: this is currently a global callback shared among all enhancements, which
// is far from ideal

options.processIncludeRelationPayload = policyProcessIncludeRelationPayload;

if (kinds.includes('policy') && hasDefaultAuth) {
// @default(auth()) proxy
result = withDefaultAuth(result, options, context);
Expand Down
62 changes: 43 additions & 19 deletions packages/runtime/src/enhancements/node/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@ import {
isDelegateModel,
resolveField,
} from '../../cross';
import type { CrudContract, DbClientContract } from '../../types';
import type { CrudContract, DbClientContract, EnhancementContext } from '../../types';
import type { InternalEnhancementOptions } from './create-enhancement';
import { Logger } from './logger';
import { DefaultPrismaProxyHandler, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';
import { formatObject, prismaClientValidationError } from './utils';

export function withDelegate<DbClient extends object>(prisma: DbClient, options: InternalEnhancementOptions): DbClient {
export function withDelegate<DbClient extends object>(
prisma: DbClient,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options),
(_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options, context),
'delegate'
);
}
Expand All @@ -35,7 +39,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
private readonly logger: Logger;
private readonly queryUtils: QueryUtils;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
constructor(
prisma: DbClientContract,
model: string,
options: InternalEnhancementOptions,
private readonly context: EnhancementContext | undefined
) {
super(prisma, model, options);
this.logger = new Logger(prisma);
this.queryUtils = new QueryUtils(prisma, this.options);
Expand Down Expand Up @@ -76,7 +85,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = args ? clone(args) : {};

this.injectWhereHierarchy(model, args?.where);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

// discriminator field is needed during post process to determine the
// actual concrete model type
Expand Down Expand Up @@ -166,7 +175,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
});
}

private injectSelectIncludeHierarchy(model: string, args: any) {
private async injectSelectIncludeHierarchy(model: string, args: any) {
if (!args || typeof args !== 'object') {
return;
}
Expand All @@ -186,7 +195,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
// make sure the payload is an object
args[kind][field] = {};
}
this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]);
await this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]);
}
}

Expand All @@ -208,7 +217,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
// make sure the payload is an object
args[kind][field] = nextValue = {};
}
this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue);
await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue);
}
}
}
Expand All @@ -220,11 +229,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
this.injectBaseIncludeRecursively(model, args);

// include sub models downwards
this.injectConcreteIncludeRecursively(model, args);
await this.injectConcreteIncludeRecursively(model, args);
}
}

private buildSelectIncludeHierarchy(model: string, args: any) {
private async buildSelectIncludeHierarchy(model: string, args: any) {
args = clone(args);
const selectInclude: any = this.extractSelectInclude(args) || {};

Expand All @@ -248,7 +257,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

if (!selectInclude.select) {
this.injectBaseIncludeRecursively(model, selectInclude);
this.injectConcreteIncludeRecursively(model, selectInclude);
await this.injectConcreteIncludeRecursively(model, selectInclude);
}
return selectInclude;
}
Expand Down Expand Up @@ -319,7 +328,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
this.injectBaseIncludeRecursively(base.name, selectInclude.include[baseRelationName]);
}

private injectConcreteIncludeRecursively(model: string, selectInclude: any) {
private async injectConcreteIncludeRecursively(model: string, selectInclude: any) {
const modelInfo = getModelInfo(this.options.modelMeta, model);
if (!modelInfo) {
return;
Expand All @@ -333,13 +342,27 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
for (const subModel of subModels) {
// include sub model relation field
const subRelationName = this.makeAuxRelationName(subModel);
const includePayload: any = {};

if (this.options.processIncludeRelationPayload) {
// use the callback in options to process the include payload, so enhancements
// like 'policy' can do extra work (e.g., inject policy rules)
await this.options.processIncludeRelationPayload(
this.prisma,
subModel.name,
includePayload,
this.options,
this.context
);
}

if (selectInclude.select) {
selectInclude.include = { [subRelationName]: {}, ...selectInclude.select };
selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.select };
delete selectInclude.select;
} else {
selectInclude.include = { [subRelationName]: {}, ...selectInclude.include };
selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.include };
}
this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]);
await this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]);
}
}

Expand Down Expand Up @@ -480,7 +503,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

await this.injectCreateHierarchy(model, args);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`create\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down Expand Up @@ -702,7 +725,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

args = clone(args);
this.injectWhereHierarchy(this.model, (args as any)?.where);
this.injectSelectIncludeHierarchy(this.model, args);
await this.injectSelectIncludeHierarchy(this.model, args);
if (args.create) {
this.doProcessCreatePayload(this.model, args.create);
}
Expand All @@ -721,7 +744,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

await this.injectUpdateHierarchy(db, model, args);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`update\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down Expand Up @@ -915,7 +938,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

return this.queryUtils.transaction(this.prisma, async (tx) => {
const selectInclude = this.buildSelectIncludeHierarchy(this.model, args);
const selectInclude = await this.buildSelectIncludeHierarchy(this.model, args);

// make sure id fields are selected
const idFields = this.getIdFields(this.model);
Expand Down Expand Up @@ -967,6 +990,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

private async doDelete(db: CrudContract, model: string, args: any): Promise<unknown> {
this.injectWhereHierarchy(model, args.where);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down
18 changes: 18 additions & 0 deletions packages/runtime/src/enhancements/node/policy/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { InternalEnhancementOptions } from '../create-enhancement';
import { Logger } from '../logger';
import { makeProxy } from '../proxy';
import { PolicyProxyHandler } from './handler';
import { PolicyUtil } from './policy-utils';

/**
* Gets an enhanced Prisma client with access policy check.
Expand Down Expand Up @@ -60,3 +61,20 @@ export function withPolicy<DbClient extends object>(
options?.errorTransformer
);
}

/**
* Function for processing a payload for including a relation field in a query.
* @param model The relation's model name
* @param payload The payload to process
*/
export async function policyProcessIncludeRelationPayload(
prisma: DbClientContract,
model: string,
payload: unknown,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
) {
const utils = new PolicyUtil(prisma, options, context);
await utils.injectForRead(prisma, model, payload);
await utils.injectReadCheckSelect(model, payload);
}
3 changes: 3 additions & 0 deletions packages/runtime/src/enhancements/node/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,9 @@ export class PolicyUtil extends QueryUtils {
}
const result = await db[model].findFirst(readArgs);
if (!result) {
if (this.shouldLogQuery) {
this.logger.info(`[policy] cannot read back ${model}`);
}
return { error, result: undefined };
}

Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/node/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
protected readonly options: InternalEnhancementOptions
) {}

protected withFluentCall(method: keyof PrismaProxyHandler, args: any, postProcess = true): Promise<unknown> {
protected withFluentCall(method: PrismaProxyActions, args: any, postProcess = true): Promise<unknown> {
args = args ? clone(args) : {};
const promise = createFluentPromise(
async () => {
Expand All @@ -84,7 +84,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return promise;
}

protected deferred<TResult = unknown>(method: keyof PrismaProxyHandler, args: any, postProcess = true) {
protected deferred<TResult = unknown>(method: PrismaProxyActions, args: any, postProcess = true) {
return createDeferredPromise<TResult>(async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
Expand Down
Loading

0 comments on commit 738bba6

Please sign in to comment.