diff --git a/src/helpers.ts b/src/helpers.ts index 42ded7f..95cba64 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -23,6 +23,10 @@ export type PrismaExtensionCaslOptions = { beforeQuery?: (tx: Prisma.TransactionClient) => Promise, /** uses transaction to allow using client queries after actual query, if fails, whole query will be rolled back */ afterQuery?: (tx: Prisma.TransactionClient) => Promise, + /** max wait for batch transaction - default 30000 */ + txMaxWait?: number + /** timeout for batch transaction - default 30000 */ + txTimeout?: number } export type PrismaCaslOperation = diff --git a/src/index.ts b/src/index.ts index 7b89df9..4d1caa9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,6 +1,6 @@ import { AbilityBuilder, AbilityTuple, PureAbility } from '@casl/ability' import { PrismaQuery } from '@casl/prisma' -import { Prisma, PrismaClient } from '@prisma/client' +import { Prisma, PrismaClient, PrismaPromise } from '@prisma/client' import { applyCaslToQuery } from './applyCaslToQuery' import { filterQueryResults } from './filterQueryResults' import { caslOperationDict, getFluentField, getFluentModel, PrismaCaslOperation, PrismaExtensionCaslOptions, propertyFieldsByModel, relationFieldsByModel } from './helpers' @@ -38,10 +38,22 @@ export { applyCaslToQuery } export function useCaslAbilities( getAbilityFactory: () => AbilityBuilder>, opts?: PrismaExtensionCaslOptions) { - + // Set default options + const txMaxWait = opts?.txMaxWait ?? 30000 + const txTimeout = opts?.txTimeout ?? 30000 return Prisma.defineExtension((client) => { - const transactionsToBatch = new Set() + let tickActive = false; + const batches: Record void; + resolve: (result: unknown) => void; + reject: (error: unknown) => void; + }>> = {}; const allOperations = (getAbilities: () => AbilityBuilder>) => ({ async $allOperations({ args, query, model, operation, ...rest }: { args: any, query: any, model: any, operation: any }) { @@ -49,7 +61,8 @@ export function useCaslAbilities( const fluentModel = getFluentModel(model, rest) const [fluentRelationModel, fluentRelationField] = (fluentModel !== model ? Object.entries(relationFieldsByModel[model]).find(([k, v]) => v.type === fluentModel) : undefined) ?? [undefined, undefined] - const transaction = (rest as any).__internalParams.transaction + const __internalParams = (rest as any).__internalParams + const transaction = __internalParams.transaction const debug = (process.env.NODE_ENV === 'development' || process.env.NODE_ENV === 'test') && args.debugCasl const debugAllErrors = args.debugCasl delete args.debugCasl @@ -112,12 +125,12 @@ export function useCaslAbilities( perf?.mark('prisma-casl-extension-3') + if (fluentRelationModel && caslQuery.mask) { // on fluent models we need to take mask of the relation caslQuery.mask = fluentRelationModel && fluentRelationModel in caslQuery.mask ? caslQuery.mask[fluentRelationModel] : {} } const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel as Prisma.ModelName, operation, opts) - if (perf) { perf.mark('prisma-casl-extension-4') logger?.log( @@ -153,99 +166,168 @@ export function useCaslAbilities( // query(caslQuery.args).then(cleanupResults).then((result: any) => resolve(result)).catch(((e: any) => reject(e))) // }) } - const transactionQuery = async (txClient: any) => { - - if (opts?.beforeQuery) { - await opts.beforeQuery(txClient) - } - if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') { - /** - * we get all update/deleteMany entries for logging purposes. - */ - const getMany = operation === 'deleteMany' || operation === 'updateMany' - const manyResult = getMany ? await txClient[model].findMany(caslQuery.args.where ? { where: caslQuery.args.where } : undefined).then((res: any[]) => { - /** create update objects for updateMany */ - return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res - }) : [] - /** - * we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries - */ - const op = operation === 'createMany' ? 'createManyAndReturn' : operation - return txClient[model][op](caslQuery.args).then(async (result: any) => { - // we need to get the updated many result - if (opts?.afterQuery) { - await opts.afterQuery(txClient) - } - const filteredResult = cleanupResults(getMany ? manyResult : result) - const results = operation === 'createMany' - ? { count: result.length } - : getMany ? { count: manyResult.length } - : filteredResult - return results - }) - } else { - - return txClient[model][operation](caslQuery.args).then(async (result: any) => { - // we need to get the updated many result - if (opts?.afterQuery) { - await opts.afterQuery(txClient) - } - const fluentField = getFluentField(rest) + const hash = transaction?.id ?? 'batch' - if (fluentField) { - return cleanupResults(result?.[fluentField]) - } - return cleanupResults(result) - }) - } + if (!batches[hash]) { + batches[hash] = [] } - if (transaction && transaction.kind === 'itx') { - return transactionQuery((client as any)._createItxClient(transaction)) + // make sure, that we only tick once at a time + if (!tickActive) { + tickActive = true; + process.nextTick(() => { + dispatchBatches(transaction); + tickActive = false; + }); + } + /** batchQuery collects query within batches that will be dispatched every tick */ + const batchQuery = ( + model: string, + action: string, + args: any, + callback: (result: any) => void + ) => new Promise((resolve, reject) => { + batches[hash].push({ + params: __internalParams, + model, + action, + args, + reject, + resolve, + callback, + }) + }); + + + if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') { + /** + * we get all update/deleteMany entries for logging purposes. + */ + const getMany = operation === 'deleteMany' || operation === 'updateMany' + + // const manyResult: any[] = getMany ? await batchQuery(model, 'findMany', caslQuery.args.where ? { where: caslQuery.args.where } : undefined, (res: any[]) => { + // /** create update objects for updateMany */ + // return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res + // }) : [] + /** + * we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries + */ + const op = operation === 'createMany' ? 'createManyAndReturn' : operation + return batchQuery(model, op, caslQuery.args, async (result: any) => { + + const filteredResult = cleanupResults(result)//getMany ? manyResult : result) + const results = operation === 'createMany' || operation === 'deleteMany' || operation === 'updateMany' + ? { count: result.length } + // : getMany ? { count: manyResult.length } + : filteredResult + return results + }) } else { - return client.$transaction(async (tx) => { - const transactionId = tx[Symbol.for("prisma.client.transaction.id")].toString() - transactionsToBatch.add(transactionId) - //@ts-ignore - return transactionQuery(tx).finally(() => { - transactionsToBatch.delete(transactionId) - }); - }, { - //https://github.com/prisma/prisma/issues/20015 - maxWait: 10000 // default prisma pool timeout. would be better to get it from client + + return batchQuery(model, operation, caslQuery.args, async (result: any) => { + + const fluentField = getFluentField(rest) + if (fluentField) { + return cleanupResults(result?.[fluentField]) + } + return cleanupResults(result) }) } + + } }) - // Derived from yates: + // Derived from yates + // https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227 + // // By default, Prisma will batch requests by the transaction ID if it is present. - // If our transaction id does not include casl-extension- it is a normal interactive transaction - // and we hook into it. otherwise we use normal batching for our transaction + // This behaviour prevents automatic batching from working when using this client extension, since all queries are executed inside an interactive transaction. + // To get around this we monkey patch the batching function to use the batch ID and transaction ID. + // To get the batching to work we also need to ensure that all the requests we might want to batch together are generated inside the same tick. + // This means that all the requests per-tick that have the same role and context values will be batched together, + // allowing the in-built prisma batch optimizations to work for us. + // This is why we use process.nextTick and the tickActive flag to ensure we only tick once at a time. + // See: // - https://github.com/prisma/prisma/blob/5.21.1/packages/client/src/runtime/RequestHandler.ts#L122 // - https://www.prisma.io/docs/orm/prisma-client/queries/query-optimization-performance ; (client as any)._requestHandler.dataloader.options.batchBy = ( request: any, ) => { const batchId = getBatchId(request.protocolQuery); - if (request.transaction?.id && (!transactionsToBatch.has(request.transaction.id.toString()) && batchId)) { - return `transaction-${request.transaction.id}`; + if (request.transaction?.id) { + return `transaction-${request.transaction.id}${batchId ? `-${batchId}` : "" + }`; } + return batchId }; + /** + * Derived from yates + * https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227 + * + * This function is called once per tick, and processes all the batches that have been created during that tick. + * If the batch happened within an existing transaction, we use it to recreate its client, so we keep its interactve transaction logic + **/ + const dispatchBatches = (transaction?: { kind: 'itx' | 'batch' }) => { + for (const [key, batch] of Object.entries(batches)) { + delete batches[key]; + + const runBatchTransaction = async (tx: any) => { + if (opts?.beforeQuery) { + await opts.beforeQuery(tx as any) + } + + const results = await Promise.all( + batch.map((request: any) => { + //@ts-ignore + return tx[request.model][request.action](request.args).then((res) => request.callback(res)) + .catch((e: Error) => { + throw (e) + }) + + }), + ); + // Switch role back to admin user + if (opts?.afterQuery) { + await opts?.afterQuery(tx as any) + } + + return results; + } + + new Promise((resolve, reject) => { + if (transaction && transaction.kind === 'itx') { + runBatchTransaction((client as any)._createItxClient(transaction)).then(resolve).catch(reject) + } else { + client.$transaction(async (tx) => { + return runBatchTransaction(tx); + }, { + maxWait: txMaxWait, + timeout: txTimeout, + }).then(resolve).catch(reject) + } + }).then((results: any) => { + results.forEach((result: any, index: number) => { + batch[index].resolve(result); + }); + }) + .catch((e) => { + for (const request of batch) { + request.reject(e); + } + delete batches[key] + }) + } + }; + return client.$extends({ name: "prisma-extension-casl", client: { - // https://github.com/prisma/prisma/issues/20678 - // $transaction(...props: Parameters<(typeof client)['$transaction']>): ReturnType<(typeof client)['$transaction']> { - // return transactionStore.run({ alreadyInTransaction: true }, () => { - // return client.$transaction(...props); - // }); - // }, $casl(extendFactory: (factory: AbilityBuilder>) => AbilityBuilder>) { // alter the getAblities function shortly return client.$extends({ @@ -267,8 +349,14 @@ export function useCaslAbilities( } -//https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4 +/** + * recreates getBatchId from prisma + * //https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4 + * + * @param query + * @returns + */ export function getBatchId(query: any): string | undefined { if (query.action !== "findUnique" && query.action !== "findUniqueOrThrow") { return undefined; diff --git a/test/extension.test.ts b/test/extension.test.ts index 6b77bf2..c2726b1 100644 --- a/test/extension.test.ts +++ b/test/extension.test.ts @@ -2090,7 +2090,6 @@ describe('prisma extension casl', () => { expect(result).toEqual({ email: '0', id: 0, casl: ['create', 'read', 'update', 'delete'] }) }) }) - }) afterAll(async () => { await seedClient.$disconnect()