Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(encryption): support providing multiple decryption keys for key rotation #1942

Merged
merged 3 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 129 additions & 32 deletions packages/runtime/src/enhancements/node/encryption.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import { z } from 'zod';
import {
FieldInfo,
NestedWriteVisitor,
Expand Down Expand Up @@ -37,79 +38,175 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
private encoder = new TextEncoder();
private decoder = new TextDecoder();
private logger: Logger;
private encryptionKey: CryptoKey | undefined;
private encryptionKeyDigest: string | undefined;
private decryptionKeys: Array<{ key: CryptoKey; digest: string }> = [];
private encryptionMetaSchema = z.object({
// version
v: z.number(),
// algorithm
a: z.string(),
// key digest
k: z.string(),
});

// constants
private readonly ENCRYPTION_KEY_BYTES = 32;
private readonly IV_BYTES = 12;
private readonly ALGORITHM = 'AES-GCM';
private readonly ENCRYPTER_VERSION = 1;
private readonly KEY_DIGEST_BYTES = 8;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);
this.logger = new Logger(prisma);

if (!options.encryption) throw this.queryUtils.unknownError('Encryption options must be provided');
if (!options.encryption) {
throw this.queryUtils.unknownError('Encryption options must be provided');
}

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
if (!options.encryption.encrypt || !options.encryption.decrypt) {
throw this.queryUtils.unknownError('Custom encryption must provide encrypt and decrypt functions');
}
} else {
if (!options.encryption.encryptionKey)
if (!options.encryption.encryptionKey) {
throw this.queryUtils.unknownError('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32)
throw this.queryUtils.unknownError('Encryption key must be 32 bytes');
}
if (options.encryption.encryptionKey.length !== this.ENCRYPTION_KEY_BYTES) {
throw this.queryUtils.unknownError(`Encryption key must be ${this.ENCRYPTION_KEY_BYTES} bytes`);
}
}
}

private async getKey(secret: Uint8Array): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
}

private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
return 'encrypt' in encryption && 'decrypt' in encryption;
}

private async loadKey(key: Uint8Array, keyUsages: KeyUsage[]): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', key, this.ALGORITHM, false, keyUsages);
}

private async computeKeyDigest(key: Uint8Array) {
const rawDigest = await crypto.subtle.digest('SHA-256', key);
return new Uint8Array(rawDigest.slice(0, this.KEY_DIGEST_BYTES)).reduce(
(acc, byte) => acc + byte.toString(16).padStart(2, '0'),
''
);
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved

private async getEncryptionKey(): Promise<CryptoKey> {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}
if (!this.encryptionKey) {
this.encryptionKey = await this.loadKey(this.options.encryption!.encryptionKey, ['encrypt', 'decrypt']);
}
return this.encryptionKey;
}

private async getEncryptionKeyDigest() {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}
if (!this.encryptionKeyDigest) {
this.encryptionKeyDigest = await this.computeKeyDigest(this.options.encryption!.encryptionKey);
}
return this.encryptionKeyDigest;
}

private async findDecryptionKeys(keyDigest: string): Promise<CryptoKey[]> {
if (this.isCustomEncryption(this.options.encryption!)) {
throw new Error('Unexpected custom encryption settings');
}

if (this.decryptionKeys.length === 0) {
const keys = [this.options.encryption!.encryptionKey, ...(this.options.encryption!.decryptionKeys || [])];
this.decryptionKeys = await Promise.all(
keys.map(async (key) => ({
key: await this.loadKey(key, ['decrypt']),
digest: await this.computeKeyDigest(key),
}))
);
}

return this.decryptionKeys.filter((entry) => entry.digest === keyDigest).map((entry) => entry.key);
}

private async encrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.encrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const iv = crypto.getRandomValues(new Uint8Array(12));

const key = await this.getEncryptionKey();
const iv = crypto.getRandomValues(new Uint8Array(this.IV_BYTES));
const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
name: this.ALGORITHM,
iv,
},
key,
this.encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];
// combine IV and encrypted data into a single array of bytes
const cipherBytes = [...iv, ...new Uint8Array(encrypted)];

// encryption metadata
const meta = { v: this.ENCRYPTER_VERSION, a: this.ALGORITHM, k: await this.getEncryptionKeyDigest() };

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
// convert concatenated result to base64 string
return `${btoa(JSON.stringify(meta))}.${btoa(String.fromCharCode(...cipherBytes))}`;
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
}

private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const [metaText, cipherText] = data.split('.');
if (!metaText || !cipherText) {
throw new Error('Malformed encrypted data');
}

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));
let metaObj: unknown;
try {
metaObj = JSON.parse(atob(metaText));
} catch (error) {
throw new Error('Malformed metadata');
}

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);
// parse meta
const { a: algorithm, k: keyDigest } = this.encryptionMetaSchema.parse(metaObj);

