Skip to content

Commit

Permalink
refactor(runtime): unify the handling of fluent api calls (#1187)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Mar 30, 2024
1 parent be2aec9 commit 1b2f48e
Show file tree
Hide file tree
Showing 7 changed files with 679 additions and 473 deletions.
654 changes: 303 additions & 351 deletions packages/runtime/src/enhancements/policy/handler.ts

Large diffs are not rendered by default.

19 changes: 12 additions & 7 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -897,16 +897,21 @@ export class PolicyUtil extends QueryUtils {
* @returns
*/
injectReadCheckSelect(model: string, args: any) {
if (!this.hasFieldLevelPolicy(model)) {
return;
if (this.hasFieldLevelPolicy(model)) {
// recursively inject selection for fields needed for field-level read checks
const readFieldSelect = this.getReadFieldSelect(model);
if (readFieldSelect) {
this.doInjectReadCheckSelect(model, args, { select: readFieldSelect });
}
}

const readFieldSelect = this.getReadFieldSelect(model);
if (!readFieldSelect) {
return;
// recurse into relation fields
for (const [k, v] of Object.entries<any>(args.select ?? args.include ?? {})) {
const field = resolveField(this.modelMeta, model, k);
if (field?.isDataModel && v && typeof v === 'object') {
this.injectReadCheckSelect(field.type, v);
}
}

this.doInjectReadCheckSelect(model, args, { select: readFieldSelect });
}

private doInjectReadCheckSelect(model: string, args: any, input: any) {
Expand Down
38 changes: 0 additions & 38 deletions packages/runtime/src/enhancements/policy/promise.ts

This file was deleted.

99 changes: 99 additions & 0 deletions packages/runtime/src/enhancements/promise.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { getModelInfo, type ModelMeta } from '../cross';

/**
* Creates a promise that only executes when it's awaited or .then() is called.
* @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts
*/
export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T> {
let promise: Promise<T> | undefined;
const cb = () => {
try {
return (promise ??= valueToPromise(callback()));
} catch (err) {
// deal with synchronous errors
return Promise.reject<T>(err);
}
};

return {
then(onFulfilled, onRejected) {
return cb().then(onFulfilled, onRejected);
},
catch(onRejected) {
return cb().catch(onRejected);
},
finally(onFinally) {
return cb().finally(onFinally);
},
[Symbol.toStringTag]: 'ZenStackPromise',
};
}

function valueToPromise(thing: any): Promise<any> {
if (typeof thing === 'object' && typeof thing?.then === 'function') {
return thing;
} else {
return Promise.resolve(thing);
}
}

/**
* Create a deferred promise with fluent API call stub installed.
*
* @param callback The callback to execute when the promise is awaited.
* @param parentArgs The parent promise's query args.
* @param modelMeta The model metadata.
* @param model The model name.
*/
export function createFluentPromise(
callback: () => Promise<any>,
parentArgs: any,
modelMeta: ModelMeta,
model: string
): Promise<any> {
const promise: any = createDeferredPromise(callback);

const modelInfo = getModelInfo(modelMeta, model);
if (!modelInfo) {
return promise;
}

// install fluent call stub for model fields
Object.values(modelInfo.fields)
.filter((field) => field.isDataModel)
.forEach((field) => {
// e.g., `posts` in `db.user.findUnique(...).posts()`
promise[field.name] = (fluentArgs: any) => {
if (field.isArray) {
// an array relation terminates fluent call chain
return createDeferredPromise(async () => {
setFluentSelect(parentArgs, field.name, fluentArgs ?? true);
const parentResult: any = await promise;
return parentResult?.[field.name] ?? null;
});
} else {
fluentArgs = { ...fluentArgs };
// create a chained subsequent fluent call promise
return createFluentPromise(
async () => {
setFluentSelect(parentArgs, field.name, fluentArgs);
const parentResult: any = await promise;
return parentResult?.[field.name] ?? null;
},
fluentArgs,
modelMeta,
field.type
);
}
};
});

return promise;
}

function setFluentSelect(args: any, fluentFieldName: any, fluentArgs: any) {
delete args.include;
args.select = { [fluentFieldName]: fluentArgs };
}
125 changes: 62 additions & 63 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import deepcopy from 'deepcopy';
import { PRISMA_PROXY_ENHANCER } from '../constants';
import type { ModelMeta } from '../cross';
import type { DbClientContract } from '../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { createDeferredPromise } from './policy/promise';
import type { InternalEnhancementOptions } from './create-enhancement';
import { createDeferredPromise, createFluentPromise } from './promise';

/**
* Prisma batch write operation result
Expand Down Expand Up @@ -70,93 +71,91 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
protected readonly options: InternalEnhancementOptions
) {}

async findUnique(args: any): Promise<unknown> {
args = await this.preprocessArgs('findUnique', args);
const r = await this.prisma[this.model].findUnique(args);
return this.processResultEntity(r);
protected withFluentCall(method: keyof PrismaProxyHandler, args: any, postProcess = true): Promise<unknown> {
args = args ? deepcopy(args) : {};
const promise = createFluentPromise(
async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
},
args,
this.options.modelMeta,
this.model
);
return promise;
}

async findUniqueOrThrow(args: any): Promise<unknown> {
args = await this.preprocessArgs('findUniqueOrThrow', args);
const r = await this.prisma[this.model].findUniqueOrThrow(args);
return this.processResultEntity(r);
protected deferred<TResult = unknown>(method: keyof PrismaProxyHandler, args: any, postProcess = true) {
return createDeferredPromise<TResult>(async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
});
}

async findFirst(args: any): Promise<unknown> {
args = await this.preprocessArgs('findFirst', args);
const r = await this.prisma[this.model].findFirst(args);
return this.processResultEntity(r);
findUnique(args: any) {
return this.withFluentCall('findUnique', args);
}

async findFirstOrThrow(args: any): Promise<unknown> {
args = await this.preprocessArgs('findFirstOrThrow', args);
const r = await this.prisma[this.model].findFirstOrThrow(args);
return this.processResultEntity(r);
findUniqueOrThrow(args: any) {
return this.withFluentCall('findUniqueOrThrow', args);
}

async findMany(args: any): Promise<unknown[]> {
args = await this.preprocessArgs('findMany', args);
const r = await this.prisma[this.model].findMany(args);
return this.processResultEntity(r);
findFirst(args: any) {
return this.withFluentCall('findFirst', args);
}

async create(args: any): Promise<unknown> {
args = await this.preprocessArgs('create', args);
const r = await this.prisma[this.model].create(args);
return this.processResultEntity(r);
findFirstOrThrow(args: any) {
return this.withFluentCall('findFirstOrThrow', args);
}

async createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> {
args = await this.preprocessArgs('createMany', args);
return this.prisma[this.model].createMany(args);
findMany(args: any) {
return this.deferred<unknown[]>('findMany', args);
}

async update(args: any): Promise<unknown> {
args = await this.preprocessArgs('update', args);
const r = await this.prisma[this.model].update(args);
return this.processResultEntity(r);
create(args: any): Promise<unknown> {
return this.deferred('create', args);
}

async updateMany(args: any): Promise<{ count: number }> {
args = await this.preprocessArgs('updateMany', args);
return this.prisma[this.model].updateMany(args);
createMany(args: { data: any; skipDuplicates?: boolean }) {
return this.deferred<{ count: number }>('createMany', args, false);
}

async upsert(args: any): Promise<unknown> {
args = await this.preprocessArgs('upsert', args);
const r = await this.prisma[this.model].upsert(args);
return this.processResultEntity(r);
update(args: any) {
return this.deferred('update', args);
}

async delete(args: any): Promise<unknown> {
args = await this.preprocessArgs('delete', args);
const r = await this.prisma[this.model].delete(args);
return this.processResultEntity(r);
updateMany(args: any) {
return this.deferred<{ count: number }>('updateMany', args, false);
}

async deleteMany(args: any): Promise<{ count: number }> {
args = await this.preprocessArgs('deleteMany', args);
return this.prisma[this.model].deleteMany(args);
upsert(args: any) {
return this.deferred('upsert', args);
}

async aggregate(args: any): Promise<unknown> {
args = await this.preprocessArgs('aggregate', args);
return this.prisma[this.model].aggregate(args);
delete(args: any) {
return this.deferred('delete', args);
}

async groupBy(args: any): Promise<unknown> {
args = await this.preprocessArgs('groupBy', args);
return this.prisma[this.model].groupBy(args);
deleteMany(args: any) {
return this.deferred<{ count: number }>('deleteMany', args, false);
}

async count(args: any): Promise<unknown> {
args = await this.preprocessArgs('count', args);
return this.prisma[this.model].count(args);
aggregate(args: any) {
return this.deferred('aggregate', args, false);
}

async subscribe(args: any): Promise<unknown> {
args = await this.preprocessArgs('subscribe', args);
return this.prisma[this.model].subscribe(args);
groupBy(args: any) {
return this.deferred('groupBy', args, false);
}

count(args: any): Promise<unknown> {
return this.deferred('count', args, false);
}

subscribe(args: any) {
return this.deferred('subscribe', args, false);
}

/**
Expand All @@ -177,6 +176,8 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
// a marker for filtering error stack trace
const ERROR_MARKER = '__error_marker__';

const customInspect = Symbol.for('nodejs.util.inspect.custom');

/**
* Makes a Prisma client proxy.
*/
Expand All @@ -196,10 +197,6 @@ export function makeProxy<T extends PrismaProxyHandler>(
return name;
}

if (prop === 'toString') {
return () => `$zenstack_prisma_${prisma._clientVersion}`;
}

if (prop === '$transaction') {
// for interactive transactions, we need to proxy the transaction function so that
// when it runs the callback, it provides a proxy to the Prisma client wrapped with
Expand Down Expand Up @@ -245,6 +242,8 @@ export function makeProxy<T extends PrismaProxyHandler>(
},
});

proxy[customInspect] = `$zenstack_prisma_${prisma._clientVersion}`;

return proxy;
}

Expand Down
Loading

0 comments on commit 1b2f48e

Please sign in to comment.