Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow comparing fields from different models in mutation policies #1476

Merged
merged 4 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 154 additions & 51 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import deepmerge from 'deepmerge';
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
import { lowerCaseFirst } from 'lower-case-first';
import invariant from 'tiny-invariant';
import { P, match } from 'ts-pattern';
Expand All @@ -23,7 +24,7 @@ import { Logger } from '../logger';
import { createDeferredPromise, createFluentPromise } from '../promise';
import { PrismaProxyHandler } from '../proxy';
import { QueryUtils } from '../query-utils';
import type { CheckerConstraint } from '../types';
import type { AdditionalCheckerFunc, CheckerConstraint } from '../types';
import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
import { ConstraintSolver } from './constraint-solver';
import { PolicyUtil } from './policy-utils';
Expand Down Expand Up @@ -152,8 +153,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

const result = await this.modelClient[actionName](_args);
this.policyUtils.postProcessForRead(result, this.model, origArgs);
return result;
return this.policyUtils.postProcessForRead(result, this.model, origArgs);
}

//#endregion
Expand Down Expand Up @@ -779,10 +779,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
};

const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => {
const _connectDisconnect = async (
model: string,
args: any,
context: NestedWriteVisitorContext,
operation: 'connect' | 'disconnect'
) => {
if (context.field?.backLink) {
const backLinkField = this.policyUtils.getModelField(model, context.field.backLink);
if (backLinkField?.isRelationOwner) {
let uniqueFilter = args;
if (operation === 'disconnect') {
// disconnect filter is not unique, need to build a reversed query to
// locate the entity and use its id fields as unique filter
const reversedQuery = this.policyUtils.buildReversedQuery(context);
const found = await db[model].findUnique({
where: reversedQuery,
select: this.policyUtils.makeIdSelection(model),
});
uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found);
}

// update happens on the related model, require updatable,
// translate args to foreign keys so field-level policies can be checked
const checkArgs: any = {};
Expand All @@ -794,10 +811,15 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}
}
await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs);

// register post-update check
await _registerPostUpdateCheck(model, args, args);
// `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist
if (uniqueFilter) {
// check for update
await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs);

// register post-update check
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
}
}
}
};
Expand Down Expand Up @@ -970,14 +992,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
},

connect: async (model, args, context) => _connectDisconnect(model, args, context),
connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'),

connectOrCreate: async (model, args, context) => {
// the where condition is already unique, so we can use it to check if the target exists
const existing = await this.policyUtils.checkExistence(db, model, args.where);
if (existing) {
// connect
await _connectDisconnect(model, args.where, context);
await _connectDisconnect(model, args.where, context, 'connect');
return true;
} else {
// create
Expand All @@ -997,7 +1019,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
},

disconnect: async (model, args, context) => _connectDisconnect(model, args, context),
disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'),

set: async (model, args, context) => {
// find the set of items to be replaced
Expand All @@ -1012,10 +1034,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
const currentSet = await db[model].findMany(findCurrSetArgs);

// register current set for update (foreign key)
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context)));
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect')));

// proceed with connecting the new set
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context)));
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect')));
},

delete: async (model, args, context) => {
Expand Down Expand Up @@ -1160,48 +1182,78 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

args.data = this.validateUpdateInputSchema(this.model, args.data);

if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) {
// use a transaction to do post-update checks
const postWriteChecks: PostWriteCheckRecord[] = [];
return this.queryUtils.transaction(this.prisma, async (tx) => {
// collect pre-update values
let select = this.policyUtils.makeIdSelection(this.model);
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
if (preValueSelect) {
select = { ...select, ...preValueSelect };
}
const currentSetQuery = { select, where: args.where };
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read');
const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'update');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
}
const currentSet = await tx[this.model].findMany(currentSetQuery);
const canProceedWithoutTransaction =
// no post-update rules
!this.policyUtils.hasAuthGuard(this.model, 'postUpdate') &&
// no Zod schema
!this.policyUtils.getZodSchema(this.model) &&
// no additional checker
!additionalChecker;

postWriteChecks.push(
...currentSet.map((preValue) => ({
model: this.model,
operation: 'postUpdate' as PolicyOperationKind,
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
preValue: preValueSelect ? preValue : undefined,
}))
);

// proceed with the update
const result = await tx[this.model].updateMany(args);

// run post-write checks
await this.runPostWriteChecks(postWriteChecks, tx);

return result;
});
} else {
if (canProceedWithoutTransaction) {
// proceed without a transaction
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`);
}
return this.modelClient.updateMany(args);
}

// collect post-update checks
const postWriteChecks: PostWriteCheckRecord[] = [];

return this.queryUtils.transaction(this.prisma, async (tx) => {
// collect pre-update values
let select = this.policyUtils.makeIdSelection(this.model);
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
if (preValueSelect) {
select = { ...select, ...preValueSelect };
}

// merge selection required for running additional checker
const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(this.model, 'update');
if (additionalCheckerSelector) {
select = deepmerge(select, additionalCheckerSelector);
}

const currentSetQuery = { select, where: args.where };
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update');

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
}
let candidates = await tx[this.model].findMany(currentSetQuery);

if (additionalChecker) {
// filter candidates with additional checker and build an id filter
const r = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker);
candidates = r.filteredCandidates;

// merge id filter into update's where clause
args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter;
}

postWriteChecks.push(
...candidates.map((preValue) => ({
model: this.model,
operation: 'postUpdate' as PolicyOperationKind,
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
preValue: preValueSelect ? preValue : undefined,
}))
);

// proceed with the update
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`);
}
const result = await tx[this.model].updateMany(args);

// run post-write checks
await this.runPostWriteChecks(postWriteChecks, tx);

return result;
});
});
}

Expand Down Expand Up @@ -1328,14 +1380,53 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
this.policyUtils.tryReject(this.prisma, this.model, 'delete');

// inject policy conditions
args = args ?? {};
args = clone(args);
this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete');

// conduct the deletion
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'delete');
if (additionalChecker) {
// additional checker exists, need to run deletion inside a transaction
return this.queryUtils.transaction(this.prisma, async (tx) => {
// find the delete candidates, selecting id fields and fields needed for
// running the additional checker
let candidateSelect = this.policyUtils.makeIdSelection(this.model);
const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(
this.model,
'delete'
);
if (additionalCheckerSelector) {
candidateSelect = deepmerge(candidateSelect, additionalCheckerSelector);
}

if (this.shouldLogQuery) {
this.logger.info(
`[policy] \`findMany\` ${this.model}: ${formatObject({
where: args.where,
select: candidateSelect,
})}`
);
}
const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect });

// build a ID filter based on id values filtered by the additional checker
const { idFilter } = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker);

// merge the ID filter into the where clause
args.where = args.where ? { AND: [args.where, idFilter] } : idFilter;

// finally, conduct the deletion with the combined where clause
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`);
}
return tx[this.model].deleteMany(args);
});
} else {
// conduct the deletion directly
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.deleteMany(args);
}
return this.modelClient.deleteMany(args);
});
}

Expand Down Expand Up @@ -1599,5 +1690,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

private buildIdFilterWithAdditionalChecker(candidates: any[], additionalChecker: AdditionalCheckerFunc) {
const filteredCandidates = candidates.filter((value) => additionalChecker({ user: this.context?.user }, value));
const idFields = this.policyUtils.getIdFields(this.model);
let idFilter: any;
if (idFields.length === 1) {
idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } };
} else {
idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) };
}
return { filteredCandidates, idFilter };
}

//#endregion
}
Loading
Loading