Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Jan 2, 2025
1 parent ff86ce0 commit d6618c9
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 125 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import {
hasAttribute,
hasValidationAttributes,
isAuthInvocation,
isDelegateModel,
isForeignKeyField,
saveSourceFile,
} from '@zenstackhq/sdk';
Expand Down Expand Up @@ -455,44 +454,38 @@ export class PolicyGenerator {
writer: CodeBlockWriter,
sourceFile: SourceFile
) {
const isDelegate = isDelegateModel(model);

if (!isDelegate) {
// handle cases where a constant function can be used
// note that this doesn't apply to delegate models because
// all concrete models inheriting it need to be considered

if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
}
writer.write(`guard: ${func.getName()!},`);
return;
}
// first handle several cases where a constant function can be used

if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
}
writer.write(`guard: ${func.getName()!},`);
return;
}

if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
}
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}

if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
}

// generate a policy function that evaluates a partial prisma query
Expand Down
71 changes: 2 additions & 69 deletions packages/schema/src/plugins/enhancer/policy/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { DELEGATE_AUX_RELATION_PREFIX, type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
import { type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
import {
ExpressionContext,
PluginError,
Expand All @@ -15,7 +15,6 @@ import {
getQueryGuardFunctionName,
isAuthInvocation,
isDataModelFieldReference,
isDelegateModel,
isEnumFieldReference,
isFromStdlib,
isFutureExpr,
Expand All @@ -40,16 +39,9 @@ import {
} from '@zenstackhq/sdk/ast';
import deepmerge from 'deepmerge';
import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import { SourceFile, WriterFunction } from 'ts-morph';
import { name } from '..';
import {
getConcreteModels,
getDiscriminatorField,
isCheckInvocation,
isCollectionPredicate,
isFutureInvocation,
} from '../../../utils/ast-utils';
import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils';
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';

/**
Expand Down Expand Up @@ -311,10 +303,6 @@ export function generateQueryGuardFunction(
forField?: DataModelField,
fieldOverride = false
) {
if (isDelegateModel(model) && !forField) {
return generateDelegateQueryGuardFunction(sourceFile, model, kind);
}

const statements: (string | WriterFunction)[] = [];
const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));
Expand Down Expand Up @@ -449,61 +437,6 @@ export function generateQueryGuardFunction(
return func;
}

function generateDelegateQueryGuardFunction(sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind) {
const concreteModels = getConcreteModels(model);

const discriminator = getDiscriminatorField(model);
if (!discriminator) {
throw new PluginError(name, `Model '${model.name}' does not have a discriminator field`);
}

const func = sourceFile.addFunction({
name: getQueryGuardFunctionName(model, undefined, false, kind),
returnType: 'any',
parameters: [
{
name: 'context',
type: 'QueryContext',
},
{
// for generating field references used by field comparison in the same model
name: 'db',
type: 'CrudContract',
},
],
statements: (writer) => {
writer.write('return ');
if (concreteModels.length === 0) {
writer.write(TRUE);
} else {
writer.block(() => {
// union all concrete model's guards
writer.writeLine('OR: [');
concreteModels.forEach((concrete) => {
writer.block(() => {
writer.write('AND: [');
// discriminator condition
writer.write(`{ ${discriminator.name}: '${concrete.name}' },`);
// concrete model guard
writer.write(
`{ ${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(
concrete.name
)}: ${getQueryGuardFunctionName(concrete, undefined, false, kind)}(context, db) }`
);
writer.writeLine(']');
});
writer.write(',');
});
writer.writeLine(']');
});
}
writer.write(';');
},
});

return func;
}

export function generateEntityCheckerFunction(
sourceFile: SourceFile,
model: DataModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,16 @@ describe('Polymorphic Policy Test', () => {
`;

for (const schema of [booleanCondition, booleanExpression]) {
const { enhanceRaw: enhance, prisma } = await loadSchema(schema);
const { enhanceRaw: enhance, prisma } = await loadSchema(schema, { logPrismaQuery: true });

const fullDb = enhance(prisma, undefined, { kinds: ['delegate'] });

const user = await fullDb.user.create({ data: { id: 1 } });
const userDb = enhance(prisma, { user: { id: user.id } }, { kinds: ['delegate', 'policy'] });
const userDb = enhance(
prisma,
{ user: { id: user.id } },
{ kinds: ['delegate', 'policy'], logPrismaQuery: true }
);

// violating Asset create
await expect(
Expand Down Expand Up @@ -588,13 +592,14 @@ describe('Polymorphic Policy Test', () => {
type String
@@delegate(type)
@@allow('all', true)
}
model Post extends Asset {
title String
private Boolean
@@allow('create', true)
@@allow('read', !private)
@@deny('read', private)
}
`
);
Expand All @@ -607,9 +612,9 @@ describe('Polymorphic Policy Test', () => {
});

const db = enhance();
await expect(db.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
asset: null,
});
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read.asset).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});

it('respects concrete policies when read as base required relation', async () => {
Expand All @@ -636,8 +641,7 @@ describe('Polymorphic Policy Test', () => {
private Boolean
@@deny('read', private)
}
`,
{ logPrismaQuery: true }
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
Expand All @@ -647,6 +651,8 @@ describe('Polymorphic Policy Test', () => {
});

const db = enhance();
await expect(db.user.findUnique({ where: { id: 1 }, include: { asset: true } })).toResolveNull();
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});
});
11 changes: 0 additions & 11 deletions tests/regression/tests/issue-1930.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ model EntityContent {
}
model Article extends Entity {
private Boolean @default(false)
@@deny('all', private)
}
model ArticleContent extends EntityContent {
Expand Down Expand Up @@ -78,14 +76,5 @@ model OtherContent extends EntityContent {
data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } },
});
await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull();

// private article's contents are not readable
const privateArticle = await fullDb.article.create({
data: { org: { connect: { id: org.id } }, private: true },
});
const content2 = await fullDb.articleContent.create({
data: { body: 'cde', entity: { connect: { id: privateArticle.id } } },
});
await expect(db.articleContent.findUnique({ where: { id: content2.id } })).toResolveNull();
});
});

0 comments on commit d6618c9

Please sign in to comment.