Skip to content

Commit

Permalink
fix(delegate): delegate model's guards are not properly including con…
Browse files Browse the repository at this point in the history
…crete models

fixes #1930
  • Loading branch information
ymc9 committed Dec 31, 2024
1 parent f609c86 commit ff86ce0
Show file tree
Hide file tree
Showing 8 changed files with 319 additions and 59 deletions.
17 changes: 3 additions & 14 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {
isArrayExpr,
isDataModel,
isGeneratorDecl,
isReferenceExpr,
isTypeDef,
type Model,
} from '@zenstackhq/sdk/ast';
Expand All @@ -45,6 +44,7 @@ import {
} from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
import { execPackage } from '../../../utils/exec-utils';
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
import { trackPrismaSchemaError } from '../../prisma';
Expand Down Expand Up @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
this.model.declarations
.filter((d): d is DataModel => isDelegateModel(d))
.forEach((dm) => {
const concreteModels = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
);
const concreteModels = getConcreteModels(dm);
if (concreteModels.length > 0) {
delegateInfo.push([dm, concreteModels]);
}
Expand Down Expand Up @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const typeName = typeAlias.getName();
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
if (payloadRecord) {
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
if (discriminatorDecl) {
source = `${payloadRecord[1]
.map(
Expand Down Expand Up @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
}

private getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
saveSourceFile(sf);
Expand Down
24 changes: 13 additions & 11 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,18 @@ export class ExpressionWriter {
operation = this.options.operationContext;
}

this.block(() => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
}
});
this.block(() =>
this.writeFieldCondition(fieldRef, () => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(FALSE);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${targetGuardFunc}(context, db)`);
}
})
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
hasAttribute,
hasValidationAttributes,
isAuthInvocation,
isDelegateModel,
isForeignKeyField,
saveSourceFile,
} from '@zenstackhq/sdk';
Expand Down Expand Up @@ -454,36 +455,44 @@ export class PolicyGenerator {
writer: CodeBlockWriter,
sourceFile: SourceFile
) {
if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
const isDelegate = isDelegateModel(model);

if (!isDelegate) {
// handle cases where a constant function can be used
// note that this doesn't apply to delegate models because
// all concrete models inheriting it need to be considered

if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
let func: FunctionDeclaration;
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
} else {
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
}
writer.write(`guard: ${func.getName()!},`);
return;
}
writer.write(`guard: ${func.getName()!},`);
return;
}

if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
// no 'postUpdate' rule, always allow
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
writer.write(`guard: ${func.getName()},`);
return;
}

if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
// constant policy
const func = generateConstantQueryGuardFunction(
sourceFile,
model,
kind,
policies[kind as keyof typeof policies] as boolean
);
writer.write(`guard: ${func.getName()!},`);
return;
}
}

// generate a policy function that evaluates a partial prisma query
Expand Down
72 changes: 69 additions & 3 deletions packages/schema/src/plugins/enhancer/policy/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime';
import { DELEGATE_AUX_RELATION_PREFIX, type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
import {
ExpressionContext,
PluginError,
Expand All @@ -15,6 +15,7 @@ import {
getQueryGuardFunctionName,
isAuthInvocation,
isDataModelFieldReference,
isDelegateModel,
isEnumFieldReference,
isFromStdlib,
isFutureExpr,
Expand All @@ -39,9 +40,16 @@ import {
} from '@zenstackhq/sdk/ast';
import deepmerge from 'deepmerge';
import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import { SourceFile, WriterFunction } from 'ts-morph';
import { name } from '..';
import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils';
import {
getConcreteModels,
getDiscriminatorField,
isCheckInvocation,
isCollectionPredicate,
isFutureInvocation,
} from '../../../utils/ast-utils';
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';

/**
Expand Down Expand Up @@ -303,8 +311,11 @@ export function generateQueryGuardFunction(
forField?: DataModelField,
fieldOverride = false
) {
const statements: (string | WriterFunction)[] = [];
if (isDelegateModel(model) && !forField) {
return generateDelegateQueryGuardFunction(sourceFile, model, kind);
}

const statements: (string | WriterFunction)[] = [];
const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));

Expand Down Expand Up @@ -438,6 +449,61 @@ export function generateQueryGuardFunction(
return func;
}

function generateDelegateQueryGuardFunction(sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind) {
const concreteModels = getConcreteModels(model);

const discriminator = getDiscriminatorField(model);
if (!discriminator) {
throw new PluginError(name, `Model '${model.name}' does not have a discriminator field`);
}

const func = sourceFile.addFunction({
name: getQueryGuardFunctionName(model, undefined, false, kind),
returnType: 'any',
parameters: [
{
name: 'context',
type: 'QueryContext',
},
{
// for generating field references used by field comparison in the same model
name: 'db',
type: 'CrudContract',
},
],
statements: (writer) => {
writer.write('return ');
if (concreteModels.length === 0) {
writer.write(TRUE);
} else {
writer.block(() => {
// union all concrete model's guards
writer.writeLine('OR: [');
concreteModels.forEach((concrete) => {
writer.block(() => {
writer.write('AND: [');
// discriminator condition
writer.write(`{ ${discriminator.name}: '${concrete.name}' },`);
// concrete model guard
writer.write(
`{ ${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(
concrete.name
)}: ${getQueryGuardFunctionName(concrete, undefined, false, kind)}(context, db) }`
);
writer.writeLine(']');
});
writer.write(',');
});
writer.writeLine(']');
});
}
writer.write(';');
},
});

return func;
}

export function generateEntityCheckerFunction(
sourceFile: SourceFile,
model: DataModel,
Expand Down
5 changes: 2 additions & 3 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import path from 'path';
import semver from 'semver';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { getConcreteModels } from '../../utils/ast-utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
Expand Down Expand Up @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
}

// collect concrete models inheriting this model
const concreteModels = decl.$container.declarations.filter(
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
);
const concreteModels = getConcreteModels(decl);

// generate an optional relation field in delegate base model to each concrete model
concreteModels.forEach((concrete) => {
Expand Down
28 changes: 27 additions & 1 deletion packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
TypeDef,
} from '@zenstackhq/language/ast';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
}
return undefined;
}

/**
* Gets all concrete models that inherit from the given delegate model
*/
export function getConcreteModels(dataModel: DataModel): DataModel[] {
if (!isDelegateModel(dataModel)) {
return [];
}
return dataModel.$container.declarations.filter(
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
);
}

/**
* Gets the discriminator field for the given delegate model
*/
export function getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}
Loading

0 comments on commit ff86ce0

Please sign in to comment.