Skip to content

Commit

Permalink
test: Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
genu committed Dec 20, 2024
1 parent f8ee204 commit 6bff7f4
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 7 deletions.
141 changes: 141 additions & 0 deletions packages/runtime/src/enhancements/edge/encrypted.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import { NestedWriteVisitor, enumerate, getModelFields, resolveField, type PrismaWriteActionType } from '../../cross';
import { DbClientContract } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

/**
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
*
* @private
*/
export function withEncrypted<DbClient extends object = any>(
prisma: DbClient,
options: InternalEnhancementOptions
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
'encrypted'
);
}

const encoder = new TextEncoder();
const decoder = new TextDecoder();

const getKey = async (secret: string): Promise<CryptoKey> => {
return crypto.subtle.importKey('raw', encoder.encode(secret).slice(0, 32), 'AES-GCM', false, [
'encrypt',
'decrypt',
]);
};
const encryptFunc = async (data: string, secret: string): Promise<string> => {
const key = await getKey(secret);
const iv = crypto.getRandomValues(new Uint8Array(12));

const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
iv,
},
key,
encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
};

const decryptFunc = async (encryptedData: string, secret: string): Promise<string> => {
const key = await getKey(secret);

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(encryptedData), (c) => c.charCodeAt(0));

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);

return decoder.decode(decrypted);
};

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);
}

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
}

// base override
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}

return data;
}

private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

if (!fieldInfo) {
continue;
}

const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
const descryptSecret = shouldDecrypt.args.find((arg) => arg.name === 'secret')?.value as string;

entityData[field] = await decryptFunc(entityData[field], descryptSecret);
}
}
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
if (encAttr && field.type === 'String') {
// encrypt value

const secret: string = encAttr.args.find((arg) => arg.name === 'secret')?.value as string;

context.parent[field.name] = await encryptFunc(data, secret);
}
},
});

await visitor.visit(model, action, args);
}
}
40 changes: 36 additions & 4 deletions packages/runtime/src/enhancements/node/encrypted.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import { NestedWriteVisitor, type PrismaWriteActionType } from '../../cross';
import { NestedWriteVisitor, enumerate, getModelFields, resolveField, type PrismaWriteActionType } from '../../cross';
import { DbClientContract } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

/**
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
Expand Down Expand Up @@ -72,8 +73,12 @@ const decryptFunc = async (encryptedData: string, secret: string): Promise<strin
};

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);
}

// base override
Expand All @@ -86,8 +91,35 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
}

// base override
protected async processResultEntity(action: PrismaProxyActions, args: any) {
return args;
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}

return data;
}

private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

if (!fieldInfo) {
continue;
}

const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
const descryptSecret = shouldDecrypt.args.find((arg) => arg.name === 'secret')?.value as string;

entityData[field] = await decryptFunc(entityData[field], descryptSecret);
}
}
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
Expand All @@ -97,7 +129,7 @@ class EncryptedHandler extends DefaultPrismaProxyHandler {
if (encAttr && field.type === 'String') {
// encrypt value

let secret: string = encAttr.args.find((arg) => arg.name === 'secret')?.value as string;
const secret: string = encAttr.args.find((arg) => arg.name === 'secret')?.value as string;

context.parent[field.name] = await encryptFunc(data, secret);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,32 @@ describe('Encrypted test', () => {
});

it('encrypted tests', async () => {
const ENCRYPTION_KEY = 'c558Gq0YQK2QcqtkMF9BGXHCQn4dMF8w';

const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted(saltLength: 16)
encrypted_value String @encrypted(secret: "${ENCRYPTION_KEY}")
@@allow('all', true)
}`);

const db = enhance();
const r = await db.user.create({

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

expect(r.encrypted_value).toBe('abc123');
const read = await db.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
});
});

0 comments on commit 6bff7f4

Please sign in to comment.