Skip to content

Commit

Permalink
feat(encryption): support providing multiple decryption keys for key …
Browse files Browse the repository at this point in the history
…rotation (#1942)
  • Loading branch information
ymc9 authored Jan 7, 2025
1 parent 00c1982 commit 7ed9841
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 38 deletions.
161 changes: 129 additions & 32 deletions packages/runtime/src/enhancements/node/encryption.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import { z } from 'zod';
import {
FieldInfo,
NestedWriteVisitor,
Expand Down Expand Up @@ -37,79 +38,175 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
private encoder = new TextEncoder();
private decoder = new TextDecoder();
private logger: Logger;
private encryptionKey: CryptoKey | undefined;
private encryptionKeyDigest: string | undefined;
private decryptionKeys: Array<{ key: CryptoKey; digest: string }> = [];
private encryptionMetaSchema = z.object({
// version
v: z.number(),
// algorithm
a: z.string(),
// key digest
k: z.string(),
});

// constants
private readonly ENCRYPTION_KEY_BYTES = 32;
private readonly IV_BYTES = 12;
private readonly ALGORITHM = 'AES-GCM';
private readonly ENCRYPTER_VERSION = 1;
private readonly KEY_DIGEST_BYTES = 8;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);
this.logger = new Logger(prisma);

if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided');
if (!options.encryption) {
throw this.queryUtils.unknownError('Encryption options must be provided');
}

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
if (!options.encryption.encrypt || !options.encryption.decrypt) {
throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions');
}
} else {
if (!options.encryption.encryptionKey)
if (!options.encryption.encryptionKey) {
throw this.queryUtils.unknownError('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32)
throw this.queryUtils.unknownError('Encryption key must be 32 bytes');
}
if (options.encryption.encryptionKey.length !== this.ENCRYPTION_KEY_BYTES) {
throw this.queryUtils.unknownError(`Encryption key must be ${this.ENCRYPTION_KEY_BYTES} bytes`);
}
}
}

private async getKey(secret: Uint8Array): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
}

private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
return 'encrypt' in encryption && 'decrypt' in encryption;
}

private async loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', key, this.ALGORITHM, false, keyUsages);
}

private async computeKeyDigest(key: Uint8Array) {
const rawDigest = await crypto.subtle.digest('SHA-256', key);
return new Uint8Array(rawDigest.slice(0, this.KEY_DIGEST_BYTES)).reduce(
(acc, byte) => acc + byte.toString(16).padStart(2, '0'),
''
);
}

private async getEncryptionKey(): Promise<CryptoKey> {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}
if (!this.encryptionKey) {
this.encryptionKey = await this.loadKey(this.options.encryption!.encryptionKey, ['encrypt', 'decrypt']);
}
return this.encryptionKey;
}

private async getEncryptionKeyDigest() {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}
if (!this.encryptionKeyDigest) {
this.encryptionKeyDigest = await this.computeKeyDigest(this.options.encryption!.encryptionKey);
}
return this.encryptionKeyDigest;
}

private async findDecryptionKeys(keyDigest: string): Promise<CryptoKey[]> {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}

if (this.decryptionKeys.length === 0) {
const keys = [this.options.encryption!.encryptionKey, ...(this.options.encryption!.decryptionKeys || [])];
this.decryptionKeys = await Promise.all(
keys.map(async (key) => ({
key: await this.loadKey(key, ['decrypt']),
digest: await this.computeKeyDigest(key),
}))
);
}

return this.decryptionKeys.filter((entry) => entry.digest === keyDigest).map((entry) => entry.key);
}

private async encrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.encrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const iv = crypto.getRandomValues(new Uint8Array(12));

const key = await this.getEncryptionKey();
const iv = crypto.getRandomValues(new Uint8Array(this.IV_BYTES));
const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
name: this.ALGORITHM,
iv,
},
key,
this.encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];
// combine IV and encrypted data into a single array of bytes
const cipherBytes = [...iv, ...new Uint8Array(encrypted)];

