Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: query injection error when create (in array form) is nested inside an update #865

Merged
merged 1 commit into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/runtime/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"pluralize": "^8.0.0",
"semver": "^7.3.8",
"superjson": "^1.11.0",
"tiny-invariant": "^1.3.1",
"tslib": "^2.4.1",
"upper-case-first": "^2.0.2",
"uuid": "^9.0.0",
Expand Down
13 changes: 12 additions & 1 deletion packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,21 @@ export type ModelMeta = {
/**
* Resolves a model field to its metadata. Returns undefined if not found.
*/
export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined {
export function resolveField(modelMeta: ModelMeta, model: string, field: string) {
return modelMeta.fields[lowerCaseFirst(model)]?.[field];
}

/**
* Resolves a model field to its metadata. Throws an error if not found.
*/
export function requireField(modelMeta: ModelMeta, model: string, field: string) {
const f = resolveField(modelMeta, model, field);
if (!f) {
throw new Error(`Field ${model}.${field} cannot be resolved`);
}
return f;
}

/**
* Gets all fields of a model.
*/
Expand Down
79 changes: 31 additions & 48 deletions packages/runtime/src/cross/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,60 +145,61 @@ export class NestedWriteVisitor {
return;
}

const context = { parent, field, nestingPath: [...nestingPath] };
const toplevel = field == undefined;

const context = { parent, field, nestingPath: [...nestingPath] };
const pushNewContext = (field: FieldInfo | undefined, model: string, where: any, unique = false) => {
return { ...context, nestingPath: [...context.nestingPath, { field, model, where, unique }] };
};

// visit payload
switch (action) {
case 'create':
context.nestingPath.push({ field, model, where: {}, unique: false });
for (const item of enumerate(data)) {
const newContext = pushNewContext(field, model, {});
let callbackResult: any;
if (this.callback.create) {
callbackResult = await this.callback.create(model, item, context);
callbackResult = await this.callback.create(model, item, newContext);
}
if (callbackResult !== false) {
const subPayload = typeof callbackResult === 'object' ? callbackResult : item;
await this.visitSubPayload(model, action, subPayload, context.nestingPath);
await this.visitSubPayload(model, action, subPayload, newContext.nestingPath);
}
}
break;

case 'createMany':
if (data) {
context.nestingPath.push({ field, model, where: {}, unique: false });
const newContext = pushNewContext(field, model, {});
let callbackResult: any;
if (this.callback.createMany) {
callbackResult = await this.callback.createMany(model, data, context);
callbackResult = await this.callback.createMany(model, data, newContext);
}
if (callbackResult !== false) {
const subPayload = typeof callbackResult === 'object' ? callbackResult : data.data;
await this.visitSubPayload(model, action, subPayload, context.nestingPath);
await this.visitSubPayload(model, action, subPayload, newContext.nestingPath);
}
}
break;

case 'connectOrCreate':
context.nestingPath.push({ field, model, where: data.where, unique: false });
for (const item of enumerate(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.connectOrCreate) {
callbackResult = await this.callback.connectOrCreate(model, item, context);
callbackResult = await this.callback.connectOrCreate(model, item, newContext);
}
if (callbackResult !== false) {
const subPayload = typeof callbackResult === 'object' ? callbackResult : item.create;
await this.visitSubPayload(model, action, subPayload, context.nestingPath);
await this.visitSubPayload(model, action, subPayload, newContext.nestingPath);
}
}
break;

case 'connect':
if (this.callback.connect) {
for (const item of enumerate(data)) {
const newContext = {
...context,
nestingPath: [...context.nestingPath, { field, model, where: item, unique: true }],
};
const newContext = pushNewContext(field, model, item, true);
await this.callback.connect(model, item, newContext);
}
}
Expand All @@ -210,31 +211,25 @@ export class NestedWriteVisitor {
// if relation is to-one, the payload can only be boolean `true`
if (this.callback.disconnect) {
for (const item of enumerate(data)) {
const newContext = {
...context,
nestingPath: [
...context.nestingPath,
{ field, model, where: item, unique: typeof item === 'object' },
],
};
const newContext = pushNewContext(field, model, item, typeof item === 'object');
await this.callback.disconnect(model, item, newContext);
}
}
break;

case 'set':
if (this.callback.set) {
context.nestingPath.push({ field, model, where: {}, unique: false });
await this.callback.set(model, data, context);
const newContext = pushNewContext(field, model, {});
await this.callback.set(model, data, newContext);
}
break;

case 'update':
context.nestingPath.push({ field, model, where: data.where, unique: false });
for (const item of enumerate(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.update) {
callbackResult = await this.callback.update(model, item, context);
callbackResult = await this.callback.update(model, item, newContext);
}
if (callbackResult !== false) {
const subPayload =
Expand All @@ -243,38 +238,38 @@ export class NestedWriteVisitor {
: typeof item.data === 'object'
? item.data
: item;
await this.visitSubPayload(model, action, subPayload, context.nestingPath);
await this.visitSubPayload(model, action, subPayload, newContext.nestingPath);
}
}
break;

case 'updateMany':
context.nestingPath.push({ field, model, where: data.where, unique: false });
for (const item of enumerate(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.updateMany) {
callbackResult = await this.callback.updateMany(model, item, context);
callbackResult = await this.callback.updateMany(model, item, newContext);
}
if (callbackResult !== false) {
const subPayload = typeof callbackResult === 'object' ? callbackResult : item;
await this.visitSubPayload(model, action, subPayload, context.nestingPath);
await this.visitSubPayload(model, action, subPayload, newContext.nestingPath);
}
}
break;

case 'upsert': {
context.nestingPath.push({ field, model, where: data.where, unique: false });
for (const item of enumerate(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.upsert) {
callbackResult = await this.callback.upsert(model, item, context);
callbackResult = await this.callback.upsert(model, item, newContext);
}
if (callbackResult !== false) {
if (typeof callbackResult === 'object') {
await this.visitSubPayload(model, action, callbackResult, context.nestingPath);
await this.visitSubPayload(model, action, callbackResult, newContext.nestingPath);
} else {
await this.visitSubPayload(model, action, item.create, context.nestingPath);
await this.visitSubPayload(model, action, item.update, context.nestingPath);
await this.visitSubPayload(model, action, item.create, newContext.nestingPath);
await this.visitSubPayload(model, action, item.update, newContext.nestingPath);
}
}
}
Expand All @@ -284,13 +279,7 @@ export class NestedWriteVisitor {
case 'delete': {
if (this.callback.delete) {
for (const item of enumerate(data)) {
const newContext = {
...context,
nestingPath: [
...context.nestingPath,
{ field, model, where: toplevel ? item.where : item, unique: false },
],
};
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.delete(model, item, newContext);
}
}
Expand All @@ -300,13 +289,7 @@ export class NestedWriteVisitor {
case 'deleteMany':
if (this.callback.deleteMany) {
for (const item of enumerate(data)) {
const newContext = {
...context,
nestingPath: [
...context.nestingPath,
{ field, model, where: toplevel ? item.where : item, unique: false },
],
};
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.deleteMany(model, item, newContext);
}
}
Expand Down
67 changes: 54 additions & 13 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { lowerCaseFirst } from 'lower-case-first';
import invariant from 'tiny-invariant';
import { upperCaseFirst } from 'upper-case-first';
import { fromZodError } from 'zod-validation-error';
import { CrudFailureReason, PRISMA_TX_FLAG } from '../../constants';
Expand All @@ -10,6 +11,7 @@ import {
NestedWriteVisitorContext,
enumerate,
getIdFields,
requireField,
resolveField,
type FieldInfo,
type ModelMeta,
Expand Down Expand Up @@ -641,20 +643,62 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

// handles the connection to upstream entity
const reversedQuery = this.utils.buildReversedQuery(context, true, unsafe);
if (reversedQuery[context.field.backLink]) {
// the built reverse query contains a condition for the backlink field, build a "connect" with it
if ((!unsafe || context.field.isRelationOwner) && reversedQuery[context.field.backLink]) {
// if mutation is safe, or current field owns the relation (so the other side has no fk),
// and the reverse query contains the back link, then we can build a "connect" with it
createData = {
...createData,
[context.field.backLink]: {
connect: reversedQuery[context.field.backLink],
},
};
} else {
// otherwise, the reverse query is translated to foreign key setting, merge it to the create data
createData = {
...createData,
...reversedQuery,
};
// otherwise, the reverse query should be translated to foreign key setting
// and merged to the create data

const backLinkField = this.requireBackLink(context.field);
invariant(backLinkField.foreignKeyMapping);

// try to extract foreign key values from the reverse query
let fkValues = Object.values(backLinkField.foreignKeyMapping).reduce<any>((obj, fk) => {
obj[fk] = reversedQuery[fk];
return obj;
}, {});

if (Object.values(fkValues).every((v) => v !== undefined)) {
// all foreign key values are available, merge them to the create data
createData = {
...createData,
...fkValues,
};
} else {
// some foreign key values are missing, need to look up the upstream entity,
// this can happen when the upstream entity doesn't have a unique where clause,
// for example when it's nested inside a one-to-one update
const upstreamQuery = {
where: reversedQuery[backLinkField.name],
select: this.utils.makeIdSelection(backLinkField.type),
};

// fetch the upstream entity
if (this.logger.enabled('info')) {
this.logger.info(
`[policy] \`findUniqueOrThrow\` ${model}: looking up upstream entity of ${
backLinkField.type
}, ${formatObject(upstreamQuery)}`
);
}
const upstreamEntity = await this.prisma[backLinkField.type].findUniqueOrThrow(upstreamQuery);

// map ids to foreign keys
fkValues = Object.entries(backLinkField.foreignKeyMapping).reduce<any>((obj, [id, fk]) => {
obj[fk] = upstreamEntity[id];
return obj;
}, {});

// merge them to the create data
createData = { ...createData, ...fkValues };
}
}
}

Expand Down Expand Up @@ -1192,7 +1236,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// already in transaction, don't nest
return action(this.prisma);
} else {
return this.prisma.$transaction((tx) => action(tx));
return this.prisma.$transaction((tx) => action(tx), { maxWait: 100000, timeout: 100000 });
}
}

Expand All @@ -1217,11 +1261,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

private requireBackLink(fieldInfo: FieldInfo) {
const backLinkField = fieldInfo.backLink && resolveField(this.modelMeta, fieldInfo.type, fieldInfo.backLink);
if (!backLinkField) {
throw new Error('Missing back link for field: ' + fieldInfo.name);
}
return backLinkField;
invariant(fieldInfo.backLink, `back link not found for field ${fieldInfo.name}`);
return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink);
}

//#endregion
Expand Down
3 changes: 3 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading