diff --git a/README.md b/README.md index 620056619..32091aaa7 100644 --- a/README.md +++ b/README.md @@ -234,6 +234,7 @@ Thank you for your support! + diff --git a/package.json b/package.json index 2b9bdf1ea..9b0b3d158 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "2.0.3", + "version": "2.1.0", "description": "", "scripts": { "build": "pnpm -r build", @@ -29,7 +29,7 @@ "@typescript-eslint/parser": "^7.6.0", "concurrently": "^7.4.0", "copyfiles": "^2.4.1", - "eslint": "^8.56.0", + "eslint": "^8.57.0", "eslint-plugin-jest": "^28.2.0", "jest": "^29.7.0", "replace-in-file": "^7.0.1", diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts index ad5cb1600..24b570e37 100644 --- a/packages/ide/jetbrains/build.gradle.kts +++ b/packages/ide/jetbrains/build.gradle.kts @@ -9,7 +9,7 @@ plugins { } group = "dev.zenstack" -version = "2.0.3" +version = "2.1.0" repositories { mavenCentral() diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 51621e680..296e9da55 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -1,6 +1,6 @@ { "name": "jetbrains", - "version": "2.0.3", + "version": "2.1.0", "displayName": "ZenStack JetBrains IDE Plugin", "description": "ZenStack JetBrains IDE plugin", "homepage": "https://zenstack.dev", diff --git a/packages/language/package.json b/packages/language/package.json index 4de5ca0da..56121c456 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "2.0.3", + "version": "2.1.0", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index a95a748d9..74fa1707b 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -84,6 +84,12 @@ export function isRegularID(item: unknown): item is RegularID { return item === 'model' || item === 'enum' || item === 'attribute' || item === 'datasource' || item === 'plugin' || item === 'abstract' || item === 'in' || item === 'view' || item === 'import' || (typeof item === 'string' && (/[_a-zA-Z][\w_]*/.test(item))); } +export type RegularIDWithTypeNames = 'Any' | 'BigInt' | 'Boolean' | 'Bytes' | 'DateTime' | 'Decimal' | 'Float' | 'Int' | 'Json' | 'Null' | 'Object' | 'String' | 'Unsupported' | RegularID; + +export function isRegularIDWithTypeNames(item: unknown): item is RegularIDWithTypeNames { + return isRegularID(item) || item === 'String' || item === 'Boolean' || item === 'Int' || item === 'BigInt' || item === 'Float' || item === 'Decimal' || item === 'DateTime' || item === 'Json' || item === 'Bytes' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Unsupported'; +} + export type TypeDeclaration = DataModel | Enum; export const TypeDeclaration = 'TypeDeclaration'; @@ -288,7 +294,7 @@ export interface DataModelField extends AstNode { readonly $type: 'DataModelField'; attributes: Array comments: Array - name: RegularID + name: RegularIDWithTypeNames type: DataModelFieldType } diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index 45aa3ff97..08656b422 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -70,7 +70,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@63" + "$ref": "#/rules@64" }, "arguments": [] } @@ -140,7 +140,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "arguments": [] } @@ -162,7 +162,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -222,7 +222,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -282,7 +282,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -333,7 +333,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -393,7 +393,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -481,7 +481,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@64" + "$ref": "#/rules@65" }, "arguments": [] } @@ -503,7 +503,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@63" + "$ref": "#/rules@64" }, "arguments": [] } @@ -525,7 +525,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@57" + "$ref": "#/rules@58" }, "arguments": [] } @@ -649,7 +649,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@62" + "$ref": "#/rules@63" }, "arguments": [] } @@ -747,7 +747,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@62" + "$ref": "#/rules@63" }, "arguments": [] } @@ -1055,7 +1055,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@62" + "$ref": "#/rules@63" }, "arguments": [] } @@ -1176,7 +1176,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@63" + "$ref": "#/rules@64" }, "arguments": [] } @@ -1894,7 +1894,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2032,7 +2032,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2066,7 +2066,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2079,7 +2079,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@46" + "$ref": "#/rules@47" }, "arguments": [] } @@ -2103,7 +2103,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@50" + "$ref": "#/rules@51" }, "arguments": [] }, @@ -2134,7 +2134,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@56" + "$ref": "#/rules@57" }, "arguments": [] } @@ -2262,7 +2262,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2310,7 +2310,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@51" + "$ref": "#/rules@52" }, "arguments": [] } @@ -2344,7 +2344,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2369,7 +2369,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@50" + "$ref": "#/rules@51" }, "arguments": [] }, @@ -2393,7 +2393,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -2506,7 +2506,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] }, @@ -2530,7 +2530,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -2598,7 +2598,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@55" + "$ref": "#/rules@56" }, "arguments": [] } @@ -2662,7 +2662,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@62" + "$ref": "#/rules@63" }, "arguments": [] }, @@ -2711,6 +2711,81 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "parameters": [], "wildcard": false }, + { + "$type": "ParserRule", + "name": "RegularIDWithTypeNames", + "dataType": "string", + "definition": { + "$type": "Alternatives", + "elements": [ + { + "$type": "RuleCall", + "rule": { + "$ref": "#/rules@46" + }, + "arguments": [] + }, + { + "$type": "Keyword", + "value": "String" + }, + { + "$type": "Keyword", + "value": "Boolean" + }, + { + "$type": "Keyword", + "value": "Int" + }, + { + "$type": "Keyword", + "value": "BigInt" + }, + { + "$type": "Keyword", + "value": "Float" + }, + { + "$type": "Keyword", + "value": "Decimal" + }, + { + "$type": "Keyword", + "value": "DateTime" + }, + { + "$type": "Keyword", + "value": "Json" + }, + { + "$type": "Keyword", + "value": "Bytes" + }, + { + "$type": "Keyword", + "value": "Null" + }, + { + "$type": "Keyword", + "value": "Object" + }, + { + "$type": "Keyword", + "value": "Any" + }, + { + "$type": "Keyword", + "value": "Unsupported" + } + ] + }, + "definesHiddenTokens": false, + "entry": false, + "fragment": false, + "hiddenTokens": [], + "parameters": [], + "wildcard": false + }, { "$type": "ParserRule", "name": "Attribute", @@ -2724,7 +2799,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2744,21 +2819,21 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@59" + "$ref": "#/rules@60" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@60" + "$ref": "#/rules@61" }, "arguments": [] }, { "$type": "RuleCall", "rule": { - "$ref": "#/rules@61" + "$ref": "#/rules@62" }, "arguments": [] } @@ -2779,7 +2854,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -2798,7 +2873,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@48" + "$ref": "#/rules@49" }, "arguments": [] } @@ -2820,7 +2895,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] }, @@ -2848,7 +2923,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [] }, @@ -2887,7 +2962,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@49" + "$ref": "#/rules@50" }, "arguments": [] } @@ -2899,7 +2974,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@52" + "$ref": "#/rules@53" }, "arguments": [] }, @@ -2933,7 +3008,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@55" + "$ref": "#/rules@56" }, "arguments": [] }, @@ -3024,12 +3099,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@61" + "$ref": "#/rules@62" }, "arguments": [] }, @@ -3046,7 +3121,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "arguments": [], "cardinality": "?" @@ -3076,7 +3151,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@65" + "$ref": "#/rules@66" }, "arguments": [], "cardinality": "*" @@ -3088,12 +3163,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@60" + "$ref": "#/rules@61" }, "arguments": [] }, @@ -3110,7 +3185,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "arguments": [], "cardinality": "?" @@ -3144,12 +3219,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "CrossReference", "type": { - "$ref": "#/rules@47" + "$ref": "#/rules@48" }, "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@59" + "$ref": "#/rules@60" }, "arguments": [] }, @@ -3166,7 +3241,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel { "$type": "RuleCall", "rule": { - "$ref": "#/rules@53" + "$ref": "#/rules@54" }, "arguments": [], "cardinality": "?" @@ -3201,7 +3276,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@54" + "$ref": "#/rules@55" }, "arguments": [] } @@ -3220,7 +3295,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "terminal": { "$type": "RuleCall", "rule": { - "$ref": "#/rules@54" + "$ref": "#/rules@55" }, "arguments": [] } diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index 8fcc72c34..132350bcc 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -190,7 +190,7 @@ DataModel: DataModelField: (comments+=TRIPLE_SLASH_COMMENT)* - name=RegularID type=DataModelFieldType (attributes+=DataModelFieldAttribute)*; + name=RegularIDWithTypeNames type=DataModelFieldType (attributes+=DataModelFieldAttribute)*; DataModelFieldType: (type=BuiltinType | unsupported=UnsupportedFieldType | reference=[TypeDeclaration:RegularID]) (array?='[' ']')? (optional?='?')?; @@ -226,6 +226,9 @@ RegularID returns string: // include keywords that we'd like to work as ID in most places ID | 'model' | 'enum' | 'attribute' | 'datasource' | 'plugin' | 'abstract' | 'in' | 'view' | 'import'; +RegularIDWithTypeNames returns string: + RegularID | 'String' | 'Boolean' | 'Int' | 'BigInt' | 'Float' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes' | 'Null' | 'Object' | 'Any' | 'Unsupported'; + // attribute Attribute: (comments+=TRIPLE_SLASH_COMMENT)* 'attribute' name=(INTERNAL_ATTRIBUTE_NAME|MODEL_ATTRIBUTE_NAME|FIELD_ATTRIBUTE_NAME) '(' (params+=AttributeParam (',' params+=AttributeParam)*)? ')' (attributes+=InternalAttribute)*; diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json index 19f34c238..f5d7bdfaa 100644 --- a/packages/misc/redwood/package.json +++ b/packages/misc/redwood/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/redwood", "displayName": "ZenStack RedwoodJS Integration", - "version": "2.0.3", + "version": "2.1.0", "description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.", "repository": { "type": "git", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 314f33919..75103e17f 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/openapi/src/rpc-generator.ts b/packages/plugins/openapi/src/rpc-generator.ts index cb388aae2..7339ef788 100644 --- a/packages/plugins/openapi/src/rpc-generator.ts +++ b/packages/plugins/openapi/src/rpc-generator.ts @@ -1,22 +1,22 @@ // Inspired by: https://github.com/omar-dulaimi/prisma-trpc-generator -import { analyzePolicies, PluginError, requireOption, resolvePath } from '@zenstackhq/sdk'; +import { PluginError, analyzePolicies, requireOption, resolvePath } from '@zenstackhq/sdk'; import { DataModel, isDataModel } from '@zenstackhq/sdk/ast'; import { + AggregateOperationSupport, addMissingInputObjectTypesForAggregate, addMissingInputObjectTypesForInclude, addMissingInputObjectTypesForModelArgs, addMissingInputObjectTypesForSelect, - AggregateOperationSupport, resolveAggregateOperationSupport, } from '@zenstackhq/sdk/dmmf-helpers'; -import type { DMMF } from '@zenstackhq/sdk/prisma'; +import { supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import * as fs from 'fs'; import { lowerCaseFirst } from 'lower-case-first'; import type { OpenAPIV3_1 as OAPI } from 'openapi-types'; import * as path from 'path'; import invariant from 'tiny-invariant'; -import { match, P } from 'ts-pattern'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import YAML from 'yaml'; import { name } from '.'; @@ -166,7 +166,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase { }); } - if (ops['createMany']) { + if (ops['createMany'] && supportCreateMany(zmodel.$container)) { definitions.push({ method: 'post', operation: 'createMany', @@ -704,7 +704,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase { private generateEnumComponent(_enum: DMMF.SchemaEnum): OAPI.SchemaObject { const schema: OAPI.SchemaObject = { type: 'string', - enum: _enum.values, + enum: _enum.values as string[], }; return schema; } @@ -793,17 +793,14 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase { return result; } - private setInputRequired(fields: { name: string; isRequired: boolean }[], result: OAPI.NonArraySchemaObject) { + private setInputRequired(fields: readonly DMMF.SchemaArg[], result: OAPI.NonArraySchemaObject) { const required = fields.filter((f) => f.isRequired).map((f) => f.name); if (required.length > 0) { result.required = required; } } - private setOutputRequired( - fields: { name: string; isNullable?: boolean; outputType: DMMF.OutputTypeRef }[], - result: OAPI.NonArraySchemaObject - ) { + private setOutputRequired(fields: readonly DMMF.SchemaField[], result: OAPI.NonArraySchemaObject) { const required = fields.filter((f) => f.isNullable !== true).map((f) => f.name); if (required.length > 0) { result.required = required; diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index a36af725d..6a6992553 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { @@ -46,6 +46,7 @@ "lower-case-first": "^2.0.2", "semver": "^7.5.2", "ts-morph": "^16.0.0", + "ts-pattern": "^4.3.0", "upper-case-first": "^2.0.2" }, "devDependencies": { diff --git a/packages/plugins/swr/src/generator.ts b/packages/plugins/swr/src/generator.ts index 4ab3fb79e..65c80fc8c 100644 --- a/packages/plugins/swr/src/generator.ts +++ b/packages/plugins/swr/src/generator.ts @@ -1,5 +1,6 @@ import { PluginOptions, + RUNTIME_PACKAGE, createProject, ensureEmptyDir, generateModelMeta, @@ -8,11 +9,12 @@ import { resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { DataModel, Model } from '@zenstackhq/sdk/ast'; -import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma'; +import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast'; +import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import { paramCase } from 'change-case'; import path from 'path'; import type { OptionalKind, ParameterDeclarationStructure, Project, SourceFile } from 'ts-morph'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; @@ -66,6 +68,7 @@ function generateModelHooks( }); sf.addStatements([ `import { type GetNextArgs, type QueryOptions, type InfiniteQueryOptions, type MutationOptions, type PickEnumerable } from '@zenstackhq/swr/runtime';`, + `import type { PolicyCrudKind } from '${RUNTIME_PACKAGE}'`, `import metadata from './__model_meta';`, `import * as request from '@zenstackhq/swr/runtime';`, ]); @@ -82,7 +85,7 @@ function generateModelHooks( } // createMany - if (mapping.createMany) { + if (mapping.createMany && supportCreateMany(model.$container)) { const argsType = `Prisma.${model.name}CreateManyArgs`; mutationFuncs.push(generateMutation(sf, model, 'POST', 'createMany', argsType, true)); } @@ -239,6 +242,11 @@ function generateModelHooks( const returnType = `T extends { select: any; } ? T['select'] extends true ? number : Prisma.GetScalarType : number`; generateQueryHook(sf, model, 'count', argsType, inputType, returnType); } + + // extra `check` hook for ZenStack's permission checker API + { + generateCheckHook(sf, model, prismaImport); + } } function makeOptimistic(returnType: string) { @@ -337,3 +345,47 @@ function generateMutation( return funcName; } + +function generateCheckHook(sf: SourceFile, model: DataModel, prismaImport: string) { + const mapFilterType = (type: DataModelFieldType) => { + return match(type.type) + .with(P.union('Int', 'BigInt'), () => 'number') + .with('String', () => 'string') + .with('Boolean', () => 'boolean') + .otherwise(() => undefined); + }; + + const filterFields: Array<{ name: string; type: string }> = []; + const enumsToImport = new Set(); + + // collect filterable fields and enums to import + model.fields.forEach((f) => { + if (isEnum(f.type.reference?.ref)) { + enumsToImport.add(f.type.reference.$refText); + filterFields.push({ name: f.name, type: f.type.reference.$refText }); + } + + const mappedType = mapFilterType(f.type); + if (mappedType) { + filterFields.push({ name: f.name, type: mappedType }); + } + }); + + if (enumsToImport.size > 0) { + // import enums + sf.addStatements(`import type { ${Array.from(enumsToImport).join(', ')} } from '${prismaImport}';`); + } + + const whereType = `{ ${filterFields.map(({ name, type }) => `${name}?: ${type}`).join('; ')} }`; + + const func = sf.addFunction({ + name: `useCheck${model.name}`, + isExported: true, + parameters: [ + { name: 'args', type: `{ operation: PolicyCrudKind; where?: ${whereType}; }` }, + { name: 'options?', type: `QueryOptions` }, + ], + }); + + func.addStatements(`return request.useModelQuery('${model.name}', 'check', args, options);`); +} diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 6f0805a52..d663d12f1 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index 7666cb742..6a4ba53d9 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -1,6 +1,7 @@ import { PluginError, PluginOptions, + RUNTIME_PACKAGE, createProject, ensureEmptyDir, generateModelMeta, @@ -9,13 +10,13 @@ import { resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { DataModel, Model } from '@zenstackhq/sdk/ast'; -import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma'; +import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast'; +import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import { paramCase } from 'change-case'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import { Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; -import { match } from 'ts-pattern'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; @@ -261,6 +262,62 @@ function generateMutationHook( func.addStatements('return mutation;'); } +function generateCheckHook( + target: string, + version: TanStackVersion, + sf: SourceFile, + model: DataModel, + prismaImport: string +) { + const mapFilterType = (type: DataModelFieldType) => { + return match(type.type) + .with(P.union('Int', 'BigInt'), () => 'number') + .with('String', () => 'string') + .with('Boolean', () => 'boolean') + .otherwise(() => undefined); + }; + + const filterFields: Array<{ name: string; type: string }> = []; + const enumsToImport = new Set(); + + // collect filterable fields and enums to import + model.fields.forEach((f) => { + if (isEnum(f.type.reference?.ref)) { + enumsToImport.add(f.type.reference.$refText); + filterFields.push({ name: f.name, type: f.type.reference.$refText }); + } + + const mappedType = mapFilterType(f.type); + if (mappedType) { + filterFields.push({ name: f.name, type: mappedType }); + } + }); + + if (enumsToImport.size > 0) { + // import enums + sf.addStatements(`import type { ${Array.from(enumsToImport).join(', ')} } from '${prismaImport}';`); + } + + const whereType = `{ ${filterFields.map(({ name, type }) => `${name}?: ${type}`).join('; ')} }`; + + const func = sf.addFunction({ + name: `useCheck${model.name}`, + isExported: true, + typeParameters: ['TError = DefaultError'], + parameters: [ + { name: 'args', type: `{ operation: PolicyCrudKind; where?: ${whereType}; }` }, + { name: 'options?', type: makeQueryOptions(target, 'boolean', 'boolean', false, false, version) }, + ], + }); + + func.addStatements([ + makeGetContext(target), + `return useModelQuery('${model.name}', \`\${endpoint}/${lowerCaseFirst( + model.name + )}/check\`, args, options, fetch);`, + ]); +} + function generateModelHooks( target: TargetFramework, version: TanStackVersion, @@ -291,7 +348,7 @@ function generateModelHooks( } // createMany - if (mapping.createMany) { + if (mapping.createMany && supportCreateMany(model.$container)) { generateMutationHook(target, sf, model.name, 'createMany', 'post', false, 'Prisma.BatchPayload'); } @@ -494,6 +551,11 @@ function generateModelHooks( `TArgs extends { select: any; } ? TArgs['select'] extends true ? number : Prisma.GetScalarType : number` ); } + + { + // extra `check` hook for ZenStack's permission checker API + generateCheckHook(target, version, sf, model, prismaImport); + } } function generateIndex( @@ -538,6 +600,7 @@ function makeBaseImports(target: TargetFramework, version: TanStackVersion) { const shared = [ `import { useModelQuery, useInfiniteModelQuery, useModelMutation } from '${runtimeImportBase}/${target}';`, `import type { PickEnumerable, CheckSelect, QueryError, ExtraQueryOptions, ExtraMutationOptions } from '${runtimeImportBase}';`, + `import type { PolicyCrudKind } from '${RUNTIME_PACKAGE}'`, `import metadata from './__model_meta';`, `type DefaultError = QueryError;`, ]; diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 78bc59131..08538d870 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/plugins/trpc/src/generator.ts b/packages/plugins/trpc/src/generator.ts index cf57a1baf..0180cc649 100644 --- a/packages/plugins/trpc/src/generator.ts +++ b/packages/plugins/trpc/src/generator.ts @@ -10,7 +10,7 @@ import { type PluginOptions, } from '@zenstackhq/sdk'; import { Model } from '@zenstackhq/sdk/ast'; -import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma'; +import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import fs from 'fs'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; @@ -79,11 +79,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF. function createAppRouter( outDir: string, - modelOperations: DMMF.ModelMapping[], + modelOperations: readonly DMMF.ModelMapping[], hiddenModels: string[], generateModelActions: string[] | undefined, generateClientHelpers: string[] | undefined, - _zmodel: Model, + zmodel: Model, zodSchemasImport: string, options: PluginOptions ) { @@ -99,19 +99,21 @@ function createAppRouter( { namedImports: [ 'unsetMarker', - 'type AnyRouter', - 'type AnyRootConfig', - 'type CreateRouterInner', - 'type Procedure', - 'type ProcedureBuilder', - 'type ProcedureParams', - 'type ProcedureRouterRecord', - 'type ProcedureType', + 'AnyRouter', + 'AnyRootConfig', + 'CreateRouterInner', + 'Procedure', + 'ProcedureBuilder', + 'ProcedureParams', + 'ProcedureRouterRecord', + 'ProcedureType', ], + isTypeOnly: true, moduleSpecifier: '@trpc/server', }, { - namedImports: ['type PrismaClient'], + namedImports: ['PrismaClient'], + isTypeOnly: true, moduleSpecifier: prismaImport, }, ]); @@ -169,7 +171,8 @@ function createAppRouter( generateModelActions, generateClientHelpers, zodSchemasImport, - options + options, + zmodel ); appRouter.addImportDeclaration({ @@ -239,7 +242,8 @@ function generateModelCreateRouter( generateModelActions: string[] | undefined, generateClientHelpers: string[] | undefined, zodSchemasImport: string, - options: PluginOptions + options: PluginOptions, + zmodel: Model ) { const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, { overwrite: true, @@ -296,6 +300,10 @@ function generateModelCreateRouter( inputType && (!generateModelActions || generateModelActions.includes(generateOpName)) ) { + if (generateOpName === 'createMany' && !supportCreateMany(zmodel)) { + continue; + } + generateProcedure(funcWriter, generateOpName, upperCaseFirst(inputType), model, baseOpType); if (routerTypingStructure) { diff --git a/packages/plugins/trpc/src/helpers.ts b/packages/plugins/trpc/src/helpers.ts index 947b96b98..c165288f7 100644 --- a/packages/plugins/trpc/src/helpers.ts +++ b/packages/plugins/trpc/src/helpers.ts @@ -237,7 +237,12 @@ export function generateRouterTypingImports(sourceFile: SourceFile, options: Plu // eslint-disable-next-line @typescript-eslint/no-unused-vars export function generateRouterSchemaImport(sourceFile: SourceFile, zodSchemasImport: string) { - sourceFile.addStatements(`import * as $Schema from '${zodSchemasImport}/input';`); + sourceFile.addStatements([ + `import * as _Schema from '${zodSchemasImport}/input';`, + // temporary solution for dealing with the issue that Node.js wraps named exports under a `default` + // key when importing from a CJS module + `const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema;`, + ]); } export function generateHelperImport(sourceFile: SourceFile) { @@ -327,7 +332,7 @@ export const getProcedureTypeByOpName = (opName: string) => { return procType; }; -export function resolveModelsComments(models: DMMF.Model[], hiddenModels: string[]) { +export function resolveModelsComments(models: readonly DMMF.Model[], hiddenModels: string[]) { const modelAttributeRegex = /(@@Gen\.)+([A-z])+(\()+(.+)+(\))+/; const attributeNameRegex = /(?:\.)+([A-Za-z])+(?:\()+/; const attributeArgsRegex = /(?:\()+([A-Za-z])+:+(.+)+(?:\))+/; diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts index fbc73cf06..3edfe9c94 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts @@ -1,6 +1,7 @@ /* eslint-disable */ import { type RouterFactory, type ProcBuilder, type BaseConfig, db } from "."; -import * as $Schema from '@zenstackhq/runtime/zod/input'; +import * as _Schema from '@zenstackhq/runtime/zod/input'; +const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema; import { checkRead, checkMutate } from '../helper'; import type { Prisma } from '@prisma/client'; import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared'; @@ -12,6 +13,8 @@ export default function createRouter(router: RouterFa aggregate: procedure.input($Schema.PostInputSchema.aggregate).query(({ ctx, input }) => checkRead(db(ctx).post.aggregate(input as any))), + createMany: procedure.input($Schema.PostInputSchema.createMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.createMany(input as any))), + create: procedure.input($Schema.PostInputSchema.create).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.create(input as any))), deleteMany: procedure.input($Schema.PostInputSchema.deleteMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.deleteMany(input as any))), @@ -60,6 +63,20 @@ export interface ClientType >; + }; + createMany: { + + useMutation: (opts?: UseTRPCMutationOptions< + Prisma.PostCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >,) => + Omit, Prisma.SelectSubset, Context>, 'mutateAsync'> & { + mutateAsync: + (variables: T, opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>) => Promise + }; + }; create: { diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts index c4bdb89de..366ccfcb0 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts @@ -1,6 +1,7 @@ /* eslint-disable */ import { type RouterFactory, type ProcBuilder, type BaseConfig, db } from "."; -import * as $Schema from '@zenstackhq/runtime/zod/input'; +import * as _Schema from '@zenstackhq/runtime/zod/input'; +const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema; import { checkRead, checkMutate } from '../helper'; import type { Prisma } from '@prisma/client'; import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared'; @@ -12,6 +13,8 @@ export default function createRouter(router: RouterFa aggregate: procedure.input($Schema.UserInputSchema.aggregate).query(({ ctx, input }) => checkRead(db(ctx).user.aggregate(input as any))), + createMany: procedure.input($Schema.UserInputSchema.createMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.createMany(input as any))), + create: procedure.input($Schema.UserInputSchema.create).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.create(input as any))), deleteMany: procedure.input($Schema.UserInputSchema.deleteMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.deleteMany(input as any))), @@ -60,6 +63,20 @@ export interface ClientType >; + }; + createMany: { + + useMutation: (opts?: UseTRPCMutationOptions< + Prisma.UserCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >,) => + Omit, Prisma.SelectSubset, Context>, 'mutateAsync'> & { + mutateAsync: + (variables: T, opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>) => Promise + }; + }; create: { diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts index f474aa5b5..523f4b645 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts @@ -1,6 +1,6 @@ /* eslint-disable */ -import { unsetMarker, type AnyRouter, type AnyRootConfig, type CreateRouterInner, type Procedure, type ProcedureBuilder, type ProcedureParams, type ProcedureRouterRecord, type ProcedureType } from "@trpc/server"; -import { type PrismaClient } from "@prisma/client"; +import type { unsetMarker, AnyRouter, AnyRootConfig, CreateRouterInner, Procedure, ProcedureBuilder, ProcedureParams, ProcedureRouterRecord, ProcedureType } from "@trpc/server"; +import type { PrismaClient } from "@prisma/client"; import createUserRouter from "./User.router"; import createPostRouter from "./Post.router"; import { ClientType as UserClientType } from "./User.router"; diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 10da4f5c9..94d8a587b 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "2.0.3", + "version": "2.1.0", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", @@ -29,6 +29,10 @@ "types": "./enhancements/index.d.ts", "default": "./enhancements/index.js" }, + "./constraint-solver": { + "types": "./constraint-solver.d.ts", + "default": "./constraint-solver.js" + }, "./zod": { "types": "./zod/index.d.ts", "default": "./zod/index.js" @@ -79,12 +83,14 @@ "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", "deepmerge": "^4.3.1", + "logic-solver": "^2.0.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", "safe-json-stringify": "^1.2.0", "semver": "^7.5.2", "superjson": "^1.11.0", "tiny-invariant": "^1.3.1", + "ts-pattern": "^4.3.0", "tslib": "^2.4.1", "upper-case-first": "^2.0.2", "uuid": "^9.0.0", diff --git a/packages/runtime/src/enhance.d.ts b/packages/runtime/src/enhance.d.ts index 48e877878..38a519830 100644 --- a/packages/runtime/src/enhance.d.ts +++ b/packages/runtime/src/enhance.d.ts @@ -1,2 +1,2 @@ // @ts-expect-error stub for re-exporting generated code -export { enhance } from '.zenstack/enhance'; +export { auth, enhance } from '.zenstack/enhance'; diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts index e2fdc65f2..8e4c3569a 100644 --- a/packages/runtime/src/enhancements/delegate.ts +++ b/packages/runtime/src/enhancements/delegate.ts @@ -99,6 +99,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } Object.entries(where).forEach(([field, value]) => { + if (['AND', 'OR', 'NOT'].includes(field)) { + // recurse into logical group + enumerate(value).forEach((item) => this.injectWhereHierarchy(model, item)); + return; + } + const fieldInfo = resolveField(this.options.modelMeta, model, field); if (!fieldInfo?.inheritedFrom) { return; diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts new file mode 100644 index 000000000..c87a528e7 --- /dev/null +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -0,0 +1,219 @@ +import Logic from 'logic-solver'; +import { match } from 'ts-pattern'; +import type { + CheckerConstraint, + ComparisonConstraint, + ComparisonTerm, + LogicalConstraint, + ValueConstraint, + VariableConstraint, +} from '../types'; + +/** + * A boolean constraint solver based on `logic-solver`. Only boolean and integer types are supported. + */ +export class ConstraintSolver { + // a table for internalizing string literals + private stringTable: string[] = []; + + // a map for storing variable names and their corresponding formulas + private variables: Map = new Map(); + + /** + * Check the satisfiability of the given constraint. + */ + checkSat(constraint: CheckerConstraint): boolean { + // reset state + this.stringTable = []; + this.variables = new Map(); + + // convert the constraint to a "logic-solver" formula + const formula = this.buildFormula(constraint); + + // solve the formula + const solver = new Logic.Solver(); + solver.require(formula); + + // DEBUG: + // const solution = solver.solve(); + // if (solution) { + // console.log('Solution:'); + // this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`)); + // } else { + // console.log('No solution'); + // } + + return !!solver.solve(); + } + + private buildFormula(constraint: CheckerConstraint): Logic.Formula { + return match(constraint) + .when( + (c): c is ValueConstraint => c.kind === 'value', + (c) => this.buildValueFormula(c) + ) + .when( + (c): c is VariableConstraint => c.kind === 'variable', + (c) => this.buildVariableFormula(c) + ) + .when( + (c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind), + (c) => this.buildComparisonFormula(c) + ) + .when( + (c): c is LogicalConstraint => ['and', 'or', 'not'].includes(c.kind), + (c) => this.buildLogicalFormula(c) + ) + .otherwise(() => { + throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`); + }); + } + + private buildLogicalFormula(constraint: LogicalConstraint) { + return match(constraint.kind) + .with('and', () => this.buildAndFormula(constraint)) + .with('or', () => this.buildOrFormula(constraint)) + .with('not', () => this.buildNotFormula(constraint)) + .exhaustive(); + } + + private buildAndFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isFalse(c))) { + // short-circuit + return Logic.FALSE; + } + return Logic.and(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildOrFormula(constraint: LogicalConstraint): Logic.Formula { + if (constraint.children.some((c) => this.isTrue(c))) { + // short-circuit + return Logic.TRUE; + } + return Logic.or(...constraint.children.map((c) => this.buildFormula(c))); + } + + private buildNotFormula(constraint: LogicalConstraint) { + if (constraint.children.length !== 1) { + throw new Error('"not" constraint must have exactly one child'); + } + return Logic.not(this.buildFormula(constraint.children[0])); + } + + private isTrue(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === true; + } + + private isFalse(constraint: CheckerConstraint): unknown { + return constraint.kind === 'value' && constraint.value === false; + } + + private buildComparisonFormula(constraint: ComparisonConstraint) { + if (constraint.left.kind === 'value' && constraint.right.kind === 'value') { + // constant comparison + const left: ValueConstraint = constraint.left; + const right: ValueConstraint = constraint.right; + return match(constraint.kind) + .with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE)) + .with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE)) + .with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE)) + .with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE)) + .with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE)) + .with('lte', () => (left.value <= right.value ? Logic.TRUE : Logic.FALSE)) + .exhaustive(); + } + + return match(constraint.kind) + .with('eq', () => this.transformEquality(constraint.left, constraint.right)) + .with('ne', () => this.transformInequality(constraint.left, constraint.right)) + .with('gt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r)) + ) + .with('gte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThanOrEqual(l, r)) + ) + .with('lt', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThan(l, r)) + ) + .with('lte', () => + this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThanOrEqual(l, r)) + ) + .exhaustive(); + } + + private buildVariableFormula(constraint: VariableConstraint) { + return ( + match(constraint.type) + .with('boolean', () => this.booleanVariable(constraint.name)) + .with('number', () => this.intVariable(constraint.name)) + // strings are internalized and represented by their indices + .with('string', () => this.intVariable(constraint.name)) + .exhaustive() + ); + } + + private buildValueFormula(constraint: ValueConstraint) { + return match(constraint.value) + .when( + (v): v is boolean => typeof v === 'boolean', + (v) => (v === true ? Logic.TRUE : Logic.FALSE) + ) + .when( + (v): v is number => typeof v === 'number', + (v) => Logic.constantBits(v) + ) + .when( + (v): v is string => typeof v === 'string', + (v) => { + // internalize the string and use its index as formula representation + const index = this.stringTable.indexOf(v); + if (index === -1) { + this.stringTable.push(v); + return Logic.constantBits(this.stringTable.length - 1); + } else { + return Logic.constantBits(index); + } + } + ) + .exhaustive(); + } + + private booleanVariable(name: string) { + this.variables.set(name, name); + return name; + } + + private intVariable(name: string) { + const r = Logic.variableBits(name, 32); + this.variables.set(name, r); + return r; + } + + private transformEquality(left: ComparisonTerm, right: ComparisonTerm) { + if (left.type !== right.type) { + throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`); + } + + const leftFormula = this.buildFormula(left); + const rightFormula = this.buildFormula(right); + if (left.type === 'boolean' && right.type === 'boolean') { + // logical equivalence + return Logic.equiv(leftFormula, rightFormula); + } else { + // integer equality + return Logic.equalBits(leftFormula, rightFormula); + } + } + + private transformInequality(left: ComparisonTerm, right: ComparisonTerm) { + return Logic.not(this.transformEquality(left, right)); + } + + private transformComparison( + left: ComparisonTerm, + right: ComparisonTerm, + func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula + ) { + return func(this.buildFormula(left), this.buildFormula(right)); + } +} diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index d6d893d4e..997e727d5 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -2,6 +2,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; +import { P, match } from 'ts-pattern'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason } from '../../constants'; @@ -16,13 +17,15 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; +import { PolicyCrudKind, PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; +import type { CheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; +import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; // a record for post-write policy check @@ -35,6 +38,12 @@ type PostWriteCheckRecord = { type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany'; +// input arg type for `check` API +type PermissionCheckArgs = { + operation: PolicyCrudKind; + where?: Record; +}; + /** * Prisma proxy handler for injecting access policy check. */ @@ -1436,6 +1445,115 @@ export class PolicyProxyHandler implements Pr //#endregion + //#region Check + + /** + * Checks if the given operation is possibly allowed by the policy, without querying the database. + * @param operation The CRUD operation. + * @param fieldValues Extra field value filters to be combined with the policy constraints. + */ + async check(args: PermissionCheckArgs): Promise { + return createDeferredPromise(() => this.doCheck(args)); + } + + private async doCheck(args: PermissionCheckArgs) { + if (!['create', 'read', 'update', 'delete'].includes(args.operation)) { + throw prismaClientValidationError(this.prisma, this.prismaModule, `Invalid "operation" ${args.operation}`); + } + + let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation); + if (typeof constraint === 'boolean') { + return constraint; + } + + if (args.where) { + // combine runtime filters with generated constraints + + const extraConstraints: CheckerConstraint[] = []; + for (const [field, value] of Object.entries(args.where)) { + if (value === undefined) { + continue; + } + + if (value === null) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Using "null" as filter value is not supported yet` + ); + } + + const fieldInfo = requireField(this.modelMeta, this.model, field); + + // relation and array fields are not supported + if (fieldInfo.isDataModel || fieldInfo.isArray) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.` + ); + } + + // map field type to constraint type + const fieldType = match(fieldInfo.type) + .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .with('Boolean', () => 'boolean') + .otherwise(() => { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.` + ); + }); + + // check value type + const valueType = typeof value; + if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value type for field "${field}". Only number, string or boolean is allowed.` + ); + } + + if (fieldType !== valueType) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value type for field "${field}". Expected "${fieldType}".` + ); + } + + // check number validity + if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + `Invalid value for field "${field}". Only non-negative integers are allowed.` + ); + } + + // build a constraint + extraConstraints.push({ + kind: 'eq', + left: { kind: 'variable', name: field, type: fieldType }, + right: { kind: 'value', value, type: fieldType }, + }); + } + + if (extraConstraints.length > 0) { + // combine the constraints + constraint = { kind: 'and', children: [constraint, ...extraConstraints] }; + } + } + + // check satisfiability + return new ConstraintSolver().checkSat(constraint); + } + + //#endregion + //#region Utils private get shouldLogQuery() { diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts new file mode 100644 index 000000000..d10e688f6 --- /dev/null +++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts @@ -0,0 +1,109 @@ +/** + * Type definitions for the `logic-solver` npm package. + */ +declare module 'logic-solver' { + /** + * A boolean formula. + */ + interface Formula {} + + /** + * The `TRUE` formula. + */ + const TRUE: Formula; + + /** + * The `FALSE` formula. + */ + const FALSE: Formula; + + /** + * Boolean equivalence. + */ + export function equiv(operand1: Formula, operand2: Formula): Formula; + + /** + * Bits equality. + */ + export function equalBits(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits greater-than. + */ + export function greaterThan(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits greater-than-or-equal. + */ + export function greaterThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits less-than. + */ + export function lessThan(bits1: Formula, bits2: Formula): Formula; + + /** + * Bits less-than-or-equal. + */ + export function lessThanOrEqual(bits1: Formula, bits2: Formula): Formula; + + /** + * Logical AND. + */ + export function and(...args: Formula[]): Formula; + + /** + * Logical OR. + */ + export function or(...args: Formula[]): Formula; + + /** + * Logical NOT. + */ + export function not(arg: Formula): Formula; + + /** + * Creates a bits variable with the given name and bit length. + */ + export function variableBits(baseName: string, N: number): Formula; + + /** + * Creates a constant bits formula from the given whole number. + */ + export function constantBits(wholeNumber: number): Formula; + + /** + * A solution to a constraint. + */ + interface Solution { + /** + * Returns a map of variable assignments. + */ + getMap(): object; + + /** + * Evaluates the given formula against the solution. + */ + evaluate(formula: Formula): unknown; + } + + /** + * A constraint solver. + */ + class Solver { + /** + * Adds constraints to the solver. + */ + require(...args: Formula[]): void; + + /** + * Adds negated constraints from the solver. + */ + forbid(...args: Formula[]): void; + + /** + * Solves the constraints. + */ + solve(): Solution; + } +} diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index bcb946877..b288063c0 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -17,12 +17,12 @@ import { PrismaErrorCode, } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; +import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; +import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -228,7 +228,7 @@ export class PolicyUtil extends QueryUtils { //#endregion - //# Auth guard + //#region Auth guard private readonly FULLY_OPEN_AUTH_GUARD = { create: true, @@ -267,7 +267,7 @@ export class PolicyUtil extends QueryUtils { } if (!provider) { - throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`); + throw this.unknownError(`unable to load authorization guard for ${model}`); } const r = provider({ user: this.user, preValue }, db); return this.reduce(r); @@ -561,6 +561,46 @@ export class PolicyUtil extends QueryUtils { return true; } + //#endregion + + //#region Checker + + /** + * Gets checker constraints for the given model and operation. + */ + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + const checker = this.getModelChecker(model); + const provider = checker[operation]; + if (typeof provider === 'boolean') { + return provider; + } + + if (typeof provider !== 'function') { + throw this.unknownError(`invalid ${operation} checker function for ${model}`); + } + + // call checker function + return provider({ user: this.user }); + } + + private getModelChecker(model: string) { + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return a constant true checker + return { create: true, read: true, update: true, delete: true }; + } else { + const result = this.options.policy.checker?.[lowerCaseFirst(model)]; + if (!result) { + // checker generation not enabled, return constant false checker + throw new Error( + `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` + ); + } + return result; + } + } + + //#endregion + /** * Gets unique constraints for the given model. */ @@ -609,6 +649,10 @@ export class PolicyUtil extends QueryUtils { const hoistedConditions: any[] = []; for (const field of getModelFields(injectTarget)) { + if (injectTarget[field] === false) { + continue; + } + const fieldInfo = resolveField(this.modelMeta, model, field); if (!fieldInfo || !fieldInfo.isDataModel) { // only care about relation fields @@ -934,6 +978,11 @@ export class PolicyUtil extends QueryUtils { } private doInjectReadCheckSelect(model: string, args: any, input: any) { + // omit should be ignored to avoid interfering with field selection + if (args.omit) { + delete args.omit; + } + if (!input?.select) { return; } @@ -1130,6 +1179,12 @@ export class PolicyUtil extends QueryUtils { continue; } + if (queryArgs?.omit?.[field] === true) { + // respect `{ omit: { [field]: true } }` + delete entityData[field]; + continue; + } + if (hasFieldLevelPolicy) { // 1. remove fields selected for checking field-level policies but not selected by the original query args // 2. evaluate field-level policies and remove fields that are not readable diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 9fecc375e..89d5ce9f6 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -9,7 +9,7 @@ import { HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; -import type { CrudContract, PolicyOperationKind, QueryContext } from '../types'; +import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -33,6 +33,57 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking if an operation is possibly allowed. + */ +export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; + +/** + * Supported checker constraint checking value types. + */ +export type ConstraintValueTypes = 'boolean' | 'number' | 'string'; + +/** + * Free variable constraint + */ +export type VariableConstraint = { kind: 'variable'; name: string; type: ConstraintValueTypes }; + +/** + * Constant value constraint + */ +export type ValueConstraint = { + kind: 'value'; + value: number | boolean | string; + type: ConstraintValueTypes; +}; + +/** + * Terms for comparison constraints + */ +export type ComparisonTerm = VariableConstraint | ValueConstraint; + +/** + * Comparison constraint + */ +export type ComparisonConstraint = { + kind: 'eq' | 'ne' | 'gt' | 'gte' | 'lt' | 'lte'; + left: ComparisonTerm; + right: ComparisonTerm; +}; + +/** + * Logical constraint + */ +export type LogicalConstraint = { + kind: 'and' | 'or' | 'not'; + children: CheckerConstraint[]; +}; + +/** + * Operation allowability checking constraint + */ +export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; + /** * Function for getting policy guard with a given context */ @@ -71,6 +122,8 @@ export type PolicyDef = { } >; + checker?: Record>; + // tracks which models have data validation rules validation: Record; diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4bcab85a1..4c32480ba 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -22,6 +22,7 @@ export interface DbOperations { groupBy(args: unknown): Promise; count(args?: unknown): Promise; subscribe(args?: unknown): Promise; + check(args: unknown): Promise; fields: Record; } @@ -30,10 +31,12 @@ export interface DbOperations { */ export type PolicyKind = 'allow' | 'deny'; +export type PolicyCrudKind = 'read' | 'create' | 'update' | 'delete'; + /** * Kinds of operations controlled by access policies */ -export type PolicyOperationKind = 'create' | 'update' | 'postUpdate' | 'read' | 'delete'; +export type PolicyOperationKind = PolicyCrudKind | 'postUpdate'; /** * Current login user info @@ -56,6 +59,21 @@ export type QueryContext = { preValue?: any; }; +/** + * Context for checking operation allowability. + */ +export type CheckerContext = { + /** + * Current user + */ + user?: AuthUser; + + /** + * Extra field value filters. + */ + fieldValues?: Record; +}; + /** * Prisma contract for CRUD operations. */ diff --git a/packages/schema/package.json b/packages/schema/package.json index 07e1ba1df..55515cbf5 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database", - "version": "2.0.3", + "version": "2.1.0", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index e9db60fb6..18e2c6be3 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -12,7 +12,7 @@ import { URI } from 'vscode-uri'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../language-server/constants'; import { ZModelFormatter } from '../language-server/zmodel-formatter'; import { createZModelServices, ZModelServices } from '../language-server/zmodel-module'; -import { mergeBaseModel, resolveImport, resolveTransitiveImports } from '../utils/ast-utils'; +import { mergeBaseModels, resolveImport, resolveTransitiveImports } from '../utils/ast-utils'; import { findUp } from '../utils/pkg-utils'; import { getVersion } from '../utils/version-utils'; import { CliError } from './cli-error'; @@ -101,10 +101,10 @@ export async function loadDocument(fileName: string): Promise { imported.map((m) => m.$document!.uri) ); - validationAfterMerge(model); + validationAfterImportMerge(model); // merge fields and attributes from base models - mergeBaseModel(model, services.references.Linker); + mergeBaseModels(model, services.references.Linker); // finally relink all references const relinkedModel = await relinkAll(model, services); @@ -113,7 +113,7 @@ export async function loadDocument(fileName: string): Promise { } // check global unique thing after merge imports -function validationAfterMerge(model: Model) { +function validationAfterImportMerge(model: Model) { const dataSources = model.declarations.filter((d) => isDataSource(d)); if (dataSources.length == 0) { console.error(colors.red('Validation error: Model must define a datasource')); diff --git a/packages/schema/src/language-server/utils.ts b/packages/schema/src/language-server/utils.ts index c57836a91..019004a62 100644 --- a/packages/schema/src/language-server/utils.ts +++ b/packages/schema/src/language-server/utils.ts @@ -1,4 +1,10 @@ -import { DataModel, DataModelField, isArrayExpr, isReferenceExpr, ReferenceExpr } from '@zenstackhq/language/ast'; +import { + isArrayExpr, + isReferenceExpr, + type DataModel, + type DataModelField, + type ReferenceExpr, +} from '@zenstackhq/language/ast'; import { resolved } from '@zenstackhq/sdk'; /** diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 0baf5ace3..9185443f3 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -2,14 +2,19 @@ import { ArrayExpr, DataModel, DataModelField, - isDataModel, - isStringLiteral, ReferenceExpr, + isDataModel, isEnum, + isStringLiteral, } from '@zenstackhq/language/ast'; -import { getLiteral, getModelIdFields, getModelUniqueFields, isDelegateModel } from '@zenstackhq/sdk'; -import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; -import { getModelFieldsWithBases } from '../../utils/ast-utils'; +import { + getLiteral, + getModelFieldsWithBases, + getModelIdFields, + getModelUniqueFields, + isDelegateModel, +} from '@zenstackhq/sdk'; +import { AstNode, DiagnosticInfo, ValidationAcceptor, getDocument } from 'langium'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; @@ -361,6 +366,11 @@ export default class DataModelValidator implements AstValidator { const containingModel = field.$container as DataModel; const uniqueFieldList = getUniqueFields(containingModel); + // field is defined in the abstract base model + if (containingModel !== contextModel) { + uniqueFieldList.push(...getUniqueFields(contextModel)); + } + thisRelation.fields?.forEach((ref) => { const refField = ref.target.ref as DataModelField; if (refField) { @@ -372,7 +382,7 @@ export default class DataModelValidator implements AstValidator { } accept( 'error', - `Field "${refField.name}" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`, + `Field "${refField.name}" on model "${containingModel.name}" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`, { node: refField } ); } diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index cb42e4cb1..d65e304dc 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -30,7 +30,7 @@ export default class ExpressionValidator implements AstValidator { // check was done at link time accept( 'error', - 'auth() cannot be resolved because no model marked wth "@@auth()" or named "User" is found', + 'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found', { node: expr } ); } else { diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index f71e479ef..a4c99da96 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -10,8 +10,8 @@ import { getDocument, } from 'langium'; +import { getModelFieldsWithBases } from '@zenstackhq/sdk'; import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; -import { getModelFieldsWithBases } from '../utils/ast-utils'; import { IssueCodes } from './constants'; import { MissingOppositeRelationData } from './validator/datamodel-validator'; import { ZModelFormatter } from './zmodel-formatter'; diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 5a15f9336..c2751b921 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,13 +35,7 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { - getAuthModel, - getContainingModel, - getModelFieldsWithBases, - isAuthInvocation, - isFutureExpr, -} from '@zenstackhq/sdk'; +import { getAuthModel, getModelFieldsWithBases, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -52,13 +46,14 @@ import { LangiumServices, LinkingError, Reference, + getContainerOfType, interruptAndCheck, isReference, streamContents, } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { getAllDataModelsIncludingImports, getContainingDataModel } from '../utils/ast-utils'; +import { getAllLoadedAndReachableDataModels, getContainingDataModel } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -283,15 +278,17 @@ export class ZModelLinker extends DefaultLinker { // eslint-disable-next-line @typescript-eslint/ban-types const funcDecl = node.function.ref as FunctionDecl; if (isAuthInvocation(node)) { - // auth() function is resolved to User model in the current document - const model = getContainingModel(node); - - if (model) { - const allDataModels = getAllDataModelsIncludingImports(this.langiumDocuments(), model); - const authModel = getAuthModel(allDataModels); - if (authModel) { - node.$resolvedType = { decl: authModel, nullable: true }; - } + // auth() function is resolved against all loaded and reachable documents + + // get all data models from loaded and reachable documents + const allDataModels = getAllLoadedAndReachableDataModels( + this.langiumDocuments(), + getContainerOfType(node, isDataModel) + ); + + const authModel = getAuthModel(allDataModels); + if (authModel) { + node.$resolvedType = { decl: authModel, nullable: true }; } } else if (isFutureExpr(node)) { // future() function is resolved to current model diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index e48a17621..e8e8880b5 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -32,7 +32,7 @@ import { import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; import { - getAllDataModelsIncludingImports, + getAllLoadedAndReachableDataModels, isCollectionPredicate, isFutureInvocation, resolveImportUri, @@ -219,18 +219,18 @@ export class ZModelScopeProvider extends DefaultScopeProvider { } private createScopeForAuthModel(node: AstNode, globalScope: Scope) { - const model = getContainerOfType(node, isModel); - if (model) { - const allDataModels = getAllDataModelsIncludingImports( - this.services.shared.workspace.LangiumDocuments, - model - ); - const authModel = getAuthModel(allDataModels); - if (authModel) { - return this.createScopeForModel(authModel, globalScope); - } + // get all data models from loaded and reachable documents + const allDataModels = getAllLoadedAndReachableDataModels( + this.services.shared.workspace.LangiumDocuments, + getContainerOfType(node, isDataModel) + ); + + const authModel = getAuthModel(allDataModels); + if (authModel) { + return this.createScopeForModel(authModel, globalScope); + } else { + return EMPTY_SCOPE; } - return EMPTY_SCOPE; } } diff --git a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts index 18bbd8c72..bf92d1c9d 100644 --- a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts +++ b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts @@ -106,7 +106,7 @@ export function generateAuthType(model: Model, authModel: DataModel) { // } // ` - return `namespace auth { + return `export namespace auth { type WithRequired = T & { [P in K]-?: T[P] }; ${Array.from(types.entries()) .map(([model, fields]) => { diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts new file mode 100644 index 000000000..7202f2cdf --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts @@ -0,0 +1,55 @@ +import { getDataModels } from '@zenstackhq/sdk'; +import type { DataModel, DataModelField, Model } from '@zenstackhq/sdk/ast'; +import { lowerCaseFirst } from 'lower-case-first'; +import { P, match } from 'ts-pattern'; + +/** + * Generates a `ModelCheckers` interface that contains a `check` method for each model in the schema. + * + * E.g.: + * + * ```ts + * type CheckerOperation = 'create' | 'read' | 'update' | 'delete'; + * + * export interface ModelCheckers { + * user: { check(op: CheckerOperation, args?: { email?: string; age?: number; }): Promise }, + * ... + * } + * ``` + */ +export function generateCheckerType(model: Model) { + return ` +import type { PolicyCrudKind } from '@zenstackhq/runtime'; + +export interface ModelCheckers { + ${getDataModels(model) + .map((dataModel) => `\t${lowerCaseFirst(dataModel.name)}: ${generateDataModelChecker(dataModel)}`) + .join(',\n')} +} +`; +} + +function generateDataModelChecker(dataModel: DataModel) { + return `{ + check(args: { operation: PolicyCrudKind, where?: ${generateDataModelArgs(dataModel)} }): Promise + }`; +} + +function generateDataModelArgs(dataModel: DataModel) { + return `{ ${dataModel.fields + .filter((field) => isFieldFilterable(field)) + .map((field) => `${field.name}?: ${mapFieldType(field)}`) + .join('; ')} }`; +} + +function isFieldFilterable(field: DataModelField) { + return !!mapFieldType(field); +} + +function mapFieldType(field: DataModelField) { + return match(field.type.type) + .with('Boolean', () => 'boolean') + .with(P.union('BigInt', 'Int', 'Float', 'Decimal'), () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 0425b76d0..0993f323c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -41,6 +41,7 @@ import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; import { isDefaultWithAuth } from '../enhancer-utils'; import { generateAuthType } from './auth-type-generator'; +import { generateCheckerType } from './checker-type-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; @@ -55,7 +56,7 @@ export class EnhancerGenerator { private readonly outDir: string ) {} - async generate() { + async generate(): Promise<{ dmmf: DMMF.Document | undefined }> { let logicalPrismaClientDir: string | undefined; let dmmf: DMMF.Document | undefined; @@ -89,6 +90,8 @@ export class EnhancerGenerator { const authTypes = authModel ? generateAuthType(this.model, authModel) : ''; const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; + const checkerTypes = this.generatePermissionChecker ? generateCheckerType(this.model) : ''; + const enhanceTs = this.project.createSourceFile( path.join(this.outDir, 'enhance.ts'), `import { type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime'; @@ -105,6 +108,8 @@ ${ ${authTypes} +${checkerTypes} + ${ logicalPrismaClientDir ? this.createLogicalPrismaEnhanceFunction(authTypeParam) @@ -126,15 +131,16 @@ import type * as _P from '${prismaImport}'; } private createSimplePrismaEnhanceFunction(authTypeParam: string) { + const returnType = `DbClient${this.generatePermissionChecker ? ' & ModelCheckers' : ''}`; return ` -export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions) { +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): ${returnType} { return createEnhancement(prisma, { modelMeta, policy, zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), prismaModule: Prisma, ...options - }, context); + }, context) as ${returnType}; } `; } @@ -157,12 +163,16 @@ import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed // overload for plain PrismaClient export function enhance & InternalArgs>( prisma: _PrismaClient, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; // overload for extended PrismaClient export function enhance & InternalArgs>( prisma: DynamicClientExtensionThis, - context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis; + context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis${ + this.generatePermissionChecker ? ' & ModelCheckers' : '' + }; export function enhance(prisma: any, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): any { return createEnhancement(prisma, { @@ -204,54 +214,59 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara // calculate a relative output path to output the logical prisma client into enhancer's output dir const prismaClientOutDir = path.join(path.relative(zmodelDir, this.outDir), LOGICAL_CLIENT_GENERATION_PATH); - try { - await prismaGenerator.generate({ - provider: '@internal', // doesn't matter - schemaPath: this.options.schemaPath, - output: logicalPrismaFile, - overrideClientGenerationPath: prismaClientOutDir, - mode: 'logical', - }); + await prismaGenerator.generate({ + provider: '@internal', // doesn't matter + schemaPath: this.options.schemaPath, + output: logicalPrismaFile, + overrideClientGenerationPath: prismaClientOutDir, + mode: 'logical', + }); - // generate the prisma client + // generate the prisma client - // only run prisma client generator for the logical schema - const prismaClientGeneratorName = this.getPrismaClientGeneratorName(this.model); - let generateCmd = `prisma generate --schema "${logicalPrismaFile}" --generator=${prismaClientGeneratorName}`; + // only run prisma client generator for the logical schema + const prismaClientGeneratorName = this.getPrismaClientGeneratorName(this.model); + let generateCmd = `prisma generate --schema "${logicalPrismaFile}" --generator=${prismaClientGeneratorName}`; - const prismaVersion = getPrismaVersion(); - if (!prismaVersion || semver.gte(prismaVersion, '5.2.0')) { - // add --no-engine to reduce generation size if the prisma version supports - generateCmd += ' --no-engine'; - } + const prismaVersion = getPrismaVersion(); + if (!prismaVersion || semver.gte(prismaVersion, '5.2.0')) { + // add --no-engine to reduce generation size if the prisma version supports + generateCmd += ' --no-engine'; + } + try { + // run 'prisma generate' + await execPackage(generateCmd, { stdio: 'ignore' }); + } catch { + await trackPrismaSchemaError(logicalPrismaFile); try { - // run 'prisma generate' - await execPackage(generateCmd, { stdio: 'ignore' }); + // run 'prisma generate' again with output to the console + await execPackage(generateCmd); } catch { - await trackPrismaSchemaError(logicalPrismaFile); - try { - // run 'prisma generate' again with output to the console - await execPackage(generateCmd); - } catch { - // noop - } - throw new PluginError(name, `Failed to run "prisma generate" on logical schema: ${logicalPrismaFile}`); + // noop } + throw new PluginError(name, `Failed to run "prisma generate" on logical schema: ${logicalPrismaFile}`); + } - // make a bunch of typing fixes to the generated prisma client - await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH)); + // make a bunch of typing fixes to the generated prisma client + await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH)); - return { - prismaSchema: logicalPrismaFile, - // load the dmmf of the logical prisma schema - dmmf: await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }), - }; - } finally { + const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }); + + try { + // clean up temp schema if (fs.existsSync(logicalPrismaFile)) { fs.rmSync(logicalPrismaFile); } + } catch { + // ignore errors } + + return { + prismaSchema: logicalPrismaFile, + // load the dmmf of the logical prisma schema + dmmf, + }; } private getPrismaClientGeneratorName(model: Model) { @@ -277,12 +292,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara this.model.declarations .filter((d): d is DataModel => isDelegateModel(d)) .forEach((dm) => { - delegateInfo.push([ - dm, - this.model.declarations.filter( - (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) - ), - ]); + const concreteModels = this.model.declarations.filter( + (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) + ); + if (concreteModels.length > 0) { + delegateInfo.push([dm, concreteModels]); + } }); // transform index.d.ts and save it into a new file (better perf than in-line editing) @@ -622,4 +637,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara await sf.save(); } } + + private get generatePermissionChecker() { + return this.options.generatePermissionChecker === true; + } } diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts new file mode 100644 index 000000000..a0b1c1dd2 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts @@ -0,0 +1,379 @@ +import { + getRelationKeyPairs, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; +import { + BinaryExpr, + BooleanLiteral, + DataModelField, + Expression, + ExpressionType, + LiteralExpr, + MemberAccessExpr, + NumberLiteral, + ReferenceExpr, + StringLiteral, + UnaryExpr, + isBinaryExpr, + isDataModelField, + isEnum, + isLiteralExpr, + isMemberAccessExpr, + isNullExpr, + isReferenceExpr, + isThisExpr, + isUnaryExpr, +} from '@zenstackhq/sdk/ast'; +import { P, match } from 'ts-pattern'; + +/** + * Options for {@link ConstraintTransformer}. + */ +export type ConstraintTransformerOptions = { + authAccessor: string; +}; + +/** + * Transform a set of allow and deny rules into a single constraint expression. + */ +export class ConstraintTransformer { + // a counter for generating unique variable names + private varCounter = 0; + + constructor(private readonly options: ConstraintTransformerOptions) {} + + /** + * Transforms a set of allow and deny rules into a single constraint expression. + */ + transformRules(allows: Expression[], denies: Expression[]): string { + // reset state + this.varCounter = 0; + + if (allows.length === 0) { + // unconditionally deny + return this.value('false', 'boolean'); + } + + let result: string; + + // transform allow rules + const allowConstraints = allows.map((allow) => this.transformExpression(allow)); + if (allowConstraints.length > 1) { + result = this.or(...allowConstraints); + } else { + result = allowConstraints[0]; + } + + // transform deny rules and compose + if (denies.length > 0) { + const denyConstraints = denies.map((deny) => this.transformExpression(deny)); + result = this.and(result, ...denyConstraints.map((c) => this.not(c))); + } + + // DEBUG: + // console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`); + + return result; + } + + private and(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? constraints[0] : `{ kind: 'and', children: [ ${constraints.join(', ')} ] }`; + } + + private or(...constraints: string[]) { + if (constraints.length === 0) { + throw new Error('No expressions to combine'); + } + return constraints.length === 1 ? constraints[0] : `{ kind: 'or', children: [ ${constraints.join(', ')} ] }`; + } + + private not(constraint: string) { + return `{ kind: 'not', children: [${constraint}] }`; + } + + private transformExpression(expression: Expression) { + return ( + match(expression) + .when(isBinaryExpr, (expr) => this.transformBinary(expr)) + .when(isUnaryExpr, (expr) => this.transformUnary(expr)) + // top-level boolean literal + .when(isLiteralExpr, (expr) => this.transformLiteral(expr)) + // top-level boolean reference expr + .when(isReferenceExpr, (expr) => this.transformReference(expr)) + // top-level boolean member access expr + .when(isMemberAccessExpr, (expr) => this.transformMemberAccess(expr)) + .otherwise(() => this.nextVar()) + ); + } + + private transformLiteral(expr: LiteralExpr) { + return match(expr.$type) + .with(NumberLiteral, () => { + const parsed = parseFloat(expr.value as string); + if (isNaN(parsed) || parsed < 0 || !Number.isInteger(parsed)) { + // only non-negative integers are supported, for other cases, + // transform into a free variable + return this.nextVar('number'); + } + return this.value(expr.value.toString(), 'number'); + }) + .with(StringLiteral, () => this.value(`'${expr.value}'`, 'string')) + .with(BooleanLiteral, () => this.value(expr.value.toString(), 'boolean')) + .exhaustive(); + } + + private transformReference(expr: ReferenceExpr) { + // top-level reference is transformed into a named variable + return this.variable(expr.target.$refText, 'boolean'); + } + + private transformMemberAccess(expr: MemberAccessExpr) { + // "this.x" is transformed into a named variable + if (isThisExpr(expr.operand)) { + return this.variable(expr.member.$refText, 'boolean'); + } + + // top-level auth access + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + return this.value(`${authAccess} ?? false`, 'boolean'); + } + + // other top-level member access expressions are not supported + // and thus transformed into a free variable + return this.nextVar(); + } + + private transformBinary(expr: BinaryExpr): string { + return ( + match(expr.operator) + .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right))) + .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr)) + // unsupported operators (e.g., collection predicate) are transformed into a free variable + .otherwise(() => this.nextVar()) + ); + } + + private transformUnary(expr: UnaryExpr): string { + return match(expr.operator) + .with('!', () => this.not(this.transformExpression(expr.operand))) + .exhaustive(); + } + + private transformComparison(expr: BinaryExpr) { + if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) { + // handle the case if any operand is `auth()` invocation + const authComparison = this.transformAuthComparison(expr); + return authComparison ?? this.nextVar(); + } + + const leftOperand = this.getComparisonOperand(expr.left); + const rightOperand = this.getComparisonOperand(expr.right); + + const op = this.mapOperatorToConstraintKind(expr.operator); + const result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`; + + // `auth()` member access can be undefined, when that happens, we assume a false condition + // for the comparison + + const leftAuthAccess = this.getAuthAccess(expr.left); + const rightAuthAccess = this.getAuthAccess(expr.right); + + if (leftAuthAccess && rightOperand) { + // `auth().f op x` => `auth().f !== undefined && auth().f op x` + return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result); + } else if (rightAuthAccess && leftOperand) { + // `x op auth().f` => `auth().f !== undefined && x op auth().f` + return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result); + } + + if (leftOperand === undefined || rightOperand === undefined) { + // if either operand is not supported, transform into a free variable + return this.nextVar(); + } + + return result; + } + + private transformAuthComparison(expr: BinaryExpr) { + if (this.isAuthEqualNull(expr)) { + // `auth() == null` => `user === null` + return this.value(`${this.options.authAccessor} === null`, 'boolean'); + } + + if (this.isAuthNotEqualNull(expr)) { + // `auth() != null` => `user !== null` + return this.value(`${this.options.authAccessor} !== null`, 'boolean'); + } + + // auth() equality check against a relation, translate to id-fk comparison + const operand = isAuthInvocation(expr.left) ? expr.right : expr.left; + if (!isDataModelFieldReference(operand)) { + return undefined; + } + + // get id-fk field pairs from the relation field + const relationField = operand.target.ref as DataModelField; + const idFkPairs = getRelationKeyPairs(relationField); + + // build id-fk field comparison constraints + const fieldConstraints: string[] = []; + + idFkPairs.forEach(({ id, foreignKey }) => { + const idFieldType = this.mapType(id.type.type as ExpressionType); + if (!idFieldType) { + return; + } + const fkFieldType = this.mapType(foreignKey.type.type as ExpressionType); + if (!fkFieldType) { + return; + } + + const op = this.mapOperatorToConstraintKind(expr.operator); + const authIdAccess = `${this.options.authAccessor}?.${id.name}`; + + fieldConstraints.push( + this.and( + // `auth()?.id != null` guard + this.value(`${this.normalizeToNull(authIdAccess)} !== null`, 'boolean'), + // `auth()?.id [op] fkField` + `{ kind: '${op}', left: ${this.value(authIdAccess, idFieldType)}, right: ${this.variable( + foreignKey.name, + fkFieldType + )} }` + ) + ); + }); + + // combine field constraints + if (fieldConstraints.length > 0) { + return this.and(...fieldConstraints); + } + + return undefined; + } + + // normalize `auth()` access undefined value to null + private normalizeToNull(expr: string) { + return `(${expr} ?? null)`; + } + + private isAuthEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '==' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + + private isAuthNotEqualNull(expr: BinaryExpr) { + return ( + expr.operator === '!=' && + ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) || + (isAuthInvocation(expr.right) && isNullExpr(expr.left))) + ); + } + + private getComparisonOperand(expr: Expression) { + if (isLiteralExpr(expr)) { + return this.transformLiteral(expr); + } + + if (isEnumFieldReference(expr)) { + return this.value(`'${expr.target.$refText}'`, 'string'); + } + + const fieldAccess = this.getFieldAccess(expr); + if (fieldAccess) { + // model field access is transformed into a named variable + const mappedType = this.mapExpressionType(expr); + if (mappedType) { + return this.variable(fieldAccess.name, mappedType); + } else { + return undefined; + } + } + + const authAccess = this.getAuthAccess(expr); + if (authAccess) { + const mappedType = this.mapExpressionType(expr); + if (mappedType) { + return `${this.value(authAccess, mappedType)}`; + } else { + return undefined; + } + } + + return undefined; + } + + private mapExpressionType(expression: Expression) { + if (isEnum(expression.$resolvedType?.decl)) { + return 'string'; + } else { + return this.mapType(expression.$resolvedType?.decl as ExpressionType); + } + } + + private mapType(type: ExpressionType) { + return match(type) + .with('Boolean', () => 'boolean') + .with('Int', () => 'number') + .with('String', () => 'string') + .otherwise(() => undefined); + } + + private mapOperatorToConstraintKind(operator: BinaryExpr['operator']) { + return match(operator) + .with('==', () => 'eq') + .with('!=', () => 'ne') + .with('<', () => 'lt') + .with('<=', () => 'lte') + .with('>', () => 'gt') + .with('>=', () => 'gte') + .otherwise(() => { + throw new Error(`Unsupported operator: ${operator}`); + }); + } + + private getFieldAccess(expr: Expression) { + if (isReferenceExpr(expr)) { + return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined; + } + if (isMemberAccessExpr(expr)) { + return isThisExpr(expr.operand) ? { name: expr.member.$refText } : undefined; + } + return undefined; + } + + private getAuthAccess(expr: Expression): string | undefined { + if (!isMemberAccessExpr(expr)) { + return undefined; + } + + if (isAuthInvocation(expr.operand)) { + return `${this.options.authAccessor}?.${expr.member.$refText}`; + } else { + const operand = this.getAuthAccess(expr.operand); + return operand ? `${operand}?.${expr.member.$refText}` : undefined; + } + } + + private nextVar(type = 'boolean') { + return this.variable(`__var${this.varCounter++}`, type); + } + + private variable(name: string, type: string) { + return `{ kind: 'variable', name: '${name}', type: '${type}' }`; + } + + private value(value: string, type: string) { + return `{ kind: 'value', value: ${value}, type: '${type}' }`; + } +} diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 753ef8f19..a36a52126 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -54,6 +54,7 @@ import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, Writ import { name } from '..'; import { isCollectionPredicate } from '../../../utils/ast-utils'; import { ALL_OPERATION_KINDS } from '../../plugin-utils'; +import { ConstraintTransformer } from './constraint-transformer'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; /** @@ -70,6 +71,8 @@ export class PolicyGenerator { { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, + { name: 'type CheckerContext' }, + { name: 'type CheckerConstraint' }, ], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -85,11 +88,22 @@ export class PolicyGenerator { const models = getDataModels(model); + // policy guard functions const policyMap: Record> = {}; for (const model of models) { policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); } + const generatePermissionChecker = options.generatePermissionChecker === true; + + // CRUD checker functions + const checkerMap: Record> = {}; + if (generatePermissionChecker) { + for (const model of models) { + checkerMap[model.name] = await this.generateCheckerForModel(model, sf); + } + } + const authSelector = this.generateAuthSelector(models); sf.addVariableStatement({ @@ -118,6 +132,22 @@ export class PolicyGenerator { }); writer.writeLine(','); + if (generatePermissionChecker) { + writer.write('checker:'); + writer.inlineBlock(() => { + for (const [model, map] of Object.entries(checkerMap)) { + writer.write(`${lowerCaseFirst(model)}:`); + writer.inlineBlock(() => { + Object.entries(map).forEach(([op, func]) => { + writer.write(`${op}: ${func},`); + }); + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } + writer.write('validation:'); writer.inlineBlock(() => { for (const model of models) { @@ -301,7 +331,6 @@ export class PolicyGenerator { } const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { @@ -313,7 +342,6 @@ export class PolicyGenerator { if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind + '_input'] = inputCheckFunc.getName()!; } } @@ -847,4 +875,70 @@ export class PolicyGenerator { statements.push(`const user: any = context.user ?? null;`); } } + + private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) { + const result: Record = {}; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const policies = analyzePolicies(model); + + for (const kind of ['create', 'read', 'update', 'delete'] as const) { + if (policies[kind] === true || policies[kind] === false) { + result[kind] = policies[kind] as boolean; + continue; + } + + const denies = this.getPolicyExpressions(model, 'deny', kind); + const allows = this.getPolicyExpressions(model, 'allow', kind); + + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + result[kind] = false; + continue; + } else { + result[kind] = true; + continue; + } + } + + const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies); + result[kind] = guardFunc.getName()!; + } + + return result; + } + + private generateCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: string, + allows: Expression[], + denies: Expression[] + ) { + const statements: string[] = []; + + this.generateNormalizedAuthRef(model, allows, denies, statements); + + const transformed = new ConstraintTransformer({ + authAccessor: 'user', + }).transformRules(allows, denies); + + statements.push(`return ${transformed};`); + + const func = sourceFile.addFunction({ + name: `${model.name}$checker$${kind}`, + returnType: 'CheckerConstraint', + parameters: [ + { + name: 'context', + type: 'CheckerContext', + }, + ], + statements, + }); + + return func; + } } diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 5440e03cf..58ecd68f4 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -374,6 +374,11 @@ export class PrismaSchemaGenerator { // for the given model, find relation fields of delegate model type, find all concrete models // of the delegate model and generate an auxiliary opposite relation field to each of them decl.fields.forEach((f) => { + // don't process fields inherited from a delegate model + if (f.$inheritedFrom && isDelegateModel(f.$inheritedFrom)) { + return; + } + const fieldType = f.type.reference?.ref; if (!isDataModel(fieldType)) { return; diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index a9b963d30..91f152af8 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -13,7 +13,7 @@ import { } from '@zenstackhq/sdk'; import { DataModel, EnumField, Model, isDataModel, isEnum } from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; -import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma'; +import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import { streamAllContents } from 'langium'; import path from 'path'; import type { SourceFile } from 'ts-morph'; @@ -106,8 +106,9 @@ export class ZodSchemaGenerator { aggregateOperationSupport, project: this.project, inputObjectTypes, + zmodel: this.model, }); - await transformer.generateInputSchemas(this.options); + await transformer.generateInputSchemas(this.options, this.model); this.sourceFiles.push(...transformer.sourceFiles); } @@ -189,7 +190,10 @@ export class ZodSchemaGenerator { ); } - private async generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) { + private async generateEnumSchemas( + prismaSchemaEnum: readonly DMMF.SchemaEnum[], + modelSchemaEnum: readonly DMMF.SchemaEnum[] + ) { const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum]; const enumNames = enumTypes.map((enumItem) => upperCaseFirst(enumItem.name)); Transformer.enumNames = enumNames ?? []; @@ -197,6 +201,7 @@ export class ZodSchemaGenerator { enumTypes, project: this.project, inputObjectTypes: [], + zmodel: this.model, }); await transformer.generateEnumSchemas(); this.sourceFiles.push(...transformer.sourceFiles); @@ -210,14 +215,21 @@ export class ZodSchemaGenerator { for (let i = 0; i < inputObjectTypes.length; i += 1) { const fields = inputObjectTypes[i]?.fields; const name = inputObjectTypes[i]?.name; + if (!generateUnchecked && name.includes('Unchecked')) { continue; } + + if (name.includes('CreateMany') && !supportCreateMany(this.model)) { + continue; + } + const transformer = new Transformer({ name, fields, project: this.project, inputObjectTypes, + zmodel: this.model, }); const moduleName = transformer.generateObjectSchema(generateUnchecked, this.options); moduleNames.push(moduleName); diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index 4fb04cdcc..09804f42b 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,7 +1,8 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ import { indentString, type PluginOptions } from '@zenstackhq/sdk'; +import type { Model } from '@zenstackhq/sdk/ast'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; -import { type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; +import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; import path from 'path'; import type { Project, SourceFile } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; @@ -11,12 +12,12 @@ import { AggregateOperationSupport, TransformerParams } from './types'; export default class Transformer { name: string; originalName: string; - fields: PrismaDMMF.SchemaArg[]; + fields: readonly PrismaDMMF.SchemaArg[]; schemaImports = new Set(); - models: PrismaDMMF.Model[]; + models: readonly PrismaDMMF.Model[]; modelOperations: PrismaDMMF.ModelMapping[]; aggregateOperationSupport: AggregateOperationSupport; - enumTypes: PrismaDMMF.SchemaEnum[]; + enumTypes: readonly PrismaDMMF.SchemaEnum[]; static enumNames: string[] = []; static rawOpsMap: { [name: string]: string } = {}; @@ -26,6 +27,7 @@ export default class Transformer { private project: Project; private inputObjectTypes: PrismaDMMF.InputType[]; public sourceFiles: SourceFile[] = []; + private zmodel: Model; constructor(params: TransformerParams) { this.originalName = params.name ?? ''; @@ -37,6 +39,7 @@ export default class Transformer { this.enumTypes = params.enumTypes ?? []; this.project = params.project; this.inputObjectTypes = params.inputObjectTypes; + this.zmodel = params.zmodel; } static setOutputPath(outPath: string) { @@ -118,6 +121,10 @@ export default class Transformer { return result; } + if (inputType.type.includes('CreateMany') && !supportCreateMany(this.zmodel)) { + return result; + } + // TODO: unify the following with `schema-gen.ts` if (inputType.type === 'String') { @@ -389,7 +396,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`; return wrapped; } - async generateInputSchemas(options: PluginOptions) { + async generateInputSchemas(options: PluginOptions, zmodel: Model) { const globalExports: string[] = []; // whether Prisma's Unchecked* series of input types should be generated @@ -489,7 +496,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`; operations.push(['create', origModelName]); } - if (createMany) { + if (createMany && supportCreateMany(zmodel)) { imports.push( `import { ${modelName}CreateManyInputObjectSchema } from '../objects/${modelName}CreateManyInput.schema'` ); diff --git a/packages/schema/src/plugins/zod/types.ts b/packages/schema/src/plugins/zod/types.ts index b64995448..f74d690e4 100644 --- a/packages/schema/src/plugins/zod/types.ts +++ b/packages/schema/src/plugins/zod/types.ts @@ -1,17 +1,19 @@ +import type { Model } from '@zenstackhq/sdk/ast'; import type { DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; import { Project } from 'ts-morph'; export type TransformerParams = { - enumTypes?: PrismaDMMF.SchemaEnum[]; - fields?: PrismaDMMF.SchemaArg[]; + enumTypes?: readonly PrismaDMMF.SchemaEnum[]; + fields?: readonly PrismaDMMF.SchemaArg[]; name?: string; - models?: PrismaDMMF.Model[]; + models?: readonly PrismaDMMF.Model[]; modelOperations?: PrismaDMMF.ModelMapping[]; aggregateOperationSupport?: AggregateOperationSupport; isDefaultPrismaClientOutput?: boolean; prismaClientOutputPath?: string; project: Project; inputObjectTypes: PrismaDMMF.InputType[]; + zmodel: Model; }; export type AggregateOperationSupport = { diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index e6a335221..5f3321b94 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -141,10 +141,14 @@ export function makeFieldSchema(field: DataModelField) { } if (field.attributes.some(isDefaultWithAuth)) { - // field uses `auth()` in `@default()`, this was transformed into a pseudo default - // value, while compiling to zod we should turn it into an optional field instead - // of `.default()` - schema += '.nullish()'; + if (field.type.optional) { + schema += '.nullish()'; + } else { + // field uses `auth()` in `@default()`, this was transformed into a pseudo default + // value, while compiling to zod we should turn it into an optional field instead + // of `.default()` + schema += '.optional()'; + } } else { const schemaDefault = getFieldSchemaDefault(field); if (schemaDefault !== undefined) { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 83a5a6a57..e891056c2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -17,7 +17,7 @@ import { ModelImport, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; +import { getModelFieldsWithBases, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -47,16 +47,13 @@ type BuildReference = ( refText: string ) => Reference; -export function mergeBaseModel(model: Model, linker: Linker) { +export function mergeBaseModels(model: Model, linker: Linker) { const buildReference = linker.buildReference.bind(linker); - model.declarations.filter(isDataModel).forEach((decl) => { - const dataModel = decl as DataModel; - + model.declarations.filter(isDataModel).forEach((dataModel) => { const bases = getRecursiveBases(dataModel).reverse(); if (bases.length > 0) { dataModel.fields = bases - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion .flatMap((base) => base.fields) // don't inherit skip-level fields .filter((f) => !f.$inheritedFrom) @@ -67,16 +64,25 @@ export function mergeBaseModel(model: Model, linker: Linker) { .flatMap((base) => base.attributes.filter((attr) => filterBaseAttribute(base, attr))) .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); - - // fix $containerIndex - linkContentToContainer(dataModel); } + // mark base merged dataModel.$baseMerged = true; }); // remove abstract models model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract)); + + model.declarations.filter(isDataModel).forEach((dm) => { + // remove abstract super types + dm.superTypes = dm.superTypes.filter((t) => t.ref && isDelegateModel(t.ref)); + + // fix $containerIndex + linkContentToContainer(dm); + }); + + // fix $containerIndex after deleting abstract models + linkContentToContainer(model); } function filterBaseAttribute(base: DataModel, attr: DataModelAttribute) { @@ -244,29 +250,6 @@ export function getContainingDataModel(node: Expression): DataModel | undefined return undefined; } -export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) { - if (model.$baseMerged) { - return model.fields; - } else { - return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)]; - } -} - -export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] { - const result: DataModel[] = []; - dataModel.superTypes.forEach((superType) => { - const baseDecl = superType.ref; - if (baseDecl) { - if (!includeDelegate && isDelegateModel(baseDecl)) { - return; - } - result.push(baseDecl); - result.push(...getRecursiveBases(baseDecl)); - } - }); - return result; -} - /** * Walk upward from the current AST node to find the first node that satisfies the predicate. */ @@ -280,3 +263,36 @@ export function findUpAst(node: AstNode, predicate: (node: AstNode) => boolean): } return undefined; } + +/** + * Gets all data models from all loaded documents + */ +export function getAllLoadedDataModels(langiumDocuments: LangiumDocuments) { + return langiumDocuments.all + .map((doc) => doc.parseResult.value as Model) + .flatMap((model) => model.declarations.filter(isDataModel)) + .toArray(); +} + +/** + * Gets all data models from loaded and reachable documents + */ +export function getAllLoadedAndReachableDataModels(langiumDocuments: LangiumDocuments, fromModel?: DataModel) { + // get all data models from loaded documents + const allDataModels = getAllLoadedDataModels(langiumDocuments); + + if (fromModel) { + // merge data models transitively reached from the current model + const model = getContainerOfType(fromModel, isModel); + if (model) { + const transitiveDataModels = getAllDataModelsIncludingImports(langiumDocuments, model); + transitiveDataModels.forEach((dm) => { + if (!allDataModels.includes(dm)) { + allDataModels.push(dm); + } + }); + } + } + + return allDataModels; +} diff --git a/packages/schema/src/utils/pkg-utils.ts b/packages/schema/src/utils/pkg-utils.ts index 82b5ef019..067810a8e 100644 --- a/packages/schema/src/utils/pkg-utils.ts +++ b/packages/schema/src/utils/pkg-utils.ts @@ -13,7 +13,7 @@ export type PackageManagers = 'npm' | 'yarn' | 'pnpm'; * @export * @template e A type parameter that extends boolean */ -export type FindUp = e extends true ? string[] | undefined : string | undefined +export type FindUp = e extends true ? string[] | undefined : string | undefined; /** * Find and return file paths by searching parent directories based on the given names list and current working directory (cwd) path. * Optionally return a single path or multiple paths. @@ -28,7 +28,12 @@ export type FindUp = e extends true ? string[] | undefined : * @param [result=[]] An array of strings representing the accumulated results used in multiple results * @returns Path(s) to a specific file or folder within the directory or parent directories */ -export function findUp(names: string[], cwd: string = process.cwd(), multiple: e = false as e, result: string[] = []): FindUp { +export function findUp( + names: string[], + cwd: string = process.cwd(), + multiple: e = false as e, + result: string[] = [] +): FindUp { if (!names.some((name) => !!name)) return undefined; const target = names.find((name) => fs.existsSync(path.join(cwd, name))); if (multiple == false && target) return path.join(cwd, target) as FindUp; @@ -38,7 +43,6 @@ export function findUp(names: string[], cwd: string = return findUp(names, up, multiple, result); } - /** * Find a Node module/file given its name in a specific directory, with a fallback to the current working directory. * If the name is empty, return undefined. @@ -54,11 +58,11 @@ export function findNodeModulesFile(name: string, cwd: string = process.cwd()) { if (!name) return undefined; try { // Use require.resolve to find the module/file. The paths option allows specifying the directory to start from. - const resolvedPath = require.resolve(name, { paths: [cwd] }) - return resolvedPath + const resolvedPath = require.resolve(name, { paths: [cwd] }); + return resolvedPath; } catch (error) { // If require.resolve fails to find the module/file, it will throw an error. - return undefined + return undefined; } } @@ -86,7 +90,7 @@ export function installPackage( projectPath = '.', exactVersion = true ) { - const manager = pkgManager ?? getPackageManager(projectPath); + const manager = pkgManager ?? getPackageManager(projectPath).packageManager; console.log(`Installing package "${pkg}@${tag}" with ${manager}`); switch (manager) { case 'yarn': diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index aca9e2674..380836e21 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1051,7 +1051,7 @@ describe('Attribute tests', () => { @@allow('all', auth() != null) } `) - ).toContain(`auth() cannot be resolved because no model marked wth "@@auth()" or named "User" is found`); + ).toContain(`auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found`); await loadModel(` ${prelude} diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index 0bf12245d..e0778da51 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -521,7 +521,7 @@ describe('Data Model Validation Tests', () => { `) ).toMatchObject( errorLike( - `Field "aId" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute` + `Field "aId" on model "B" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute` ) ); @@ -736,6 +736,26 @@ describe('Data Model Validation Tests', () => { ).toMatchObject( errorLike(`The relation field "user" on model "A" is missing an opposite relation field on model "User"`) ); + + // one-to-one relation field and @@unique defined in different models + await loadModel(` + ${prelude} + + abstract model Base { + id String @id + a A @relation(fields: [aId], references: [id]) + aId String + } + + model A { + id String @id + b B? + } + + model B extends Base{ + @@unique([aId]) + } + `); }); it('delegate base type', async () => { diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts index 0c92c6ee2..e72373f82 100644 --- a/packages/schema/tests/utils.ts +++ b/packages/schema/tests/utils.ts @@ -5,7 +5,7 @@ import * as path from 'path'; import * as tmp from 'tmp'; import { URI } from 'vscode-uri'; import { createZModelServices } from '../src/language-server/zmodel-module'; -import { mergeBaseModel } from '../src/utils/ast-utils'; +import { mergeBaseModels } from '../src/utils/ast-utils'; tmp.setGracefulCleanup(); @@ -68,7 +68,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; if (mergeBase) { - mergeBaseModel(model, ZModel.references.Linker); + mergeBaseModels(model, ZModel.references.Linker); } return model; @@ -87,13 +87,13 @@ export async function loadModelWithError(content: string, verbose = false) { } export async function safelyLoadModel(content: string, validate = true, verbose = false) { - const [ result ] = await Promise.allSettled([ loadModel(content, validate, verbose) ]); + const [result] = await Promise.allSettled([loadModel(content, validate, verbose)]); - return result + return result; } export const errorLike = (msg: string) => ({ reason: { - message: expect.stringContaining(msg) + message: expect.stringContaining(msg), }, -}) +}); diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 7bb3d3d2f..e0489d48f 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { @@ -18,8 +18,8 @@ "author": "", "license": "MIT", "dependencies": { - "@prisma/generator-helper": "5.7.0", - "@prisma/internals": "5.7.0", + "@prisma/generator-helper": "^5.13.0", + "@prisma/internals": "^5.13.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", "langium": "1.3.1", diff --git a/packages/sdk/src/dmmf-helpers/include-helpers.ts b/packages/sdk/src/dmmf-helpers/include-helpers.ts index c09c72426..23c9a9e12 100644 --- a/packages/sdk/src/dmmf-helpers/include-helpers.ts +++ b/packages/sdk/src/dmmf-helpers/include-helpers.ts @@ -1,7 +1,10 @@ import type { DMMF } from '../prisma'; import { checkIsModelRelationField, checkModelHasManyModelRelation, checkModelHasModelRelation } from './model-helpers'; -export function addMissingInputObjectTypesForInclude(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) { +export function addMissingInputObjectTypesForInclude( + inputObjectTypes: DMMF.InputType[], + models: readonly DMMF.Model[] +) { // generate input object types necessary to support ModelInclude with relation support const generatedIncludeInputObjectTypes = generateModelIncludeInputObjectTypes(models); @@ -9,7 +12,7 @@ export function addMissingInputObjectTypesForInclude(inputObjectTypes: DMMF.Inpu inputObjectTypes.push(includeInputObjectType); } } -function generateModelIncludeInputObjectTypes(models: DMMF.Model[]) { +function generateModelIncludeInputObjectTypes(models: readonly DMMF.Model[]) { const modelIncludeInputObjectTypes: DMMF.InputType[] = []; for (const model of models) { const { name: modelName, fields: modelFields } = model; diff --git a/packages/sdk/src/dmmf-helpers/model-helpers.ts b/packages/sdk/src/dmmf-helpers/model-helpers.ts index 62bbd9980..2528eeeed 100644 --- a/packages/sdk/src/dmmf-helpers/model-helpers.ts +++ b/packages/sdk/src/dmmf-helpers/model-helpers.ts @@ -31,6 +31,6 @@ export function checkIsManyModelRelationField(modelField: DMMF.Field) { return checkIsModelRelationField(modelField) && modelField.isList; } -export function findModelByName(models: DMMF.Model[], modelName: string) { +export function findModelByName(models: readonly DMMF.Model[], modelName: string) { return models.find(({ name }) => name === modelName); } diff --git a/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts b/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts index 79b3a9f98..fb47b0096 100644 --- a/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts +++ b/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts @@ -1,14 +1,17 @@ import type { DMMF } from '../prisma'; import { checkModelHasModelRelation } from './model-helpers'; -export function addMissingInputObjectTypesForModelArgs(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) { +export function addMissingInputObjectTypesForModelArgs( + inputObjectTypes: DMMF.InputType[], + models: readonly DMMF.Model[] +) { const modelArgsInputObjectTypes = generateModelArgsInputObjectTypes(models); for (const modelArgsInputObjectType of modelArgsInputObjectTypes) { inputObjectTypes.push(modelArgsInputObjectType); } } -function generateModelArgsInputObjectTypes(models: DMMF.Model[]) { +function generateModelArgsInputObjectTypes(models: readonly DMMF.Model[]) { const modelArgsInputObjectTypes: DMMF.InputType[] = []; for (const model of models) { const { name: modelName } = model; diff --git a/packages/sdk/src/dmmf-helpers/select-helpers.ts b/packages/sdk/src/dmmf-helpers/select-helpers.ts index 6037eecd8..a5e3e68a1 100644 --- a/packages/sdk/src/dmmf-helpers/select-helpers.ts +++ b/packages/sdk/src/dmmf-helpers/select-helpers.ts @@ -4,7 +4,7 @@ import { checkIsModelRelationField, checkModelHasManyModelRelation } from './mod export function addMissingInputObjectTypesForSelect( inputObjectTypes: DMMF.InputType[], outputObjectTypes: DMMF.OutputType[], - models: DMMF.Model[] + models: readonly DMMF.Model[] ) { // generate input object types necessary to support ModelSelect._count const modelCountOutputTypes = getModelCountOutputTypes(outputObjectTypes); @@ -89,7 +89,7 @@ function generateModelCountOutputTypeArgsInputObjectTypes(modelCountOutputTypes: return modelCountOutputTypeArgsInputObjectTypes; } -function generateModelSelectInputObjectTypes(models: DMMF.Model[]) { +function generateModelSelectInputObjectTypes(models: readonly DMMF.Model[]) { const modelSelectInputObjectTypes: DMMF.InputType[] = []; for (const model of models) { const { name: modelName, fields: modelFields } = model; @@ -97,16 +97,9 @@ function generateModelSelectInputObjectTypes(models: DMMF.Model[]) { for (const modelField of modelFields) { const { name: modelFieldName, isList, type } = modelField; + const inputTypes = [{ isList: false, type: 'Boolean', location: 'scalar' }]; const isRelationField = checkIsModelRelationField(modelField); - - const field: DMMF.SchemaArg = { - name: modelFieldName, - isRequired: false, - isNullable: false, - inputTypes: [{ isList: false, type: 'Boolean', location: 'scalar' }], - }; - if (isRelationField) { const schemaArgInputType: DMMF.InputTypeRef = { isList: false, @@ -114,9 +107,15 @@ function generateModelSelectInputObjectTypes(models: DMMF.Model[]) { location: 'inputObjectTypes', namespace: 'prisma', }; - field.inputTypes.push(schemaArgInputType); + inputTypes.push(schemaArgInputType); } + const field: DMMF.SchemaArg = { + name: modelFieldName, + isRequired: false, + isNullable: false, + inputTypes: inputTypes as DMMF.InputTypeRef[], + }; fields.push(field); } diff --git a/packages/sdk/src/prisma.ts b/packages/sdk/src/prisma.ts index b45dd7cfb..1edab3c3d 100644 --- a/packages/sdk/src/prisma.ts +++ b/packages/sdk/src/prisma.ts @@ -4,8 +4,11 @@ import type { DMMF } from '@prisma/generator-helper'; import { getDMMF as _getDMMF, type GetDMMFOptions } from '@prisma/internals'; import { DEFAULT_RUNTIME_LOAD_PATH } from '@zenstackhq/runtime'; import path from 'path'; +import semver from 'semver'; +import { Model } from './ast'; import { RUNTIME_PACKAGE } from './constants'; import type { PluginOptions } from './types'; +import { getDataSourceProvider } from './utils'; /** * Given an import context directory and plugin options, compute the import spec for the Prisma Client. @@ -75,4 +78,14 @@ export function getPrismaVersion(): string | undefined { } } +/** + * Returns if the given model supports `createMany` operation. + */ +export function supportCreateMany(model: Model) { + // `createMany` is supported for sqlite since Prisma 5.12.0 + const prismaVersion = getPrismaVersion(); + const dsProvider = getDataSourceProvider(model); + return dsProvider !== 'sqlite' || (prismaVersion && semver.gte(prismaVersion, '5.12.0')); +} + export type { DMMF } from '@prisma/generator-helper'; diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 6617983aa..df83186b5 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -17,6 +17,7 @@ import { isConfigArrayExpr, isDataModel, isDataModelField, + isDataSource, isEnumField, isExpression, isGeneratorDecl, @@ -298,9 +299,9 @@ export function isForeignKeyField(field: DataModelField) { } /** - * Gets the foreign key fields of the given relation field. + * Gets the foreign key-id field pairs from the given relation field. */ -export function getForeignKeyFields(relationField: DataModelField) { +export function getRelationKeyPairs(relationField: DataModelField) { if (!isRelationshipField(relationField)) { return []; } @@ -309,11 +310,31 @@ export function getForeignKeyFields(relationField: DataModelField) { if (relAttr) { // find "fields" arg const fieldsArg = getAttributeArg(relAttr, 'fields'); + let fkFields: DataModelField[]; if (fieldsArg && isArrayExpr(fieldsArg)) { - return fieldsArg.items + fkFields = fieldsArg.items .filter((item): item is ReferenceExpr => isReferenceExpr(item)) .map((item) => item.target.ref as DataModelField); + } else { + return []; + } + + // find "references" arg + const referencesArg = getAttributeArg(relAttr, 'references'); + let idFields: DataModelField[]; + if (referencesArg && isArrayExpr(referencesArg)) { + idFields = referencesArg.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => item.target.ref as DataModelField); + } else { + return []; + } + + if (idFields.length !== fkFields.length) { + throw new Error(`Relation's references arg and fields are must have equal length`); } + + return idFields.map((idField, i) => ({ id: idField, foreignKey: fkFields[i] })); } return []; @@ -347,7 +368,7 @@ export function resolvePath(_path: string, options: Pick(options: PluginDeclaredOptions, name: string, pluginName: string): T { const value = options[name]; if (value === undefined) { - throw new PluginError(pluginName, `Plugin "${options.name}" is missing required option: ${name}`); + throw new PluginError(pluginName, `required option "${name}" is not provided`); } return value as T; } @@ -481,17 +502,24 @@ export function getDataModelFieldReference(expr: Expression): DataModelField | u } } -export function getModelFieldsWithBases(model: DataModel) { - return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; +export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) { + if (model.$baseMerged) { + return model.fields; + } else { + return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)]; + } } -export function getRecursiveBases(dataModel: DataModel): DataModel[] { +export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] { const result: DataModel[] = []; dataModel.superTypes.forEach((superType) => { const baseDecl = superType.ref; if (baseDecl) { + if (!includeDelegate && isDelegateModel(baseDecl)) { + return; + } result.push(baseDecl); - result.push(...getRecursiveBases(baseDecl)); + result.push(...getRecursiveBases(baseDecl, includeDelegate)); } }); return result; @@ -511,3 +539,18 @@ export function ensureEmptyDir(dir: string) { throw new Error(`Path "${dir}" already exists and is not a directory`); } } + +/** + * Gets the data source provider from the given model. + */ +export function getDataSourceProvider(model: Model) { + const dataSource = model.declarations.find(isDataSource); + if (!dataSource) { + return undefined; + } + const provider = dataSource?.fields.find((f) => f.name === 'provider'); + if (!provider) { + return undefined; + } + return getLiteral(provider.value); +} diff --git a/packages/server/package.json b/packages/server/package.json index 412979a3f..390378154 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "2.0.3", + "version": "2.1.0", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index a7fb44d72..a8882b8be 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -81,6 +81,7 @@ class RequestHandler extends APIHandlerBase { case 'aggregate': case 'groupBy': case 'count': + case 'check': if (method !== 'GET') { return { status: 400, diff --git a/packages/server/tests/api/rpc.test.ts b/packages/server/tests/api/rpc.test.ts index 432abec2c..56ad8c5e0 100644 --- a/packages/server/tests/api/rpc.test.ts +++ b/packages/server/tests/api/rpc.test.ts @@ -15,7 +15,7 @@ describe('RPC API Handler Tests', () => { let zodSchemas: any; beforeAll(async () => { - const params = await loadSchema(schema, { fullZod: true }); + const params = await loadSchema(schema, { fullZod: true, generatePermissionChecker: true }); prisma = params.prisma; enhance = params.enhance; modelMeta = params.modelMeta; @@ -131,6 +131,37 @@ describe('RPC API Handler Tests', () => { expect(r.data.count).toBe(1); }); + it('check', async () => { + const handleRequest = makeHandler(); + + let r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read' }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', where: { published: false } }) }, + prisma: enhance(), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(false); + + r = await handleRequest({ + method: 'get', + path: '/post/check', + query: { q: JSON.stringify({ operation: 'read', where: { authorId: '1', published: false } }) }, + prisma: enhance({ id: '1' }), + }); + expect(r.status).toBe(200); + expect(r.data).toEqual(true); + }); + it('policy violation', async () => { await prisma.user.create({ data: { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 7e24868d6..8cb635d5b 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "2.0.3", + "version": "2.1.0", "description": "ZenStack Test Tools", "main": "index.js", "private": true, diff --git a/packages/testtools/src/model.ts b/packages/testtools/src/model.ts index adbc89453..b6e2a3b71 100644 --- a/packages/testtools/src/model.ts +++ b/packages/testtools/src/model.ts @@ -5,7 +5,7 @@ import * as path from 'path'; import * as tmp from 'tmp'; import { URI } from 'vscode-uri'; import { createZModelServices } from 'zenstack/language-server/zmodel-module'; -import { mergeBaseModel } from 'zenstack/utils/ast-utils'; +import { mergeBaseModels } from 'zenstack/utils/ast-utils'; tmp.setGracefulCleanup(); @@ -53,7 +53,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; - mergeBaseModel(model, ZModel.references.Linker); + mergeBaseModels(model, ZModel.references.Linker); return model; } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 7249e6c4a..4495ddf14 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -97,11 +97,13 @@ datasource db { generator js { provider = 'prisma-client-js' + ${options.previewFeatures ? `previewFeatures = ${JSON.stringify(options.previewFeatures)}` : ''} } plugin enhancer { provider = '@core/enhancer' ${options.preserveTsFiles ? 'preserveTsFiles = true' : ''} + ${options.generatePermissionChecker ? 'generatePermissionChecker = true' : ''} } plugin zod { @@ -131,6 +133,8 @@ export type SchemaLoadOptions = { extraSourceFiles?: { name: string; content: string }[]; projectDir?: string; preserveTsFiles?: boolean; + generatePermissionChecker?: boolean; + previewFeatures?: string[]; }; const defaultOptions: SchemaLoadOptions = { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 5ce4ac030..a8e8c7318 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -30,7 +30,7 @@ importers: specifier: ^2.4.1 version: 2.4.1 eslint: - specifier: ^8.56.0 + specifier: ^8.57.0 version: 8.57.0 eslint-plugin-jest: specifier: ^28.2.0 @@ -211,6 +211,9 @@ importers: ts-morph: specifier: ^16.0.0 version: 16.0.0 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 upper-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -409,6 +412,9 @@ importers: deepmerge: specifier: ^4.3.1 version: 4.3.1 + logic-solver: + specifier: ^2.0.1 + version: 2.0.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -427,6 +433,9 @@ importers: tiny-invariant: specifier: ^1.3.1 version: 1.3.1 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 tslib: specifier: ^2.4.1 version: 2.4.1 @@ -612,11 +621,11 @@ importers: packages/sdk: dependencies: '@prisma/generator-helper': - specifier: 5.7.0 - version: 5.7.0 + specifier: ^5.13.0 + version: 5.13.0 '@prisma/internals': - specifier: 5.7.0 - version: 5.7.0 + specifier: ^5.13.0 + version: 5.13.0 '@zenstackhq/language': specifier: workspace:* version: link:../language/dist @@ -3868,7 +3877,6 @@ packages: /@prisma/debug@5.13.0: resolution: {integrity: sha512-699iqlEvzyCj9ETrXhs8o8wQc/eVW+FigSsHpiskSFydhjVuwTJEfj/nIYqTaWFYuxiWQRfm3r01meuW97SZaQ==} - dev: true /@prisma/debug@5.7.0: resolution: {integrity: sha512-tZ+MOjWlVvz1kOEhNYMa4QUGURY+kgOUBqLHYIV8jmCsMuvA1tWcn7qtIMLzYWCbDcQT4ZS8xDgK0R2gl6/0wA==} @@ -3876,7 +3884,6 @@ packages: /@prisma/engines-version@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b: resolution: {integrity: sha512-AyUuhahTINGn8auyqYdmxsN+qn0mw3eg+uhkp8zwknXYIqoT3bChG4RqNY/nfDkPvzWAPBa9mrDyBeOnWSgO6A==} - dev: true /@prisma/engines-version@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9: resolution: {integrity: sha512-V6tgRVi62jRwTm0Hglky3Scwjr/AKFBFtS+MdbsBr7UOuiu1TKLPc6xfPiyEN1+bYqjEtjxwGsHgahcJsd1rNg==} @@ -3890,7 +3897,6 @@ packages: '@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b '@prisma/fetch-engine': 5.13.0 '@prisma/get-platform': 5.13.0 - dev: true /@prisma/engines@5.7.0: resolution: {integrity: sha512-TkOMgMm60n5YgEKPn9erIvFX2/QuWnl3GBo6yTRyZKk5O5KQertXiNnrYgSLy0SpsKmhovEPQb+D4l0SzyE7XA==} @@ -3908,7 +3914,6 @@ packages: '@prisma/debug': 5.13.0 '@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b '@prisma/get-platform': 5.13.0 - dev: true /@prisma/fetch-engine@5.7.0: resolution: {integrity: sha512-zIn/qmO+N/3FYe7/L9o+yZseIU8ivh4NdPKSkQRIHfg2QVTVMnbhGoTcecbxfVubeTp+DjcbjS0H9fCuM4W04w==} @@ -3918,6 +3923,12 @@ packages: '@prisma/get-platform': 5.7.0 dev: false + /@prisma/generator-helper@5.13.0: + resolution: {integrity: sha512-i+53beJ0dxkDrkHdsXxmeMf+eVhyhOIpL0SdBga8vwe0qHPrAIJ/lpuT/Hj0y5awTmq40qiUEmhXwCEuM/Z17w==} + dependencies: + '@prisma/debug': 5.13.0 + dev: false + /@prisma/generator-helper@5.7.0: resolution: {integrity: sha512-Fn4hJHKGJ49+E8sxpfslRauB3Goa3RAENJ/W25NMR754B9KxvmbCJyE3MT/lIZxML2nGgIdXYUtoDHZHnRaKDw==} dependencies: @@ -3928,7 +3939,6 @@ packages: resolution: {integrity: sha512-B/WrQwYTzwr7qCLifQzYOmQhZcFmIFhR81xC45gweInSUn2hTEbfKUPd2keAog+y5WI5xLAFNJ3wkXplvSVkSw==} dependencies: '@prisma/debug': 5.13.0 - dev: true /@prisma/get-platform@5.7.0: resolution: {integrity: sha512-ZeV/Op4bZsWXuw5Tg05WwRI8BlKiRFhsixPcAM+5BKYSiUZiMKIi713tfT3drBq8+T0E1arNZgYSA9QYcglWNA==} @@ -3936,6 +3946,20 @@ packages: '@prisma/debug': 5.7.0 dev: false + /@prisma/internals@5.13.0: + resolution: {integrity: sha512-OPMzS+IBPzCLT4s+IfGUbOhGFY51CFbokIFMZuoSeLKWE8UvDlitiXZ3OlVqDPUc0AlH++ysQHzDISHbZD+ZUg==} + dependencies: + '@prisma/debug': 5.13.0 + '@prisma/engines': 5.13.0 + '@prisma/fetch-engine': 5.13.0 + '@prisma/generator-helper': 5.13.0 + '@prisma/get-platform': 5.13.0 + '@prisma/prisma-schema-wasm': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b + '@prisma/schema-files-loader': 5.13.0 + arg: 5.0.2 + prompts: 2.4.2 + dev: false + /@prisma/internals@5.7.0: resolution: {integrity: sha512-O9x47W1DECAyvNjYUx6oZHmTX10emKuBgsFHZemUbkIcJdCsp3X8Cy2JMJ5z3hqkRX6a6omMamFsWjuTARoaSw==} dependencies: @@ -3949,10 +3973,20 @@ packages: prompts: 2.4.2 dev: false + /@prisma/prisma-schema-wasm@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b: + resolution: {integrity: sha512-+IhHvuE1wKlyOpJgwAhGop1oqEt+1eixrCeikBIshRhdX6LwjmtRxVxVMlP5nS1yyughmpfkysIW4jZTa+Zjuw==} + dev: false + /@prisma/prisma-schema-wasm@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9: resolution: {integrity: sha512-w+HdQtux0dJDEn6BG3fgNn+fXErXiekj9n//uHRAgrmZghockJkhnikOmG8aSXjTb1Tu5DrGasBX+rYX6rHT1w==} dev: false + /@prisma/schema-files-loader@5.13.0: + resolution: {integrity: sha512-6sVMoqobkWKsmzb98LfLiIt/aFRucWfkzSUBsqk7sc+h99xjynJt6aKtM2SSkyndFdWpRU0OiCHfQ9UlYUEJIw==} + dependencies: + fs-extra: 11.1.1 + dev: false + /@readme/better-ajv-errors@1.6.0(ajv@8.12.0): resolution: {integrity: sha512-9gO9rld84Jgu13kcbKRU+WHseNhaVt76wYMeRDGsUGYxwJtI3RmEJ9LY9dZCYQGI8eUZLuxb5qDja0nqklpFjQ==} engines: {node: '>=14'} @@ -8492,7 +8526,6 @@ packages: graceful-fs: 4.2.11 jsonfile: 6.1.0 universalify: 2.0.0 - dev: true /fs-extra@7.0.1: resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==} @@ -8793,7 +8826,6 @@ packages: /graceful-fs@4.2.11: resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} - dev: true /grapheme-splitter@1.0.4: resolution: {integrity: sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==} @@ -10151,7 +10183,6 @@ packages: universalify: 2.0.0 optionalDependencies: graceful-fs: 4.2.11 - dev: true /jsonpointer@5.0.1: resolution: {integrity: sha512-p/nXbhSEcu3pZRdkW1OfJhpsVtW1gd4Wa1fnQc9YLiTfAjn0312eMKimbdIQzuZl9aa9xUGaRlP9T/CJE/ditQ==} @@ -10495,6 +10526,12 @@ packages: wrap-ansi: 8.1.0 dev: false + /logic-solver@2.0.1: + resolution: {integrity: sha512-F1oCywXUzvAF4Z98mMyXySUCpUU3hNyc+JfYV3g2x/4BupC/xv94iPJuHh9us2XX5UrvY5lnKUXNvjcJNQBJ/g==} + dependencies: + underscore: 1.13.6 + dev: false + /loose-envify@1.4.0: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true @@ -14372,7 +14409,6 @@ packages: /underscore@1.13.6: resolution: {integrity: sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==} - dev: true /undici-types@5.26.5: resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} @@ -14441,7 +14477,6 @@ packages: /universalify@2.0.0: resolution: {integrity: sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==} engines: {node: '>= 10.0.0'} - dev: true /unpipe@1.0.0: resolution: {integrity: sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==} diff --git a/tests/integration/tests/e2e/type-coverage.test.ts b/tests/integration/tests/e2e/type-coverage.test.ts index c8c88211c..2a41b3ffd 100644 --- a/tests/integration/tests/e2e/type-coverage.test.ts +++ b/tests/integration/tests/e2e/type-coverage.test.ts @@ -13,14 +13,14 @@ describe('Type Coverage Tests', () => { model Foo { id String @id @default(cuid()) - string String - int Int - bigInt BigInt - date DateTime - float Float - decimal Decimal - boolean Boolean - bytes Bytes + String String + Int Int + BigInt BigInt + DateTime DateTime + Float Float + Decimal Decimal + Boolean Boolean + Bytes Bytes @@allow('all', true) } @@ -41,14 +41,14 @@ describe('Type Coverage Tests', () => { const date = new Date(); const data = { id: '1', - string: 'string', - int: 100, - bigInt: BigInt(9007199254740991), - date, - float: 1.23, - decimal: new Decimal(1.2345), - boolean: true, - bytes: Buffer.from('hello'), + String: 'string', + Int: 100, + BigInt: BigInt(9007199254740991), + DateTime: date, + Float: 1.23, + Decimal: new Decimal(1.2345), + Boolean: true, + Bytes: Buffer.from('hello'), }; await db.foo.create({ diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts new file mode 100644 index 000000000..e4ca61fad --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts @@ -0,0 +1,652 @@ +import { SchemaLoadOptions, createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; + +describe('Permission checker', () => { + const PRELUDE = ` + datasource db { + provider = 'sqlite' + url = 'file:./dev.db' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + `; + + const load = (schema: string, options?: SchemaLoadOptions) => + loadSchema(schema, { + ...options, + generatePermissionChecker: true, + }); + + it('checker generation not enabled', async () => { + const { enhance } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).rejects.toThrow('Generated permission checkers not found'); + }); + + it('empty rules', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveFalsy(); + }); + + it('unconditional allow', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', true) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveTruthy(); + }); + + it('multiple allow rules', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value == 1) + @@allow('all', value == 2) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy(); + }); + + it('deny rule', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('all', value > 0) + @@deny('all', value == 1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy(); + }); + + it('int field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value == 1) + @@allow('create', value != 1) + @@allow('update', value > 1) + @@allow('delete', value <= 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { value: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { value: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', where: { value: 2 } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', where: { value: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', where: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', where: { value: 2 } })).toResolveFalsy(); + }); + + it('boolean field toplevel condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: true } })).toResolveTruthy(); + }); + + it('boolean field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Boolean + @@allow('read', value == true) + @@allow('create', value == false) + @@allow('update', value != true) + @@allow('delete', value != false) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: true } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'create', where: { value: false } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { value: true } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'update', where: { value: false } })).toResolveTruthy(); + + await expect(db.model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'delete', where: { value: false } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'delete', where: { value: true } })).toResolveTruthy(); + }); + + it('string field condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', value == 'admin') + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 'user' } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { value: 'admin' } })).toResolveTruthy(); + }); + + it('enum', async () => { + const dbUrl = await createPostgresDb('permission-checker-enum'); + let prisma: any; + try { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = '${dbUrl}' + } + + generator js { + provider = 'prisma-client-js' + } + + plugin enhancer { + provider = '@core/enhancer' + generatePermissionChecker = true + } + + enum Role { + USER + ADMIN + } + model User { + id Int @id @default(autoincrement()) + role Role + } + model Model { + id Int @id @default(autoincrement()) + @@allow('read', auth().role == ADMIN) + } + `, + { addPrelude: false, generatePermissionChecker: true } + ); + + prisma = r.prisma; + const enhance = r.enhance; + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'USER' }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, role: 'ADMIN' }).model.check({ operation: 'read' })).toResolveTruthy(); + } finally { + await prisma.$disconnect(); + await dropPostgresDb('permission-checker-enum'); + } + }); + + it('function noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + @@allow('read', startsWith(value, 'admin')) + @@allow('update', !startsWith(value, 'admin')) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 'admin' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { value: 'user' } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { value: 'admin' } })).toResolveTruthy(); + }); + + it('relation noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo? + + @@allow('read', foo.x > 0) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int @unique + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { foo: { x: 0 } } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); + }); + + it('collection predicate noop', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value String + foo Foo[] + + @@allow('read', foo?[x > 0]) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + modelId Int + model Model @relation(fields: [modelId], references: [id]) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { foo: [{ x: 0 }] } })).rejects.toThrow( + 'Providing filter for field "foo"' + ); + }); + + it('field complex condition', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x > y) + @@allow('create', x > 1 || x > y) + @@allow('update', !(x >= y)) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { x: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'create', where: { x: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { x: 1, y: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { x: 1, y: 1 } })).toResolveFalsy(); + + await expect(db.model.check({ operation: 'update' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { x: 0 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'update', where: { y: 0 } })).toResolveFalsy(); // numbers are non-negative + await expect(db.model.check({ operation: 'update', where: { x: 1, y: 1 } })).toResolveFalsy(); + }); + + it('field condition unsolvable', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int + @@allow('read', x > 0 && x < y && y <= 1) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { x: 0 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { x: 1 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { x: 1, y: 2 } })).toResolveFalsy(); + await expect(db.model.check({ operation: 'read', where: { x: 1, y: 1 } })).toResolveFalsy(); + }); + + it('simple auth condition', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + level Int + admin Boolean + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().level > 0) + @@allow('create', auth().admin) + @@allow('update', !auth().admin) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'create' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'update' })).toResolveTruthy(); + }); + + it('auth compared with relation field', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + } + + model Model { + id Int @id @default(autoincrement()) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + @@allow('read', auth().id == ownerId) + @@allow('create', auth().id != ownerId) + @@allow('update', auth() == owner) + @@allow('delete', auth() != owner) + } + `, + { preserveTsFiles: true } + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', where: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read', where: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', where: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create', where: { ownerId: 2 } })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', where: { ownerId: 1 } })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update', where: { ownerId: 2 } })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'delete' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', where: { ownerId: 1 } })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'delete', where: { ownerId: 2 } })).toResolveTruthy(); + }); + + it('auth null check', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + level Int + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth() != null) + @@allow('create', auth() == null) + @@allow('update', auth().level > 0) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy(); + + await expect(enhance().model.check({ operation: 'create' })).toResolveTruthy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy(); + + await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'update' })).toResolveFalsy(); + await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'update' })).toResolveTruthy(); + }); + + it('auth with relation access', async () => { + const { enhance } = await load( + ` + model User { + id Int @id @default(autoincrement()) + profile Profile? + } + + model Profile { + id Int @id @default(autoincrement()) + level Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', auth().profile.level > 0) + } + ` + ); + + await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 0 } }).model.check({ operation: 'read' })).toResolveFalsy(); + await expect(enhance({ id: 1, profile: { level: 1 } }).model.check({ operation: 'read' })).toResolveTruthy(); + }); + + it('nullable field', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int? + @@allow('read', value != null) + @@allow('create', value == null) + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'create', where: { value: 1 } })).toResolveTruthy(); + }); + + it('compilation', async () => { + await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value == 1) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + const prisma = new PrismaClient(); + const db = enhance(prisma); + db.model.check({ operation: 'read' }); + db.model.check({ operation: 'read', where: { value: 1 }}); + `, + }, + ], + } + ); + }); + + it('invalid filter', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + foo Foo? + d DateTime + + @@allow('read', value == 1) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + model Model @relation(fields: [modelId], references: [id]) + modelId Int @unique + } + ` + ); + + const db = enhance(); + await expect(db.model.check({ operation: 'read', where: { foo: { x: 1 } } })).rejects.toThrow( + `Providing filter for field "foo" is not supported. Only scalar fields are allowed.` + ); + await expect(db.model.check({ operation: 'read', where: { d: new Date() } })).rejects.toThrow( + `Providing filter for field "d" is not supported. Only number, string, and boolean fields are allowed.` + ); + await expect(db.model.check({ operation: 'read', where: { value: null } })).rejects.toThrow( + `Using "null" as filter value is not supported yet` + ); + await expect(db.model.check({ operation: 'read', where: { value: {} } })).rejects.toThrow( + 'Invalid value type for field "value". Only number, string or boolean is allowed.' + ); + await expect(db.model.check({ operation: 'read', where: { value: 'abc' } })).rejects.toThrow( + 'Invalid value type for field "value". Expected "number"' + ); + await expect(db.model.check({ operation: 'read', where: { value: -1 } })).rejects.toThrow( + 'Invalid value for field "value". Only non-negative integers are allowed.' + ); + }); + + it('float field ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Float + @@allow('read', value == 1.1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + }); + + it('float value ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value > 1.1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy(); + }); + + it('negative value ignored', async () => { + const { enhance } = await load( + ` + model Model { + id Int @id @default(autoincrement()) + value Int + @@allow('read', value >-1) + } + ` + ); + const db = enhance(); + await expect(db.model.check({ operation: 'read' })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy(); + await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts new file mode 100644 index 000000000..d46c31245 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts @@ -0,0 +1,57 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('prisma omit', () => { + it('test', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + name String + profile Profile? + age Int + value Int @allow('read', age > 20) + @@allow('all', age > 18) + } + + model Profile { + id String @id @default(cuid()) + user User @relation(fields: [userId], references: [id]) + userId String @unique + level Int + @@allow('all', level > 1) + } + `, + { previewFeatures: ['omitApi'], logPrismaQuery: true } + ); + + await prisma.user.create({ + data: { + name: 'John', + age: 25, + value: 10, + profile: { + create: { level: 2 }, + }, + }, + }); + + const db = enhance(); + let found = await db.user.findFirst({ + include: { profile: { omit: { level: true } } }, + omit: { + age: true, + }, + }); + expect(found.age).toBeUndefined(); + expect(found.value).toEqual(10); + expect(found.profile.level).toBeUndefined(); + + found = await db.user.findFirst({ + select: { value: true, profile: { omit: { level: true } } }, + }); + console.log(found); + expect(found.age).toBeUndefined(); + expect(found.value).toEqual(10); + expect(found.profile.level).toBeUndefined(); + }); +}); diff --git a/tests/integration/tsconfig.json b/tests/integration/tsconfig.json index c6cc8d4a7..2771cd805 100644 --- a/tests/integration/tsconfig.json +++ b/tests/integration/tsconfig.json @@ -8,5 +8,5 @@ "skipLibCheck": true, "experimentalDecorators": true }, - "include": ["**/*.ts", "**/*.d.ts", "../regression/tests/issue-177.test.ts", "../regression/tests/issue-416.test.ts", "../regression/tests/issue-646.test.ts", "../regression/tests/issue-657.test.ts", "../regression/tests/issue-665.test.ts", "../regression/tests/issue-674.test.ts", "../regression/tests/issue-689.test.ts", "../regression/tests/issue-703.test.ts", "../regression/tests/issue-714.test.ts", "../regression/tests/issue-724.test.ts", "../regression/tests/issue-735.test.ts", "../regression/tests/issue-744.test.ts", "../regression/tests/issue-756.test.ts", "../regression/tests/issue-764.test.ts", "../regression/tests/issue-765.test.ts", "../regression/tests/issue-804.test.ts", "../regression/tests/issue-811.test.ts", "../regression/tests/issue-814.test.ts", "../regression/tests/issue-825.test.ts", "../regression/tests/issue-864.test.ts", "../regression/tests/issue-886.test.ts", "../regression/tests/issue-925.test.ts", "../regression/tests/issue-947.test.ts", "../regression/tests/issue-961.test.ts", "../regression/tests/issue-965.test.ts", "../regression/tests/issue-971.test.ts", "../regression/tests/issue-992.test.ts", "../regression/tests/issue-1014.test.ts", "../regression/tests/issue-1078.test.ts", "../regression/tests/issue-1080.test.ts", "../regression/tests/issue-1129.test.ts", "../regression/tests/issue-1167.test.ts", "../regression/tests/issue-1179.test.ts", "../regression/tests/issue-1186.test.ts", "../regression/tests/issue-1210.test.ts", "../regression/tests/issue-1235.test.ts", "../regression/tests/issue-1241.test.ts", "../regression/tests/issue-1257.test.ts", "../regression/tests/issue-1265.test.ts", "../regression/tests/issues.test.ts"] + "include": ["**/*.ts", "**/*.d.ts"] } diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts b/tests/regression/tests/issue-1058.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts rename to tests/regression/tests/issue-1058.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts b/tests/regression/tests/issue-1064.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts rename to tests/regression/tests/issue-1064.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1100.test.ts b/tests/regression/tests/issue-1100.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1100.test.ts rename to tests/regression/tests/issue-1100.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1123.test.ts b/tests/regression/tests/issue-1123.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1123.test.ts rename to tests/regression/tests/issue-1123.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1135.test.ts b/tests/regression/tests/issue-1135.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1135.test.ts rename to tests/regression/tests/issue-1135.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1149.test.ts b/tests/regression/tests/issue-1149.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1149.test.ts rename to tests/regression/tests/issue-1149.test.ts diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts b/tests/regression/tests/issue-1243.test.ts similarity index 100% rename from tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts rename to tests/regression/tests/issue-1243.test.ts diff --git a/tests/regression/tests/issue-1378.test.ts b/tests/regression/tests/issue-1378.test.ts new file mode 100644 index 000000000..29d4b16a8 --- /dev/null +++ b/tests/regression/tests/issue-1378.test.ts @@ -0,0 +1,47 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1378', () => { + it('regression', async () => { + await loadSchema( + ` + model User { + id String @id @default(cuid()) + todos Todo[] + } + + model Todo { + id String @id @default(cuid()) + name String @length(3,255) + userId String @default(auth().id) + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + @@allow("all", auth() == user) + } + `, + { + extraDependencies: ['zod'], + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { z } from 'zod'; + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + import { TodoCreateSchema } from '.zenstack/zod/models'; + + const prisma = new PrismaClient(); + const db = enhance(prisma); + + export const onSubmit = async (values: z.infer) => { + await db.todo.create({ + data: values, + }); + }; + `, + }, + ], + compile: true, + } + ); + }); +}); diff --git a/tests/regression/tests/issue-1388.test.ts b/tests/regression/tests/issue-1388.test.ts new file mode 100644 index 000000000..3ffbc967b --- /dev/null +++ b/tests/regression/tests/issue-1388.test.ts @@ -0,0 +1,26 @@ +import { FILE_SPLITTER, loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1388', () => { + it('regression', async () => { + await loadSchema( + `schema.zmodel + import './auth' + import './post' + + ${FILE_SPLITTER}auth.zmodel + model User { + id String @id @default(cuid()) + role String + } + + ${FILE_SPLITTER}post.zmodel + model Post { + id String @id @default(nanoid(6)) + title String + @@deny('all', auth() == null) + @@allow('all', auth().id == 'user1') + } + ` + ); + }); +}); diff --git a/tests/regression/tests/issue-1410.test.ts b/tests/regression/tests/issue-1410.test.ts new file mode 100644 index 000000000..488cd1bf0 --- /dev/null +++ b/tests/regression/tests/issue-1410.test.ts @@ -0,0 +1,146 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1410', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model Drink { + id Int @id @default(autoincrement()) + slug String @unique + + manufacturer_id Int + manufacturer Manufacturer @relation(fields: [manufacturer_id], references: [id]) + + type String + + name String @unique + description String + abv Float + image String? + + gluten Boolean + lactose Boolean + organic Boolean + + containers Container[] + + @@delegate(type) + + @@allow('all', true) + } + + model Beer extends Drink { + style_id Int + style BeerStyle @relation(fields: [style_id], references: [id]) + + ibu Float? + + @@allow('all', true) + } + + model BeerStyle { + id Int @id @default(autoincrement()) + + name String @unique + color String + + beers Beer[] + + @@allow('all', true) + } + + model Wine extends Drink { + style_id Int + style WineStyle @relation(fields: [style_id], references: [id]) + + heavy_score Int? + tannine_score Int? + dry_score Int? + fresh_score Int? + notes String? + + @@allow('all', true) + } + + model WineStyle { + id Int @id @default(autoincrement()) + + name String @unique + color String + + wines Wine[] + + @@allow('all', true) + } + + model Soda extends Drink { + carbonated Boolean + + @@allow('all', true) + } + + model Cocktail extends Drink { + mix Boolean + + @@allow('all', true) + } + + model Container { + barcode String @id + + drink_id Int + drink Drink @relation(fields: [drink_id], references: [id]) + + type String + volume Int + portions Int? + + inventory Int @default(0) + + @@allow('all', true) + } + + model Manufacturer { + id Int @id @default(autoincrement()) + + country_id String + country Country @relation(fields: [country_id], references: [code]) + + name String @unique + description String? + image String? + + drinks Drink[] + + @@allow('all', true) + } + + model Country { + code String @id + name String + + manufacturers Manufacturer[] + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { NOT: { gluten: true } }, + }); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { AND: [{ gluten: true }, { abv: { gt: 50 } }] }, + }); + + await db.beer.findMany({ + include: { style: true, manufacturer: true }, + where: { OR: [{ AND: [{ NOT: { gluten: true } }] }, { abv: { gt: 50 } }] }, + }); + }); +}); diff --git a/tests/regression/tests/issue-1415.test.ts b/tests/regression/tests/issue-1415.test.ts new file mode 100644 index 000000000..791413557 --- /dev/null +++ b/tests/regression/tests/issue-1415.test.ts @@ -0,0 +1,22 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1415', () => { + it('regression', async () => { + await loadSchema( + ` + model User { + id String @id @default(cuid()) + prices Price[] + } + + model Price { + id String @id @default(cuid()) + owner User @relation(fields: [ownerId], references: [id]) + ownerId String @default(auth().id) + priceType String + @@delegate(priceType) + } + ` + ); + }); +}); diff --git a/tests/regression/tests/issue-1416.test.ts b/tests/regression/tests/issue-1416.test.ts new file mode 100644 index 000000000..5c18d6d4e --- /dev/null +++ b/tests/regression/tests/issue-1416.test.ts @@ -0,0 +1,37 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1416', () => { + it('regression', async () => { + await loadSchema( + ` + model User { + id String @id @default(cuid()) + role String + } + + model Price { + id String @id @default(nanoid(6)) + entity Entity? @relation(fields: [entityId], references: [id]) + entityId String? + priceType String + @@delegate(priceType) + } + + model MyPrice extends Price { + foo String + } + + model Entity { + id String @id @default(nanoid(6)) + price Price[] + type String + @@delegate(type) + } + + model MyEntity extends Entity { + foo String + } + ` + ); + }); +}); diff --git a/tests/regression/tests/issue-1427.test.ts b/tests/regression/tests/issue-1427.test.ts new file mode 100644 index 000000000..0d7c7c07e --- /dev/null +++ b/tests/regression/tests/issue-1427.test.ts @@ -0,0 +1,42 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1427', () => { + it('regression', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + name String + profile Profile? + @@allow('all', true) + } + + model Profile { + id String @id @default(cuid()) + user User @relation(fields: [userId], references: [id]) + userId String @unique + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { + name: 'John', + profile: { + create: {}, + }, + }, + }); + + const db = enhance(); + const found = await db.user.findFirst({ + select: { + id: true, + name: true, + profile: false, + }, + }); + expect(found.profile).toBeUndefined(); + }); +}); diff --git a/tests/regression/tests/issue-1435.test.ts b/tests/regression/tests/issue-1435.test.ts new file mode 100644 index 000000000..0093aff8b --- /dev/null +++ b/tests/regression/tests/issue-1435.test.ts @@ -0,0 +1,119 @@ +import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1435', () => { + it('regression', async () => { + let prisma: any; + let dbUrl: string; + + try { + dbUrl = await createPostgresDb('issue-1435'); + const r = await loadSchema( + ` + /* Interfaces */ + abstract model IBase { + updatedAt DateTime @updatedAt + createdAt DateTime @default(now()) + } + + abstract model IAuth extends IBase { + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String @unique + + @@allow('create', true) + @@allow('all', auth() == user) + } + + abstract model IIntegration extends IBase { + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + organizationId String @unique + + @@allow('all', organization.members?[user == auth() && type == OWNER]) + @@allow('read', organization.members?[user == auth()]) + } + + /* Auth Stuff */ + model User extends IBase { + id String @id @default(cuid()) + firstName String + lastName String + google GoogleAuth? + memberships Member[] + + @@allow('create', true) + @@allow('all', auth() == this) + } + + model GoogleAuth extends IAuth { + reference String @id + refreshToken String + } + + /* Org Stuff */ + enum MemberType { + OWNER + MEMBER + } + + model Organization extends IBase { + id String @id @default(cuid()) + name String + members Member[] + google GoogleIntegration? + + @@allow('create', true) + @@allow('all', members?[user == auth() && type == OWNER]) + @@allow('read', members?[user == auth()]) + } + + + model Member extends IBase { + type MemberType @default(MEMBER) + organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade) + organizationId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + + @@id([organizationId, userId]) + @@allow('all', organization.members?[user == auth() && type == OWNER]) + @@allow('read', user == auth()) + } + + /* Google Stuff */ + model GoogleIntegration extends IIntegration { + reference String @id + } + `, + { provider: 'postgresql', dbUrl, logPrismaQuery: true } + ); + + prisma = r.prisma; + const enhance = r.enhance; + + await prisma.organization.create({ + data: { + name: 'My Organization', + members: { + create: { + type: 'OWNER', + user: { + create: { + id: '1', + firstName: 'John', + lastName: 'Doe', + }, + }, + }, + }, + }, + }); + + const db = enhance({ id: '1' }); + await expect(db.organization.findMany()).resolves.toHaveLength(1); + } finally { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb('issue-1435'); + } + }); +}); diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts index 7bc42071b..f6dc3a135 100644 --- a/tests/regression/tests/issue-961.test.ts +++ b/tests/regression/tests/issue-961.test.ts @@ -123,10 +123,8 @@ describe('Regression: issue 961', () => { await expect(db.userColumn.findMany()).resolves.toHaveLength(1); }); - // disabled because of Prisma V4 bug: https://github.com/prisma/prisma/issues/18371 - // eslint-disable-next-line jest/no-disabled-tests - it.skip('updateMany', async () => { - const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true }); + it('updateMany', async () => { + const { prisma, enhance } = await loadSchema(schema); const user = await prisma.user.create({ data: { diff --git a/tests/regression/tsconfig.json b/tests/regression/tsconfig.json index c6cc8d4a7..2771cd805 100644 --- a/tests/regression/tsconfig.json +++ b/tests/regression/tsconfig.json @@ -8,5 +8,5 @@ "skipLibCheck": true, "experimentalDecorators": true }, - "include": ["**/*.ts", "**/*.d.ts", "../regression/tests/issue-177.test.ts", "../regression/tests/issue-416.test.ts", "../regression/tests/issue-646.test.ts", "../regression/tests/issue-657.test.ts", "../regression/tests/issue-665.test.ts", "../regression/tests/issue-674.test.ts", "../regression/tests/issue-689.test.ts", "../regression/tests/issue-703.test.ts", "../regression/tests/issue-714.test.ts", "../regression/tests/issue-724.test.ts", "../regression/tests/issue-735.test.ts", "../regression/tests/issue-744.test.ts", "../regression/tests/issue-756.test.ts", "../regression/tests/issue-764.test.ts", "../regression/tests/issue-765.test.ts", "../regression/tests/issue-804.test.ts", "../regression/tests/issue-811.test.ts", "../regression/tests/issue-814.test.ts", "../regression/tests/issue-825.test.ts", "../regression/tests/issue-864.test.ts", "../regression/tests/issue-886.test.ts", "../regression/tests/issue-925.test.ts", "../regression/tests/issue-947.test.ts", "../regression/tests/issue-961.test.ts", "../regression/tests/issue-965.test.ts", "../regression/tests/issue-971.test.ts", "../regression/tests/issue-992.test.ts", "../regression/tests/issue-1014.test.ts", "../regression/tests/issue-1078.test.ts", "../regression/tests/issue-1080.test.ts", "../regression/tests/issue-1129.test.ts", "../regression/tests/issue-1167.test.ts", "../regression/tests/issue-1179.test.ts", "../regression/tests/issue-1186.test.ts", "../regression/tests/issue-1210.test.ts", "../regression/tests/issue-1235.test.ts", "../regression/tests/issue-1241.test.ts", "../regression/tests/issue-1257.test.ts", "../regression/tests/issue-1265.test.ts", "../regression/tests/issues.test.ts"] + "include": ["**/*.ts", "**/*.d.ts"] }
CodeRabbit
CodeRabbit
Johann Rohn
Johann Rohn
Benjamin Zecirovic
Benjamin Zecirovic