Skip to content

Commit

Permalink
fix: 🐛 custom batching function to allow before and after queries
Browse files Browse the repository at this point in the history
  • Loading branch information
dennemark committed Dec 5, 2024
1 parent 28a1d44 commit 233abda
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 74 deletions.
4 changes: 4 additions & 0 deletions src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ export type PrismaExtensionCaslOptions = {
beforeQuery?: (tx: Prisma.TransactionClient) => Promise<void>,
/** uses transaction to allow using client queries after actual query, if fails, whole query will be rolled back */
afterQuery?: (tx: Prisma.TransactionClient) => Promise<void>,
/** max wait for batch transaction - default 30000 */
txMaxWait?: number
/** timeout for batch transaction - default 30000 */
txTimeout?: number
}

export type PrismaCaslOperation =
Expand Down
234 changes: 161 additions & 73 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { AbilityBuilder, AbilityTuple, PureAbility } from '@casl/ability'
import { PrismaQuery } from '@casl/prisma'
import { Prisma, PrismaClient } from '@prisma/client'
import { Prisma, PrismaClient, PrismaPromise } from '@prisma/client'
import { applyCaslToQuery } from './applyCaslToQuery'
import { filterQueryResults } from './filterQueryResults'
import { caslOperationDict, getFluentField, getFluentModel, PrismaCaslOperation, PrismaExtensionCaslOptions, propertyFieldsByModel, relationFieldsByModel } from './helpers'
Expand Down Expand Up @@ -38,18 +38,31 @@ export { applyCaslToQuery }
export function useCaslAbilities(
getAbilityFactory: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>,
opts?: PrismaExtensionCaslOptions) {

// Set default options
const txMaxWait = opts?.txMaxWait ?? 30000
const txTimeout = opts?.txTimeout ?? 30000

return Prisma.defineExtension((client) => {
const transactionsToBatch = new Set()
let tickActive = false;
const batches: Record<string, Array<{
params: object;
model: string;
action: string;
args: unknown;
/** called before resolve */
callback: (result: unknown) => void;
resolve: (result: unknown) => void;
reject: (error: unknown) => void;
}>> = {};

const allOperations = (getAbilities: () => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => ({
async $allOperations<T>({ args, query, model, operation, ...rest }: { args: any, query: any, model: any, operation: any }) {

const fluentModel = getFluentModel(model, rest)

const [fluentRelationModel, fluentRelationField] = (fluentModel !== model ? Object.entries(relationFieldsByModel[model]).find(([k, v]) => v.type === fluentModel) : undefined) ?? [undefined, undefined]
const transaction = (rest as any).__internalParams.transaction
const __internalParams = (rest as any).__internalParams
const transaction = __internalParams.transaction
const debug = (process.env.NODE_ENV === 'development' || process.env.NODE_ENV === 'test') && args.debugCasl
const debugAllErrors = args.debugCasl
delete args.debugCasl
Expand Down Expand Up @@ -112,12 +125,12 @@ export function useCaslAbilities(

perf?.mark('prisma-casl-extension-3')


if (fluentRelationModel && caslQuery.mask) {
// on fluent models we need to take mask of the relation
caslQuery.mask = fluentRelationModel && fluentRelationModel in caslQuery.mask ? caslQuery.mask[fluentRelationModel] : {}
}
const filteredResult = filterQueryResults(result, caslQuery.mask, caslQuery.creationTree, abilities, fluentModel as Prisma.ModelName, operation, opts)

if (perf) {
perf.mark('prisma-casl-extension-4')
logger?.log(
Expand Down Expand Up @@ -153,99 +166,168 @@ export function useCaslAbilities(
// query(caslQuery.args).then(cleanupResults).then((result: any) => resolve(result)).catch(((e: any) => reject(e)))
// })
}
const transactionQuery = async (txClient: any) => {

if (opts?.beforeQuery) {
await opts.beforeQuery(txClient)
}
if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') {
/**
* we get all update/deleteMany entries for logging purposes.
*/
const getMany = operation === 'deleteMany' || operation === 'updateMany'
const manyResult = getMany ? await txClient[model].findMany(caslQuery.args.where ? { where: caslQuery.args.where } : undefined).then((res: any[]) => {
/** create update objects for updateMany */
return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res
}) : []
/**
* we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries
*/
const op = operation === 'createMany' ? 'createManyAndReturn' : operation
return txClient[model][op](caslQuery.args).then(async (result: any) => {
// we need to get the updated many result
if (opts?.afterQuery) {
await opts.afterQuery(txClient)
}
const filteredResult = cleanupResults(getMany ? manyResult : result)
const results = operation === 'createMany'
? { count: result.length }
: getMany ? { count: manyResult.length }
: filteredResult
return results
})
} else {

return txClient[model][operation](caslQuery.args).then(async (result: any) => {
// we need to get the updated many result
if (opts?.afterQuery) {
await opts.afterQuery(txClient)
}
const fluentField = getFluentField(rest)
const hash = transaction?.id ?? 'batch'

if (fluentField) {
return cleanupResults(result?.[fluentField])
}
return cleanupResults(result)
})
}
if (!batches[hash]) {
batches[hash] = []
}
if (transaction && transaction.kind === 'itx') {

return transactionQuery((client as any)._createItxClient(transaction))
// make sure, that we only tick once at a time
if (!tickActive) {
tickActive = true;
process.nextTick(() => {
dispatchBatches(transaction);
tickActive = false;
});
}
/** batchQuery collects query within batches that will be dispatched every tick */
const batchQuery = (
model: string,
action: string,
args: any,
callback: (result: any) => void
) => new Promise((resolve, reject) => {
batches[hash].push({
params: __internalParams,
model,
action,
args,
reject,
resolve,
callback,
})
});


if (operationAbility.action === 'update' || operationAbility.action === 'create' || operation === 'deleteMany') {
/**
* we get all update/deleteMany entries for logging purposes.
*/
const getMany = operation === 'deleteMany' || operation === 'updateMany'

// const manyResult: any[] = getMany ? await batchQuery(model, 'findMany', caslQuery.args.where ? { where: caslQuery.args.where } : undefined, (res: any[]) => {
// /** create update objects for updateMany */
// return operation === 'updateMany' ? res.map((r) => ({ ...caslQuery.args.data, id: r.id })) : res
// }) : []
/**
* we use createManyAndReturn instead of createMany createMany entries for logging purposes and to check permissions on new entries
*/
const op = operation === 'createMany' ? 'createManyAndReturn' : operation
return batchQuery(model, op, caslQuery.args, async (result: any) => {

const filteredResult = cleanupResults(result)//getMany ? manyResult : result)
const results = operation === 'createMany' || operation === 'deleteMany' || operation === 'updateMany'
? { count: result.length }
// : getMany ? { count: manyResult.length }
: filteredResult
return results
})
} else {
return client.$transaction(async (tx) => {
const transactionId = tx[Symbol.for("prisma.client.transaction.id")].toString()
transactionsToBatch.add(transactionId)
//@ts-ignore
return transactionQuery(tx).finally(() => {
transactionsToBatch.delete(transactionId)
});
}, {
//https://github.com/prisma/prisma/issues/20015
maxWait: 10000 // default prisma pool timeout. would be better to get it from client

return batchQuery(model, operation, caslQuery.args, async (result: any) => {

const fluentField = getFluentField(rest)
if (fluentField) {
return cleanupResults(result?.[fluentField])
}
return cleanupResults(result)
})
}



}
})


// Derived from yates:
// Derived from yates
// https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227
//
// By default, Prisma will batch requests by the transaction ID if it is present.
// If our transaction id does not include casl-extension- it is a normal interactive transaction
// and we hook into it. otherwise we use normal batching for our transaction
// This behaviour prevents automatic batching from working when using this client extension, since all queries are executed inside an interactive transaction.
// To get around this we monkey patch the batching function to use the batch ID and transaction ID.
// To get the batching to work we also need to ensure that all the requests we might want to batch together are generated inside the same tick.
// This means that all the requests per-tick that have the same role and context values will be batched together,
// allowing the in-built prisma batch optimizations to work for us.
// This is why we use process.nextTick and the tickActive flag to ensure we only tick once at a time.
// See:
// - https://github.com/prisma/prisma/blob/5.21.1/packages/client/src/runtime/RequestHandler.ts#L122
// - https://www.prisma.io/docs/orm/prisma-client/queries/query-optimization-performance
; (client as any)._requestHandler.dataloader.options.batchBy = (
request: any,
) => {
const batchId = getBatchId(request.protocolQuery);
if (request.transaction?.id && (!transactionsToBatch.has(request.transaction.id.toString()) && batchId)) {
return `transaction-${request.transaction.id}`;
if (request.transaction?.id) {
return `transaction-${request.transaction.id}${batchId ? `-${batchId}` : ""
}`;
}

return batchId
};

/**
* Derived from yates
* https://github.com/cerebruminc/yates/blob/master/src/index.ts#L227
*
* This function is called once per tick, and processes all the batches that have been created during that tick.
* If the batch happened within an existing transaction, we use it to recreate its client, so we keep its interactve transaction logic
**/
const dispatchBatches = (transaction?: { kind: 'itx' | 'batch' }) => {
for (const [key, batch] of Object.entries(batches)) {
delete batches[key];

const runBatchTransaction = async (tx: any) => {
if (opts?.beforeQuery) {
await opts.beforeQuery(tx as any)
}

const results = await Promise.all(
batch.map((request: any) => {
//@ts-ignore
return tx[request.model][request.action](request.args).then((res) => request.callback(res))
.catch((e: Error) => {
throw (e)
})

}),
);
// Switch role back to admin user
if (opts?.afterQuery) {
await opts?.afterQuery(tx as any)
}

return results;
}

new Promise((resolve, reject) => {
if (transaction && transaction.kind === 'itx') {
runBatchTransaction((client as any)._createItxClient(transaction)).then(resolve).catch(reject)
} else {
client.$transaction(async (tx) => {
return runBatchTransaction(tx);
}, {
maxWait: txMaxWait,
timeout: txTimeout,
}).then(resolve).catch(reject)
}
}).then((results: any) => {
results.forEach((result: any, index: number) => {
batch[index].resolve(result);
});
})
.catch((e) => {
for (const request of batch) {
request.reject(e);
}
delete batches[key]
})
}
};


return client.$extends({
name: "prisma-extension-casl",
client: {
// https://github.com/prisma/prisma/issues/20678
// $transaction(...props: Parameters<(typeof client)['$transaction']>): ReturnType<(typeof client)['$transaction']> {
// return transactionStore.run({ alreadyInTransaction: true }, () => {
// return client.$transaction(...props);
// });
// },
$casl(extendFactory: (factory: AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) => AbilityBuilder<PureAbility<AbilityTuple, PrismaQuery>>) {
// alter the getAblities function shortly
return client.$extends({
Expand All @@ -267,8 +349,14 @@ export function useCaslAbilities(
}


//https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4

/**
* recreates getBatchId from prisma
* //https://github.com/prisma/prisma/blob/1a9ef0fbd3948ee708add6816a33743e1ff7df9c/packages/client/src/runtime/core/jsonProtocol/getBatchId.ts#L4
*
* @param query
* @returns
*/
export function getBatchId(query: any): string | undefined {
if (query.action !== "findUnique" && query.action !== "findUniqueOrThrow") {
return undefined;
Expand Down
1 change: 0 additions & 1 deletion test/extension.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2090,7 +2090,6 @@ describe('prisma extension casl', () => {
expect(result).toEqual({ email: '0', id: 0, casl: ['create', 'read', 'update', 'delete'] })
})
})

})
afterAll(async () => {
await seedClient.$disconnect()
Expand Down

0 comments on commit 233abda

Please sign in to comment.