Skip to content

Commit

Permalink
fix(encryption): fixes for createMany and createManyAndReturn ope…
Browse files Browse the repository at this point in the history
…rations (#1944)
  • Loading branch information
ymc9 authored Jan 7, 2025
1 parent 7ed9841 commit 96d0ce5
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 38 deletions.
12 changes: 12 additions & 0 deletions packages/runtime/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,15 @@ export const PRISMA_MINIMUM_VERSION = '5.0.0';
* Prefix for auxiliary relation field generated for delegated models
*/
export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux';

/**
* Prisma actions that can have a write payload
*/
export const ACTIONS_WITH_WRITE_PAYLOAD = [
'create',
'createMany',
'createManyAndReturn',
'update',
'updateMany',
'upsert',
];
48 changes: 25 additions & 23 deletions packages/runtime/src/cross/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import type { FieldInfo, ModelMeta } from './model-meta';
import { resolveField } from './model-meta';
import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types';
import { getModelFields } from './utils';
import { enumerate, getModelFields } from './utils';

type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean };

Expand Down Expand Up @@ -310,31 +310,33 @@ export class NestedWriteVisitor {
payload: any,
nestingPath: NestingPathItem[]
) {
for (const field of getModelFields(payload)) {
const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo) {
continue;
}
for (const item of enumerate(payload)) {
for (const field of getModelFields(item)) {
const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo) {
continue;
}

if (fieldInfo.isDataModel) {
if (payload[field]) {
// recurse into nested payloads
for (const [subAction, subData] of Object.entries<any>(payload[field])) {
if (this.isPrismaWriteAction(subAction) && subData) {
await this.doVisit(fieldInfo.type, subAction, subData, payload[field], fieldInfo, [
...nestingPath,
]);
if (fieldInfo.isDataModel) {
if (item[field]) {
// recurse into nested payloads
for (const [subAction, subData] of Object.entries<any>(item[field])) {
if (this.isPrismaWriteAction(subAction) && subData) {
await this.doVisit(fieldInfo.type, subAction, subData, item[field], fieldInfo, [
...nestingPath,
]);
}
}
}
}
} else {
// visit plain field
if (this.callback.field) {
await this.callback.field(fieldInfo, action, payload[field], {
parent: payload,
nestingPath,
field: fieldInfo,
});
} else {
// visit plain field
if (this.callback.field) {
await this.callback.field(fieldInfo, action, item[field], {
parent: item,
nestingPath,
field: fieldInfo,
});
}
}
}
}
Expand Down
11 changes: 2 additions & 9 deletions packages/runtime/src/enhancements/node/default-auth.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable @typescript-eslint/no-explicit-any */

import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants';
import {
FieldInfo,
NestedWriteVisitor,
Expand Down Expand Up @@ -50,15 +51,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = [
'create',
'createMany',
'createManyAndReturn',
'update',
'updateMany',
'upsert',
];
if (actionsOfInterest.includes(action)) {
if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) {
const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
return newArgs;
}
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/node/encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
/* eslint-disable @typescript-eslint/no-unused-vars */

import { z } from 'zod';
import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants';
import {
FieldInfo,
NestedWriteVisitor,
Expand Down Expand Up @@ -211,8 +212,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
Expand Down
5 changes: 2 additions & 3 deletions packages/runtime/src/enhancements/node/password.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import { DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants';
import { ACTIONS_WITH_WRITE_PAYLOAD, DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants';
import { NestedWriteVisitor, type PrismaWriteActionType } from '../../cross';
import { DbClientContract } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
Expand Down Expand Up @@ -39,8 +39,7 @@ class PasswordHandler extends DefaultPrismaProxyHandler {

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,124 @@ describe('Encrypted test', () => {
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).not.toBe('abc123');
expect(rawRead.encrypted_value).not.toBe('abc123');

// update
const updated = await db.user.update({
where: { id: '1' },
data: { encrypted_value: 'abc234' },
});
expect(updated.encrypted_value).toBe('abc234');
await expect(db.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({
encrypted_value: 'abc234',
});
await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.not.toMatchObject({
encrypted_value: 'abc234',
});

// upsert with create
const upsertCreate = await db.user.upsert({
where: { id: '2' },
create: {
id: '2',
encrypted_value: 'abc345',
},
update: {
encrypted_value: 'abc456',
},
});
expect(upsertCreate.encrypted_value).toBe('abc345');
await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({
encrypted_value: 'abc345',
});
await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({
encrypted_value: 'abc345',
});

// upsert with update
const upsertUpdate = await db.user.upsert({
where: { id: '2' },
create: {
id: '2',
encrypted_value: 'abc345',
},
update: {
encrypted_value: 'abc456',
},
});
expect(upsertUpdate.encrypted_value).toBe('abc456');
await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({
encrypted_value: 'abc456',
});
await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({
encrypted_value: 'abc456',
});

// createMany
await db.user.createMany({
data: [
{ id: '3', encrypted_value: 'abc567' },
{ id: '4', encrypted_value: 'abc678' },
],
});
await expect(db.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({
encrypted_value: 'abc567',
});
await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.not.toMatchObject({
encrypted_value: 'abc567',
});

// createManyAndReturn
await expect(
db.user.createManyAndReturn({
data: [
{ id: '5', encrypted_value: 'abc789' },
{ id: '6', encrypted_value: 'abc890' },
],
})
).resolves.toEqual(
expect.arrayContaining([
{ id: '5', encrypted_value: 'abc789' },
{ id: '6', encrypted_value: 'abc890' },
])
);
await expect(db.user.findUnique({ where: { id: '5' } })).resolves.toMatchObject({
encrypted_value: 'abc789',
});
await expect(prisma.user.findUnique({ where: { id: '5' } })).resolves.not.toMatchObject({
encrypted_value: 'abc789',
});
});

it('Works with nullish values', async () => {
const { enhance, prisma } = await loadSchema(
`
model User {
id String @id @default(cuid())
encrypted_value String? @encrypted()
}`,
{
enhancements: ['encryption'],
enhanceOptions: {
encryption: { encryptionKey },
},
}
);

const db = enhance();
await expect(db.user.create({ data: { id: '1', encrypted_value: '' } })).resolves.toMatchObject({
encrypted_value: '',
});
await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ encrypted_value: '' });

await expect(db.user.create({ data: { id: '2' } })).resolves.toMatchObject({
encrypted_value: null,
});
await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ encrypted_value: null });

await expect(db.user.create({ data: { id: '3', encrypted_value: null } })).resolves.toMatchObject({
encrypted_value: null,
});
await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ encrypted_value: null });
});

it('Decrypts nested fields', async () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ describe('Password test', () => {
});

it('password tests', async () => {
const { enhance } = await loadSchema(`
const { enhance, prisma } = await loadSchema(`
model User {
id String @id @default(cuid())
password String @password(saltLength: 16)
Expand All @@ -38,6 +38,27 @@ describe('Password test', () => {
},
});
expect(compareSync('abc456', r1.password)).toBeTruthy();

await db.user.createMany({
data: [
{ id: '2', password: 'user2' },
{ id: '3', password: 'user3' },
],
});
await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ password: 'user2' });
const r2 = await db.user.findUnique({ where: { id: '2' } });
expect(compareSync('user2', r2.password)).toBeTruthy();

const [u4] = await db.user.createManyAndReturn({
data: [
{ id: '4', password: 'user4' },
{ id: '5', password: 'user5' },
],
});
expect(compareSync('user4', u4.password)).toBeTruthy();
await expect(prisma.user.findUnique({ where: { id: '4' } })).resolves.not.toMatchObject({ password: 'user4' });
const r4 = await db.user.findUnique({ where: { id: '4' } });
expect(compareSync('user4', r4.password)).toBeTruthy();
});

it('length tests', async () => {
Expand Down

0 comments on commit 96d0ce5

Please sign in to comment.