From e0f0c80a99773125f9d74bedb793ac0e6fe9efac Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 8 Mar 2024 15:24:16 -0800 Subject: [PATCH 1/4] fix: `@@validate` should ignore fields that are not present --- .../attribute-application-validator.ts | 15 +- .../validator/expression-validator.ts | 51 +++-- packages/schema/src/utils/ast-utils.ts | 17 +- .../typescript-expression-transformer.ts | 100 +++++++--- .../tests/generator/expression-writer.test.ts | 18 +- .../validation/attribute-validation.test.ts | 41 ++-- .../with-policy/field-validation.test.ts | 179 +++++++++++++++++- 7 files changed, 340 insertions(+), 81 deletions(-) diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index f81f5c166..92c086005 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -15,7 +15,7 @@ import { isEnum, isReferenceExpr, } from '@zenstackhq/language/ast'; -import { isFutureExpr, isRelationshipField, resolved } from '@zenstackhq/sdk'; +import { isDataModelFieldReference, isFutureExpr, isRelationshipField, resolved } from '@zenstackhq/sdk'; import { ValidationAcceptor, streamAst } from 'langium'; import pluralize from 'pluralize'; import { AstValidator } from '../types'; @@ -151,6 +151,19 @@ export default class AttributeApplicationValidator implements AstValidator isDataModelFieldReference(node) && isDataModel(node.$resolvedType?.decl) + ) + ) { + accept('error', `\`@@validate\` condition cannot use relation fields`, { node: condition }); + } + } + private validatePolicyKinds( kind: string, candidates: string[], diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 7644521b8..7d8c4dd95 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -3,16 +3,17 @@ import { Expression, ExpressionType, isDataModel, + isDataModelAttribute, + isDataModelField, isEnum, + isLiteralExpr, isMemberAccessExpr, isNullExpr, isThisExpr, - isDataModelField, - isLiteralExpr, } from '@zenstackhq/language/ast'; import { isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; -import { ValidationAcceptor } from 'langium'; -import { getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; +import { AstNode, ValidationAcceptor } from 'langium'; +import { findUpAst, getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; @@ -123,6 +124,17 @@ export default class ExpressionValidator implements AstValidator { case '==': case '!=': { + if (this.isInValidationContext(expr)) { + // in validation context, all fields are optional, so we should allow + // comparing any field against null + if ( + (isDataModelFieldReference(expr.left) && isNullExpr(expr.right)) || + (isDataModelFieldReference(expr.right) && isNullExpr(expr.left)) + ) { + return; + } + } + if (!!expr.left.$resolvedType?.array !== !!expr.right.$resolvedType?.array) { accept('error', 'incompatible operand types', { node: expr }); break; @@ -132,18 +144,24 @@ export default class ExpressionValidator implements AstValidator { // - foo.user.id == userId // except: // - future().userId == userId - if(isMemberAccessExpr(expr.left) && isDataModelField(expr.left.member.ref) && expr.left.member.ref.$container != getContainingDataModel(expr) - || isMemberAccessExpr(expr.right) && isDataModelField(expr.right.member.ref) && expr.right.member.ref.$container != getContainingDataModel(expr)) - { + if ( + (isMemberAccessExpr(expr.left) && + isDataModelField(expr.left.member.ref) && + expr.left.member.ref.$container != getContainingDataModel(expr)) || + (isMemberAccessExpr(expr.right) && + isDataModelField(expr.right.member.ref) && + expr.right.member.ref.$container != getContainingDataModel(expr)) + ) { // foo.user.id == auth().id // foo.user.id == "123" // foo.user.id == null // foo.user.id == EnumValue - if(!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) - { - accept('error', 'comparison between fields of different models are not supported', { node: expr }); - break; - } + if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { + accept('error', 'comparison between fields of different models are not supported', { + node: expr, + }); + break; + } } if ( @@ -205,14 +223,17 @@ export default class ExpressionValidator implements AstValidator { } } + private isInValidationContext(node: AstNode) { + return findUpAst(node, (n) => isDataModelAttribute(n) && n.decl.$refText === '@@validate'); + } private isNotModelFieldExpr(expr: Expression) { - return isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + return ( + isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + ); } private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } - } - diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 661f14b26..348752fae 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -157,7 +157,6 @@ export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } - export function getContainingDataModel(node: Expression): DataModel | undefined { let curr: AstNode | undefined = node.$container; while (curr) { @@ -167,4 +166,18 @@ export function getContainingDataModel(node: Expression): DataModel | undefined curr = curr.$container; } return undefined; -} \ No newline at end of file +} + +/** + * Walk upward from the current AST node to find the first node that satisfies the predicate. + */ +export function findUpAst(node: AstNode, predicate: (node: AstNode) => boolean): AstNode | undefined { + let curr: AstNode | undefined = node; + while (curr) { + if (predicate(curr)) { + return curr; + } + curr = curr.$container; + } + return undefined; +} diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index ee63b718a..16b8637d6 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -7,6 +7,7 @@ import { InvocationExpr, isDataModel, isEnumField, + isNullExpr, isThisExpr, LiteralExpr, MemberAccessExpr, @@ -168,13 +169,13 @@ export class TypeScriptExpressionTransformer { const max = getLiteral(args[2]); let result: string; if (min === undefined) { - result = `(${field}?.length > 0)`; + result = this.ensureBooleanTernary(field, `${field}?.length > 0`); } else if (max === undefined) { - result = `(${field}?.length >= ${min})`; + result = this.ensureBooleanTernary(field, `${field}?.length >= ${min}`); } else { - result = `(${field}?.length >= ${min} && ${field}?.length <= ${max})`; + result = this.ensureBooleanTernary(field, `${field}?.length >= ${min} && ${field}?.length <= ${max}`); } - return this.ensureBoolean(result); + return result; } @func('contains') @@ -208,25 +209,25 @@ export class TypeScriptExpressionTransformer { private _regex(args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); - return `new RegExp(${JSON.stringify(pattern)}).test(${field})`; + return this.ensureBoolean(`${field}?.match(new RegExp(${JSON.stringify(pattern)}))`); } @func('email') private _email(args: Expression[]) { const field = this.transform(args[0], false); - return `z.string().email().safeParse(${field}).success`; + return this.ensureBooleanTernary(field, `z.string().email().safeParse(${field}).success`); } @func('datetime') private _datetime(args: Expression[]) { const field = this.transform(args[0], false); - return `z.string().datetime({ offset: true }).safeParse(${field}).success`; + return this.ensureBooleanTernary(field, `z.string().datetime({ offset: true }).safeParse(${field}).success`); } @func('url') private _url(args: Expression[]) { const field = this.transform(args[0], false); - return `z.string().url().safeParse(${field}).success`; + return this.ensureBooleanTernary(field, `z.string().url().safeParse(${field}).success`); } @func('has') @@ -239,22 +240,25 @@ export class TypeScriptExpressionTransformer { @func('hasEvery') private _hasEvery(args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); - const result = `${this.transform(args[1], normalizeUndefined)}?.every((item) => ${field}?.includes(item))`; - return this.ensureBoolean(result); + return this.ensureBooleanTernary( + field, + `${this.transform(args[1], normalizeUndefined)}?.every((item) => ${field}?.includes(item))` + ); } @func('hasSome') private _hasSome(args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); - const result = `${this.transform(args[1], normalizeUndefined)}?.some((item) => ${field}?.includes(item))`; - return this.ensureBoolean(result); + return this.ensureBooleanTernary( + field, + `${this.transform(args[1], normalizeUndefined)}?.some((item) => ${field}?.includes(item))` + ); } @func('isEmpty') private _isEmpty(args: Expression[]) { const field = this.transform(args[0], false); - const result = `(!${field} || ${field}?.length === 0)`; - return this.ensureBoolean(result); + return `(!${field} || ${field}?.length === 0)`; } private ensureBoolean(expr: string) { @@ -263,7 +267,17 @@ export class TypeScriptExpressionTransformer { // as boolean true return `(${expr} ?? true)`; } else { - return `(${expr} ?? false)`; + return `((${expr}) ?? false)`; + } + } + + private ensureBooleanTernary(predicate: string, value: string) { + if (this.options.context === ExpressionContext.ValidationRule) { + // all fields are optional in a validation context, so we treat undefined + // as boolean true + return `((${predicate}) !== undefined ? (${value}): true)`; + } else { + return `((${predicate}) !== undefined ? (${value}): false)`; } } @@ -315,7 +329,7 @@ export class TypeScriptExpressionTransformer { isDataModelFieldReference(expr.operand) ) { // in a validation context, we treat unary involving undefined as boolean true - result = `(${operand} !== undefined ? (${result}): true)`; + result = this.ensureBooleanTernary(operand, result); } return result; } @@ -336,21 +350,39 @@ export class TypeScriptExpressionTransformer { let _default = `(${left} ${expr.operator} ${right})`; if (this.options.context === ExpressionContext.ValidationRule) { - // in a validation context, we treat binary involving undefined as boolean true - if (isDataModelFieldReference(expr.left)) { - _default = `(${left} !== undefined ? (${_default}): true)`; - } - if (isDataModelFieldReference(expr.right)) { - _default = `(${right} !== undefined ? (${_default}): true)`; + const nullComparison = this.extractNullComparison(expr); + if (nullComparison) { + // null comparison covers both null and undefined + const { fieldRef } = nullComparison; + const field = this.transform(fieldRef, normalizeUndefined); + if (expr.operator === '==') { + _default = `(${field} === null || ${field} === undefined)`; + } else if (expr.operator === '!=') { + _default = `(${field} !== null && ${field} !== undefined)`; + } + } else { + // for other comparisons, in a validation context, + // we treat binary involving undefined as boolean true + if (isDataModelFieldReference(expr.left)) { + _default = this.ensureBooleanTernary(left, _default); + } + if (isDataModelFieldReference(expr.right)) { + _default = this.ensureBooleanTernary(right, _default); + } } } return match(expr.operator) - .with('in', () => - this.ensureBoolean( - `${this.transform(expr.right, false)}?.includes(${this.transform(expr.left, normalizeUndefined)})` - ) - ) + .with('in', () => { + const left = `${this.transform(expr.left, normalizeUndefined)}`; + let result = this.ensureBoolean(`${this.transform(expr.right, false)}?.includes(${left})`); + + if (this.options.context === ExpressionContext.ValidationRule) { + // in a validation context, we treat binary involving undefined as boolean true + result = this.ensureBooleanTernary(left, result); + } + return result; + }) .with(P.union('==', '!='), () => { if (isThisExpr(expr.left) || isThisExpr(expr.right)) { // map equality comparison with `this` to id comparison @@ -376,6 +408,20 @@ export class TypeScriptExpressionTransformer { .otherwise(() => _default); } + private extractNullComparison(expr: BinaryExpr) { + if (expr.operator !== '==' && expr.operator !== '!=') { + return undefined; + } + + if (isDataModelFieldReference(expr.left) && isNullExpr(expr.right)) { + return { fieldRef: expr.left, nullExpr: expr.right }; + } else if (isDataModelFieldReference(expr.right) && isNullExpr(expr.left)) { + return { fieldRef: expr.right, nullExpr: expr.left }; + } else { + return undefined; + } + } + private collectionPredicate(expr: BinaryExpr, operator: '?' | '!' | '^', normalizeUndefined: boolean) { const operand = this.transform(expr.left, normalizeUndefined); const innerTransformer = new TypeScriptExpressionTransformer({ diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index f9baa0de9..7121c9589 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -1178,7 +1178,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.roles?.includes(Role.ADMIN)??false)?{AND:[]}:{OR:[]}`, + `((user?.roles?.includes(Role.ADMIN))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1205,7 +1205,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.email?.includes('test')??false)?{AND:[]}:{OR:[]}`, + `((user?.email?.includes('test'))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1218,7 +1218,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.email?.toLowerCase().includes('test'?.toLowerCase())??false)?{AND:[]}:{OR:[]}`, + `((user?.email?.toLowerCase().includes('test'?.toLowerCase()))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1231,7 +1231,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.email?.startsWith('test')??false)?{AND:[]}:{OR:[]}`, + `((user?.email?.startsWith('test'))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1244,7 +1244,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.email?.endsWith('test')??false)?{AND:[]}:{OR:[]}`, + `((user?.email?.endsWith('test'))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1257,7 +1257,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user?.roles?.includes(Role.ADMIN)??false)?{AND:[]}:{OR:[]}`, + `((user?.roles?.includes(Role.ADMIN))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1270,7 +1270,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))??false)?{AND:[]}:{OR:[]}`, + `((user?.roles)!==undefined?([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))):false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1283,7 +1283,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))??false)?{AND:[]}:{OR:[]}`, + `((user?.roles)!==undefined?([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))):false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1296,7 +1296,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `((!user?.roles||user?.roles?.length===0)??false)?{AND:[]}:{OR:[]}`, + `(!user?.roles||user?.roles?.length===0)?{AND:[]}:{OR:[]}`, userInit ); }); diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 8eb674b2f..eb8a6065b 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -227,7 +227,7 @@ describe('Attribute tests', () => { `); await loadModel(` - ${ prelude } + ${prelude} model A { id String @id x String @@ -927,17 +927,6 @@ describe('Attribute tests', () => { @@validate(hasSome(es, [E1])) @@validate(hasEvery(es, [E1])) @@validate(isEmpty(es)) - - @@validate(n.e in [E1, E2]) - @@validate(n.i in [1, 2]) - @@validate(contains(n.s, 'a')) - @@validate(contains(n.s, 'a', true)) - @@validate(startsWith(n.s, 'a')) - @@validate(endsWith(n.s, 'a')) - @@validate(has(n.es, E1)) - @@validate(hasSome(n.es, [E1])) - @@validate(hasEvery(n.es, [E1])) - @@validate(isEmpty(n.es)) } `); @@ -1000,26 +989,21 @@ describe('Attribute tests', () => { expect( await loadModelWithError(` ${prelude} - model N { - id String @id - m M @relation(fields: [mId], references: [id]) - mId String - } model M { id String @id - n N? - @@validate(n in [1]) + x Int + @@validate(has(x, 1)) } `) - ).toContain('left operand of "in" must be of scalar type'); + ).toContain('argument is not assignable to parameter'); expect( await loadModelWithError(` ${prelude} model M { id String @id - x Int - @@validate(has(x, 1)) + x Int[] + @@validate(hasSome(x, 1)) } `) ).toContain('argument is not assignable to parameter'); @@ -1029,11 +1013,18 @@ describe('Attribute tests', () => { ${prelude} model M { id String @id - x Int[] - @@validate(hasSome(x, 1)) + n N? + @@validate(n.value > 0) + } + + model N { + id String @id + value Int + m M @relation(fields: [mId], references: [id]) + mId String @unique } `) - ).toContain('argument is not assignable to parameter'); + ).toContain('`@@validate` condition cannot use relation fields'); }); it('auth function check', async () => { diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index d34c7183b..d54913bb3 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -1,5 +1,5 @@ import { CrudFailureReason, isPrismaClientKnownRequestError } from '@zenstackhq/runtime'; -import { FullDbClientContract, loadSchema, run } from '@zenstackhq/testtools'; +import { FullDbClientContract, createPostgresDb, dropPostgresDb, loadSchema, run } from '@zenstackhq/testtools'; describe('With Policy: field validation', () => { let db: FullDbClientContract; @@ -685,7 +685,7 @@ describe('With Policy: model-level validation', () => { await expect(db.model.create({ data: {} })).toResolveTruthy(); }); - it('optionality with comparison', async () => { + it('optionality with binary', async () => { const { enhance } = await loadSchema(` model Model { id Int @id @default(autoincrement()) @@ -705,6 +705,56 @@ describe('With Policy: model-level validation', () => { await expect(db.model.create({ data: {} })).toResolveTruthy(); }); + it('optionality with in operator lhs', async () => { + const { enhance } = await loadSchema(` + model Model { + id Int @id @default(autoincrement()) + x String? + + @@validate(x in ['foo', 'bar']) + @@allow('all', true) + } + `); + + const db = enhance(); + + await expect(db.model.create({ data: { x: 'hello' } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: 'foo' } })).toResolveTruthy(); + await expect(db.model.create({ data: {} })).toResolveTruthy(); + }); + + it('optionality with in operator rhs', async () => { + let prisma; + try { + const dbUrl = await createPostgresDb('field-validation-in-operator'); + const r = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + x String[] + + @@validate('foo' in x) + @@allow('all', true) + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + const db = r.enhance(); + prisma = r.prisma; + + await expect(db.model.create({ data: { x: ['hello'] } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: ['foo', 'bar'] } })).toResolveTruthy(); + await expect(db.model.create({ data: {} })).toResolveTruthy(); + } finally { + await prisma.$disconnect(); + await dropPostgresDb('field-validation-in-operator'); + } + }); + it('optionality with complex expression', async () => { const { enhance } = await loadSchema(` model Model { @@ -762,4 +812,129 @@ describe('With Policy: model-level validation', () => { await expect(db.model.update({ where: { id: 1 }, data: { y: 1 } })).toResolveTruthy(); await expect(db.model.update({ where: { id: 1 }, data: {} })).toResolveTruthy(); }); + + it('optionality with scalar functions', async () => { + const { enhance } = await loadSchema(` + model Model { + id Int @id @default(autoincrement()) + s String + e String + u String + d String + + @@validate( + length(s, 1, 5) && + contains(s, 'b') && + startsWith(s, 'a') && + endsWith(s, 'c') && + regex(s, '^[0-9a-zA-Z]*$'), + 'invalid s') + @@validate(email(e), 'invalid e') + @@validate(url(u), 'invalid u') + @@validate(datetime(d), 'invalid d') + + @@allow('all', true) + } + `); + + const db = enhance(); + + await expect( + db.model.create({ + data: { + id: 1, + s: 'a1b2c', + e: 'a@bcd.com', + u: 'https://www.zenstack.dev', + d: '2024-01-01T00:00:00.000Z', + }, + }) + ).toResolveTruthy(); + + await expect(db.model.update({ where: { id: 1 }, data: {} })).toResolveTruthy(); + + await expect(db.model.update({ where: { id: 1 }, data: { s: 'a2b3c' } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { s: 'c2b3c' } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { s: 'a1b2c3' } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { s: 'aaccc' } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { s: 'a1b2d' } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { s: 'a1-3c' } })).toBeRejectedByPolicy(); + + await expect(db.model.update({ where: { id: 1 }, data: { e: 'b@def.com' } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { e: 'xyz' } })).toBeRejectedByPolicy(); + + await expect(db.model.update({ where: { id: 1 }, data: { u: 'https://zenstack.dev/docs' } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { u: 'xyz' } })).toBeRejectedByPolicy(); + + await expect(db.model.update({ where: { id: 1 }, data: { d: '2025-01-01T00:00:00.000Z' } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { d: 'xyz' } })).toBeRejectedByPolicy(); + }); + + it('optionality with array functions', async () => { + let prisma; + try { + const dbUrl = await createPostgresDb('field-validation-array-funcs'); + const r = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + x String[] + y Int[] + + @@validate( + has(x, 'a') && + hasEvery(x, ['a', 'b']) && + hasSome(x, ['x', 'y']) && + (y == null || !isEmpty(y)) + ) + + @@allow('all', true) + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + const db = r.enhance(); + prisma = r.prisma; + + await expect(db.model.create({ data: { id: 1, x: ['a', 'b', 'x'], y: [1] } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: {} })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { x: ['b', 'x'] } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { x: ['a', 'b'] } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { y: [] } })).toBeRejectedByPolicy(); + } finally { + await prisma.$disconnect(); + await dropPostgresDb('field-validation-array-funcs'); + } + }); + + it('null comparison', async () => { + const { enhance } = await loadSchema(` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + + @@validate(x == null || !(x <= 0)) + @@validate(y != null && !(y > 1)) + + @@allow('all', true) + } + `); + + const db = enhance(); + + await expect(db.model.create({ data: { id: 1, x: 1 } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { id: 1, x: 1, y: 2 } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { id: 1, x: 0, y: 0 } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { id: 1, x: 1, y: 0 } })).toResolveTruthy(); + + await expect(db.model.update({ where: { id: 1 }, data: {} })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { y: 2 } })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: 1 }, data: { y: 1 } })).toResolveTruthy(); + await expect(db.model.update({ where: { id: 1 }, data: { x: 2, y: 1 } })).toResolveTruthy(); + }); }); From b54405265d3200b2a099653ced3756b08a51b7b9 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:01:21 -0800 Subject: [PATCH 2/4] fixes --- .../schema/src/utils/typescript-expression-transformer.ts | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index 16b8637d6..ec6c114cc 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -209,7 +209,7 @@ export class TypeScriptExpressionTransformer { private _regex(args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); - return this.ensureBoolean(`${field}?.match(new RegExp(${JSON.stringify(pattern)}))`); + return this.ensureBooleanTernary(field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); } @func('email') @@ -375,11 +375,12 @@ export class TypeScriptExpressionTransformer { return match(expr.operator) .with('in', () => { const left = `${this.transform(expr.left, normalizeUndefined)}`; - let result = this.ensureBoolean(`${this.transform(expr.right, false)}?.includes(${left})`); - + let result = `${this.transform(expr.right, false)}?.includes(${left})`; if (this.options.context === ExpressionContext.ValidationRule) { // in a validation context, we treat binary involving undefined as boolean true result = this.ensureBooleanTernary(left, result); + } else { + result = this.ensureBoolean(result); } return result; }) From 7abede1d1ed29257d92371f23b7a2a42f2da4e87 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:20:24 -0800 Subject: [PATCH 3/4] fixes --- .../schema/src/utils/typescript-expression-transformer.ts | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index ec6c114cc..07dd78517 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -375,10 +375,11 @@ export class TypeScriptExpressionTransformer { return match(expr.operator) .with('in', () => { const left = `${this.transform(expr.left, normalizeUndefined)}`; - let result = `${this.transform(expr.right, false)}?.includes(${left})`; + const right = `${this.transform(expr.right, false)}`; + let result = `${right}?.includes(${left})`; if (this.options.context === ExpressionContext.ValidationRule) { // in a validation context, we treat binary involving undefined as boolean true - result = this.ensureBooleanTernary(left, result); + result = this.ensureBooleanTernary(left, this.ensureBooleanTernary(right, result)); } else { result = this.ensureBoolean(result); } From 8294e49c0f742daed28a5c632af8c125062b1137 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:52:54 -0800 Subject: [PATCH 4/4] fixes --- .../typescript-expression-transformer.ts | 59 +++++++++++++------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index 07dd78517..ec4f89fcb 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -5,10 +5,6 @@ import { DataModel, Expression, InvocationExpr, - isDataModel, - isEnumField, - isNullExpr, - isThisExpr, LiteralExpr, MemberAccessExpr, NullExpr, @@ -17,9 +13,15 @@ import { StringLiteral, ThisExpr, UnaryExpr, + isArrayExpr, + isDataModel, + isEnumField, + isLiteralExpr, + isNullExpr, + isThisExpr, } from '@zenstackhq/language/ast'; import { ExpressionContext, getLiteral, isDataModelFieldReference, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; -import { match, P } from 'ts-pattern'; +import { P, match } from 'ts-pattern'; import { getIdFields } from './ast-utils'; export class TypeScriptExpressionTransformerError extends Error { @@ -169,11 +171,15 @@ export class TypeScriptExpressionTransformer { const max = getLiteral(args[2]); let result: string; if (min === undefined) { - result = this.ensureBooleanTernary(field, `${field}?.length > 0`); + result = this.ensureBooleanTernary(args[0], field, `${field}?.length > 0`); } else if (max === undefined) { - result = this.ensureBooleanTernary(field, `${field}?.length >= ${min}`); + result = this.ensureBooleanTernary(args[0], field, `${field}?.length >= ${min}`); } else { - result = this.ensureBooleanTernary(field, `${field}?.length >= ${min} && ${field}?.length <= ${max}`); + result = this.ensureBooleanTernary( + args[0], + field, + `${field}?.length >= ${min} && ${field}?.length <= ${max}` + ); } return result; } @@ -209,25 +215,29 @@ export class TypeScriptExpressionTransformer { private _regex(args: Expression[]) { const field = this.transform(args[0], false); const pattern = getLiteral(args[1]); - return this.ensureBooleanTernary(field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); + return this.ensureBooleanTernary(args[0], field, `new RegExp(${JSON.stringify(pattern)}).test(${field})`); } @func('email') private _email(args: Expression[]) { const field = this.transform(args[0], false); - return this.ensureBooleanTernary(field, `z.string().email().safeParse(${field}).success`); + return this.ensureBooleanTernary(args[0], field, `z.string().email().safeParse(${field}).success`); } @func('datetime') private _datetime(args: Expression[]) { const field = this.transform(args[0], false); - return this.ensureBooleanTernary(field, `z.string().datetime({ offset: true }).safeParse(${field}).success`); + return this.ensureBooleanTernary( + args[0], + field, + `z.string().datetime({ offset: true }).safeParse(${field}).success` + ); } @func('url') private _url(args: Expression[]) { const field = this.transform(args[0], false); - return this.ensureBooleanTernary(field, `z.string().url().safeParse(${field}).success`); + return this.ensureBooleanTernary(args[0], field, `z.string().url().safeParse(${field}).success`); } @func('has') @@ -241,6 +251,7 @@ export class TypeScriptExpressionTransformer { private _hasEvery(args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( + args[0], field, `${this.transform(args[1], normalizeUndefined)}?.every((item) => ${field}?.includes(item))` ); @@ -250,6 +261,7 @@ export class TypeScriptExpressionTransformer { private _hasSome(args: Expression[], normalizeUndefined: boolean) { const field = this.transform(args[0], false); return this.ensureBooleanTernary( + args[0], field, `${this.transform(args[1], normalizeUndefined)}?.some((item) => ${field}?.includes(item))` ); @@ -271,13 +283,18 @@ export class TypeScriptExpressionTransformer { } } - private ensureBooleanTernary(predicate: string, value: string) { + private ensureBooleanTernary(predicate: Expression, transformedPredicate: string, value: string) { + if (isLiteralExpr(predicate) || isArrayExpr(predicate)) { + // these are never undefined + return value; + } + if (this.options.context === ExpressionContext.ValidationRule) { // all fields are optional in a validation context, so we treat undefined // as boolean true - return `((${predicate}) !== undefined ? (${value}): true)`; + return `((${transformedPredicate}) !== undefined ? (${value}): true)`; } else { - return `((${predicate}) !== undefined ? (${value}): false)`; + return `((${transformedPredicate}) !== undefined ? (${value}): false)`; } } @@ -329,7 +346,7 @@ export class TypeScriptExpressionTransformer { isDataModelFieldReference(expr.operand) ) { // in a validation context, we treat unary involving undefined as boolean true - result = this.ensureBooleanTernary(operand, result); + result = this.ensureBooleanTernary(expr.operand, operand, result); } return result; } @@ -364,10 +381,10 @@ export class TypeScriptExpressionTransformer { // for other comparisons, in a validation context, // we treat binary involving undefined as boolean true if (isDataModelFieldReference(expr.left)) { - _default = this.ensureBooleanTernary(left, _default); + _default = this.ensureBooleanTernary(expr.left, left, _default); } if (isDataModelFieldReference(expr.right)) { - _default = this.ensureBooleanTernary(right, _default); + _default = this.ensureBooleanTernary(expr.right, right, _default); } } } @@ -379,7 +396,11 @@ export class TypeScriptExpressionTransformer { let result = `${right}?.includes(${left})`; if (this.options.context === ExpressionContext.ValidationRule) { // in a validation context, we treat binary involving undefined as boolean true - result = this.ensureBooleanTernary(left, this.ensureBooleanTernary(right, result)); + result = this.ensureBooleanTernary( + expr.left, + left, + this.ensureBooleanTernary(expr.right, right, result) + ); } else { result = this.ensureBoolean(result); }