Skip to content

Commit

Permalink
fix: cross-model field comparison validation issue (#1509)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Jun 14, 2024
1 parent 665f9b3 commit 9c7527f
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
DataModelAttribute,
Expression,
ExpressionType,
isArrayExpr,
isDataModel,
isDataModelAttribute,
isDataModelField,
Expand Down Expand Up @@ -82,6 +83,8 @@ export default class ExpressionValidator implements AstValidator<Expression> {
node: expr.right,
});
}

this.validateCrossModelFieldComparison(expr, accept);
break;
}

Expand Down Expand Up @@ -137,6 +140,7 @@ export default class ExpressionValidator implements AstValidator<Expression> {
accept('error', 'incompatible operand types', { node: expr });
}

this.validateCrossModelFieldComparison(expr, accept);
break;
}

Expand All @@ -158,43 +162,8 @@ export default class ExpressionValidator implements AstValidator<Expression> {
break;
}

// not supported:
// - foo.a == bar
// - foo.user.id == userId
// except:
// - future().userId == userId
if (
(isMemberAccessExpr(expr.left) &&
isDataModelField(expr.left.member.ref) &&
expr.left.member.ref.$container != getContainingDataModel(expr)) ||
(isMemberAccessExpr(expr.right) &&
isDataModelField(expr.right.member.ref) &&
expr.right.member.ref.$container != getContainingDataModel(expr))
) {
// foo.user.id == auth().id
// foo.user.id == "123"
// foo.user.id == null
// foo.user.id == EnumValue
if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) {
const containingPolicyAttr = findUpAst(
expr,
(node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText)
) as DataModelAttribute | undefined;

if (containingPolicyAttr) {
const operation = getAttributeArgLiteral<string>(containingPolicyAttr, 'operation');
if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) {
accept(
'error',
'comparison between fields of different models is not supported in model-level "read" rules',
{
node: expr,
}
);
break;
}
}
}
if (!this.validateCrossModelFieldComparison(expr, accept)) {
break;
}

if (
Expand Down Expand Up @@ -262,6 +231,49 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}
}

private validateCrossModelFieldComparison(expr: BinaryExpr, accept: ValidationAcceptor) {
// not supported in "read" rules:
// - foo.a == bar
// - foo.user.id == userId
// except:
// - future().userId == userId
if (
(isMemberAccessExpr(expr.left) &&
isDataModelField(expr.left.member.ref) &&
expr.left.member.ref.$container != getContainingDataModel(expr)) ||
(isMemberAccessExpr(expr.right) &&
isDataModelField(expr.right.member.ref) &&
expr.right.member.ref.$container != getContainingDataModel(expr))
) {
// foo.user.id == auth().id
// foo.user.id == "123"
// foo.user.id == null
// foo.user.id == EnumValue
if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) {
const containingPolicyAttr = findUpAst(
expr,
(node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText)
) as DataModelAttribute | undefined;

if (containingPolicyAttr) {
const operation = getAttributeArgLiteral<string>(containingPolicyAttr, 'operation');
if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) {
accept(
'error',
'comparison between fields of different models is not supported in model-level "read" rules',
{
node: expr,
}
);
return false;
}
}
}
}

return true;
}

private validateCollectionPredicate(expr: BinaryExpr, accept: ValidationAcceptor) {
if (!expr.$resolvedType) {
accept('error', 'collection predicate can only be used on an array of model type', { node: expr });
Expand All @@ -273,9 +285,18 @@ export default class ExpressionValidator implements AstValidator<Expression> {
return findUpAst(node, (n) => isDataModelAttribute(n) && n.decl.$refText === '@@validate');
}

private isNotModelFieldExpr(expr: Expression) {
private isNotModelFieldExpr(expr: Expression): boolean {
return (
isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr)
// literal
isLiteralExpr(expr) ||
// enum field
isEnumFieldReference(expr) ||
// null
isNullExpr(expr) ||
// `auth()` access
this.isAuthOrAuthMemberAccess(expr) ||
// array
(isArrayExpr(expr) && expr.items.every((item) => this.isNotModelFieldExpr(item)))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,37 @@ describe('Attribute tests', () => {
`)
).toContain('comparison between fields of different models is not supported in model-level "read" rules');

expect(
await loadModelWithError(`
${prelude}
model User {
id Int @id
lists List[]
todos Todo[]
value Int
}
model List {
id Int @id
user User @relation(fields: [userId], references: [id])
userId Int
todos Todo[]
}
model Todo {
id Int @id
user User @relation(fields: [userId], references: [id])
userId Int
list List @relation(fields: [listId], references: [id])
listId Int
value Int
@@allow('all', list.user.value > value)
}
`)
).toContain('comparison between fields of different models is not supported in model-level "read" rules');

expect(
await loadModel(`
${prelude}
Expand Down
34 changes: 8 additions & 26 deletions tests/regression/tests/issue-1506.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import { loadSchema } from '@zenstackhq/testtools';
import { loadModelWithError } from '@zenstackhq/testtools';
describe('issue 1506', () => {
it('regression', async () => {
const { prisma, enhance } = await loadSchema(
`
await expect(
loadModelWithError(
`
model A {
id Int @id @default(autoincrement())
value Int
Expand All @@ -29,29 +30,10 @@ describe('issue 1506', () => {
@@allow('read', true)
}
`,
{ preserveTsFiles: true, logPrismaQuery: true }
`
)
).resolves.toContain(
'comparison between fields of different models is not supported in model-level "read" rules'
);

await prisma.a.create({
data: {
value: 3,
b: {
create: {
value: 2,
c: {
create: {
value: 1,
},
},
},
},
},
});

const db = enhance();
const read = await db.a.findMany({ include: { b: true } });
expect(read).toHaveLength(1);
expect(read[0].b).toBeTruthy();
});
});

0 comments on commit 9c7527f

Please sign in to comment.