From ff004e0f28019a77d1a215f1d1316c219e0955c0 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 9 Nov 2024 15:35:31 -0800 Subject: [PATCH 1/3] feat: allow using a type as auth model --- packages/language/src/generated/ast.ts | 38 ++++-- packages/language/src/generated/grammar.ts | 72 +++++++++-- packages/language/src/zmodel.langium | 15 ++- .../nuxt-trpc-v10/prisma/schema.prisma | 9 -- .../nuxt-trpc-v11/prisma/schema.prisma | 9 -- .../projects/t3-trpc-v11/prisma/schema.prisma | 2 - packages/schema/src/cli/cli-util.ts | 10 +- packages/schema/src/language-server/utils.ts | 10 ++ .../attribute-application-validator.ts | 4 + .../validator/schema-validator.ts | 5 +- .../validator/typedef-validator.ts | 5 + .../src/language-server/zmodel-linker.ts | 20 +-- .../src/language-server/zmodel-scope.ts | 64 +++++---- .../enhancer/enhance/auth-type-generator.ts | 9 +- .../src/plugins/enhancer/enhance/index.ts | 9 +- .../src/plugins/enhancer/policy/utils.ts | 6 +- packages/schema/src/res/stdlib.zmodel | 6 +- packages/schema/src/utils/ast-utils.ts | 17 ++- .../validation/attribute-validation.test.ts | 2 +- packages/sdk/src/model-meta-generator.ts | 8 +- packages/sdk/src/utils.ts | 42 ++++-- .../enhancements/json/validation.test.ts | 122 +++++++++++++++++- .../enhancements/with-policy/auth.test.ts | 56 ++++++++ 23 files changed, 410 insertions(+), 130 deletions(-) diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index 1705f019b..32fa2515d 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -70,7 +70,15 @@ export function isLiteralExpr(item: unknown): item is LiteralExpr { return reflection.isInstance(item, LiteralExpr); } -export type ReferenceTarget = DataModelField | EnumField | FunctionParam; +export type MemberAccessTarget = DataModelField | TypeDefField; + +export const MemberAccessTarget = 'MemberAccessTarget'; + +export function isMemberAccessTarget(item: unknown): item is MemberAccessTarget { + return reflection.isInstance(item, MemberAccessTarget); +} + +export type ReferenceTarget = DataModelField | EnumField | FunctionParam | TypeDefField; export const ReferenceTarget = 'ReferenceTarget'; @@ -285,7 +293,7 @@ export function isDataModel(item: unknown): item is DataModel { } export interface DataModelAttribute extends AstNode { - readonly $container: DataModel | Enum; + readonly $container: DataModel | Enum | TypeDef; readonly $type: 'DataModelAttribute'; args: Array decl: Reference @@ -298,7 +306,7 @@ export function isDataModelAttribute(item: unknown): item is DataModelAttribute } export interface DataModelField extends AstNode { - readonly $container: DataModel | Enum | FunctionDecl; + readonly $container: DataModel | Enum | FunctionDecl | TypeDef; readonly $type: 'DataModelField'; attributes: Array comments: Array @@ -370,7 +378,7 @@ export function isEnum(item: unknown): item is Enum { } export interface EnumField extends AstNode { - readonly $container: DataModel | Enum | FunctionDecl; + readonly $container: DataModel | Enum | FunctionDecl | TypeDef; readonly $type: 'EnumField'; attributes: Array comments: Array @@ -413,7 +421,7 @@ export function isFunctionDecl(item: unknown): item is FunctionDecl { } export interface FunctionParam extends AstNode { - readonly $container: DataModel | Enum | FunctionDecl; + readonly $container: DataModel | Enum | FunctionDecl | TypeDef; readonly $type: 'FunctionParam'; name: RegularID optional: boolean @@ -482,7 +490,7 @@ export function isInvocationExpr(item: unknown): item is InvocationExpr { export interface MemberAccessExpr extends AstNode { readonly $container: Argument | ArrayExpr | AttributeArg | BinaryExpr | ConfigArrayExpr | ConfigField | ConfigInvocationArg | FieldInitializer | FunctionDecl | MemberAccessExpr | PluginField | ReferenceArg | UnaryExpr | UnsupportedFieldType; readonly $type: 'MemberAccessExpr'; - member: Reference + member: Reference operand: Expression } @@ -631,6 +639,7 @@ export function isThisExpr(item: unknown): item is ThisExpr { export interface TypeDef extends AstNode { readonly $container: Model; readonly $type: 'TypeDef'; + attributes: Array comments: Array fields: Array name: RegularID @@ -643,7 +652,7 @@ export function isTypeDef(item: unknown): item is TypeDef { } export interface TypeDefField extends AstNode { - readonly $container: TypeDef; + readonly $container: DataModel | Enum | FunctionDecl | TypeDef; readonly $type: 'TypeDefField'; attributes: Array comments: Array @@ -730,6 +739,7 @@ export type ZModelAstType = { InvocationExpr: InvocationExpr LiteralExpr: LiteralExpr MemberAccessExpr: MemberAccessExpr + MemberAccessTarget: MemberAccessTarget Model: Model ModelImport: ModelImport NullExpr: NullExpr @@ -754,7 +764,7 @@ export type ZModelAstType = { export class ZModelAstReflection extends AbstractAstReflection { getAllTypes(): string[] { - return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'TypeDefFieldTypes', 'UnaryExpr', 'UnsupportedFieldType']; + return ['AbstractDeclaration', 'Argument', 'ArrayExpr', 'Attribute', 'AttributeArg', 'AttributeParam', 'AttributeParamType', 'BinaryExpr', 'BooleanLiteral', 'ConfigArrayExpr', 'ConfigExpr', 'ConfigField', 'ConfigInvocationArg', 'ConfigInvocationExpr', 'DataModel', 'DataModelAttribute', 'DataModelField', 'DataModelFieldAttribute', 'DataModelFieldType', 'DataSource', 'Enum', 'EnumField', 'Expression', 'FieldInitializer', 'FunctionDecl', 'FunctionParam', 'FunctionParamType', 'GeneratorDecl', 'InternalAttribute', 'InvocationExpr', 'LiteralExpr', 'MemberAccessExpr', 'MemberAccessTarget', 'Model', 'ModelImport', 'NullExpr', 'NumberLiteral', 'ObjectExpr', 'Plugin', 'PluginField', 'ReferenceArg', 'ReferenceExpr', 'ReferenceTarget', 'StringLiteral', 'ThisExpr', 'TypeDeclaration', 'TypeDef', 'TypeDefField', 'TypeDefFieldType', 'TypeDefFieldTypes', 'UnaryExpr', 'UnsupportedFieldType']; } protected override computeIsSubtype(subtype: string, supertype: string): boolean { @@ -788,14 +798,17 @@ export class ZModelAstReflection extends AbstractAstReflection { return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype); } case DataModelField: - case EnumField: - case FunctionParam: { - return this.isSubtype(ReferenceTarget, supertype); + case TypeDefField: { + return this.isSubtype(MemberAccessTarget, supertype) || this.isSubtype(ReferenceTarget, supertype); } case Enum: case TypeDef: { return this.isSubtype(AbstractDeclaration, supertype) || this.isSubtype(TypeDeclaration, supertype) || this.isSubtype(TypeDefFieldTypes, supertype); } + case EnumField: + case FunctionParam: { + return this.isSubtype(ReferenceTarget, supertype); + } case InvocationExpr: case LiteralExpr: { return this.isSubtype(ConfigExpr, supertype) || this.isSubtype(Expression, supertype); @@ -826,7 +839,7 @@ export class ZModelAstReflection extends AbstractAstReflection { return FunctionDecl; } case 'MemberAccessExpr:member': { - return DataModelField; + return MemberAccessTarget; } case 'ReferenceExpr:target': { return ReferenceTarget; @@ -1055,6 +1068,7 @@ export class ZModelAstReflection extends AbstractAstReflection { return { name: 'TypeDef', mandatory: [ + { name: 'attributes', type: 'array' }, { name: 'comments', type: 'array' }, { name: 'fields', type: 'array' } ] diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 43beb12a2..c6a0113b4 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -1301,7 +1301,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@38" + "$ref": "#/types@1" }, "deprecatedSyntax": false } @@ -2165,7 +2165,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@2" + "$ref": "#/types@3" }, "terminal": { "$type": "RuleCall", @@ -2257,16 +2257,33 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "value": "{" }, { - "$type": "Assignment", - "feature": "fields", - "operator": "+=", - "terminal": { - "$type": "RuleCall", - "rule": { - "$ref": "#/rules@41" + "$type": "Alternatives", + "elements": [ + { + "$type": "Assignment", + "feature": "fields", + "operator": "+=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@41" + }, + "arguments": [] + } }, - "arguments": [] - }, + { + "$type": "Assignment", + "feature": "attributes", + "operator": "+=", + "terminal": { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@55" + }, + "arguments": [] + } + } + ], "cardinality": "*" }, { @@ -2375,7 +2392,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@1" + "$ref": "#/types@2" }, "terminal": { "$type": "RuleCall", @@ -2827,7 +2844,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@2" + "$ref": "#/types@3" }, "terminal": { "$type": "RuleCall", @@ -3255,7 +3272,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/types@2" + "$ref": "#/types@3" }, "terminal": { "$type": "RuleCall", @@ -3829,6 +3846,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "$ref": "#/rules@38" } }, + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@41" + } + }, { "$type": "SimpleType", "typeRef": { @@ -3838,6 +3861,27 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel ] } }, + { + "$type": "Type", + "name": "MemberAccessTarget", + "type": { + "$type": "UnionType", + "types": [ + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@38" + } + }, + { + "$type": "SimpleType", + "typeRef": { + "$ref": "#/rules@41" + } + } + ] + } + }, { "$type": "Type", "name": "TypeDefFieldTypes", diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index ef5a1f883..86d28276b 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -66,7 +66,7 @@ ConfigArrayExpr: ConfigExpr: LiteralExpr | InvocationExpr | ConfigArrayExpr; -type ReferenceTarget = FunctionParam | DataModelField | EnumField; +type ReferenceTarget = FunctionParam | DataModelField | TypeDefField | EnumField; ThisExpr: value='this'; @@ -94,18 +94,20 @@ FieldInitializer: InvocationExpr: function=[FunctionDecl] '(' ArgumentList? ')'; -// binary operator precedence follow Javascript's rules: -// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_Precedence#table +type MemberAccessTarget = DataModelField | TypeDefField; MemberAccessExpr infers Expression: PrimaryExpr ( {infer MemberAccessExpr.operand=current} - ('.' member=[DataModelField]) + ('.' member=[MemberAccessTarget]) )*; UnaryExpr: operator=('!') operand=MemberAccessExpr; +// binary operator precedence follow Javascript's rules: +// https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Operator_Precedence#table + CollectionPredicateExpr infers Expression: MemberAccessExpr ( {infer BinaryExpr.left=current} @@ -179,10 +181,12 @@ DataModelField: DataModelFieldType: (type=BuiltinType | unsupported=UnsupportedFieldType | reference=[TypeDeclaration:RegularID]) (array?='[' ']')? (optional?='?')?; +// TODO: unify TypeDef and abstract DataModel TypeDef: (comments+=TRIPLE_SLASH_COMMENT)* 'type' name=RegularID '{' ( - fields+=TypeDefField + fields+=TypeDefField | + attributes+=DataModelAttribute )* '}'; @@ -245,6 +249,7 @@ type TypeDeclaration = DataModel | TypeDef | Enum; DataModelFieldAttribute: decl=[Attribute:FIELD_ATTRIBUTE_NAME] ('(' AttributeArgList? ')')?; +// TODO: need rename since it's for both DataModel and TypeDef DataModelAttribute: TRIPLE_SLASH_COMMENT* decl=[Attribute:MODEL_ATTRIBUTE_NAME] ('(' AttributeArgList? ')')?; diff --git a/packages/plugins/trpc/tests/projects/nuxt-trpc-v10/prisma/schema.prisma b/packages/plugins/trpc/tests/projects/nuxt-trpc-v10/prisma/schema.prisma index e83468bb0..71cd1ce9b 100644 --- a/packages/plugins/trpc/tests/projects/nuxt-trpc-v10/prisma/schema.prisma +++ b/packages/plugins/trpc/tests/projects/nuxt-trpc-v10/prisma/schema.prisma @@ -12,26 +12,17 @@ generator client { provider = "prisma-client-js" } -/// @@allow('create', true) -/// @@allow('all', auth() == this) model User { id String @id() @default(cuid()) - /// @email - /// @length(6, 32) email String @unique() - /// @password - /// @omit password String posts Post[] } -/// @@allow('read', auth() != null && published) -/// @@allow('all', author == auth()) model Post { id String @id() @default(cuid()) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt() - /// @length(1, 256) title String content String published Boolean @default(false) diff --git a/packages/plugins/trpc/tests/projects/nuxt-trpc-v11/prisma/schema.prisma b/packages/plugins/trpc/tests/projects/nuxt-trpc-v11/prisma/schema.prisma index e83468bb0..71cd1ce9b 100644 --- a/packages/plugins/trpc/tests/projects/nuxt-trpc-v11/prisma/schema.prisma +++ b/packages/plugins/trpc/tests/projects/nuxt-trpc-v11/prisma/schema.prisma @@ -12,26 +12,17 @@ generator client { provider = "prisma-client-js" } -/// @@allow('create', true) -/// @@allow('all', auth() == this) model User { id String @id() @default(cuid()) - /// @email - /// @length(6, 32) email String @unique() - /// @password - /// @omit password String posts Post[] } -/// @@allow('read', auth() != null && published) -/// @@allow('all', author == auth()) model Post { id String @id() @default(cuid()) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt() - /// @length(1, 256) title String content String published Boolean @default(false) diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v11/prisma/schema.prisma b/packages/plugins/trpc/tests/projects/t3-trpc-v11/prisma/schema.prisma index 5199cfaa7..a28fea9fb 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v11/prisma/schema.prisma +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v11/prisma/schema.prisma @@ -12,14 +12,12 @@ generator client { provider = "prisma-client-js" } -/// @@allow('all', true) model User { id Int @id() @default(autoincrement()) email String @unique() posts Post[] } -/// @@allow('all', true) model Post { id Int @id() @default(autoincrement()) name String diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 43216656e..a5bb81627 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -1,5 +1,5 @@ import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast'; -import { getDataModels, getLiteral, hasAttribute } from '@zenstackhq/sdk'; +import { getDataModelAndTypeDefs, getLiteral, hasAttribute } from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'fs'; import { getDocument, LangiumDocument, LangiumDocuments, linkContentToContainer } from 'langium'; @@ -133,10 +133,10 @@ function validationAfterImportMerge(model: Model) { } // at most one `@@auth` model - const dataModels = getDataModels(model, true); - const authModels = dataModels.filter((d) => hasAttribute(d, '@@auth')); - if (authModels.length > 1) { - console.error(colors.red('Validation error: Multiple `@@auth` models are not allowed')); + const decls = getDataModelAndTypeDefs(model, true); + const authDecls = decls.filter((d) => hasAttribute(d, '@@auth')); + if (authDecls.length > 1) { + console.error(colors.red('Validation error: Multiple `@@auth` declarations are not allowed')); throw new CliError('schema validation errors'); } } diff --git a/packages/schema/src/language-server/utils.ts b/packages/schema/src/language-server/utils.ts index 019004a62..2d31975a6 100644 --- a/packages/schema/src/language-server/utils.ts +++ b/packages/schema/src/language-server/utils.ts @@ -1,6 +1,9 @@ import { isArrayExpr, + isDataModel, isReferenceExpr, + isTypeDef, + TypeDef, type DataModel, type DataModelField, type ReferenceExpr, @@ -25,3 +28,10 @@ export function getUniqueFields(model: DataModel) { .map((item) => resolved(item.target) as DataModelField); }); } + +/** + * Checks if the given node can contain resolvable members. + */ +export function isMemberContainer(node: unknown): node is DataModel | TypeDef { + return isDataModel(node) || isTypeDef(node); +} diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index 2a7f43c29..a7c0fef9a 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -69,6 +69,10 @@ export default class AttributeApplicationValidator implements AstValidator(); for (const arg of attr.args) { diff --git a/packages/schema/src/language-server/validator/schema-validator.ts b/packages/schema/src/language-server/validator/schema-validator.ts index d071324c1..0757c6993 100644 --- a/packages/schema/src/language-server/validator/schema-validator.ts +++ b/packages/schema/src/language-server/validator/schema-validator.ts @@ -1,5 +1,5 @@ import { Model, isDataModel, isDataSource } from '@zenstackhq/language/ast'; -import { hasAttribute } from '@zenstackhq/sdk'; +import { getDataModelAndTypeDefs, hasAttribute } from '@zenstackhq/sdk'; import { LangiumDocuments, ValidationAcceptor } from 'langium'; import { getAllDeclarationsIncludingImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../constants'; @@ -36,7 +36,8 @@ export default class SchemaValidator implements AstValidator { } // at most one `@@auth` model - const authModels = model.declarations.filter((d) => isDataModel(d) && hasAttribute(d, '@@auth')); + const decls = getDataModelAndTypeDefs(model, true); + const authModels = decls.filter((d) => isDataModel(d) && hasAttribute(d, '@@auth')); if (authModels.length > 1) { accept('error', 'Multiple `@@auth` models are not allowed', { node: authModels[1] }); } diff --git a/packages/schema/src/language-server/validator/typedef-validator.ts b/packages/schema/src/language-server/validator/typedef-validator.ts index 55c127d7d..70b6ec860 100644 --- a/packages/schema/src/language-server/validator/typedef-validator.ts +++ b/packages/schema/src/language-server/validator/typedef-validator.ts @@ -10,9 +10,14 @@ import { validateDuplicatedDeclarations } from './utils'; export default class TypeDefValidator implements AstValidator { validate(typeDef: TypeDef, accept: ValidationAcceptor): void { validateDuplicatedDeclarations(typeDef, typeDef.fields, accept); + this.validateAttributes(typeDef, accept); this.validateFields(typeDef, accept); } + private validateAttributes(typeDef: TypeDef, accept: ValidationAcceptor) { + typeDef.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); + } + private validateFields(typeDef: TypeDef, accept: ValidationAcceptor) { typeDef.fields.forEach((field) => this.validateField(field, accept)); } diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index c2751b921..5e6baf848 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -24,6 +24,7 @@ import { ResolvedShape, StringLiteral, ThisExpr, + TypeDefFieldType, UnaryExpr, isArrayExpr, isBooleanLiteral, @@ -35,7 +36,7 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { getAuthModel, getModelFieldsWithBases, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; +import { getAuthDecl, getModelFieldsWithBases, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -53,7 +54,8 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { getAllLoadedAndReachableDataModels, getContainingDataModel } from '../utils/ast-utils'; +import { getAllLoadedAndReachableDataModelsAndTypeDefs, getContainingDataModel } from '../utils/ast-utils'; +import { isMemberContainer } from './utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -281,14 +283,14 @@ export class ZModelLinker extends DefaultLinker { // auth() function is resolved against all loaded and reachable documents // get all data models from loaded and reachable documents - const allDataModels = getAllLoadedAndReachableDataModels( + const allDecls = getAllLoadedAndReachableDataModelsAndTypeDefs( this.langiumDocuments(), getContainerOfType(node, isDataModel) ); - const authModel = getAuthModel(allDataModels); - if (authModel) { - node.$resolvedType = { decl: authModel, nullable: true }; + const authDecl = getAuthDecl(allDecls); + if (authDecl) { + node.$resolvedType = { decl: authDecl, nullable: true }; } } else if (isFutureExpr(node)) { // future() function is resolved to current model @@ -319,7 +321,7 @@ export class ZModelLinker extends DefaultLinker { this.resolveDefault(node, document, extraScopes); const operandResolved = node.operand.$resolvedType; - if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) { + if (operandResolved && !operandResolved.array && isMemberContainer(operandResolved.decl)) { // member access is resolved only in the context of the operand type if (node.member.ref) { this.resolveToDeclaredType(node, node.member.ref.type); @@ -337,7 +339,7 @@ export class ZModelLinker extends DefaultLinker { this.resolveDefault(node, document, extraScopes); const resolvedType = node.left.$resolvedType; - if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) { + if (resolvedType && isMemberContainer(resolvedType.decl) && resolvedType.array) { this.resolveToBuiltinTypeOrDecl(node, 'Boolean'); } else { // error is reported in validation pass @@ -513,7 +515,7 @@ export class ZModelLinker extends DefaultLinker { //#region Utils - private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType) { + private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType | TypeDefFieldType) { let nullable = false; if (isDataModelFieldType(type)) { nullable = type.optional; diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index cde2d4b5a..089bb906b 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -9,8 +9,10 @@ import { isModel, isReferenceExpr, isThisExpr, + isTypeDef, + isTypeDefField, } from '@zenstackhq/language/ast'; -import { getAuthModel, getModelFieldsWithBases, getRecursiveBases, isAuthInvocation } from '@zenstackhq/sdk'; +import { getAuthDecl, getModelFieldsWithBases, getRecursiveBases, isAuthInvocation } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -32,12 +34,13 @@ import { import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; import { - getAllLoadedAndReachableDataModels, + getAllLoadedAndReachableDataModelsAndTypeDefs, isCollectionPredicate, isFutureInvocation, resolveImportUri, } from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; +import { isAuthOrAuthMemberAccess } from './validator/utils'; /** * Custom Langium ScopeComputation implementation which adds enum fields into global scope @@ -133,22 +136,26 @@ export class ZModelScopeProvider extends DefaultScopeProvider { const globalScope = this.getGlobalScope(referenceType, context); const node = context.container as MemberAccessExpr; + // typedef's fields are only added to the scope if the access starts with `auth().` + const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand); + return match(node.operand) .when(isReferenceExpr, (operand) => { - // operand is a reference, it can only be a model field + // operand is a reference, it can only be a model/type-def field const ref = operand.target.ref; if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - return this.createScopeForModel(targetModel, globalScope); + return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } return EMPTY_SCOPE; }) .when(isMemberAccessExpr, (operand) => { - // operand is a member access, it must be resolved to a non-array data model type + // operand is a member access, it must be resolved to a non-array model/typedef type const ref = operand.member.ref; if (isDataModelField(ref) && !ref.type.array) { - const targetModel = ref.type.reference?.ref; - return this.createScopeForModel(targetModel, globalScope); + return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); + } + if (isTypeDefField(ref) && !ref.type.array) { + return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } return EMPTY_SCOPE; }) @@ -159,8 +166,8 @@ export class ZModelScopeProvider extends DefaultScopeProvider { .when(isInvocationExpr, (operand) => { // deal with member access from `auth()` and `future() if (isAuthInvocation(operand)) { - // resolve to `User` or `@@auth` model - return this.createScopeForAuthModel(node, globalScope); + // resolve to `User` or `@@auth` decl + return this.createScopeForAuth(node, globalScope); } if (isFutureInvocation(operand)) { // resolve `future()` to the containing model @@ -176,27 +183,28 @@ export class ZModelScopeProvider extends DefaultScopeProvider { const globalScope = this.getGlobalScope(referenceType, context); const collection = collectionPredicate.left; + // typedef's fields are only added to the scope if the access starts with `auth().` + const allowTypeDefScope = isAuthOrAuthMemberAccess(collection); + return match(collection) .when(isReferenceExpr, (expr) => { - // collection is a reference, it can only be a model field + // collection is a reference - model or typedef field const ref = expr.target.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - return this.createScopeForModel(targetModel, globalScope); + if (isDataModelField(ref) || isTypeDefField(ref)) { + return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } return EMPTY_SCOPE; }) .when(isMemberAccessExpr, (expr) => { - // collection is a member access, it can only be resolved to a model field + // collection is a member access, it can only be resolved to a model or typedef field const ref = expr.member.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - return this.createScopeForModel(targetModel, globalScope); + if (isDataModelField(ref) || isTypeDefField(ref)) { + return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } return EMPTY_SCOPE; }) .when(isAuthInvocation, (expr) => { - return this.createScopeForAuthModel(expr, globalScope); + return this.createScopeForAuth(expr, globalScope); }) .otherwise(() => EMPTY_SCOPE); } @@ -204,30 +212,32 @@ export class ZModelScopeProvider extends DefaultScopeProvider { private createScopeForContainingModel(node: AstNode, globalScope: Scope) { const model = getContainerOfType(node, isDataModel); if (model) { - return this.createScopeForModel(model, globalScope); + return this.createScopeForContainer(model, globalScope); } else { return EMPTY_SCOPE; } } - private createScopeForModel(node: AstNode | undefined, globalScope: Scope) { + private createScopeForContainer(node: AstNode | undefined, globalScope: Scope, includeTypeDefScope = false) { if (isDataModel(node)) { return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope); + } else if (includeTypeDefScope && isTypeDef(node)) { + return this.createScopeForNodes(node.fields, globalScope); } else { return EMPTY_SCOPE; } } - private createScopeForAuthModel(node: AstNode, globalScope: Scope) { - // get all data models from loaded and reachable documents - const allDataModels = getAllLoadedAndReachableDataModels( + private createScopeForAuth(node: AstNode, globalScope: Scope) { + // get all data models and type defs from loaded and reachable documents + const decls = getAllLoadedAndReachableDataModelsAndTypeDefs( this.services.shared.workspace.LangiumDocuments, getContainerOfType(node, isDataModel) ); - const authModel = getAuthModel(allDataModels); - if (authModel) { - return this.createScopeForModel(authModel, globalScope); + const authDecl = getAuthDecl(decls); + if (authDecl) { + return this.createScopeForContainer(authDecl, globalScope, true); } else { return EMPTY_SCOPE; } diff --git a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts index a4e09fbb2..3736682ed 100644 --- a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts @@ -5,6 +5,7 @@ import { Expression, isDataModel, isMemberAccessExpr, + TypeDef, type Model, } from '@zenstackhq/sdk/ast'; import { streamAst, type AstNode } from 'langium'; @@ -14,7 +15,7 @@ import { isCollectionPredicate } from '../../../utils/ast-utils'; * Generate types for typing the `user` context object passed to the `enhance` call, based * on the fields (potentially deeply) access through `auth()`. */ -export function generateAuthType(model: Model, authModel: DataModel) { +export function generateAuthType(model: Model, authDecl: DataModel | TypeDef) { const types = new Map< string, { @@ -23,7 +24,7 @@ export function generateAuthType(model: Model, authModel: DataModel) { } >(); - types.set(authModel.name, { requiredRelations: [] }); + types.set(authDecl.name, { requiredRelations: [] }); const ensureType = (model: string) => { if (!types.has(model)) { @@ -88,9 +89,9 @@ ${Array.from(types.entries()) .map(([model, fields]) => { let result = `Partial<_P.${model}>`; - if (model === authModel.name) { + if (model === authDecl.name) { // auth model's id fields are always required - const idFields = getIdFields(authModel).map((f) => f.name); + const idFields = getIdFields(authDecl).map((f) => f.name); if (idFields.length > 0) { result = `WithRequired<${result}, ${idFields.map((f) => `'${f}'`).join('|')}>`; } diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 0b69b6a25..3f9d371a2 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -3,7 +3,8 @@ import { PluginError, getAttribute, getAttributeArg, - getAuthModel, + getAuthDecl, + getDataModelAndTypeDefs, getDataModels, getLiteral, isDelegateModel, @@ -83,9 +84,9 @@ export class EnhancerGenerator { ); await prismaDts.save(); - const authModel = getAuthModel(getDataModels(this.model)); - const authTypes = authModel ? generateAuthType(this.model, authModel) : ''; - const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; + const authDecl = getAuthDecl(getDataModelAndTypeDefs(this.model)); + const authTypes = authDecl ? generateAuthType(this.model, authDecl) : ''; + const authTypeParam = authDecl ? `auth.${authDecl.name}` : 'AuthUser'; const checkerTypes = this.generatePermissionChecker ? generateCheckerType(this.model) : ''; diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts index f09263dca..93212e386 100644 --- a/packages/schema/src/plugins/enhancer/policy/utils.ts +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -6,8 +6,8 @@ import { TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, getAttributeArg, - getAuthModel, - getDataModels, + getAuthDecl, + getDataModelAndTypeDefs, getEntityCheckerFunctionName, getIdFields, getLiteral, @@ -519,7 +519,7 @@ export function generateNormalizedAuthRef( const hasAuthRef = [...allows, ...denies].some((rule) => streamAst(rule).some((child) => isAuthInvocation(child))); if (hasAuthRef) { - const authModel = getAuthModel(getDataModels(model.$container, true)); + const authModel = getAuthDecl(getDataModelAndTypeDefs(model.$container, true)); if (!authModel) { throw new PluginError(name, 'Auth model not found'); } diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 62cefd36e..3316a90a9 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -209,7 +209,7 @@ attribute @@@completionHint(_ values: String[]) * @param sort: Allows you to specify in what order the entries of the ID are stored in the database. The available options are Asc and Desc. * @param clustered: Defines whether the ID is clustered or non-clustered. Defaults to true. */ -attribute @id(map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma +attribute @id(map: String?, length: Int?, sort: SortOrder?, clustered: Boolean?) @@@prisma @@@supportTypeDef /** * Defines a default value for a field. @@ -536,7 +536,7 @@ attribute @deny(_ operation: String @@@completionHint(["'create'", "'read'", "'u * Used to specify the model for resolving `auth()` function call in access policies. A Zmodel * can have at most one model with this attribute. By default, the model named "User" is used. */ -attribute @@auth() +attribute @@auth() @@@supportTypeDef /** * Indicates that the field is a password field and needs to be hashed before persistence. @@ -639,7 +639,7 @@ attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, Float /** * Validates the entity with a complex condition. */ -attribute @@validate(_ value: Boolean, _ message: String?, _ path: String[]?) @@@validation +attribute @@validate(_ value: Boolean, _ message: String?, _ path: String[]?) @@@validation @@@supportTypeDef /** * Validates length of a string field. diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index effd472f0..73e17d2e0 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -13,9 +13,11 @@ import { isMemberAccessExpr, isModel, isReferenceExpr, + isTypeDef, Model, ModelImport, ReferenceExpr, + TypeDef, } from '@zenstackhq/language/ast'; import { getInheritanceChain, @@ -302,21 +304,24 @@ export function findUpAst(node: AstNode, predicate: (node: AstNode) => boolean): } /** - * Gets all data models from all loaded documents + * Gets all data models and type defs from all loaded documents */ -export function getAllLoadedDataModels(langiumDocuments: LangiumDocuments) { +export function getAllLoadedDataModelsAndTypeDefs(langiumDocuments: LangiumDocuments) { return langiumDocuments.all .map((doc) => doc.parseResult.value as Model) - .flatMap((model) => model.declarations.filter(isDataModel)) + .flatMap((model) => model.declarations.filter((d): d is DataModel | TypeDef => isDataModel(d) || isTypeDef(d))) .toArray(); } /** - * Gets all data models from loaded and reachable documents + * Gets all data models and type defs from loaded and reachable documents */ -export function getAllLoadedAndReachableDataModels(langiumDocuments: LangiumDocuments, fromModel?: DataModel) { +export function getAllLoadedAndReachableDataModelsAndTypeDefs( + langiumDocuments: LangiumDocuments, + fromModel?: DataModel +) { // get all data models from loaded documents - const allDataModels = getAllLoadedDataModels(langiumDocuments); + const allDataModels = getAllLoadedDataModelsAndTypeDefs(langiumDocuments); if (fromModel) { // merge data models transitively reached from the current model diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 4d86837d0..0133f452c 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1151,7 +1151,7 @@ describe('Attribute tests', () => { @@allow('all', auth().email != null) } `) - ).toContain(`Could not resolve reference to DataModelField named 'email'.`); + ).toContain(`Could not resolve reference to MemberAccessTarget named 'email'.`); }); it('collection predicate expression check', async () => { diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 859399673..ae76f2fe5 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -26,7 +26,7 @@ import { getAttributeArg, getAttributeArgLiteral, getAttributeArgs, - getAuthModel, + getAuthDecl, getDataModels, getInheritedFromDelegate, getLiteral, @@ -101,7 +101,7 @@ function generateModelMetadata( writeTypeDefs(sourceFile, writer, typeDefs, options); writeDeleteCascade(writer, dataModels); writeShortNameMap(options, writer); - writeAuthModel(writer, dataModels); + writeAuthModel(writer, dataModels, typeDefs); }); } @@ -162,8 +162,8 @@ function writeBaseTypes(writer: CodeBlockWriter, model: DataModel) { } } -function writeAuthModel(writer: CodeBlockWriter, dataModels: DataModel[]) { - const authModel = getAuthModel(dataModels); +function writeAuthModel(writer: CodeBlockWriter, dataModels: DataModel[], typeDefs: TypeDef[]) { + const authModel = getAuthDecl([...dataModels, ...typeDefs]); if (authModel) { writer.writeLine(`authModel: '${authModel.name}'`); } diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 6b2bfe868..8313d662c 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -27,9 +27,11 @@ import { isModel, isObjectExpr, isReferenceExpr, + isTypeDef, Model, Reference, ReferenceExpr, + TypeDef, TypeDefField, } from '@zenstackhq/language/ast'; import fs from 'node:fs'; @@ -49,6 +51,18 @@ export function getDataModels(model: Model, includeIgnored = false) { } } +/** + * Gets data models and type defs in the ZModel schema. + */ +export function getDataModelAndTypeDefs(model: Model, includeIgnored = false) { + const r = model.declarations.filter((d): d is DataModel | TypeDef => isDataModel(d) || isTypeDef(d)); + if (includeIgnored) { + return r; + } else { + return r.filter((model) => !hasAttribute(model, '@@ignore')); + } +} + export function resolved(ref: Reference): T { if (!ref.ref) { throw new Error(`Reference not resolved: ${ref.$refText}`); @@ -117,14 +131,23 @@ export function indentString(string: string, count = 4): string { } export function hasAttribute( - decl: DataModel | DataModelField | Enum | EnumField | FunctionDecl | Attribute | AttributeParam, + decl: DataModel | TypeDef | DataModelField | Enum | EnumField | FunctionDecl | Attribute | AttributeParam, name: string ) { return !!getAttribute(decl, name); } export function getAttribute( - decl: DataModel | DataModelField | TypeDefField | Enum | EnumField | FunctionDecl | Attribute | AttributeParam, + decl: + | DataModel + | TypeDef + | DataModelField + | TypeDefField + | Enum + | EnumField + | FunctionDecl + | Attribute + | AttributeParam, name: string ) { return (decl.attributes as (DataModelAttribute | DataModelFieldAttribute)[]).find( @@ -448,10 +471,10 @@ export function getPreviewFeatures(model: Model) { return [] as string[]; } -export function getAuthModel(dataModels: DataModel[]) { - let authModel = dataModels.find((m) => hasAttribute(m, '@@auth')); +export function getAuthDecl(decls: (DataModel | TypeDef)[]) { + let authModel = decls.find((m) => hasAttribute(m, '@@auth')); if (!authModel) { - authModel = dataModels.find((m) => m.name === 'User'); + authModel = decls.find((m) => m.name === 'User'); } return authModel; } @@ -473,15 +496,14 @@ export function isDiscriminatorField(field: DataModelField) { return isDataModelFieldReference(arg) && arg.target.$refText === field.name; } -export function getIdFields(dataModel: DataModel) { - const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => - f.attributes.some((attr) => attr.decl.$refText === '@id') - ); +export function getIdFields(decl: DataModel | TypeDef) { + const fields = isDataModel(decl) ? getModelFieldsWithBases(decl) : decl.fields; + const fieldLevelId = fields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id')); if (fieldLevelId) { return [fieldLevelId]; } else { // get model level @@id attribute - const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + const modelIdAttr = decl.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); if (modelIdAttr) { // get fields referenced in the attribute: @@id([field1, field2]]) if (!isArrayExpr(modelIdAttr.args[0].value)) { diff --git a/tests/integration/tests/enhancements/json/validation.test.ts b/tests/integration/tests/enhancements/json/validation.test.ts index 27f5e5067..9d135e1d8 100644 --- a/tests/integration/tests/enhancements/json/validation.test.ts +++ b/tests/integration/tests/enhancements/json/validation.test.ts @@ -1,4 +1,4 @@ -import { loadModelWithError, loadSchema } from '@zenstackhq/testtools'; +import { loadModel, loadModelWithError, loadSchema } from '@zenstackhq/testtools'; describe('JSON field typing', () => { it('is only supported by postgres', async () => { @@ -36,4 +36,124 @@ describe('JSON field typing', () => { ) ).resolves.toContain('Custom-typed field must have @json attribute'); }); + + it('disallows normal member accesses in policy rules', async () => { + await expect( + loadModelWithError( + ` + type Profile { + age Int @gt(0) + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', profile.age > 18) + } + ` + ) + ).resolves.toContain(`Could not resolve reference to MemberAccessMember named 'age'.`); + }); + + it('allows auth member accesses in policy rules', async () => { + await expect( + loadModel( + ` + type Profile { + age Int @gt(0) + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', auth().profile.age > 18) + } + ` + ) + ).toResolveTruthy(); + }); + + it('disallows normal collection accesses in policy rules', async () => { + await expect( + loadModelWithError( + ` + type Profile { + roles Role[] + } + + type Role { + name String + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', profile.roles?[name == 'ADMIN']) + } + ` + ) + ).resolves.toContain(`Could not resolve reference to MemberAccessMember named 'roles'.`); + + await expect( + loadModelWithError( + ` + type Profile { + role String + } + + model User { + id Int @id @default(autoincrement()) + profiles Profile[] @json + @@allow('all', profiles?[role == 'ADMIN']) + } + ` + ) + ).resolves.toContain(`Could not resolve reference to ReferenceTarget named 'role'.`); + }); + + it('disallows auth collection accesses in policy rules', async () => { + await expect( + loadModel( + ` + type Profile { + roles Role[] + } + + type Role { + name String + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', auth().profile.roles?[name == 'ADMIN']) + } + ` + ) + ).toResolveTruthy(); + }); + + it('only allows whitelisted type-level attributes', async () => { + await expect( + loadModel( + ` + type User { + id Int @id + @@auth + } + ` + ) + ).toResolveTruthy(); + + await expect( + loadModelWithError( + ` + type User { + id Int @id + @@allow('all', true) + } + ` + ) + ).resolves.toContain('attribute "@@allow" cannot be used on type declarations'); + }); }); diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 02c3959d0..296eefee7 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -864,4 +864,60 @@ describe('auth() compile-time test', () => { } ); }); + + it('"User" type as auth', async () => { + const { enhance } = await loadSchema( + ` + type Profile { + age Int + } + + type Role { + name String + permissions String[] + } + + type User { + myId Int @id + banned Boolean + profile Profile + roles Role[] + } + + model Foo { + id Int @id @default(autoincrement()) + @@allow('read', true) + @@allow('create', auth().myId == 1 && !auth().banned) + @@allow('delete', auth().roles?['DELETE' in permissions]) + @@deny('all', auth().profile.age < 18) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { myId: 1, profile: { age: 20 } } }); + `, + }, + ], + } + ); + + await expect(enhance().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(enhance({ myId: 1, banned: true }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(enhance({ myId: 1, profile: { age: 16 } }).foo.create({ data: {} })).toBeRejectedByPolicy(); + const r = await enhance({ myId: 1, profile: { age: 20 } }).foo.create({ data: {} }); + await expect( + enhance({ myId: 1, profile: { age: 20 } }).foo.delete({ where: { id: r.id } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ myId: 1, profile: { age: 20 }, roles: [{ name: 'ADMIN', permissions: ['DELETE'] }] }).foo.delete({ + where: { id: r.id }, + }) + ).toResolveTruthy(); + }); }); From 81991fb84b0536db9607a31268122b7b9af208b2 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 9 Nov 2024 15:47:18 -0800 Subject: [PATCH 2/3] fix tests --- packages/schema/src/language-server/zmodel-linker.ts | 3 ++- tests/regression/tests/issue-756.test.ts | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 5e6baf848..1e2491bda 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,6 +35,7 @@ import { isNumberLiteral, isReferenceExpr, isStringLiteral, + isTypeDefField, } from '@zenstackhq/language/ast'; import { getAuthDecl, getModelFieldsWithBases, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; import { @@ -517,7 +518,7 @@ export class ZModelLinker extends DefaultLinker { private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType | TypeDefFieldType) { let nullable = false; - if (isDataModelFieldType(type)) { + if (isDataModelFieldType(type) || isTypeDefField(type)) { nullable = type.optional; // referencing a field of 'Unsupported' type diff --git a/tests/regression/tests/issue-756.test.ts b/tests/regression/tests/issue-756.test.ts index 9f6750ea9..dd1a10ccf 100644 --- a/tests/regression/tests/issue-756.test.ts +++ b/tests/regression/tests/issue-756.test.ts @@ -28,6 +28,6 @@ describe('Regression: issue 756', () => { } ` ) - ).toContain(`Could not resolve reference to DataModelField named 'authorId'.`); + ).toContain(`Could not resolve reference to MemberAccessTarget named 'authorId'.`); }); }); From d3d49e41f3caab9866fac464b6e4cccd0616a58a Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 9 Nov 2024 16:10:29 -0800 Subject: [PATCH 3/3] fix tests --- tests/integration/tests/enhancements/json/validation.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/tests/enhancements/json/validation.test.ts b/tests/integration/tests/enhancements/json/validation.test.ts index 9d135e1d8..df5dfc281 100644 --- a/tests/integration/tests/enhancements/json/validation.test.ts +++ b/tests/integration/tests/enhancements/json/validation.test.ts @@ -52,7 +52,7 @@ describe('JSON field typing', () => { } ` ) - ).resolves.toContain(`Could not resolve reference to MemberAccessMember named 'age'.`); + ).resolves.toContain(`Could not resolve reference to MemberAccessTarget named 'age'.`); }); it('allows auth member accesses in policy rules', async () => { @@ -92,7 +92,7 @@ describe('JSON field typing', () => { } ` ) - ).resolves.toContain(`Could not resolve reference to MemberAccessMember named 'roles'.`); + ).resolves.toContain(`Could not resolve reference to MemberAccessTarget named 'roles'.`); await expect( loadModelWithError(