diff --git a/packages/runtime/src/enhancements/edge/encrypted.ts b/packages/runtime/src/enhancements/edge/encrypted.ts index 4fcb64dd0..3d758edb6 100644 --- a/packages/runtime/src/enhancements/edge/encrypted.ts +++ b/packages/runtime/src/enhancements/edge/encrypted.ts @@ -1,8 +1,15 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { NestedWriteVisitor, enumerate, getModelFields, resolveField, type PrismaWriteActionType } from '../../cross'; -import { DbClientContract } from '../../types'; +import { + FieldInfo, + NestedWriteVisitor, + enumerate, + getModelFields, + resolveField, + type PrismaWriteActionType, +} from '../../cross'; +import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types'; import { InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; import { QueryUtils } from './query-utils'; @@ -33,52 +40,65 @@ const getKey = async (secret: string): Promise => { 'decrypt', ]); }; -const encryptFunc = async (data: string, secret: string): Promise => { - const key = await getKey(secret); - const iv = crypto.getRandomValues(new Uint8Array(12)); - - const encrypted = await crypto.subtle.encrypt( - { - name: 'AES-GCM', - iv, - }, - key, - encoder.encode(data) - ); - // Combine IV and encrypted data into a single array of bytes - const bytes = [...iv, ...new Uint8Array(encrypted)]; +class EncryptedHandler extends DefaultPrismaProxyHandler { + private queryUtils: QueryUtils; - // Convert bytes to base64 string - return btoa(String.fromCharCode(...bytes)); -}; + constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { + super(prisma, model, options); -const decryptFunc = async (encryptedData: string, secret: string): Promise => { - const key = await getKey(secret); + this.queryUtils = new QueryUtils(prisma, options); + } - // Convert base64 back to bytes - const bytes = Uint8Array.from(atob(encryptedData), (c) => c.charCodeAt(0)); + private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption { + return 'encrypt' in encryption && 'decrypt' in encryption; + } - // First 12 bytes are IV, rest is encrypted data - const decrypted = await crypto.subtle.decrypt( - { - name: 'AES-GCM', - iv: bytes.slice(0, 12), - }, - key, - bytes.slice(12) - ); + private async encrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.encrypt(this.model, field, data); + } - return decoder.decode(decrypted); -}; + const key = await getKey(this.options.encryption!.encryptionKey); + const iv = crypto.getRandomValues(new Uint8Array(12)); -class EncryptedHandler extends DefaultPrismaProxyHandler { - private queryUtils: QueryUtils; + const encrypted = await crypto.subtle.encrypt( + { + name: 'AES-GCM', + iv, + }, + key, + encoder.encode(data) + ); - constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) { - super(prisma, model, options); + // Combine IV and encrypted data into a single array of bytes + const bytes = [...iv, ...new Uint8Array(encrypted)]; - this.queryUtils = new QueryUtils(prisma, options); + // Convert bytes to base64 string + return btoa(String.fromCharCode(...bytes)); + } + + private async decrypt(field: FieldInfo, data: string): Promise { + if (this.isCustomEncryption(this.options.encryption!)) { + return this.options.encryption.decrypt(this.model, field, data); + } + + const key = await getKey(this.options.encryption!.encryptionKey); + + // Convert base64 back to bytes + const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0)); + + // First 12 bytes are IV, rest is encrypted data + const decrypted = await crypto.subtle.decrypt( + { + name: 'AES-GCM', + iv: bytes.slice(0, 12), + }, + key, + bytes.slice(12) + ); + + return decoder.decode(decrypted); } // base override @@ -115,9 +135,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted'); if (shouldDecrypt) { - const descryptSecret = shouldDecrypt.args.find((arg) => arg.name === 'secret')?.value as string; - - entityData[field] = await decryptFunc(entityData[field], descryptSecret); + entityData[field] = await this.decrypt(fieldInfo, entityData[field]); } } } @@ -131,7 +149,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler { const secret: string = encAttr.args.find((arg) => arg.name === 'secret')?.value as string; - context.parent[field.name] = await encryptFunc(data, secret); + context.parent[field.name] = await this.encrypt(field, data); } }, });