diff --git a/src/index.ts b/src/index.ts index cb5e2b2..7b89df9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -41,6 +41,8 @@ export function useCaslAbilities( return Prisma.defineExtension((client) => { + const transactionsToBatch = new Set() + const allOperations = (getAbilities: () => AbilityBuilder>) => ({ async $allOperations({ args, query, model, operation, ...rest }: { args: any, query: any, model: any, operation: any }) { @@ -202,14 +204,12 @@ export function useCaslAbilities( return transactionQuery((client as any)._createItxClient(transaction)) } else { return client.$transaction(async (tx) => { - - (tx as any)[ - Symbol.for("prisma.client.transaction.id") - ] = 'casl-extension-' + (tx as any)[ - Symbol.for("prisma.client.transaction.id") - ] + const transactionId = tx[Symbol.for("prisma.client.transaction.id")].toString() + transactionsToBatch.add(transactionId) //@ts-ignore - return transactionQuery(tx) + return transactionQuery(tx).finally(() => { + transactionsToBatch.delete(transactionId) + }); }, { //https://github.com/prisma/prisma/issues/20015 maxWait: 10000 // default prisma pool timeout. would be better to get it from client @@ -229,12 +229,11 @@ export function useCaslAbilities( ; (client as any)._requestHandler.dataloader.options.batchBy = ( request: any, ) => { - - if (request.transaction?.id && !request.transaction?.id?.toString().startsWith('casl-extension-')) { + const batchId = getBatchId(request.protocolQuery); + if (request.transaction?.id && (!transactionsToBatch.has(request.transaction.id.toString()) && batchId)) { return `transaction-${request.transaction.id}`; } - - return getBatchId(request.protocolQuery); + return batchId };