From 96d0ce502154a0216f28d444ffc45b00c7f9f741 Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 7 Jan 2025 16:47:42 +0800 Subject: [PATCH] fix(encryption): fixes for `createMany` and `createManyAndReturn` operations (#1944) --- packages/runtime/src/constants.ts | 12 ++ .../runtime/src/cross/nested-write-visitor.ts | 48 +++---- .../src/enhancements/node/default-auth.ts | 11 +- .../src/enhancements/node/encryption.ts | 4 +- .../runtime/src/enhancements/node/password.ts | 5 +- .../with-encrypted/with-encrypted.test.ts | 118 ++++++++++++++++++ .../with-password/with-password.test.ts | 23 +++- 7 files changed, 183 insertions(+), 38 deletions(-) diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 36acf8c83..495e1853d 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -67,3 +67,15 @@ export const PRISMA_MINIMUM_VERSION = '5.0.0'; * Prefix for auxiliary relation field generated for delegated models */ export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; + +/** + * Prisma actions that can have a write payload + */ +export const ACTIONS_WITH_WRITE_PAYLOAD = [ + 'create', + 'createMany', + 'createManyAndReturn', + 'update', + 'updateMany', + 'upsert', +]; diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index c69f9d203..ba4b232a6 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -4,7 +4,7 @@ import type { FieldInfo, ModelMeta } from './model-meta'; import { resolveField } from './model-meta'; import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types'; -import { getModelFields } from './utils'; +import { enumerate, getModelFields } from './utils'; type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean }; @@ -310,31 +310,33 @@ export class NestedWriteVisitor { payload: any, nestingPath: NestingPathItem[] ) { - for (const field of getModelFields(payload)) { - const fieldInfo = resolveField(this.modelMeta, model, field); - if (!fieldInfo) { - continue; - } + for (const item of enumerate(payload)) { + for (const field of getModelFields(item)) { + const fieldInfo = resolveField(this.modelMeta, model, field); + if (!fieldInfo) { + continue; + } - if (fieldInfo.isDataModel) { - if (payload[field]) { - // recurse into nested payloads - for (const [subAction, subData] of Object.entries(payload[field])) { - if (this.isPrismaWriteAction(subAction) && subData) { - await this.doVisit(fieldInfo.type, subAction, subData, payload[field], fieldInfo, [ - ...nestingPath, - ]); + if (fieldInfo.isDataModel) { + if (item[field]) { + // recurse into nested payloads + for (const [subAction, subData] of Object.entries(item[field])) { + if (this.isPrismaWriteAction(subAction) && subData) { + await this.doVisit(fieldInfo.type, subAction, subData, item[field], fieldInfo, [ + ...nestingPath, + ]); + } } } - } - } else { - // visit plain field - if (this.callback.field) { - await this.callback.field(fieldInfo, action, payload[field], { - parent: payload, - nestingPath, - field: fieldInfo, - }); + } else { + // visit plain field + if (this.callback.field) { + await this.callback.field(fieldInfo, action, item[field], { + parent: item, + nestingPath, + field: fieldInfo, + }); + } } } } diff --git a/packages/runtime/src/enhancements/node/default-auth.ts b/packages/runtime/src/enhancements/node/default-auth.ts index 03ce3750c..e6162a2d2 100644 --- a/packages/runtime/src/enhancements/node/default-auth.ts +++ b/packages/runtime/src/enhancements/node/default-auth.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ /* eslint-disable @typescript-eslint/no-explicit-any */ +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, @@ -50,15 +51,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = [ - 'create', - 'createMany', - 'createManyAndReturn', - 'update', - 'updateMany', - 'upsert', - ]; - if (actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); return newArgs; } diff --git a/packages/runtime/src/enhancements/node/encryption.ts b/packages/runtime/src/enhancements/node/encryption.ts index 65666d8cd..42001fc16 100644 --- a/packages/runtime/src/enhancements/node/encryption.ts +++ b/packages/runtime/src/enhancements/node/encryption.ts @@ -2,6 +2,7 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ import { z } from 'zod'; +import { ACTIONS_WITH_WRITE_PAYLOAD } from '../../constants'; import { FieldInfo, NestedWriteVisitor, @@ -211,8 +212,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; diff --git a/packages/runtime/src/enhancements/node/password.ts b/packages/runtime/src/enhancements/node/password.ts index 8c1aeb959..a2fdae42c 100644 --- a/packages/runtime/src/enhancements/node/password.ts +++ b/packages/runtime/src/enhancements/node/password.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; +import { ACTIONS_WITH_WRITE_PAYLOAD, DEFAULT_PASSWORD_SALT_LENGTH } from '../../constants'; import { NestedWriteVisitor, type PrismaWriteActionType } from '../../cross'; import { DbClientContract } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; @@ -39,8 +39,7 @@ class PasswordHandler extends DefaultPrismaProxyHandler { // base override protected async preprocessArgs(action: PrismaProxyActions, args: any) { - const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; - if (args && args.data && actionsOfInterest.includes(action)) { + if (args && ACTIONS_WITH_WRITE_PAYLOAD.includes(action)) { await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); } return args; diff --git a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts index 71ccd0323..71d32769f 100644 --- a/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts +++ b/tests/integration/tests/enhancements/with-encrypted/with-encrypted.test.ts @@ -58,6 +58,124 @@ describe('Encrypted test', () => { expect(read.encrypted_value).toBe('abc123'); expect(sudoRead.encrypted_value).not.toBe('abc123'); expect(rawRead.encrypted_value).not.toBe('abc123'); + + // update + const updated = await db.user.update({ + where: { id: '1' }, + data: { encrypted_value: 'abc234' }, + }); + expect(updated.encrypted_value).toBe('abc234'); + await expect(db.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ + encrypted_value: 'abc234', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc234', + }); + + // upsert with create + const upsertCreate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertCreate.encrypted_value).toBe('abc345'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc345', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc345', + }); + + // upsert with update + const upsertUpdate = await db.user.upsert({ + where: { id: '2' }, + create: { + id: '2', + encrypted_value: 'abc345', + }, + update: { + encrypted_value: 'abc456', + }, + }); + expect(upsertUpdate.encrypted_value).toBe('abc456'); + await expect(db.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: 'abc456', + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc456', + }); + + // createMany + await db.user.createMany({ + data: [ + { id: '3', encrypted_value: 'abc567' }, + { id: '4', encrypted_value: 'abc678' }, + ], + }); + await expect(db.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ + encrypted_value: 'abc567', + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc567', + }); + + // createManyAndReturn + await expect( + db.user.createManyAndReturn({ + data: [ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ], + }) + ).resolves.toEqual( + expect.arrayContaining([ + { id: '5', encrypted_value: 'abc789' }, + { id: '6', encrypted_value: 'abc890' }, + ]) + ); + await expect(db.user.findUnique({ where: { id: '5' } })).resolves.toMatchObject({ + encrypted_value: 'abc789', + }); + await expect(prisma.user.findUnique({ where: { id: '5' } })).resolves.not.toMatchObject({ + encrypted_value: 'abc789', + }); + }); + + it('Works with nullish values', async () => { + const { enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + encrypted_value String? @encrypted() + }`, + { + enhancements: ['encryption'], + enhanceOptions: { + encryption: { encryptionKey }, + }, + } + ); + + const db = enhance(); + await expect(db.user.create({ data: { id: '1', encrypted_value: '' } })).resolves.toMatchObject({ + encrypted_value: '', + }); + await expect(prisma.user.findUnique({ where: { id: '1' } })).resolves.toMatchObject({ encrypted_value: '' }); + + await expect(db.user.create({ data: { id: '2' } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.toMatchObject({ encrypted_value: null }); + + await expect(db.user.create({ data: { id: '3', encrypted_value: null } })).resolves.toMatchObject({ + encrypted_value: null, + }); + await expect(prisma.user.findUnique({ where: { id: '3' } })).resolves.toMatchObject({ encrypted_value: null }); }); it('Decrypts nested fields', async () => { diff --git a/tests/integration/tests/enhancements/with-password/with-password.test.ts b/tests/integration/tests/enhancements/with-password/with-password.test.ts index b2fd89a65..a54d0c42d 100644 --- a/tests/integration/tests/enhancements/with-password/with-password.test.ts +++ b/tests/integration/tests/enhancements/with-password/with-password.test.ts @@ -14,7 +14,7 @@ describe('Password test', () => { }); it('password tests', async () => { - const { enhance } = await loadSchema(` + const { enhance, prisma } = await loadSchema(` model User { id String @id @default(cuid()) password String @password(saltLength: 16) @@ -38,6 +38,27 @@ describe('Password test', () => { }, }); expect(compareSync('abc456', r1.password)).toBeTruthy(); + + await db.user.createMany({ + data: [ + { id: '2', password: 'user2' }, + { id: '3', password: 'user3' }, + ], + }); + await expect(prisma.user.findUnique({ where: { id: '2' } })).resolves.not.toMatchObject({ password: 'user2' }); + const r2 = await db.user.findUnique({ where: { id: '2' } }); + expect(compareSync('user2', r2.password)).toBeTruthy(); + + const [u4] = await db.user.createManyAndReturn({ + data: [ + { id: '4', password: 'user4' }, + { id: '5', password: 'user5' }, + ], + }); + expect(compareSync('user4', u4.password)).toBeTruthy(); + await expect(prisma.user.findUnique({ where: { id: '4' } })).resolves.not.toMatchObject({ password: 'user4' }); + const r4 = await db.user.findUnique({ where: { id: '4' } }); + expect(compareSync('user4', r4.password)).toBeTruthy(); }); it('length tests', async () => {