// encryption metadata
const meta = { v: this.ENCRYPTER_VERSION, a: this.ALGORITHM, k: await this.getEncryptionKeyDigest() };

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
// convert concatenated result to base64 string
return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`;
}

private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const [metaText, cipherText] = data.split('.');
if (!metaText || !cipherText) {
throw new Error('Malformed encrypted data');
}

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
let metaObj: unknown;
try {
metaObj = JSON.parse(atob(metaText));
} catch (error) {
throw new Error('Malformed metadata');
}

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);
// parse meta
const { a: algorithm, k: keyDigest } = this.encryptionMetaSchema.parse(metaObj);

// find a matching decryption key
const keys = await this.findDecryptionKeys(keyDigest);
if (keys.length === 0) {
throw new Error('No matching decryption key found');
}

// convert base64 back to bytes
const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0));

// extract IV from the head
const iv = bytes.slice(0, this.IV_BYTES);
const cipher = bytes.slice(this.IV_BYTES);
let lastError: unknown;

for (const key of keys) {
let decrypted: ArrayBuffer;
try {
decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher);
} catch (err) {
lastError = err;
continue;
}
return this.decoder.decode(decrypted);
}

return this.decoder.decode(decrypted);
throw lastError;
}

// base override
Expand Down Expand Up @@ -138,7 +235,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
// Don't decrypt null, undefined or empty string values
// don't decrypt null, undefined or empty string values
if (!entityData[field]) continue;

const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);
Expand Down Expand Up @@ -169,7 +266,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
// Don't encrypt null, undefined or empty string values
// don't encrypt null, undefined or empty string values
if (!data) return;

const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
Expand Down
33 changes: 31 additions & 2 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,38 @@ export type ZodSchemas = {
input?: Record<string, Record<string, z.ZodSchema>>;
};

/**
* Simple encryption settings for processing fields marked with `@encrypted`.
*/
export type SimpleEncryption = {
/**
* The encryption key.
*/
encryptionKey: Uint8Array;

/**
* Optional list of all decryption keys that were previously used to encrypt the data
* , for supporting key rotation. The `encryptionKey` field value is automatically
* included for decryption.
*
* When the encrypted data is persisted, a metadata object containing the digest of the
* encryption key is stored alongside the data. This digest is used to quickly determine
* the correct decryption key to use when reading the data.
*/
decryptionKeys?: Uint8Array[];
};

/**
* Custom encryption settings for processing fields marked with `@encrypted`.
*/
export type CustomEncryption = {
/**
* Encryption function.
*/
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;

/**
* Decryption function
*/
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
};

export type SimpleEncryption = { encryptionKey: Uint8Array };
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,12 @@ describe('Encrypted test', () => {
});

it('Custom encryption test', async () => {
const { enhance } = await loadSchema(`
const { enhance, prisma } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encryption'],
encryption: {
Expand Down Expand Up @@ -181,15 +180,83 @@ describe('Encrypted test', () => {
},
});

const sudoRead = await sudoDb.user.findUnique({
const rawRead = await prisma.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
expect(rawRead.encrypted_value).toBe('abc123_enc');
});

it('Works with multiple decryption keys', async () => {
const { enhanceRaw: enhance, prisma } = await loadSchema(
`
model User {
id String @id @default(cuid())
secret String @encrypted()
}`
);

const key1 = crypto.getRandomValues(new Uint8Array(32));
const key2 = crypto.getRandomValues(new Uint8Array(32));

const db1 = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1 },
});
const user1 = await db1.user.create({ data: { secret: 'user1' } });

const db2 = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key2 },
});
const user2 = await db2.user.create({ data: { secret: 'user2' } });

const dbAll = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)), decryptionKeys: [key1, key2] },
});
const allUsers = await dbAll.user.findMany();
expect(allUsers).toEqual(expect.arrayContaining([user1, user2]));

const dbWithEncryptionKeyExplicitlyProvided = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key1, key2] },
});
await expect(dbWithEncryptionKeyExplicitlyProvided.user.findMany()).resolves.toEqual(
expect.arrayContaining([user1, user2])
);

const dbWithDuplicatedKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key1, key1, key2, key2] },
});
await expect(dbWithDuplicatedKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2]));

const dbWithInvalidKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key2, crypto.getRandomValues(new Uint8Array(32))] },
});
await expect(dbWithInvalidKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2]));

const dbWithMissingKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key2 },
});
const found = await dbWithMissingKeys.user.findMany();
expect(found).not.toContainEqual(user1);
expect(found).toContainEqual(user2);

const dbWithAllWrongKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) },
});
const found1 = await dbWithAllWrongKeys.user.findMany();
expect(found1).not.toContainEqual(user1);
expect(found1).not.toContainEqual(user2);
});

it('Only supports string fields', async () => {
Expand Down

0 comments on commit 7ed9841

Please sign in to comment.