Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(delegate): enforcing concrete model policies when read from a delegate base #1726

Merged
merged 3 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import semver from 'semver';
import { PRISMA_MINIMUM_VERSION } from '../../constants';
import { isDelegateModel, type ModelMeta } from '../../cross';
import type { EnhancementContext, EnhancementKind, EnhancementOptions, ZodSchemas } from '../../types';
import type {
DbClientContract,
EnhancementContext,
EnhancementKind,
EnhancementOptions,
ZodSchemas,
} from '../../types';
import { withDefaultAuth } from './default-auth';
import { withDelegate } from './delegate';
import { Logger } from './logger';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withPolicy } from './policy';
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
import type { PolicyDef } from './types';

/**
Expand Down Expand Up @@ -41,6 +47,18 @@ export type InternalEnhancementOptions = EnhancementOptions & {
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
prismaModule: any;

/**
* A callback shared among enhancements to process the payload for including a relation
* field. e.g.: `{ author: true }`.
*/
processIncludeRelationPayload?: (
prisma: DbClientContract,
model: string,
payload: unknown,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
) => Promise<void>;
};

/**
Expand All @@ -53,7 +71,7 @@ export type InternalEnhancementOptions = EnhancementOptions & {
* @param context Context.
* @param options Options.
*/
export function createEnhancement<DbClient extends object>(
export function createEnhancement<DbClient extends DbClientContract>(
prisma: DbClient,
options: InternalEnhancementOptions,
context?: EnhancementContext
Expand Down Expand Up @@ -89,7 +107,7 @@ export function createEnhancement<DbClient extends object>(
'Your ZModel contains delegate models but "delegate" enhancement kind is not enabled. This may result in unexpected behavior.'
);
} else {
result = withDelegate(result, options);
result = withDelegate(result, options, context);
}
}

Expand All @@ -103,6 +121,16 @@ export function createEnhancement<DbClient extends object>(
// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
if (kinds.includes('policy') || kinds.includes('validation')) {
result = withPolicy(result, options, context);

// if any enhancement is to introduce an inclusion of a relation field, the
// inclusion payload must be processed by the policy enhancement for injecting
// access control rules

// TODO: this is currently a global callback shared among all enhancements, which
// is far from ideal

options.processIncludeRelationPayload = policyProcessIncludeRelationPayload;

if (kinds.includes('policy') && hasDefaultAuth) {
// @default(auth()) proxy
result = withDefaultAuth(result, options, context);
Expand Down
62 changes: 43 additions & 19 deletions packages/runtime/src/enhancements/node/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@ import {
isDelegateModel,
resolveField,
} from '../../cross';
import type { CrudContract, DbClientContract } from '../../types';
import type { CrudContract, DbClientContract, EnhancementContext } from '../../types';
import type { InternalEnhancementOptions } from './create-enhancement';
import { Logger } from './logger';
import { DefaultPrismaProxyHandler, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';
import { formatObject, prismaClientValidationError } from './utils';

export function withDelegate<DbClient extends object>(prisma: DbClient, options: InternalEnhancementOptions): DbClient {
export function withDelegate<DbClient extends object>(
prisma: DbClient,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options),
(_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options, context),
'delegate'
);
}
Expand All @@ -35,7 +39,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
private readonly logger: Logger;
private readonly queryUtils: QueryUtils;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
constructor(
prisma: DbClientContract,
model: string,
options: InternalEnhancementOptions,
private readonly context: EnhancementContext | undefined
) {
super(prisma, model, options);
this.logger = new Logger(prisma);
this.queryUtils = new QueryUtils(prisma, this.options);
Expand Down Expand Up @@ -76,7 +85,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = args ? clone(args) : {};

this.injectWhereHierarchy(model, args?.where);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

// discriminator field is needed during post process to determine the
// actual concrete model type
Expand Down Expand Up @@ -166,7 +175,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
});
}

private injectSelectIncludeHierarchy(model: string, args: any) {
private async injectSelectIncludeHierarchy(model: string, args: any) {
if (!args || typeof args !== 'object') {
return;
}
Expand All @@ -186,7 +195,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
// make sure the payload is an object
args[kind][field] = {};
}
this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]);
await this.injectSelectIncludeHierarchy(fieldInfo.type, args[kind][field]);
}
}

Expand All @@ -208,7 +217,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
// make sure the payload is an object
args[kind][field] = nextValue = {};
}
this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue);
await this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue);
}
}
}
Expand All @@ -220,11 +229,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
this.injectBaseIncludeRecursively(model, args);

// include sub models downwards
this.injectConcreteIncludeRecursively(model, args);
await this.injectConcreteIncludeRecursively(model, args);
}
}

private buildSelectIncludeHierarchy(model: string, args: any) {
private async buildSelectIncludeHierarchy(model: string, args: any) {
args = clone(args);
const selectInclude: any = this.extractSelectInclude(args) || {};

Expand All @@ -248,7 +257,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

if (!selectInclude.select) {
this.injectBaseIncludeRecursively(model, selectInclude);
this.injectConcreteIncludeRecursively(model, selectInclude);
await this.injectConcreteIncludeRecursively(model, selectInclude);
}
return selectInclude;
}
Expand Down Expand Up @@ -319,7 +328,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
this.injectBaseIncludeRecursively(base.name, selectInclude.include[baseRelationName]);
}