// find a matching decryption key
const keys = await this.findDecryptionKeys(keyDigest);
if (keys.length === 0) {
throw new Error('No matching decryption key found');
}

// convert base64 back to bytes
const bytes = Uint8Array.from(atob(cipherText), (c) => c.charCodeAt(0));

// extract IV from the head
const iv = bytes.slice(0, this.IV_BYTES);
const cipher = bytes.slice(this.IV_BYTES);
let lastError: unknown;

for (const key of keys) {
let decrypted: ArrayBuffer;
try {
decrypted = await crypto.subtle.decrypt({ name: algorithm, iv }, key, cipher);
} catch (err) {
lastError = err;
continue;
}
return this.decoder.decode(decrypted);
}

return this.decoder.decode(decrypted);
throw lastError;
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
}

// base override
Expand Down Expand Up @@ -138,7 +235,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
// Don't decrypt null, undefined or empty string values
// don't decrypt null, undefined or empty string values
if (!entityData[field]) continue;

const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);
Expand Down Expand Up @@ -169,7 +266,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
// Don't encrypt null, undefined or empty string values
// don't encrypt null, undefined or empty string values
if (!data) return;

const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
Expand Down
33 changes: 31 additions & 2 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,38 @@ export type ZodSchemas = {
input?: Record<string, Record<string, z.ZodSchema>>;
};

/**
* Simple encryption settings for processing fields marked with `@encrypted`.
*/
export type SimpleEncryption = {
/**
* The encryption key.
*/
encryptionKey: Uint8Array;

/**
* Optional list of all decryption keys that were previously used to encrypt the data
* , for supporting key rotation. The `encryptionKey` field value is automatically
* included for decryption.
*
* When the encrypted data is persisted, a metadata object containing the digest of the
* encryption key is stored alongside the data. This digest is used to quickly determine
* the correct decryption key to use when reading the data.
*/
decryptionKeys?: Uint8Array[];
};

/**
* Custom encryption settings for processing fields marked with `@encrypted`.
*/
export type CustomEncryption = {
/**
* Encryption function.
*/
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;

/**
* Decryption function
*/
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
};

export type SimpleEncryption = { encryptionKey: Uint8Array };
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,12 @@ describe('Encrypted test', () => {
});

it('Custom encryption test', async () => {
const { enhance } = await loadSchema(`
const { enhance, prisma } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encryption'],
encryption: {
Expand Down Expand Up @@ -181,15 +180,83 @@ describe('Encrypted test', () => {
},
});

const sudoRead = await sudoDb.user.findUnique({
const rawRead = await prisma.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
expect(rawRead.encrypted_value).toBe('abc123_enc');
});

it('Works with multiple decryption keys', async () => {
const { enhanceRaw: enhance, prisma } = await loadSchema(
`
model User {
id String @id @default(cuid())
secret String @encrypted()
}`
);

const key1 = crypto.getRandomValues(new Uint8Array(32));
const key2 = crypto.getRandomValues(new Uint8Array(32));

const db1 = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1 },
});
const user1 = await db1.user.create({ data: { secret: 'user1' } });

const db2 = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key2 },
});
const user2 = await db2.user.create({ data: { secret: 'user2' } });

const dbAll = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)), decryptionKeys: [key1, key2] },
});
const allUsers = await dbAll.user.findMany();
expect(allUsers).toEqual(expect.arrayContaining([user1, user2]));

const dbWithEncryptionKeyExplicitlyProvided = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key1, key2] },
});
await expect(dbWithEncryptionKeyExplicitlyProvided.user.findMany()).resolves.toEqual(
expect.arrayContaining([user1, user2])
);

const dbWithDuplicatedKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key1, key1, key2, key2] },
});
await expect(dbWithDuplicatedKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2]));

const dbWithInvalidKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key1, decryptionKeys: [key2, crypto.getRandomValues(new Uint8Array(32))] },
});
await expect(dbWithInvalidKeys.user.findMany()).resolves.toEqual(expect.arrayContaining([user1, user2]));

const dbWithMissingKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: key2 },
});
const found = await dbWithMissingKeys.user.findMany();
expect(found).not.toContainEqual(user1);
expect(found).toContainEqual(user2);

const dbWithAllWrongKeys = enhance(prisma, undefined, {
kinds: ['encryption'],
encryption: { encryptionKey: crypto.getRandomValues(new Uint8Array(32)) },
});
const found1 = await dbWithAllWrongKeys.user.findMany();
expect(found1).not.toContainEqual(user1);
expect(found1).not.toContainEqual(user2);
});

it('Only supports string fields', async () => {
Expand Down
Loading