Skip to content

Commit

Permalink
feat(enhance): Prisma Pulse support (#1658)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Aug 24, 2024
1 parent 40ea9fa commit 32c258c
Show file tree
Hide file tree
Showing 8 changed files with 534 additions and 308 deletions.
35 changes: 33 additions & 2 deletions packages/runtime/src/enhancements/omit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,39 @@ class OmitHandler extends DefaultPrismaProxyHandler {
}

// base override
protected async processResultEntity<T>(data: T): Promise<T> {
if (data) {
protected async processResultEntity<T>(method: string, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

if (method === 'subscribe' || method === 'stream') {
if (!('action' in data)) {
return data;
}

// Prisma Pulse result
switch (data.action) {
case 'create':
if ('created' in data) {
await this.doPostProcess(data.created, this.model);
}
break;
case 'update':
if ('before' in data) {
await this.doPostProcess(data.before, this.model);
}
if ('after' in data) {
await this.doPostProcess(data.after, this.model);
}
break;
case 'delete':
if ('deleted' in data) {
await this.doPostProcess(data.deleted, this.model);
}
break;
}
} else {
// regular prisma client result
for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}
Expand Down
125 changes: 87 additions & 38 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1537,53 +1537,102 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

//#endregion

//#region Subscribe (Prisma Pulse)
//#region Prisma Pulse

subscribe(args: any) {
return createDeferredPromise(() => {
const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read');
if (this.policyUtils.isTrue(readGuard)) {
// no need to inject
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
}

if (!args) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
if (typeof args !== 'object') {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object');
}
if (Object.keys(args).length === 0) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
args = this.policyUtils.safeClone(args);
}
}
return this.handleSubscribeStream('subscribe', args);
}

// inject into subscribe conditions
stream(args: any) {
return this.handleSubscribeStream('stream', args);
}

if (args.create) {
args.create.after = this.policyUtils.and(args.create.after, readGuard);
private async handleSubscribeStream(action: 'subscribe' | 'stream', args: any) {
if (!args) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
if (typeof args !== 'object') {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'argument must be an object');
}
args = this.policyUtils.safeClone(args);
}

if (args.update) {
args.update.after = this.policyUtils.and(args.update.after, readGuard);
// inject read guard as subscription filter
for (const key of ['create', 'update', 'delete']) {
if (args[key] === undefined) {
continue;
}

if (args.delete) {
args.delete.before = this.policyUtils.and(args.delete.before, readGuard);
// "update" has an extra layer of "after"
const payload = key === 'update' ? args[key].after : args[key];
const toInject = { where: payload };
this.policyUtils.injectForRead(this.prisma, this.model, toInject);
if (key === 'update') {
// "update" has an extra layer of "after"
args[key].after = toInject.where;
} else {
args[key] = toInject.where;
}
}

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
});
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`${action}\` ${this.model}:\n${formatObject(args)}`);
}

// Prisma Pulse returns an async iterable, which we need to wrap
// and post-process the iteration results
const iterable = await this.modelClient[action](args);
return {
[Symbol.asyncIterator]: () => {
const iter = iterable[Symbol.asyncIterator].bind(iterable)();
return {
next: async () => {
const { done, value } = await iter.next();
let processedValue = value;
if (value && 'action' in value) {
switch (value.action) {
case 'create':
if ('created' in value) {
processedValue = {
...value,
created: this.policyUtils.postProcessForRead(value.created, this.model, {}),
};
}
break;

case 'update':
if ('before' in value) {
processedValue = {
...value,
before: this.policyUtils.postProcessForRead(value.before, this.model, {}),
};
}
if ('after' in value) {
processedValue = {
...value,
after: this.policyUtils.postProcessForRead(value.after, this.model, {}),
};
}
break;

case 'delete':
if ('deleted' in value) {
processedValue = {
...value,
deleted: this.policyUtils.postProcessForRead(value.deleted, this.model, {}),
};
}
break;
}
}

return { done, value: processedValue };
},
return: () => iter.return?.(),
throw: () => iter.throw?.(),
};
},
};
}

//#endregion
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ export class PolicyUtil extends QueryUtils {
// make select and include visible to the injection
const injected: any = { select: args.select, include: args.include };
if (!this.injectAuthGuardAsWhere(db, injected, model, 'read')) {
args.where = this.makeFalse();
return false;
}

Expand Down
36 changes: 31 additions & 5 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ export interface PrismaProxyHandler {
count(args: any): Promise<unknown | number>;

subscribe(args: any): Promise<unknown>;

stream(args: any): Promise<unknown>;
}

/**
Expand All @@ -79,7 +81,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
return postProcess ? this.processResultEntity(method, r) : r;
},
args,
this.options.modelMeta,
Expand All @@ -92,7 +94,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return createDeferredPromise<TResult>(async () => {
args = await this.preprocessArgs(method, args);
const r = await this.prisma[this.model][method](args);
return postProcess ? this.processResultEntity(r) : r;
return postProcess ? this.processResultEntity(method, r) : r;
});
}

Expand Down Expand Up @@ -161,20 +163,44 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
}

subscribe(args: any) {
return this.deferred('subscribe', args, false);
return this.doSubscribeStream('subscribe', args);
}

stream(args: any) {
return this.doSubscribeStream('stream', args);
}

private async doSubscribeStream(method: 'subscribe' | 'stream', args: any) {
// Prisma's `subscribe` and `stream` methods return an async iterable
// which we need to wrap to process the iteration results
const iterable = await this.prisma[this.model][method](args);
return {
[Symbol.asyncIterator]: () => {
const iter = iterable[Symbol.asyncIterator].bind(iterable)();
return {
next: async () => {
const { done, value } = await iter.next();
const processedValue = value ? await this.processResultEntity(method, value) : value;
return { done, value: processedValue };
},
return: () => iter.return?.(),
throw: () => iter.throw?.(),
};
},
};
}

/**
* Processes result entities before they're returned
*/
protected async processResultEntity<T>(data: T): Promise<T> {
protected async processResultEntity<T>(_method: PrismaProxyActions, data: T): Promise<T> {
return data;
}

/**
* Processes query args before they're passed to Prisma.
*/
protected async preprocessArgs(method: PrismaProxyActions, args: any) {
protected async preprocessArgs(_method: PrismaProxyActions, args: any) {
return args;
}
}
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export interface DbOperations {
groupBy(args: unknown): Promise<any>;
count(args?: unknown): Promise<any>;
subscribe(args?: unknown): Promise<any>;
stream(args?: unknown): Promise<any>;
check(args: unknown): Promise<boolean>;
fields: Record<string, any>;
}
Expand Down
11 changes: 8 additions & 3 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
}

if (opt.pushDb) {
run('npx prisma db push --skip-generate');
run('npx prisma db push --skip-generate --accept-data-loss');
}

if (opt.pulseApiKey) {
Expand All @@ -264,10 +264,10 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
// https://github.com/prisma/prisma/issues/18292
prisma[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient';

const prismaModule = require(path.join(projectDir, 'node_modules/@prisma/client')).Prisma;
const prismaModule = loadModule('@prisma/client', projectDir).Prisma;

if (opt.pulseApiKey) {
const withPulse = require(path.join(projectDir, 'node_modules/@prisma/extension-pulse/dist/cjs')).withPulse;
const withPulse = loadModule('@prisma/extension-pulse/node', projectDir).withPulse;
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

Expand Down Expand Up @@ -388,3 +388,8 @@ export async function loadZModelAndDmmf(
const dmmf = await getDMMF({ datamodel: prismaContent });
return { model, dmmf, modelFile };
}

function loadModule(module: string, basePath: string): any {
const modulePath = require.resolve(module, { paths: [basePath] });
return require(modulePath);
}
Loading

0 comments on commit 32c258c

Please sign in to comment.