private injectConcreteIncludeRecursively(model: string, selectInclude: any) {
private async injectConcreteIncludeRecursively(model: string, selectInclude: any) {
const modelInfo = getModelInfo(this.options.modelMeta, model);
if (!modelInfo) {
return;
Expand All @@ -333,13 +342,27 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
for (const subModel of subModels) {
// include sub model relation field
const subRelationName = this.makeAuxRelationName(subModel);
const includePayload: any = {};

if (this.options.processIncludeRelationPayload) {
// use the callback in options to process the include payload, so enhancements
// like 'policy' can do extra work (e.g., inject policy rules)
await this.options.processIncludeRelationPayload(
this.prisma,
subModel.name,
includePayload,
this.options,
this.context
);
}

if (selectInclude.select) {
selectInclude.include = { [subRelationName]: {}, ...selectInclude.select };
selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.select };
delete selectInclude.select;
} else {
selectInclude.include = { [subRelationName]: {}, ...selectInclude.include };
selectInclude.include = { [subRelationName]: includePayload, ...selectInclude.include };
}
this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]);
await this.injectConcreteIncludeRecursively(subModel.name, selectInclude.include[subRelationName]);
}
}

Expand Down Expand Up @@ -480,7 +503,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

await this.injectCreateHierarchy(model, args);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`create\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down Expand Up @@ -702,7 +725,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

args = clone(args);
this.injectWhereHierarchy(this.model, (args as any)?.where);
this.injectSelectIncludeHierarchy(this.model, args);
await this.injectSelectIncludeHierarchy(this.model, args);
if (args.create) {
this.doProcessCreatePayload(this.model, args.create);
}
Expand All @@ -721,7 +744,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

await this.injectUpdateHierarchy(db, model, args);
this.injectSelectIncludeHierarchy(model, args);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`update\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down Expand Up @@ -915,7 +938,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

return this.queryUtils.transaction(this.prisma, async (tx) => {
const selectInclude = this.buildSelectIncludeHierarchy(this.model, args);
const selectInclude = await this.buildSelectIncludeHierarchy(this.model, args);

// make sure id fields are selected
const idFields = this.getIdFields(this.model);
Expand Down Expand Up @@ -967,6 +990,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

private async doDelete(db: CrudContract, model: string, args: any): Promise<unknown> {
this.injectWhereHierarchy(model, args.where);
await this.injectSelectIncludeHierarchy(model, args);

if (this.options.logPrismaQuery) {
this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`);
Expand Down
20 changes: 19 additions & 1 deletion packages/runtime/src/enhancements/node/policy/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { InternalEnhancementOptions } from '../create-enhancement';
import { Logger } from '../logger';
import { makeProxy } from '../proxy';
import { PolicyProxyHandler } from './handler';
import { PolicyUtil } from './policy-utils';

/**
* Gets an enhanced Prisma client with access policy check.
Expand All @@ -18,7 +19,7 @@ import { PolicyProxyHandler } from './handler';
*
* @private
*/
export function withPolicy<DbClient extends object>(
export function withPolicy<DbClient extends DbClientContract>(
prisma: DbClient,
options: InternalEnhancementOptions,
context?: EnhancementContext
Expand Down Expand Up @@ -60,3 +61,20 @@ export function withPolicy<DbClient extends object>(
options?.errorTransformer
);
}

/**
* Function for processing a payload for including a relation field in a query.
* @param model The relation's model name
* @param payload The payload to process
*/
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
export async function policyProcessIncludeRelationPayload(
prisma: DbClientContract,
model: string,
payload: unknown,
options: InternalEnhancementOptions,
context: EnhancementContext | undefined
) {
const utils = new PolicyUtil(prisma, options, context);
await utils.injectForRead(prisma, model, payload);
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
await utils.injectReadCheckSelect(model, payload);
}
3 changes: 3 additions & 0 deletions packages/runtime/src/enhancements/node/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,9 @@ export class PolicyUtil extends QueryUtils {
}
const result = await db[model].findFirst(readArgs);
if (!result) {
if (this.shouldLogQuery) {
this.logger.info(`[policy] cannot read back ${model}`);
}
return { error, result: undefined };
}

Expand Down
6 changes: 3 additions & 3 deletions packages/runtime/src/enhancements/node/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
protected readonly options: InternalEnhancementOptions
) {}

protected withFluentCall(method: keyof PrismaProxyHandler, args: any, postProcess = true): Promise<unknown> {
protected withFluentCall(method: PrismaProxyActions, args: any, postProcess = true): Promise<unknown> {
args = args ? clone(args) : {};
const promise = createFluentPromise(
async () => {
Expand All @@ -84,7 +84,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return promise;
}

protected deferred<TResult = unknown>(method: keyof PrismaProxyHandler, args: any, postProcess = true) {
protected deferred<TResult = unknown>(method: PrismaProxyActions, args: any, postProcess = true) {
return createDeferredPromise<TResult>(async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
Expand Down Expand Up @@ -210,7 +210,7 @@ const customInspect = Symbol.for('nodejs.util.inspect.custom');
export function makeProxy<T extends PrismaProxyHandler>(
prisma: any,
modelMeta: ModelMeta,
makeHandler: (prisma: object, model: string) => T,
makeHandler: (prisma: DbClientContract, model: string) => T,
name = 'unnamed_enhancer',
errorTransformer?: ErrorTransformer
) {
Expand Down
Loading
Loading