From 0a2aaf7a6c41e183c18b1b40e01fbb5f7bca1449 Mon Sep 17 00:00:00 2001 From: Yiming Date: Tue, 9 Apr 2024 22:12:08 +0800 Subject: [PATCH] refactor: make data validation a separate enhancement kind (#1226) --- .../src/enhancements/create-enhancement.ts | 10 +- .../src/enhancements/policy/policy-utils.ts | 35 +++- .../runtime/src/enhancements/query-utils.ts | 2 +- .../with-policy/field-validation.test.ts | 151 ++++++++++++------ 4 files changed, 139 insertions(+), 59 deletions(-) diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index 2616ae4b4..596c3e763 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -14,12 +14,12 @@ import type { PolicyDef, ZodSchemas } from './types'; /** * Kinds of enhancements to `PrismaClient` */ -export type EnhancementKind = 'password' | 'omit' | 'policy' | 'delegate'; +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate'; /** * All enhancement kinds */ -const ALL_ENHANCEMENTS = ['password', 'omit', 'policy', 'delegate']; +const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate']; /** * Transaction isolation levels: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#transaction-isolation-level @@ -148,10 +148,10 @@ export function createEnhancement( } } - // policy proxy - if (kinds.includes('policy')) { + // 'policy' and 'validation' enhancements are both enabled by `withPolicy` + if (kinds.includes('policy') || kinds.includes('validation')) { result = withPolicy(result, options, context); - if (hasDefaultAuth) { + if (kinds.includes('policy') && hasDefaultAuth) { // @default(auth()) proxy result = withDefaultAuth(result, options, context); } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index f54285691..bcb946877 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -230,6 +230,25 @@ export class PolicyUtil extends QueryUtils { //# Auth guard + private readonly FULLY_OPEN_AUTH_GUARD = { + create: true, + read: true, + update: true, + delete: true, + postUpdate: true, + create_input: true, + update_input: true, + }; + + private getModelAuthGuard(model: string): PolicyDef['guard']['string'] { + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return an fully open guard + return this.FULLY_OPEN_AUTH_GUARD; + } else { + return this.policy.guard[lowerCaseFirst(model)]; + } + } + /** * Gets pregenerated authorization guard object for a given model and operation. * @@ -237,7 +256,7 @@ export class PolicyUtil extends QueryUtils { * otherwise returns a guard object */ getAuthGuard(db: CrudContract, model: string, operation: PolicyOperationKind, preValue?: any) { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } @@ -318,7 +337,7 @@ export class PolicyUtil extends QueryUtils { * Checks if the given model has a policy guard for the given operation. */ hasAuthGuard(model: string, operation: PolicyOperationKind) { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { return false; } @@ -347,7 +366,7 @@ export class PolicyUtil extends QueryUtils { * @returns boolean if static analysis is enough to determine the result, undefined if not */ checkInputGuard(model: string, args: any, operation: 'create'): boolean | undefined { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { return undefined; } @@ -1020,7 +1039,7 @@ export class PolicyUtil extends QueryUtils { * Gets field selection for fetching pre-update entity values for the given model. */ getPreValueSelect(model: string): object | undefined { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } @@ -1028,7 +1047,7 @@ export class PolicyUtil extends QueryUtils { } private getReadFieldSelect(model: string): object | undefined { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } @@ -1036,7 +1055,7 @@ export class PolicyUtil extends QueryUtils { } private checkReadField(model: string, field: string, entity: any) { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } @@ -1053,7 +1072,7 @@ export class PolicyUtil extends QueryUtils { } private hasFieldLevelPolicy(model: string) { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } @@ -1228,7 +1247,7 @@ export class PolicyUtil extends QueryUtils { } private requireGuard(model: string) { - const guard = this.policy.guard[lowerCaseFirst(model)]; + const guard = this.getModelAuthGuard(model); if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } diff --git a/packages/runtime/src/enhancements/query-utils.ts b/packages/runtime/src/enhancements/query-utils.ts index 6959b922f..c161d5e2c 100644 --- a/packages/runtime/src/enhancements/query-utils.ts +++ b/packages/runtime/src/enhancements/query-utils.ts @@ -13,7 +13,7 @@ import { InternalEnhancementOptions } from './create-enhancement'; import { prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; export class QueryUtils { - constructor(private readonly prisma: DbClientContract, private readonly options: InternalEnhancementOptions) {} + constructor(private readonly prisma: DbClientContract, protected readonly options: InternalEnhancementOptions) {} getIdFields(model: string) { return getIdFields(this.options.modelMeta, model, true); diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index 7508333b6..e4cf21825 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -1,7 +1,7 @@ import { CrudFailureReason, isPrismaClientKnownRequestError } from '@zenstackhq/runtime'; import { FullDbClientContract, createPostgresDb, dropPostgresDb, loadSchema, run } from '@zenstackhq/testtools'; -describe('With Policy: field validation', () => { +describe('Field validation', () => { let db: FullDbClientContract; beforeAll(async () => { @@ -37,8 +37,6 @@ describe('With Policy: field validation', () => { text5 String? @endsWith('xyz') text6 String? @trim @lower text7 String? @upper - - @@allow('all', true) } model Task { @@ -46,10 +44,9 @@ describe('With Policy: field validation', () => { user User @relation(fields: [userId], references: [id]) userId String slug String @regex("^[0-9a-zA-Z]{4,16}$") @lower - - @@allow('all', true) } -` +`, + { enhancements: ['validation'] } ); db = enhance(); }); @@ -610,9 +607,10 @@ describe('With Policy: field validation', () => { }); }); -describe('With Policy: model-level validation', () => { +describe('Model-level validation', () => { it('create', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int @@ -620,9 +618,10 @@ describe('With Policy: model-level validation', () => { @@validate(x > 0) @@validate(x >= y) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -632,16 +631,18 @@ describe('With Policy: model-level validation', () => { }); it('update', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int y Int @@validate(x >= y) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -650,15 +651,17 @@ describe('With Policy: model-level validation', () => { }); it('int optionality', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int? @@validate(x > 0) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -668,15 +671,17 @@ describe('With Policy: model-level validation', () => { }); it('boolean optionality', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Boolean? @@validate(x) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -686,16 +691,18 @@ describe('With Policy: model-level validation', () => { }); it('optionality with binary', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int? y Int? @@validate(x > y) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -706,15 +713,17 @@ describe('With Policy: model-level validation', () => { }); it('optionality with in operator lhs', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x String? @@validate(x in ['foo', 'bar']) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -734,12 +743,12 @@ describe('With Policy: model-level validation', () => { x String[] @@validate('foo' in x) - @@allow('all', true) } `, { provider: 'postgresql', dbUrl, + enhancements: ['validation'], } ); @@ -756,16 +765,18 @@ describe('With Policy: model-level validation', () => { }); it('optionality with complex expression', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int? y Int? @@validate(y > 1 && x > y) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -777,15 +788,17 @@ describe('With Policy: model-level validation', () => { }); it('optionality with negation', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Boolean? @@validate(!x) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -795,16 +808,18 @@ describe('With Policy: model-level validation', () => { }); it('update implied optionality', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int y Int @@validate(x > y) - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -814,7 +829,8 @@ describe('With Policy: model-level validation', () => { }); it('optionality with scalar functions', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) s String @@ -832,10 +848,10 @@ describe('With Policy: model-level validation', () => { @@validate(email(e), 'invalid e') @@validate(url(u), 'invalid u') @@validate(datetime(d), 'invalid d') - - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -887,13 +903,12 @@ describe('With Policy: model-level validation', () => { hasSome(x, ['x', 'y']) && (y == null || !isEmpty(y)) ) - - @@allow('all', true) } `, { provider: 'postgresql', dbUrl, + enhancements: ['validation'], } ); @@ -912,7 +927,8 @@ describe('With Policy: model-level validation', () => { }); it('null comparison', async () => { - const { enhance } = await loadSchema(` + const { enhance } = await loadSchema( + ` model Model { id Int @id @default(autoincrement()) x Int @@ -920,10 +936,10 @@ describe('With Policy: model-level validation', () => { @@validate(x == null || !(x <= 0)) @@validate(y != null && !(y > 1)) - - @@allow('all', true) } - `); + `, + { enhancements: ['validation'] } + ); const db = enhance(); @@ -938,3 +954,48 @@ describe('With Policy: model-level validation', () => { await expect(db.model.update({ where: { id: 1 }, data: { x: 2, y: 1 } })).toResolveTruthy(); }); }); + +describe('Policy and validation interaction', () => { + it('test', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + email String? @email + age Int + + @@allow('all', age > 0) + } + ` + ); + + const db = enhance(); + + await expect( + db.user.create({ + data: { + email: 'hello', + age: 18, + }, + }) + ).toBeRejectedByPolicy(['Invalid email at "email"']); + + await expect( + db.user.create({ + data: { + email: 'user@abc.com', + age: 0, + }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.user.create({ + data: { + email: 'user@abc.com', + age: 18, + }, + }) + ).toResolveTruthy(); + }); +});