Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add @encrypted enhancer #1922

Merged
merged 20 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
23a06cc
feat: add encrypted kind
genu Dec 20, 2024
7784099
chore: add encrypt function
genu Dec 20, 2024
f8ee204
test: add integration tests for encrypted model functionality
genu Dec 20, 2024
6bff7f4
test: Add test
genu Dec 20, 2024
b86e814
fix: require encryption options for @encrypted enhancement
genu Dec 23, 2024
e0789b7
feat: enhance encryption handling in EncryptedHandler and update sche…
genu Dec 23, 2024
688d92d
fix: remove hardcoded encryption key from schema loading command
genu Dec 23, 2024
aedbd93
feat: implement custom encryption handling in EncryptedHandler
genu Dec 23, 2024
8752f06
fix: update encryption methods to return promises in EncryptedHandler
genu Dec 23, 2024
83c242c
test: add integration tests for custom encryption handling in Encrypt…
genu Dec 23, 2024
d9b95ef
chore: Add symlink
genu Dec 24, 2024
2ea8bd2
refactor: streamline encryption handling by moving key retrieval and …
genu Dec 24, 2024
78046b3
refactor: don't enable `encrypted` enhancement by default
genu Dec 24, 2024
9d16be0
refactor: change encryptionKey type from string to Uint8Array in Simp…
genu Dec 24, 2024
29b7d15
refactor: enhance encryption validation and update key handling in En…
genu Dec 24, 2024
a7169ef
refactor: prevent encryption of null, undefined, or empty string valu…
genu Dec 24, 2024
acb2ee2
refactor: prevent decryption and encryption of null, undefined, or em…
genu Dec 24, 2024
f4dda18
refactor: continue instead of return
genu Dec 26, 2024
4e5a2be
refactor: add 'encrypted' enhancement kind to ALL_ENHANCEMENTS
genu Dec 27, 2024
fa5c065
refactor: improve error handling for encryption and decryption in Enc…
genu Dec 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/runtime/src/enhancements/edge/encrypted.ts
15 changes: 13 additions & 2 deletions packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ import { withJsonProcessor } from './json-processor';
import { Logger } from './logger';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withEncrypted } from './encrypted';
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
import type { PolicyDef } from './types';

/**
* All enhancement kinds
*/
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate'];
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted'];
genu marked this conversation as resolved.
Show resolved Hide resolved

/**
* Options for {@link createEnhancement}
Expand Down Expand Up @@ -100,6 +101,7 @@ export function createEnhancement<DbClient extends object>(
}

const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted'));
const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider);
const hasTypeDefField = allFields.some((field) => field.isTypeDef);
Expand All @@ -120,13 +122,22 @@ export function createEnhancement<DbClient extends object>(
}
}

// password enhancement must be applied prior to policy because it changes then length of the field
// password and encrypted enhancement must be applied prior to policy because it changes then length of the field
// and can break validation rules like `@length`
if (hasPassword && kinds.includes('password')) {
// @password proxy
result = withPassword(result, options);
}

if (hasEncrypted && kinds.includes('encrypted')) {
if (!options.encryption) {
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
throw new Error('Encryption options are required for @encrypted enhancement');
}

// @encrypted proxy
result = withEncrypted(result, options);
}

// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
if (kinds.includes('policy') || kinds.includes('validation')) {
result = withPolicy(result, options, context);
Expand Down
175 changes: 175 additions & 0 deletions packages/runtime/src/enhancements/node/encrypted.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import {
FieldInfo,
NestedWriteVisitor,
enumerate,
getModelFields,
resolveField,
type PrismaWriteActionType,
} from '../../cross';
import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

/**
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
*
* @private
*/
export function withEncrypted<DbClient extends object = any>(
prisma: DbClient,
options: InternalEnhancementOptions
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
'encrypted'
);
}

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;
private encoder = new TextEncoder();
private decoder = new TextDecoder();

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);

if (!options.encryption) throw new Error('Encryption options must be provided');

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
throw new Error('Custom encryption must provide encrypt and decrypt functions');
} else {
if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes');
}
}

private async getKey(secret: Uint8Array): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
}

private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
return 'encrypt' in encryption && 'decrypt' in encryption;
}

private async encrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.encrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const iv = crypto.getRandomValues(new Uint8Array(12));

const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
iv,
},
key,
this.encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved

private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);

return this.decoder.decode(decrypted);
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
}

// base override
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}

return data;
}

private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

if (!fieldInfo) {
continue;
}

const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
// Don't decrypt null, undefined or empty string values
if (!entityData[field]) continue;

try {
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
} catch (error) {
console.warn('Decryption failed, keeping original value:', error);
}
}
}
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
// Don't encrypt null, undefined or empty string values
if (!data) return;

const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
if (encAttr && field.type === 'String') {
try {
context.parent[field.name] = await this.encrypt(field, data);
} catch (error) {
throw new Error(`Encryption failed for field ${field.name}: ${error}`);
}
}
},
});

await visitor.visit(model, action, args);
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
}
15 changes: 14 additions & 1 deletion packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import type { z } from 'zod';
import { FieldInfo } from './cross';

export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;

Expand Down Expand Up @@ -133,6 +134,11 @@ export type EnhancementOptions = {
* The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack.
*/
transactionIsolationLevel?: TransactionIsolationLevel;

/**
* The encryption options for using the `encrypted` enhancement.
*/
encryption?: SimpleEncryption | CustomEncryption;
};

/**
Expand All @@ -145,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = {
/**
* Kinds of enhancements to `PrismaClient`
*/
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate';
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted';

/**
* Function for transforming errors.
Expand All @@ -166,3 +172,10 @@ export type ZodSchemas = {
*/
input?: Record<string, Record<string, z.ZodSchema>>;
};

export type CustomEncryption = {
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
};

export type SimpleEncryption = { encryptionKey: Uint8Array };
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,14 @@ attribute @@auth() @@@supportTypeDef
*/
attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField])


/**
* Indicates that the field is encrypted when storing in the DB and should be decrypted when read
*
* ZenStack uses the Web Crypto API to encrypt and decrypt the field.
*/
attribute @encrypted() @@@targetField([StringField])

/**
* Indicates that the field should be omitted when read from the generated services.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { FieldInfo } from '@zenstackhq/runtime';
import { loadSchema } from '@zenstackhq/testtools';
import path from 'path';

describe('Encrypted test', () => {
let origDir: string;

beforeAll(async () => {
origDir = path.resolve('.');
});

afterEach(async () => {
process.chdir(origDir);
});

it('Simple encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()

@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64'));

const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: { encryptionKey },
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).not.toBe('abc123');
});

it('Custom encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()

@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: {
encrypt: async (model: string, field: FieldInfo, data: string) => {
// Add _enc to the end of the input
return data + '_enc';
},
decrypt: async (model: string, field: FieldInfo, cipher: string) => {
// Remove _enc from the end of the input explicitly
if (cipher.endsWith('_enc')) {
return cipher.slice(0, -4); // Remove last 4 characters (_enc)
}

return cipher;
},
},
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
});
});
Loading