diff --git a/README.md b/README.md
index 620056619..32091aaa7 100644
--- a/README.md
+++ b/README.md
@@ -234,6 +234,7 @@ Thank you for your support!
+ CodeRabbit |
Johann Rohn |
Benjamin Zecirovic |
diff --git a/package.json b/package.json
index 2b9bdf1ea..9b0b3d158 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
{
"name": "zenstack-monorepo",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "",
"scripts": {
"build": "pnpm -r build",
@@ -29,7 +29,7 @@
"@typescript-eslint/parser": "^7.6.0",
"concurrently": "^7.4.0",
"copyfiles": "^2.4.1",
- "eslint": "^8.56.0",
+ "eslint": "^8.57.0",
"eslint-plugin-jest": "^28.2.0",
"jest": "^29.7.0",
"replace-in-file": "^7.0.1",
diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts
index ad5cb1600..24b570e37 100644
--- a/packages/ide/jetbrains/build.gradle.kts
+++ b/packages/ide/jetbrains/build.gradle.kts
@@ -9,7 +9,7 @@ plugins {
}
group = "dev.zenstack"
-version = "2.0.3"
+version = "2.1.0"
repositories {
mavenCentral()
diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json
index 51621e680..296e9da55 100644
--- a/packages/ide/jetbrains/package.json
+++ b/packages/ide/jetbrains/package.json
@@ -1,6 +1,6 @@
{
"name": "jetbrains",
- "version": "2.0.3",
+ "version": "2.1.0",
"displayName": "ZenStack JetBrains IDE Plugin",
"description": "ZenStack JetBrains IDE plugin",
"homepage": "https://zenstack.dev",
diff --git a/packages/language/package.json b/packages/language/package.json
index 4de5ca0da..56121c456 100644
--- a/packages/language/package.json
+++ b/packages/language/package.json
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/language",
- "version": "2.0.3",
+ "version": "2.1.0",
"displayName": "ZenStack modeling language compiler",
"description": "ZenStack modeling language compiler",
"homepage": "https://zenstack.dev",
diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts
index a95a748d9..74fa1707b 100644
--- a/packages/language/src/generated/ast.ts
+++ b/packages/language/src/generated/ast.ts
@@ -84,6 +84,12 @@ export function isRegularID(item: unknown): item is RegularID {
return item === 'model' || item === 'enum' || item === 'attribute' || item === 'datasource' || item === 'plugin' || item === 'abstract' || item === 'in' || item === 'view' || item === 'import' || (typeof item === 'string' && (/[_a-zA-Z][\w_]*/.test(item)));
}
+export type RegularIDWithTypeNames = 'Any' | 'BigInt' | 'Boolean' | 'Bytes' | 'DateTime' | 'Decimal' | 'Float' | 'Int' | 'Json' | 'Null' | 'Object' | 'String' | 'Unsupported' | RegularID;
+
+export function isRegularIDWithTypeNames(item: unknown): item is RegularIDWithTypeNames {
+ return isRegularID(item) || item === 'String' || item === 'Boolean' || item === 'Int' || item === 'BigInt' || item === 'Float' || item === 'Decimal' || item === 'DateTime' || item === 'Json' || item === 'Bytes' || item === 'Null' || item === 'Object' || item === 'Any' || item === 'Unsupported';
+}
+
export type TypeDeclaration = DataModel | Enum;
export const TypeDeclaration = 'TypeDeclaration';
@@ -288,7 +294,7 @@ export interface DataModelField extends AstNode {
readonly $type: 'DataModelField';
attributes: Array
comments: Array
- name: RegularID
+ name: RegularIDWithTypeNames
type: DataModelFieldType
}
diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts
index 45aa3ff97..08656b422 100644
--- a/packages/language/src/generated/grammar.ts
+++ b/packages/language/src/generated/grammar.ts
@@ -70,7 +70,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@63"
+ "$ref": "#/rules@64"
},
"arguments": []
}
@@ -140,7 +140,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@47"
+ "$ref": "#/rules@48"
},
"arguments": []
}
@@ -162,7 +162,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -222,7 +222,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -282,7 +282,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -333,7 +333,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -393,7 +393,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -481,7 +481,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@64"
+ "$ref": "#/rules@65"
},
"arguments": []
}
@@ -503,7 +503,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@63"
+ "$ref": "#/rules@64"
},
"arguments": []
}
@@ -525,7 +525,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@57"
+ "$ref": "#/rules@58"
},
"arguments": []
}
@@ -649,7 +649,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@62"
+ "$ref": "#/rules@63"
},
"arguments": []
}
@@ -747,7 +747,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@62"
+ "$ref": "#/rules@63"
},
"arguments": []
}
@@ -1055,7 +1055,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@62"
+ "$ref": "#/rules@63"
},
"arguments": []
}
@@ -1176,7 +1176,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@63"
+ "$ref": "#/rules@64"
},
"arguments": []
}
@@ -1894,7 +1894,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2032,7 +2032,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@51"
+ "$ref": "#/rules@52"
},
"arguments": []
}
@@ -2066,7 +2066,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2079,7 +2079,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@46"
+ "$ref": "#/rules@47"
},
"arguments": []
}
@@ -2103,7 +2103,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@50"
+ "$ref": "#/rules@51"
},
"arguments": []
},
@@ -2134,7 +2134,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@56"
+ "$ref": "#/rules@57"
},
"arguments": []
}
@@ -2262,7 +2262,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2310,7 +2310,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@51"
+ "$ref": "#/rules@52"
},
"arguments": []
}
@@ -2344,7 +2344,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2369,7 +2369,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@50"
+ "$ref": "#/rules@51"
},
"arguments": []
},
@@ -2393,7 +2393,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -2506,7 +2506,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@52"
+ "$ref": "#/rules@53"
},
"arguments": []
},
@@ -2530,7 +2530,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -2598,7 +2598,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@55"
+ "$ref": "#/rules@56"
},
"arguments": []
}
@@ -2662,7 +2662,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@62"
+ "$ref": "#/rules@63"
},
"arguments": []
},
@@ -2711,6 +2711,81 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"parameters": [],
"wildcard": false
},
+ {
+ "$type": "ParserRule",
+ "name": "RegularIDWithTypeNames",
+ "dataType": "string",
+ "definition": {
+ "$type": "Alternatives",
+ "elements": [
+ {
+ "$type": "RuleCall",
+ "rule": {
+ "$ref": "#/rules@46"
+ },
+ "arguments": []
+ },
+ {
+ "$type": "Keyword",
+ "value": "String"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Boolean"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Int"
+ },
+ {
+ "$type": "Keyword",
+ "value": "BigInt"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Float"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Decimal"
+ },
+ {
+ "$type": "Keyword",
+ "value": "DateTime"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Json"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Bytes"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Null"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Object"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Any"
+ },
+ {
+ "$type": "Keyword",
+ "value": "Unsupported"
+ }
+ ]
+ },
+ "definesHiddenTokens": false,
+ "entry": false,
+ "fragment": false,
+ "hiddenTokens": [],
+ "parameters": [],
+ "wildcard": false
+ },
{
"$type": "ParserRule",
"name": "Attribute",
@@ -2724,7 +2799,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2744,21 +2819,21 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@59"
+ "$ref": "#/rules@60"
},
"arguments": []
},
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@60"
+ "$ref": "#/rules@61"
},
"arguments": []
},
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@61"
+ "$ref": "#/rules@62"
},
"arguments": []
}
@@ -2779,7 +2854,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@48"
+ "$ref": "#/rules@49"
},
"arguments": []
}
@@ -2798,7 +2873,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@48"
+ "$ref": "#/rules@49"
},
"arguments": []
}
@@ -2820,7 +2895,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@52"
+ "$ref": "#/rules@53"
},
"arguments": []
},
@@ -2848,7 +2923,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": []
},
@@ -2887,7 +2962,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@49"
+ "$ref": "#/rules@50"
},
"arguments": []
}
@@ -2899,7 +2974,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@52"
+ "$ref": "#/rules@53"
},
"arguments": []
},
@@ -2933,7 +3008,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@55"
+ "$ref": "#/rules@56"
},
"arguments": []
},
@@ -3024,12 +3099,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
- "$ref": "#/rules@47"
+ "$ref": "#/rules@48"
},
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@61"
+ "$ref": "#/rules@62"
},
"arguments": []
},
@@ -3046,7 +3121,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@53"
+ "$ref": "#/rules@54"
},
"arguments": [],
"cardinality": "?"
@@ -3076,7 +3151,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@65"
+ "$ref": "#/rules@66"
},
"arguments": [],
"cardinality": "*"
@@ -3088,12 +3163,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
- "$ref": "#/rules@47"
+ "$ref": "#/rules@48"
},
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@60"
+ "$ref": "#/rules@61"
},
"arguments": []
},
@@ -3110,7 +3185,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@53"
+ "$ref": "#/rules@54"
},
"arguments": [],
"cardinality": "?"
@@ -3144,12 +3219,12 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "CrossReference",
"type": {
- "$ref": "#/rules@47"
+ "$ref": "#/rules@48"
},
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@59"
+ "$ref": "#/rules@60"
},
"arguments": []
},
@@ -3166,7 +3241,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
{
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@53"
+ "$ref": "#/rules@54"
},
"arguments": [],
"cardinality": "?"
@@ -3201,7 +3276,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@54"
+ "$ref": "#/rules@55"
},
"arguments": []
}
@@ -3220,7 +3295,7 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel
"terminal": {
"$type": "RuleCall",
"rule": {
- "$ref": "#/rules@54"
+ "$ref": "#/rules@55"
},
"arguments": []
}
diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium
index 8fcc72c34..132350bcc 100644
--- a/packages/language/src/zmodel.langium
+++ b/packages/language/src/zmodel.langium
@@ -190,7 +190,7 @@ DataModel:
DataModelField:
(comments+=TRIPLE_SLASH_COMMENT)*
- name=RegularID type=DataModelFieldType (attributes+=DataModelFieldAttribute)*;
+ name=RegularIDWithTypeNames type=DataModelFieldType (attributes+=DataModelFieldAttribute)*;
DataModelFieldType:
(type=BuiltinType | unsupported=UnsupportedFieldType | reference=[TypeDeclaration:RegularID]) (array?='[' ']')? (optional?='?')?;
@@ -226,6 +226,9 @@ RegularID returns string:
// include keywords that we'd like to work as ID in most places
ID | 'model' | 'enum' | 'attribute' | 'datasource' | 'plugin' | 'abstract' | 'in' | 'view' | 'import';
+RegularIDWithTypeNames returns string:
+ RegularID | 'String' | 'Boolean' | 'Int' | 'BigInt' | 'Float' | 'Decimal' | 'DateTime' | 'Json' | 'Bytes' | 'Null' | 'Object' | 'Any' | 'Unsupported';
+
// attribute
Attribute:
(comments+=TRIPLE_SLASH_COMMENT)* 'attribute' name=(INTERNAL_ATTRIBUTE_NAME|MODEL_ATTRIBUTE_NAME|FIELD_ATTRIBUTE_NAME) '(' (params+=AttributeParam (',' params+=AttributeParam)*)? ')' (attributes+=InternalAttribute)*;
diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json
index 19f34c238..f5d7bdfaa 100644
--- a/packages/misc/redwood/package.json
+++ b/packages/misc/redwood/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/redwood",
"displayName": "ZenStack RedwoodJS Integration",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.",
"repository": {
"type": "git",
diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json
index 314f33919..75103e17f 100644
--- a/packages/plugins/openapi/package.json
+++ b/packages/plugins/openapi/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/openapi",
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack plugin and runtime supporting OpenAPI",
"main": "index.js",
"repository": {
diff --git a/packages/plugins/openapi/src/rpc-generator.ts b/packages/plugins/openapi/src/rpc-generator.ts
index cb388aae2..7339ef788 100644
--- a/packages/plugins/openapi/src/rpc-generator.ts
+++ b/packages/plugins/openapi/src/rpc-generator.ts
@@ -1,22 +1,22 @@
// Inspired by: https://github.com/omar-dulaimi/prisma-trpc-generator
-import { analyzePolicies, PluginError, requireOption, resolvePath } from '@zenstackhq/sdk';
+import { PluginError, analyzePolicies, requireOption, resolvePath } from '@zenstackhq/sdk';
import { DataModel, isDataModel } from '@zenstackhq/sdk/ast';
import {
+ AggregateOperationSupport,
addMissingInputObjectTypesForAggregate,
addMissingInputObjectTypesForInclude,
addMissingInputObjectTypesForModelArgs,
addMissingInputObjectTypesForSelect,
- AggregateOperationSupport,
resolveAggregateOperationSupport,
} from '@zenstackhq/sdk/dmmf-helpers';
-import type { DMMF } from '@zenstackhq/sdk/prisma';
+import { supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import * as fs from 'fs';
import { lowerCaseFirst } from 'lower-case-first';
import type { OpenAPIV3_1 as OAPI } from 'openapi-types';
import * as path from 'path';
import invariant from 'tiny-invariant';
-import { match, P } from 'ts-pattern';
+import { P, match } from 'ts-pattern';
import { upperCaseFirst } from 'upper-case-first';
import YAML from 'yaml';
import { name } from '.';
@@ -166,7 +166,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
});
}
- if (ops['createMany']) {
+ if (ops['createMany'] && supportCreateMany(zmodel.$container)) {
definitions.push({
method: 'post',
operation: 'createMany',
@@ -704,7 +704,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
private generateEnumComponent(_enum: DMMF.SchemaEnum): OAPI.SchemaObject {
const schema: OAPI.SchemaObject = {
type: 'string',
- enum: _enum.values,
+ enum: _enum.values as string[],
};
return schema;
}
@@ -793,17 +793,14 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
return result;
}
- private setInputRequired(fields: { name: string; isRequired: boolean }[], result: OAPI.NonArraySchemaObject) {
+ private setInputRequired(fields: readonly DMMF.SchemaArg[], result: OAPI.NonArraySchemaObject) {
const required = fields.filter((f) => f.isRequired).map((f) => f.name);
if (required.length > 0) {
result.required = required;
}
}
- private setOutputRequired(
- fields: { name: string; isNullable?: boolean; outputType: DMMF.OutputTypeRef }[],
- result: OAPI.NonArraySchemaObject
- ) {
+ private setOutputRequired(fields: readonly DMMF.SchemaField[], result: OAPI.NonArraySchemaObject) {
const required = fields.filter((f) => f.isNullable !== true).map((f) => f.name);
if (required.length > 0) {
result.required = required;
diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json
index a36af725d..6a6992553 100644
--- a/packages/plugins/swr/package.json
+++ b/packages/plugins/swr/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/swr",
"displayName": "ZenStack plugin for generating SWR hooks",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack plugin for generating SWR hooks",
"main": "index.js",
"repository": {
@@ -46,6 +46,7 @@
"lower-case-first": "^2.0.2",
"semver": "^7.5.2",
"ts-morph": "^16.0.0",
+ "ts-pattern": "^4.3.0",
"upper-case-first": "^2.0.2"
},
"devDependencies": {
diff --git a/packages/plugins/swr/src/generator.ts b/packages/plugins/swr/src/generator.ts
index 4ab3fb79e..65c80fc8c 100644
--- a/packages/plugins/swr/src/generator.ts
+++ b/packages/plugins/swr/src/generator.ts
@@ -1,5 +1,6 @@
import {
PluginOptions,
+ RUNTIME_PACKAGE,
createProject,
ensureEmptyDir,
generateModelMeta,
@@ -8,11 +9,12 @@ import {
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
-import { DataModel, Model } from '@zenstackhq/sdk/ast';
-import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
+import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast';
+import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { paramCase } from 'change-case';
import path from 'path';
import type { OptionalKind, ParameterDeclarationStructure, Project, SourceFile } from 'ts-morph';
+import { P, match } from 'ts-pattern';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
@@ -66,6 +68,7 @@ function generateModelHooks(
});
sf.addStatements([
`import { type GetNextArgs, type QueryOptions, type InfiniteQueryOptions, type MutationOptions, type PickEnumerable } from '@zenstackhq/swr/runtime';`,
+ `import type { PolicyCrudKind } from '${RUNTIME_PACKAGE}'`,
`import metadata from './__model_meta';`,
`import * as request from '@zenstackhq/swr/runtime';`,
]);
@@ -82,7 +85,7 @@ function generateModelHooks(
}
// createMany
- if (mapping.createMany) {
+ if (mapping.createMany && supportCreateMany(model.$container)) {
const argsType = `Prisma.${model.name}CreateManyArgs`;
mutationFuncs.push(generateMutation(sf, model, 'POST', 'createMany', argsType, true));
}
@@ -239,6 +242,11 @@ function generateModelHooks(
const returnType = `T extends { select: any; } ? T['select'] extends true ? number : Prisma.GetScalarType : number`;
generateQueryHook(sf, model, 'count', argsType, inputType, returnType);
}
+
+ // extra `check` hook for ZenStack's permission checker API
+ {
+ generateCheckHook(sf, model, prismaImport);
+ }
}
function makeOptimistic(returnType: string) {
@@ -337,3 +345,47 @@ function generateMutation(
return funcName;
}
+
+function generateCheckHook(sf: SourceFile, model: DataModel, prismaImport: string) {
+ const mapFilterType = (type: DataModelFieldType) => {
+ return match(type.type)
+ .with(P.union('Int', 'BigInt'), () => 'number')
+ .with('String', () => 'string')
+ .with('Boolean', () => 'boolean')
+ .otherwise(() => undefined);
+ };
+
+ const filterFields: Array<{ name: string; type: string }> = [];
+ const enumsToImport = new Set();
+
+ // collect filterable fields and enums to import
+ model.fields.forEach((f) => {
+ if (isEnum(f.type.reference?.ref)) {
+ enumsToImport.add(f.type.reference.$refText);
+ filterFields.push({ name: f.name, type: f.type.reference.$refText });
+ }
+
+ const mappedType = mapFilterType(f.type);
+ if (mappedType) {
+ filterFields.push({ name: f.name, type: mappedType });
+ }
+ });
+
+ if (enumsToImport.size > 0) {
+ // import enums
+ sf.addStatements(`import type { ${Array.from(enumsToImport).join(', ')} } from '${prismaImport}';`);
+ }
+
+ const whereType = `{ ${filterFields.map(({ name, type }) => `${name}?: ${type}`).join('; ')} }`;
+
+ const func = sf.addFunction({
+ name: `useCheck${model.name}`,
+ isExported: true,
+ parameters: [
+ { name: 'args', type: `{ operation: PolicyCrudKind; where?: ${whereType}; }` },
+ { name: 'options?', type: `QueryOptions` },
+ ],
+ });
+
+ func.addStatements(`return request.useModelQuery('${model.name}', 'check', args, options);`);
+}
diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json
index 6f0805a52..d663d12f1 100644
--- a/packages/plugins/tanstack-query/package.json
+++ b/packages/plugins/tanstack-query/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/tanstack-query",
"displayName": "ZenStack plugin for generating tanstack-query hooks",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack plugin for generating tanstack-query hooks",
"main": "index.js",
"exports": {
diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts
index 7666cb742..6a4ba53d9 100644
--- a/packages/plugins/tanstack-query/src/generator.ts
+++ b/packages/plugins/tanstack-query/src/generator.ts
@@ -1,6 +1,7 @@
import {
PluginError,
PluginOptions,
+ RUNTIME_PACKAGE,
createProject,
ensureEmptyDir,
generateModelMeta,
@@ -9,13 +10,13 @@ import {
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
-import { DataModel, Model } from '@zenstackhq/sdk/ast';
-import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
+import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast';
+import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { paramCase } from 'change-case';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import { Project, SourceFile, VariableDeclarationKind } from 'ts-morph';
-import { match } from 'ts-pattern';
+import { P, match } from 'ts-pattern';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
@@ -261,6 +262,62 @@ function generateMutationHook(
func.addStatements('return mutation;');
}
+function generateCheckHook(
+ target: string,
+ version: TanStackVersion,
+ sf: SourceFile,
+ model: DataModel,
+ prismaImport: string
+) {
+ const mapFilterType = (type: DataModelFieldType) => {
+ return match(type.type)
+ .with(P.union('Int', 'BigInt'), () => 'number')
+ .with('String', () => 'string')
+ .with('Boolean', () => 'boolean')
+ .otherwise(() => undefined);
+ };
+
+ const filterFields: Array<{ name: string; type: string }> = [];
+ const enumsToImport = new Set();
+
+ // collect filterable fields and enums to import
+ model.fields.forEach((f) => {
+ if (isEnum(f.type.reference?.ref)) {
+ enumsToImport.add(f.type.reference.$refText);
+ filterFields.push({ name: f.name, type: f.type.reference.$refText });
+ }
+
+ const mappedType = mapFilterType(f.type);
+ if (mappedType) {
+ filterFields.push({ name: f.name, type: mappedType });
+ }
+ });
+
+ if (enumsToImport.size > 0) {
+ // import enums
+ sf.addStatements(`import type { ${Array.from(enumsToImport).join(', ')} } from '${prismaImport}';`);
+ }
+
+ const whereType = `{ ${filterFields.map(({ name, type }) => `${name}?: ${type}`).join('; ')} }`;
+
+ const func = sf.addFunction({
+ name: `useCheck${model.name}`,
+ isExported: true,
+ typeParameters: ['TError = DefaultError'],
+ parameters: [
+ { name: 'args', type: `{ operation: PolicyCrudKind; where?: ${whereType}; }` },
+ { name: 'options?', type: makeQueryOptions(target, 'boolean', 'boolean', false, false, version) },
+ ],
+ });
+
+ func.addStatements([
+ makeGetContext(target),
+ `return useModelQuery('${model.name}', \`\${endpoint}/${lowerCaseFirst(
+ model.name
+ )}/check\`, args, options, fetch);`,
+ ]);
+}
+
function generateModelHooks(
target: TargetFramework,
version: TanStackVersion,
@@ -291,7 +348,7 @@ function generateModelHooks(
}
// createMany
- if (mapping.createMany) {
+ if (mapping.createMany && supportCreateMany(model.$container)) {
generateMutationHook(target, sf, model.name, 'createMany', 'post', false, 'Prisma.BatchPayload');
}
@@ -494,6 +551,11 @@ function generateModelHooks(
`TArgs extends { select: any; } ? TArgs['select'] extends true ? number : Prisma.GetScalarType : number`
);
}
+
+ {
+ // extra `check` hook for ZenStack's permission checker API
+ generateCheckHook(target, version, sf, model, prismaImport);
+ }
}
function generateIndex(
@@ -538,6 +600,7 @@ function makeBaseImports(target: TargetFramework, version: TanStackVersion) {
const shared = [
`import { useModelQuery, useInfiniteModelQuery, useModelMutation } from '${runtimeImportBase}/${target}';`,
`import type { PickEnumerable, CheckSelect, QueryError, ExtraQueryOptions, ExtraMutationOptions } from '${runtimeImportBase}';`,
+ `import type { PolicyCrudKind } from '${RUNTIME_PACKAGE}'`,
`import metadata from './__model_meta';`,
`type DefaultError = QueryError;`,
];
diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json
index 78bc59131..08538d870 100644
--- a/packages/plugins/trpc/package.json
+++ b/packages/plugins/trpc/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/trpc",
"displayName": "ZenStack plugin for tRPC",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack plugin for tRPC",
"main": "index.js",
"repository": {
diff --git a/packages/plugins/trpc/src/generator.ts b/packages/plugins/trpc/src/generator.ts
index cf57a1baf..0180cc649 100644
--- a/packages/plugins/trpc/src/generator.ts
+++ b/packages/plugins/trpc/src/generator.ts
@@ -10,7 +10,7 @@ import {
type PluginOptions,
} from '@zenstackhq/sdk';
import { Model } from '@zenstackhq/sdk/ast';
-import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
+import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import fs from 'fs';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
@@ -79,11 +79,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
function createAppRouter(
outDir: string,
- modelOperations: DMMF.ModelMapping[],
+ modelOperations: readonly DMMF.ModelMapping[],
hiddenModels: string[],
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined,
- _zmodel: Model,
+ zmodel: Model,
zodSchemasImport: string,
options: PluginOptions
) {
@@ -99,19 +99,21 @@ function createAppRouter(
{
namedImports: [
'unsetMarker',
- 'type AnyRouter',
- 'type AnyRootConfig',
- 'type CreateRouterInner',
- 'type Procedure',
- 'type ProcedureBuilder',
- 'type ProcedureParams',
- 'type ProcedureRouterRecord',
- 'type ProcedureType',
+ 'AnyRouter',
+ 'AnyRootConfig',
+ 'CreateRouterInner',
+ 'Procedure',
+ 'ProcedureBuilder',
+ 'ProcedureParams',
+ 'ProcedureRouterRecord',
+ 'ProcedureType',
],
+ isTypeOnly: true,
moduleSpecifier: '@trpc/server',
},
{
- namedImports: ['type PrismaClient'],
+ namedImports: ['PrismaClient'],
+ isTypeOnly: true,
moduleSpecifier: prismaImport,
},
]);
@@ -169,7 +171,8 @@ function createAppRouter(
generateModelActions,
generateClientHelpers,
zodSchemasImport,
- options
+ options,
+ zmodel
);
appRouter.addImportDeclaration({
@@ -239,7 +242,8 @@ function generateModelCreateRouter(
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined,
zodSchemasImport: string,
- options: PluginOptions
+ options: PluginOptions,
+ zmodel: Model
) {
const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, {
overwrite: true,
@@ -296,6 +300,10 @@ function generateModelCreateRouter(
inputType &&
(!generateModelActions || generateModelActions.includes(generateOpName))
) {
+ if (generateOpName === 'createMany' && !supportCreateMany(zmodel)) {
+ continue;
+ }
+
generateProcedure(funcWriter, generateOpName, upperCaseFirst(inputType), model, baseOpType);
if (routerTypingStructure) {
diff --git a/packages/plugins/trpc/src/helpers.ts b/packages/plugins/trpc/src/helpers.ts
index 947b96b98..c165288f7 100644
--- a/packages/plugins/trpc/src/helpers.ts
+++ b/packages/plugins/trpc/src/helpers.ts
@@ -237,7 +237,12 @@ export function generateRouterTypingImports(sourceFile: SourceFile, options: Plu
// eslint-disable-next-line @typescript-eslint/no-unused-vars
export function generateRouterSchemaImport(sourceFile: SourceFile, zodSchemasImport: string) {
- sourceFile.addStatements(`import * as $Schema from '${zodSchemasImport}/input';`);
+ sourceFile.addStatements([
+ `import * as _Schema from '${zodSchemasImport}/input';`,
+ // temporary solution for dealing with the issue that Node.js wraps named exports under a `default`
+ // key when importing from a CJS module
+ `const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema;`,
+ ]);
}
export function generateHelperImport(sourceFile: SourceFile) {
@@ -327,7 +332,7 @@ export const getProcedureTypeByOpName = (opName: string) => {
return procType;
};
-export function resolveModelsComments(models: DMMF.Model[], hiddenModels: string[]) {
+export function resolveModelsComments(models: readonly DMMF.Model[], hiddenModels: string[]) {
const modelAttributeRegex = /(@@Gen\.)+([A-z])+(\()+(.+)+(\))+/;
const attributeNameRegex = /(?:\.)+([A-Za-z])+(?:\()+/;
const attributeArgsRegex = /(?:\()+([A-Za-z])+:+(.+)+(?:\))+/;
diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts
index fbc73cf06..3edfe9c94 100644
--- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts
+++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts
@@ -1,6 +1,7 @@
/* eslint-disable */
import { type RouterFactory, type ProcBuilder, type BaseConfig, db } from ".";
-import * as $Schema from '@zenstackhq/runtime/zod/input';
+import * as _Schema from '@zenstackhq/runtime/zod/input';
+const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema;
import { checkRead, checkMutate } from '../helper';
import type { Prisma } from '@prisma/client';
import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared';
@@ -12,6 +13,8 @@ export default function createRouter(router: RouterFa
aggregate: procedure.input($Schema.PostInputSchema.aggregate).query(({ ctx, input }) => checkRead(db(ctx).post.aggregate(input as any))),
+ createMany: procedure.input($Schema.PostInputSchema.createMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.createMany(input as any))),
+
create: procedure.input($Schema.PostInputSchema.create).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.create(input as any))),
deleteMany: procedure.input($Schema.PostInputSchema.deleteMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.deleteMany(input as any))),
@@ -60,6 +63,20 @@ export interface ClientType
>;
+ };
+ createMany: {
+
+ useMutation: (opts?: UseTRPCMutationOptions<
+ Prisma.PostCreateManyArgs,
+ TRPCClientErrorLike,
+ Prisma.BatchPayload,
+ Context
+ >,) =>
+ Omit, Prisma.SelectSubset, Context>, 'mutateAsync'> & {
+ mutateAsync:
+ (variables: T, opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>) => Promise
+ };
+
};
create: {
diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts
index c4bdb89de..366ccfcb0 100644
--- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts
+++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts
@@ -1,6 +1,7 @@
/* eslint-disable */
import { type RouterFactory, type ProcBuilder, type BaseConfig, db } from ".";
-import * as $Schema from '@zenstackhq/runtime/zod/input';
+import * as _Schema from '@zenstackhq/runtime/zod/input';
+const $Schema: typeof _Schema = (_Schema as any).default ?? _Schema;
import { checkRead, checkMutate } from '../helper';
import type { Prisma } from '@prisma/client';
import type { UseTRPCMutationOptions, UseTRPCMutationResult, UseTRPCQueryOptions, UseTRPCQueryResult, UseTRPCInfiniteQueryOptions, UseTRPCInfiniteQueryResult } from '@trpc/react-query/shared';
@@ -12,6 +13,8 @@ export default function createRouter(router: RouterFa
aggregate: procedure.input($Schema.UserInputSchema.aggregate).query(({ ctx, input }) => checkRead(db(ctx).user.aggregate(input as any))),
+ createMany: procedure.input($Schema.UserInputSchema.createMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.createMany(input as any))),
+
create: procedure.input($Schema.UserInputSchema.create).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.create(input as any))),
deleteMany: procedure.input($Schema.UserInputSchema.deleteMany).mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.deleteMany(input as any))),
@@ -60,6 +63,20 @@ export interface ClientType
>;
+ };
+ createMany: {
+
+ useMutation: (opts?: UseTRPCMutationOptions<
+ Prisma.UserCreateManyArgs,
+ TRPCClientErrorLike,
+ Prisma.BatchPayload,
+ Context
+ >,) =>
+ Omit, Prisma.SelectSubset, Context>, 'mutateAsync'> & {
+ mutateAsync:
+ (variables: T, opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>) => Promise
+ };
+
};
create: {
diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts
index f474aa5b5..523f4b645 100644
--- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts
+++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/index.ts
@@ -1,6 +1,6 @@
/* eslint-disable */
-import { unsetMarker, type AnyRouter, type AnyRootConfig, type CreateRouterInner, type Procedure, type ProcedureBuilder, type ProcedureParams, type ProcedureRouterRecord, type ProcedureType } from "@trpc/server";
-import { type PrismaClient } from "@prisma/client";
+import type { unsetMarker, AnyRouter, AnyRootConfig, CreateRouterInner, Procedure, ProcedureBuilder, ProcedureParams, ProcedureRouterRecord, ProcedureType } from "@trpc/server";
+import type { PrismaClient } from "@prisma/client";
import createUserRouter from "./User.router";
import createPostRouter from "./Post.router";
import { ClientType as UserClientType } from "./User.router";
diff --git a/packages/runtime/package.json b/packages/runtime/package.json
index 10da4f5c9..94d8a587b 100644
--- a/packages/runtime/package.json
+++ b/packages/runtime/package.json
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/runtime",
"displayName": "ZenStack Runtime Library",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "Runtime of ZenStack for both client-side and server-side environments.",
"repository": {
"type": "git",
@@ -29,6 +29,10 @@
"types": "./enhancements/index.d.ts",
"default": "./enhancements/index.js"
},
+ "./constraint-solver": {
+ "types": "./constraint-solver.d.ts",
+ "default": "./constraint-solver.js"
+ },
"./zod": {
"types": "./zod/index.d.ts",
"default": "./zod/index.js"
@@ -79,12 +83,14 @@
"decimal.js": "^10.4.2",
"deepcopy": "^2.1.0",
"deepmerge": "^4.3.1",
+ "logic-solver": "^2.0.1",
"lower-case-first": "^2.0.2",
"pluralize": "^8.0.0",
"safe-json-stringify": "^1.2.0",
"semver": "^7.5.2",
"superjson": "^1.11.0",
"tiny-invariant": "^1.3.1",
+ "ts-pattern": "^4.3.0",
"tslib": "^2.4.1",
"upper-case-first": "^2.0.2",
"uuid": "^9.0.0",
diff --git a/packages/runtime/src/enhance.d.ts b/packages/runtime/src/enhance.d.ts
index 48e877878..38a519830 100644
--- a/packages/runtime/src/enhance.d.ts
+++ b/packages/runtime/src/enhance.d.ts
@@ -1,2 +1,2 @@
// @ts-expect-error stub for re-exporting generated code
-export { enhance } from '.zenstack/enhance';
+export { auth, enhance } from '.zenstack/enhance';
diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts
index e2fdc65f2..8e4c3569a 100644
--- a/packages/runtime/src/enhancements/delegate.ts
+++ b/packages/runtime/src/enhancements/delegate.ts
@@ -99,6 +99,12 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}
Object.entries(where).forEach(([field, value]) => {
+ if (['AND', 'OR', 'NOT'].includes(field)) {
+ // recurse into logical group
+ enumerate(value).forEach((item) => this.injectWhereHierarchy(model, item));
+ return;
+ }
+
const fieldInfo = resolveField(this.options.modelMeta, model, field);
if (!fieldInfo?.inheritedFrom) {
return;
diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts
new file mode 100644
index 000000000..c87a528e7
--- /dev/null
+++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts
@@ -0,0 +1,219 @@
+import Logic from 'logic-solver';
+import { match } from 'ts-pattern';
+import type {
+ CheckerConstraint,
+ ComparisonConstraint,
+ ComparisonTerm,
+ LogicalConstraint,
+ ValueConstraint,
+ VariableConstraint,
+} from '../types';
+
+/**
+ * A boolean constraint solver based on `logic-solver`. Only boolean and integer types are supported.
+ */
+export class ConstraintSolver {
+ // a table for internalizing string literals
+ private stringTable: string[] = [];
+
+ // a map for storing variable names and their corresponding formulas
+ private variables: Map = new Map();
+
+ /**
+ * Check the satisfiability of the given constraint.
+ */
+ checkSat(constraint: CheckerConstraint): boolean {
+ // reset state
+ this.stringTable = [];
+ this.variables = new Map();
+
+ // convert the constraint to a "logic-solver" formula
+ const formula = this.buildFormula(constraint);
+
+ // solve the formula
+ const solver = new Logic.Solver();
+ solver.require(formula);
+
+ // DEBUG:
+ // const solution = solver.solve();
+ // if (solution) {
+ // console.log('Solution:');
+ // this.variables.forEach((v, k) => console.log(`\t${k}=${solution?.evaluate(v)}`));
+ // } else {
+ // console.log('No solution');
+ // }
+
+ return !!solver.solve();
+ }
+
+ private buildFormula(constraint: CheckerConstraint): Logic.Formula {
+ return match(constraint)
+ .when(
+ (c): c is ValueConstraint => c.kind === 'value',
+ (c) => this.buildValueFormula(c)
+ )
+ .when(
+ (c): c is VariableConstraint => c.kind === 'variable',
+ (c) => this.buildVariableFormula(c)
+ )
+ .when(
+ (c): c is ComparisonConstraint => ['eq', 'ne', 'gt', 'gte', 'lt', 'lte'].includes(c.kind),
+ (c) => this.buildComparisonFormula(c)
+ )
+ .when(
+ (c): c is LogicalConstraint => ['and', 'or', 'not'].includes(c.kind),
+ (c) => this.buildLogicalFormula(c)
+ )
+ .otherwise(() => {
+ throw new Error(`Unsupported constraint format: ${JSON.stringify(constraint)}`);
+ });
+ }
+
+ private buildLogicalFormula(constraint: LogicalConstraint) {
+ return match(constraint.kind)
+ .with('and', () => this.buildAndFormula(constraint))
+ .with('or', () => this.buildOrFormula(constraint))
+ .with('not', () => this.buildNotFormula(constraint))
+ .exhaustive();
+ }
+
+ private buildAndFormula(constraint: LogicalConstraint): Logic.Formula {
+ if (constraint.children.some((c) => this.isFalse(c))) {
+ // short-circuit
+ return Logic.FALSE;
+ }
+ return Logic.and(...constraint.children.map((c) => this.buildFormula(c)));
+ }
+
+ private buildOrFormula(constraint: LogicalConstraint): Logic.Formula {
+ if (constraint.children.some((c) => this.isTrue(c))) {
+ // short-circuit
+ return Logic.TRUE;
+ }
+ return Logic.or(...constraint.children.map((c) => this.buildFormula(c)));
+ }
+
+ private buildNotFormula(constraint: LogicalConstraint) {
+ if (constraint.children.length !== 1) {
+ throw new Error('"not" constraint must have exactly one child');
+ }
+ return Logic.not(this.buildFormula(constraint.children[0]));
+ }
+
+ private isTrue(constraint: CheckerConstraint): unknown {
+ return constraint.kind === 'value' && constraint.value === true;
+ }
+
+ private isFalse(constraint: CheckerConstraint): unknown {
+ return constraint.kind === 'value' && constraint.value === false;
+ }
+
+ private buildComparisonFormula(constraint: ComparisonConstraint) {
+ if (constraint.left.kind === 'value' && constraint.right.kind === 'value') {
+ // constant comparison
+ const left: ValueConstraint = constraint.left;
+ const right: ValueConstraint = constraint.right;
+ return match(constraint.kind)
+ .with('eq', () => (left.value === right.value ? Logic.TRUE : Logic.FALSE))
+ .with('ne', () => (left.value !== right.value ? Logic.TRUE : Logic.FALSE))
+ .with('gt', () => (left.value > right.value ? Logic.TRUE : Logic.FALSE))
+ .with('gte', () => (left.value >= right.value ? Logic.TRUE : Logic.FALSE))
+ .with('lt', () => (left.value < right.value ? Logic.TRUE : Logic.FALSE))
+ .with('lte', () => (left.value <= right.value ? Logic.TRUE : Logic.FALSE))
+ .exhaustive();
+ }
+
+ return match(constraint.kind)
+ .with('eq', () => this.transformEquality(constraint.left, constraint.right))
+ .with('ne', () => this.transformInequality(constraint.left, constraint.right))
+ .with('gt', () =>
+ this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThan(l, r))
+ )
+ .with('gte', () =>
+ this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.greaterThanOrEqual(l, r))
+ )
+ .with('lt', () =>
+ this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThan(l, r))
+ )
+ .with('lte', () =>
+ this.transformComparison(constraint.left, constraint.right, (l, r) => Logic.lessThanOrEqual(l, r))
+ )
+ .exhaustive();
+ }
+
+ private buildVariableFormula(constraint: VariableConstraint) {
+ return (
+ match(constraint.type)
+ .with('boolean', () => this.booleanVariable(constraint.name))
+ .with('number', () => this.intVariable(constraint.name))
+ // strings are internalized and represented by their indices
+ .with('string', () => this.intVariable(constraint.name))
+ .exhaustive()
+ );
+ }
+
+ private buildValueFormula(constraint: ValueConstraint) {
+ return match(constraint.value)
+ .when(
+ (v): v is boolean => typeof v === 'boolean',
+ (v) => (v === true ? Logic.TRUE : Logic.FALSE)
+ )
+ .when(
+ (v): v is number => typeof v === 'number',
+ (v) => Logic.constantBits(v)
+ )
+ .when(
+ (v): v is string => typeof v === 'string',
+ (v) => {
+ // internalize the string and use its index as formula representation
+ const index = this.stringTable.indexOf(v);
+ if (index === -1) {
+ this.stringTable.push(v);
+ return Logic.constantBits(this.stringTable.length - 1);
+ } else {
+ return Logic.constantBits(index);
+ }
+ }
+ )
+ .exhaustive();
+ }
+
+ private booleanVariable(name: string) {
+ this.variables.set(name, name);
+ return name;
+ }
+
+ private intVariable(name: string) {
+ const r = Logic.variableBits(name, 32);
+ this.variables.set(name, r);
+ return r;
+ }
+
+ private transformEquality(left: ComparisonTerm, right: ComparisonTerm) {
+ if (left.type !== right.type) {
+ throw new Error(`Type mismatch in equality constraint: ${JSON.stringify(left)}, ${JSON.stringify(right)}`);
+ }
+
+ const leftFormula = this.buildFormula(left);
+ const rightFormula = this.buildFormula(right);
+ if (left.type === 'boolean' && right.type === 'boolean') {
+ // logical equivalence
+ return Logic.equiv(leftFormula, rightFormula);
+ } else {
+ // integer equality
+ return Logic.equalBits(leftFormula, rightFormula);
+ }
+ }
+
+ private transformInequality(left: ComparisonTerm, right: ComparisonTerm) {
+ return Logic.not(this.transformEquality(left, right));
+ }
+
+ private transformComparison(
+ left: ComparisonTerm,
+ right: ComparisonTerm,
+ func: (left: Logic.Formula, right: Logic.Formula) => Logic.Formula
+ ) {
+ return func(this.buildFormula(left), this.buildFormula(right));
+ }
+}
diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts
index d6d893d4e..997e727d5 100644
--- a/packages/runtime/src/enhancements/policy/handler.ts
+++ b/packages/runtime/src/enhancements/policy/handler.ts
@@ -2,6 +2,7 @@
import { lowerCaseFirst } from 'lower-case-first';
import invariant from 'tiny-invariant';
+import { P, match } from 'ts-pattern';
import { upperCaseFirst } from 'upper-case-first';
import { fromZodError } from 'zod-validation-error';
import { CrudFailureReason } from '../../constants';
@@ -16,13 +17,15 @@ import {
type FieldInfo,
type ModelMeta,
} from '../../cross';
-import { PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types';
+import { PolicyCrudKind, PolicyOperationKind, type CrudContract, type DbClientContract } from '../../types';
import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement';
import { Logger } from '../logger';
import { createDeferredPromise, createFluentPromise } from '../promise';
import { PrismaProxyHandler } from '../proxy';
import { QueryUtils } from '../query-utils';
+import type { CheckerConstraint } from '../types';
import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
+import { ConstraintSolver } from './constraint-solver';
import { PolicyUtil } from './policy-utils';
// a record for post-write policy check
@@ -35,6 +38,12 @@ type PostWriteCheckRecord = {
type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany';
+// input arg type for `check` API
+type PermissionCheckArgs = {
+ operation: PolicyCrudKind;
+ where?: Record;
+};
+
/**
* Prisma proxy handler for injecting access policy check.
*/
@@ -1436,6 +1445,115 @@ export class PolicyProxyHandler implements Pr
//#endregion
+ //#region Check
+
+ /**
+ * Checks if the given operation is possibly allowed by the policy, without querying the database.
+ * @param operation The CRUD operation.
+ * @param fieldValues Extra field value filters to be combined with the policy constraints.
+ */
+ async check(args: PermissionCheckArgs): Promise {
+ return createDeferredPromise(() => this.doCheck(args));
+ }
+
+ private async doCheck(args: PermissionCheckArgs) {
+ if (!['create', 'read', 'update', 'delete'].includes(args.operation)) {
+ throw prismaClientValidationError(this.prisma, this.prismaModule, `Invalid "operation" ${args.operation}`);
+ }
+
+ let constraint = this.policyUtils.getCheckerConstraint(this.model, args.operation);
+ if (typeof constraint === 'boolean') {
+ return constraint;
+ }
+
+ if (args.where) {
+ // combine runtime filters with generated constraints
+
+ const extraConstraints: CheckerConstraint[] = [];
+ for (const [field, value] of Object.entries(args.where)) {
+ if (value === undefined) {
+ continue;
+ }
+
+ if (value === null) {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Using "null" as filter value is not supported yet`
+ );
+ }
+
+ const fieldInfo = requireField(this.modelMeta, this.model, field);
+
+ // relation and array fields are not supported
+ if (fieldInfo.isDataModel || fieldInfo.isArray) {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Providing filter for field "${field}" is not supported. Only scalar fields are allowed.`
+ );
+ }
+
+ // map field type to constraint type
+ const fieldType = match(fieldInfo.type)
+ .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => 'number')
+ .with('String', () => 'string')
+ .with('Boolean', () => 'boolean')
+ .otherwise(() => {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Providing filter for field "${field}" is not supported. Only number, string, and boolean fields are allowed.`
+ );
+ });
+
+ // check value type
+ const valueType = typeof value;
+ if (valueType !== 'number' && valueType !== 'string' && valueType !== 'boolean') {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Invalid value type for field "${field}". Only number, string or boolean is allowed.`
+ );
+ }
+
+ if (fieldType !== valueType) {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Invalid value type for field "${field}". Expected "${fieldType}".`
+ );
+ }
+
+ // check number validity
+ if (typeof value === 'number' && (!Number.isInteger(value) || value < 0)) {
+ throw prismaClientValidationError(
+ this.prisma,
+ this.prismaModule,
+ `Invalid value for field "${field}". Only non-negative integers are allowed.`
+ );
+ }
+
+ // build a constraint
+ extraConstraints.push({
+ kind: 'eq',
+ left: { kind: 'variable', name: field, type: fieldType },
+ right: { kind: 'value', value, type: fieldType },
+ });
+ }
+
+ if (extraConstraints.length > 0) {
+ // combine the constraints
+ constraint = { kind: 'and', children: [constraint, ...extraConstraints] };
+ }
+ }
+
+ // check satisfiability
+ return new ConstraintSolver().checkSat(constraint);
+ }
+
+ //#endregion
+
//#region Utils
private get shouldLogQuery() {
diff --git a/packages/runtime/src/enhancements/policy/logic-solver.d.ts b/packages/runtime/src/enhancements/policy/logic-solver.d.ts
new file mode 100644
index 000000000..d10e688f6
--- /dev/null
+++ b/packages/runtime/src/enhancements/policy/logic-solver.d.ts
@@ -0,0 +1,109 @@
+/**
+ * Type definitions for the `logic-solver` npm package.
+ */
+declare module 'logic-solver' {
+ /**
+ * A boolean formula.
+ */
+ interface Formula {}
+
+ /**
+ * The `TRUE` formula.
+ */
+ const TRUE: Formula;
+
+ /**
+ * The `FALSE` formula.
+ */
+ const FALSE: Formula;
+
+ /**
+ * Boolean equivalence.
+ */
+ export function equiv(operand1: Formula, operand2: Formula): Formula;
+
+ /**
+ * Bits equality.
+ */
+ export function equalBits(bits1: Formula, bits2: Formula): Formula;
+
+ /**
+ * Bits greater-than.
+ */
+ export function greaterThan(bits1: Formula, bits2: Formula): Formula;
+
+ /**
+ * Bits greater-than-or-equal.
+ */
+ export function greaterThanOrEqual(bits1: Formula, bits2: Formula): Formula;
+
+ /**
+ * Bits less-than.
+ */
+ export function lessThan(bits1: Formula, bits2: Formula): Formula;
+
+ /**
+ * Bits less-than-or-equal.
+ */
+ export function lessThanOrEqual(bits1: Formula, bits2: Formula): Formula;
+
+ /**
+ * Logical AND.
+ */
+ export function and(...args: Formula[]): Formula;
+
+ /**
+ * Logical OR.
+ */
+ export function or(...args: Formula[]): Formula;
+
+ /**
+ * Logical NOT.
+ */
+ export function not(arg: Formula): Formula;
+
+ /**
+ * Creates a bits variable with the given name and bit length.
+ */
+ export function variableBits(baseName: string, N: number): Formula;
+
+ /**
+ * Creates a constant bits formula from the given whole number.
+ */
+ export function constantBits(wholeNumber: number): Formula;
+
+ /**
+ * A solution to a constraint.
+ */
+ interface Solution {
+ /**
+ * Returns a map of variable assignments.
+ */
+ getMap(): object;
+
+ /**
+ * Evaluates the given formula against the solution.
+ */
+ evaluate(formula: Formula): unknown;
+ }
+
+ /**
+ * A constraint solver.
+ */
+ class Solver {
+ /**
+ * Adds constraints to the solver.
+ */
+ require(...args: Formula[]): void;
+
+ /**
+ * Adds negated constraints from the solver.
+ */
+ forbid(...args: Formula[]): void;
+
+ /**
+ * Solves the constraints.
+ */
+ solve(): Solution;
+ }
+}
diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts
index bcb946877..b288063c0 100644
--- a/packages/runtime/src/enhancements/policy/policy-utils.ts
+++ b/packages/runtime/src/enhancements/policy/policy-utils.ts
@@ -17,12 +17,12 @@ import {
PrismaErrorCode,
} from '../../constants';
import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross';
-import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types';
+import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types';
import { getVersion } from '../../version';
import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement';
import { Logger } from '../logger';
import { QueryUtils } from '../query-utils';
-import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types';
+import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types';
import { formatObject, prismaClientKnownRequestError } from '../utils';
/**
@@ -228,7 +228,7 @@ export class PolicyUtil extends QueryUtils {
//#endregion
- //# Auth guard
+ //#region Auth guard
private readonly FULLY_OPEN_AUTH_GUARD = {
create: true,
@@ -267,7 +267,7 @@ export class PolicyUtil extends QueryUtils {
}
if (!provider) {
- throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`);
+ throw this.unknownError(`unable to load authorization guard for ${model}`);
}
const r = provider({ user: this.user, preValue }, db);
return this.reduce(r);
@@ -561,6 +561,46 @@ export class PolicyUtil extends QueryUtils {
return true;
}
+ //#endregion
+
+ //#region Checker
+
+ /**
+ * Gets checker constraints for the given model and operation.
+ */
+ getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean {
+ const checker = this.getModelChecker(model);
+ const provider = checker[operation];
+ if (typeof provider === 'boolean') {
+ return provider;
+ }
+
+ if (typeof provider !== 'function') {
+ throw this.unknownError(`invalid ${operation} checker function for ${model}`);
+ }
+
+ // call checker function
+ return provider({ user: this.user });
+ }
+
+ private getModelChecker(model: string) {
+ if (this.options.kinds && !this.options.kinds.includes('policy')) {
+ // policy enhancement not enabled, return a constant true checker
+ return { create: true, read: true, update: true, delete: true };
+ } else {
+ const result = this.options.policy.checker?.[lowerCaseFirst(model)];
+ if (!result) {
+ // checker generation not enabled, return constant false checker
+ throw new Error(
+ `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.`
+ );
+ }
+ return result;
+ }
+ }
+
+ //#endregion
+
/**
* Gets unique constraints for the given model.
*/
@@ -609,6 +649,10 @@ export class PolicyUtil extends QueryUtils {
const hoistedConditions: any[] = [];
for (const field of getModelFields(injectTarget)) {
+ if (injectTarget[field] === false) {
+ continue;
+ }
+
const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo || !fieldInfo.isDataModel) {
// only care about relation fields
@@ -934,6 +978,11 @@ export class PolicyUtil extends QueryUtils {
}
private doInjectReadCheckSelect(model: string, args: any, input: any) {
+ // omit should be ignored to avoid interfering with field selection
+ if (args.omit) {
+ delete args.omit;
+ }
+
if (!input?.select) {
return;
}
@@ -1130,6 +1179,12 @@ export class PolicyUtil extends QueryUtils {
continue;
}
+ if (queryArgs?.omit?.[field] === true) {
+ // respect `{ omit: { [field]: true } }`
+ delete entityData[field];
+ continue;
+ }
+
if (hasFieldLevelPolicy) {
// 1. remove fields selected for checking field-level policies but not selected by the original query args
// 2. evaluate field-level policies and remove fields that are not readable
diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts
index 9fecc375e..89d5ce9f6 100644
--- a/packages/runtime/src/enhancements/types.ts
+++ b/packages/runtime/src/enhancements/types.ts
@@ -9,7 +9,7 @@ import {
HAS_FIELD_LEVEL_POLICY_FLAG,
PRE_UPDATE_VALUE_SELECTOR,
} from '../constants';
-import type { CrudContract, PolicyOperationKind, QueryContext } from '../types';
+import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types';
/**
* Common options for PrismaClient enhancements
@@ -33,6 +33,57 @@ export interface CommonEnhancementOptions {
*/
export type PolicyFunc = (context: QueryContext, db: CrudContract) => object;
+/**
+ * Function for checking if an operation is possibly allowed.
+ */
+export type CheckerFunc = (context: CheckerContext) => CheckerConstraint;
+
+/**
+ * Supported checker constraint checking value types.
+ */
+export type ConstraintValueTypes = 'boolean' | 'number' | 'string';
+
+/**
+ * Free variable constraint
+ */
+export type VariableConstraint = { kind: 'variable'; name: string; type: ConstraintValueTypes };
+
+/**
+ * Constant value constraint
+ */
+export type ValueConstraint = {
+ kind: 'value';
+ value: number | boolean | string;
+ type: ConstraintValueTypes;
+};
+
+/**
+ * Terms for comparison constraints
+ */
+export type ComparisonTerm = VariableConstraint | ValueConstraint;
+
+/**
+ * Comparison constraint
+ */
+export type ComparisonConstraint = {
+ kind: 'eq' | 'ne' | 'gt' | 'gte' | 'lt' | 'lte';
+ left: ComparisonTerm;
+ right: ComparisonTerm;
+};
+
+/**
+ * Logical constraint
+ */
+export type LogicalConstraint = {
+ kind: 'and' | 'or' | 'not';
+ children: CheckerConstraint[];
+};
+
+/**
+ * Operation allowability checking constraint
+ */
+export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint;
+
/**
* Function for getting policy guard with a given context
*/
@@ -71,6 +122,8 @@ export type PolicyDef = {
}
>;
+ checker?: Record>;
+
// tracks which models have data validation rules
validation: Record;
diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts
index 4bcab85a1..4c32480ba 100644
--- a/packages/runtime/src/types.ts
+++ b/packages/runtime/src/types.ts
@@ -22,6 +22,7 @@ export interface DbOperations {
groupBy(args: unknown): Promise;
count(args?: unknown): Promise;
subscribe(args?: unknown): Promise;
+ check(args: unknown): Promise;
fields: Record;
}
@@ -30,10 +31,12 @@ export interface DbOperations {
*/
export type PolicyKind = 'allow' | 'deny';
+export type PolicyCrudKind = 'read' | 'create' | 'update' | 'delete';
+
/**
* Kinds of operations controlled by access policies
*/
-export type PolicyOperationKind = 'create' | 'update' | 'postUpdate' | 'read' | 'delete';
+export type PolicyOperationKind = PolicyCrudKind | 'postUpdate';
/**
* Current login user info
@@ -56,6 +59,21 @@ export type QueryContext = {
preValue?: any;
};
+/**
+ * Context for checking operation allowability.
+ */
+export type CheckerContext = {
+ /**
+ * Current user
+ */
+ user?: AuthUser;
+
+ /**
+ * Extra field value filters.
+ */
+ fieldValues?: Record;
+};
+
/**
* Prisma contract for CRUD operations.
*/
diff --git a/packages/schema/package.json b/packages/schema/package.json
index 07e1ba1df..55515cbf5 100644
--- a/packages/schema/package.json
+++ b/packages/schema/package.json
@@ -3,7 +3,7 @@
"publisher": "zenstack",
"displayName": "ZenStack Language Tools",
"description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database",
- "version": "2.0.3",
+ "version": "2.1.0",
"author": {
"name": "ZenStack Team"
},
diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts
index e9db60fb6..18e2c6be3 100644
--- a/packages/schema/src/cli/cli-util.ts
+++ b/packages/schema/src/cli/cli-util.ts
@@ -12,7 +12,7 @@ import { URI } from 'vscode-uri';
import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../language-server/constants';
import { ZModelFormatter } from '../language-server/zmodel-formatter';
import { createZModelServices, ZModelServices } from '../language-server/zmodel-module';
-import { mergeBaseModel, resolveImport, resolveTransitiveImports } from '../utils/ast-utils';
+import { mergeBaseModels, resolveImport, resolveTransitiveImports } from '../utils/ast-utils';
import { findUp } from '../utils/pkg-utils';
import { getVersion } from '../utils/version-utils';
import { CliError } from './cli-error';
@@ -101,10 +101,10 @@ export async function loadDocument(fileName: string): Promise {
imported.map((m) => m.$document!.uri)
);
- validationAfterMerge(model);
+ validationAfterImportMerge(model);
// merge fields and attributes from base models
- mergeBaseModel(model, services.references.Linker);
+ mergeBaseModels(model, services.references.Linker);
// finally relink all references
const relinkedModel = await relinkAll(model, services);
@@ -113,7 +113,7 @@ export async function loadDocument(fileName: string): Promise {
}
// check global unique thing after merge imports
-function validationAfterMerge(model: Model) {
+function validationAfterImportMerge(model: Model) {
const dataSources = model.declarations.filter((d) => isDataSource(d));
if (dataSources.length == 0) {
console.error(colors.red('Validation error: Model must define a datasource'));
diff --git a/packages/schema/src/language-server/utils.ts b/packages/schema/src/language-server/utils.ts
index c57836a91..019004a62 100644
--- a/packages/schema/src/language-server/utils.ts
+++ b/packages/schema/src/language-server/utils.ts
@@ -1,4 +1,10 @@
-import { DataModel, DataModelField, isArrayExpr, isReferenceExpr, ReferenceExpr } from '@zenstackhq/language/ast';
+import {
+ isArrayExpr,
+ isReferenceExpr,
+ type DataModel,
+ type DataModelField,
+ type ReferenceExpr,
+} from '@zenstackhq/language/ast';
import { resolved } from '@zenstackhq/sdk';
/**
diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts
index 0baf5ace3..9185443f3 100644
--- a/packages/schema/src/language-server/validator/datamodel-validator.ts
+++ b/packages/schema/src/language-server/validator/datamodel-validator.ts
@@ -2,14 +2,19 @@ import {
ArrayExpr,
DataModel,
DataModelField,
- isDataModel,
- isStringLiteral,
ReferenceExpr,
+ isDataModel,
isEnum,
+ isStringLiteral,
} from '@zenstackhq/language/ast';
-import { getLiteral, getModelIdFields, getModelUniqueFields, isDelegateModel } from '@zenstackhq/sdk';
-import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium';
-import { getModelFieldsWithBases } from '../../utils/ast-utils';
+import {
+ getLiteral,
+ getModelFieldsWithBases,
+ getModelIdFields,
+ getModelUniqueFields,
+ isDelegateModel,
+} from '@zenstackhq/sdk';
+import { AstNode, DiagnosticInfo, ValidationAcceptor, getDocument } from 'langium';
import { IssueCodes, SCALAR_TYPES } from '../constants';
import { AstValidator } from '../types';
import { getUniqueFields } from '../utils';
@@ -361,6 +366,11 @@ export default class DataModelValidator implements AstValidator {
const containingModel = field.$container as DataModel;
const uniqueFieldList = getUniqueFields(containingModel);
+ // field is defined in the abstract base model
+ if (containingModel !== contextModel) {
+ uniqueFieldList.push(...getUniqueFields(contextModel));
+ }
+
thisRelation.fields?.forEach((ref) => {
const refField = ref.target.ref as DataModelField;
if (refField) {
@@ -372,7 +382,7 @@ export default class DataModelValidator implements AstValidator {
}
accept(
'error',
- `Field "${refField.name}" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`,
+ `Field "${refField.name}" on model "${containingModel.name}" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`,
{ node: refField }
);
}
diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts
index cb42e4cb1..d65e304dc 100644
--- a/packages/schema/src/language-server/validator/expression-validator.ts
+++ b/packages/schema/src/language-server/validator/expression-validator.ts
@@ -30,7 +30,7 @@ export default class ExpressionValidator implements AstValidator {
// check was done at link time
accept(
'error',
- 'auth() cannot be resolved because no model marked wth "@@auth()" or named "User" is found',
+ 'auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found',
{ node: expr }
);
} else {
diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts
index f71e479ef..a4c99da96 100644
--- a/packages/schema/src/language-server/zmodel-code-action.ts
+++ b/packages/schema/src/language-server/zmodel-code-action.ts
@@ -10,8 +10,8 @@ import {
getDocument,
} from 'langium';
+import { getModelFieldsWithBases } from '@zenstackhq/sdk';
import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver';
-import { getModelFieldsWithBases } from '../utils/ast-utils';
import { IssueCodes } from './constants';
import { MissingOppositeRelationData } from './validator/datamodel-validator';
import { ZModelFormatter } from './zmodel-formatter';
diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts
index 5a15f9336..c2751b921 100644
--- a/packages/schema/src/language-server/zmodel-linker.ts
+++ b/packages/schema/src/language-server/zmodel-linker.ts
@@ -35,13 +35,7 @@ import {
isReferenceExpr,
isStringLiteral,
} from '@zenstackhq/language/ast';
-import {
- getAuthModel,
- getContainingModel,
- getModelFieldsWithBases,
- isAuthInvocation,
- isFutureExpr,
-} from '@zenstackhq/sdk';
+import { getAuthModel, getModelFieldsWithBases, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk';
import {
AstNode,
AstNodeDescription,
@@ -52,13 +46,14 @@ import {
LangiumServices,
LinkingError,
Reference,
+ getContainerOfType,
interruptAndCheck,
isReference,
streamContents,
} from 'langium';
import { match } from 'ts-pattern';
import { CancellationToken } from 'vscode-jsonrpc';
-import { getAllDataModelsIncludingImports, getContainingDataModel } from '../utils/ast-utils';
+import { getAllLoadedAndReachableDataModels, getContainingDataModel } from '../utils/ast-utils';
import { mapBuiltinTypeToExpressionType } from './validator/utils';
interface DefaultReference extends Reference {
@@ -283,15 +278,17 @@ export class ZModelLinker extends DefaultLinker {
// eslint-disable-next-line @typescript-eslint/ban-types
const funcDecl = node.function.ref as FunctionDecl;
if (isAuthInvocation(node)) {
- // auth() function is resolved to User model in the current document
- const model = getContainingModel(node);
-
- if (model) {
- const allDataModels = getAllDataModelsIncludingImports(this.langiumDocuments(), model);
- const authModel = getAuthModel(allDataModels);
- if (authModel) {
- node.$resolvedType = { decl: authModel, nullable: true };
- }
+ // auth() function is resolved against all loaded and reachable documents
+
+ // get all data models from loaded and reachable documents
+ const allDataModels = getAllLoadedAndReachableDataModels(
+ this.langiumDocuments(),
+ getContainerOfType(node, isDataModel)
+ );
+
+ const authModel = getAuthModel(allDataModels);
+ if (authModel) {
+ node.$resolvedType = { decl: authModel, nullable: true };
}
} else if (isFutureExpr(node)) {
// future() function is resolved to current model
diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts
index e48a17621..e8e8880b5 100644
--- a/packages/schema/src/language-server/zmodel-scope.ts
+++ b/packages/schema/src/language-server/zmodel-scope.ts
@@ -32,7 +32,7 @@ import {
import { match } from 'ts-pattern';
import { CancellationToken } from 'vscode-jsonrpc';
import {
- getAllDataModelsIncludingImports,
+ getAllLoadedAndReachableDataModels,
isCollectionPredicate,
isFutureInvocation,
resolveImportUri,
@@ -219,18 +219,18 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
}
private createScopeForAuthModel(node: AstNode, globalScope: Scope) {
- const model = getContainerOfType(node, isModel);
- if (model) {
- const allDataModels = getAllDataModelsIncludingImports(
- this.services.shared.workspace.LangiumDocuments,
- model
- );
- const authModel = getAuthModel(allDataModels);
- if (authModel) {
- return this.createScopeForModel(authModel, globalScope);
- }
+ // get all data models from loaded and reachable documents
+ const allDataModels = getAllLoadedAndReachableDataModels(
+ this.services.shared.workspace.LangiumDocuments,
+ getContainerOfType(node, isDataModel)
+ );
+
+ const authModel = getAuthModel(allDataModels);
+ if (authModel) {
+ return this.createScopeForModel(authModel, globalScope);
+ } else {
+ return EMPTY_SCOPE;
}
- return EMPTY_SCOPE;
}
}
diff --git a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts
index 18bbd8c72..bf92d1c9d 100644
--- a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts
+++ b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts
@@ -106,7 +106,7 @@ export function generateAuthType(model: Model, authModel: DataModel) {
// }
// `
- return `namespace auth {
+ return `export namespace auth {
type WithRequired = T & { [P in K]-?: T[P] };
${Array.from(types.entries())
.map(([model, fields]) => {
diff --git a/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts
new file mode 100644
index 000000000..7202f2cdf
--- /dev/null
+++ b/packages/schema/src/plugins/enhancer/enhance/checker-type-generator.ts
@@ -0,0 +1,55 @@
+import { getDataModels } from '@zenstackhq/sdk';
+import type { DataModel, DataModelField, Model } from '@zenstackhq/sdk/ast';
+import { lowerCaseFirst } from 'lower-case-first';
+import { P, match } from 'ts-pattern';
+
+/**
+ * Generates a `ModelCheckers` interface that contains a `check` method for each model in the schema.
+ *
+ * E.g.:
+ *
+ * ```ts
+ * type CheckerOperation = 'create' | 'read' | 'update' | 'delete';
+ *
+ * export interface ModelCheckers {
+ * user: { check(op: CheckerOperation, args?: { email?: string; age?: number; }): Promise },
+ * ...
+ * }
+ * ```
+ */
+export function generateCheckerType(model: Model) {
+ return `
+import type { PolicyCrudKind } from '@zenstackhq/runtime';
+
+export interface ModelCheckers {
+ ${getDataModels(model)
+ .map((dataModel) => `\t${lowerCaseFirst(dataModel.name)}: ${generateDataModelChecker(dataModel)}`)
+ .join(',\n')}
+}
+`;
+}
+
+function generateDataModelChecker(dataModel: DataModel) {
+ return `{
+ check(args: { operation: PolicyCrudKind, where?: ${generateDataModelArgs(dataModel)} }): Promise
+ }`;
+}
+
+function generateDataModelArgs(dataModel: DataModel) {
+ return `{ ${dataModel.fields
+ .filter((field) => isFieldFilterable(field))
+ .map((field) => `${field.name}?: ${mapFieldType(field)}`)
+ .join('; ')} }`;
+}
+
+function isFieldFilterable(field: DataModelField) {
+ return !!mapFieldType(field);
+}
+
+function mapFieldType(field: DataModelField) {
+ return match(field.type.type)
+ .with('Boolean', () => 'boolean')
+ .with(P.union('BigInt', 'Int', 'Float', 'Decimal'), () => 'number')
+ .with('String', () => 'string')
+ .otherwise(() => undefined);
+}
diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts
index 0425b76d0..0993f323c 100644
--- a/packages/schema/src/plugins/enhancer/enhance/index.ts
+++ b/packages/schema/src/plugins/enhancer/enhance/index.ts
@@ -41,6 +41,7 @@ import { trackPrismaSchemaError } from '../../prisma';
import { PrismaSchemaGenerator } from '../../prisma/schema-generator';
import { isDefaultWithAuth } from '../enhancer-utils';
import { generateAuthType } from './auth-type-generator';
+import { generateCheckerType } from './checker-type-generator';
// information of delegate models and their sub models
type DelegateInfo = [DataModel, DataModel[]][];
@@ -55,7 +56,7 @@ export class EnhancerGenerator {
private readonly outDir: string
) {}
- async generate() {
+ async generate(): Promise<{ dmmf: DMMF.Document | undefined }> {
let logicalPrismaClientDir: string | undefined;
let dmmf: DMMF.Document | undefined;
@@ -89,6 +90,8 @@ export class EnhancerGenerator {
const authTypes = authModel ? generateAuthType(this.model, authModel) : '';
const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser';
+ const checkerTypes = this.generatePermissionChecker ? generateCheckerType(this.model) : '';
+
const enhanceTs = this.project.createSourceFile(
path.join(this.outDir, 'enhance.ts'),
`import { type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime';
@@ -105,6 +108,8 @@ ${
${authTypes}
+${checkerTypes}
+
${
logicalPrismaClientDir
? this.createLogicalPrismaEnhanceFunction(authTypeParam)
@@ -126,15 +131,16 @@ import type * as _P from '${prismaImport}';
}
private createSimplePrismaEnhanceFunction(authTypeParam: string) {
+ const returnType = `DbClient${this.generatePermissionChecker ? ' & ModelCheckers' : ''}`;
return `
-export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions) {
+export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): ${returnType} {
return createEnhancement(prisma, {
modelMeta,
policy,
zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined),
prismaModule: Prisma,
...options
- }, context);
+ }, context) as ${returnType};
}
`;
}
@@ -157,12 +163,16 @@ import type { Prisma, PrismaClient } from '${logicalPrismaClientDir}/index-fixed
// overload for plain PrismaClient
export function enhance & InternalArgs>(
prisma: _PrismaClient,
- context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient;
+ context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): PrismaClient${
+ this.generatePermissionChecker ? ' & ModelCheckers' : ''
+ };
// overload for extended PrismaClient
export function enhance & InternalArgs>(
prisma: DynamicClientExtensionThis,
- context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis;
+ context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): DynamicClientExtensionThis${
+ this.generatePermissionChecker ? ' & ModelCheckers' : ''
+ };
export function enhance(prisma: any, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions): any {
return createEnhancement(prisma, {
@@ -204,54 +214,59 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
// calculate a relative output path to output the logical prisma client into enhancer's output dir
const prismaClientOutDir = path.join(path.relative(zmodelDir, this.outDir), LOGICAL_CLIENT_GENERATION_PATH);
- try {
- await prismaGenerator.generate({
- provider: '@internal', // doesn't matter
- schemaPath: this.options.schemaPath,
- output: logicalPrismaFile,
- overrideClientGenerationPath: prismaClientOutDir,
- mode: 'logical',
- });
+ await prismaGenerator.generate({
+ provider: '@internal', // doesn't matter
+ schemaPath: this.options.schemaPath,
+ output: logicalPrismaFile,
+ overrideClientGenerationPath: prismaClientOutDir,
+ mode: 'logical',
+ });
- // generate the prisma client
+ // generate the prisma client
- // only run prisma client generator for the logical schema
- const prismaClientGeneratorName = this.getPrismaClientGeneratorName(this.model);
- let generateCmd = `prisma generate --schema "${logicalPrismaFile}" --generator=${prismaClientGeneratorName}`;
+ // only run prisma client generator for the logical schema
+ const prismaClientGeneratorName = this.getPrismaClientGeneratorName(this.model);
+ let generateCmd = `prisma generate --schema "${logicalPrismaFile}" --generator=${prismaClientGeneratorName}`;
- const prismaVersion = getPrismaVersion();
- if (!prismaVersion || semver.gte(prismaVersion, '5.2.0')) {
- // add --no-engine to reduce generation size if the prisma version supports
- generateCmd += ' --no-engine';
- }
+ const prismaVersion = getPrismaVersion();
+ if (!prismaVersion || semver.gte(prismaVersion, '5.2.0')) {
+ // add --no-engine to reduce generation size if the prisma version supports
+ generateCmd += ' --no-engine';
+ }
+ try {
+ // run 'prisma generate'
+ await execPackage(generateCmd, { stdio: 'ignore' });
+ } catch {
+ await trackPrismaSchemaError(logicalPrismaFile);
try {
- // run 'prisma generate'
- await execPackage(generateCmd, { stdio: 'ignore' });
+ // run 'prisma generate' again with output to the console
+ await execPackage(generateCmd);
} catch {
- await trackPrismaSchemaError(logicalPrismaFile);
- try {
- // run 'prisma generate' again with output to the console
- await execPackage(generateCmd);
- } catch {
- // noop
- }
- throw new PluginError(name, `Failed to run "prisma generate" on logical schema: ${logicalPrismaFile}`);
+ // noop
}
+ throw new PluginError(name, `Failed to run "prisma generate" on logical schema: ${logicalPrismaFile}`);
+ }
- // make a bunch of typing fixes to the generated prisma client
- await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH));
+ // make a bunch of typing fixes to the generated prisma client
+ await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH));
- return {
- prismaSchema: logicalPrismaFile,
- // load the dmmf of the logical prisma schema
- dmmf: await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }),
- };
- } finally {
+ const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) });
+
+ try {
+ // clean up temp schema
if (fs.existsSync(logicalPrismaFile)) {
fs.rmSync(logicalPrismaFile);
}
+ } catch {
+ // ignore errors
}
+
+ return {
+ prismaSchema: logicalPrismaFile,
+ // load the dmmf of the logical prisma schema
+ dmmf,
+ };
}
private getPrismaClientGeneratorName(model: Model) {
@@ -277,12 +292,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
this.model.declarations
.filter((d): d is DataModel => isDelegateModel(d))
.forEach((dm) => {
- delegateInfo.push([
- dm,
- this.model.declarations.filter(
- (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
- ),
- ]);
+ const concreteModels = this.model.declarations.filter(
+ (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
+ );
+ if (concreteModels.length > 0) {
+ delegateInfo.push([dm, concreteModels]);
+ }
});
// transform index.d.ts and save it into a new file (better perf than in-line editing)
@@ -622,4 +637,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
await sf.save();
}
}
+
+ private get generatePermissionChecker() {
+ return this.options.generatePermissionChecker === true;
+ }
}
diff --git a/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts
new file mode 100644
index 000000000..a0b1c1dd2
--- /dev/null
+++ b/packages/schema/src/plugins/enhancer/policy/constraint-transformer.ts
@@ -0,0 +1,379 @@
+import {
+ getRelationKeyPairs,
+ isAuthInvocation,
+ isDataModelFieldReference,
+ isEnumFieldReference,
+} from '@zenstackhq/sdk';
+import {
+ BinaryExpr,
+ BooleanLiteral,
+ DataModelField,
+ Expression,
+ ExpressionType,
+ LiteralExpr,
+ MemberAccessExpr,
+ NumberLiteral,
+ ReferenceExpr,
+ StringLiteral,
+ UnaryExpr,
+ isBinaryExpr,
+ isDataModelField,
+ isEnum,
+ isLiteralExpr,
+ isMemberAccessExpr,
+ isNullExpr,
+ isReferenceExpr,
+ isThisExpr,
+ isUnaryExpr,
+} from '@zenstackhq/sdk/ast';
+import { P, match } from 'ts-pattern';
+
+/**
+ * Options for {@link ConstraintTransformer}.
+ */
+export type ConstraintTransformerOptions = {
+ authAccessor: string;
+};
+
+/**
+ * Transform a set of allow and deny rules into a single constraint expression.
+ */
+export class ConstraintTransformer {
+ // a counter for generating unique variable names
+ private varCounter = 0;
+
+ constructor(private readonly options: ConstraintTransformerOptions) {}
+
+ /**
+ * Transforms a set of allow and deny rules into a single constraint expression.
+ */
+ transformRules(allows: Expression[], denies: Expression[]): string {
+ // reset state
+ this.varCounter = 0;
+
+ if (allows.length === 0) {
+ // unconditionally deny
+ return this.value('false', 'boolean');
+ }
+
+ let result: string;
+
+ // transform allow rules
+ const allowConstraints = allows.map((allow) => this.transformExpression(allow));
+ if (allowConstraints.length > 1) {
+ result = this.or(...allowConstraints);
+ } else {
+ result = allowConstraints[0];
+ }
+
+ // transform deny rules and compose
+ if (denies.length > 0) {
+ const denyConstraints = denies.map((deny) => this.transformExpression(deny));
+ result = this.and(result, ...denyConstraints.map((c) => this.not(c)));
+ }
+
+ // DEBUG:
+ // console.log(`Constraint transformation result:\n${JSON.stringify(result, null, 2)}`);
+
+ return result;
+ }
+
+ private and(...constraints: string[]) {
+ if (constraints.length === 0) {
+ throw new Error('No expressions to combine');
+ }
+ return constraints.length === 1 ? constraints[0] : `{ kind: 'and', children: [ ${constraints.join(', ')} ] }`;
+ }
+
+ private or(...constraints: string[]) {
+ if (constraints.length === 0) {
+ throw new Error('No expressions to combine');
+ }
+ return constraints.length === 1 ? constraints[0] : `{ kind: 'or', children: [ ${constraints.join(', ')} ] }`;
+ }
+
+ private not(constraint: string) {
+ return `{ kind: 'not', children: [${constraint}] }`;
+ }
+
+ private transformExpression(expression: Expression) {
+ return (
+ match(expression)
+ .when(isBinaryExpr, (expr) => this.transformBinary(expr))
+ .when(isUnaryExpr, (expr) => this.transformUnary(expr))
+ // top-level boolean literal
+ .when(isLiteralExpr, (expr) => this.transformLiteral(expr))
+ // top-level boolean reference expr
+ .when(isReferenceExpr, (expr) => this.transformReference(expr))
+ // top-level boolean member access expr
+ .when(isMemberAccessExpr, (expr) => this.transformMemberAccess(expr))
+ .otherwise(() => this.nextVar())
+ );
+ }
+
+ private transformLiteral(expr: LiteralExpr) {
+ return match(expr.$type)
+ .with(NumberLiteral, () => {
+ const parsed = parseFloat(expr.value as string);
+ if (isNaN(parsed) || parsed < 0 || !Number.isInteger(parsed)) {
+ // only non-negative integers are supported, for other cases,
+ // transform into a free variable
+ return this.nextVar('number');
+ }
+ return this.value(expr.value.toString(), 'number');
+ })
+ .with(StringLiteral, () => this.value(`'${expr.value}'`, 'string'))
+ .with(BooleanLiteral, () => this.value(expr.value.toString(), 'boolean'))
+ .exhaustive();
+ }
+
+ private transformReference(expr: ReferenceExpr) {
+ // top-level reference is transformed into a named variable
+ return this.variable(expr.target.$refText, 'boolean');
+ }
+
+ private transformMemberAccess(expr: MemberAccessExpr) {
+ // "this.x" is transformed into a named variable
+ if (isThisExpr(expr.operand)) {
+ return this.variable(expr.member.$refText, 'boolean');
+ }
+
+ // top-level auth access
+ const authAccess = this.getAuthAccess(expr);
+ if (authAccess) {
+ return this.value(`${authAccess} ?? false`, 'boolean');
+ }
+
+ // other top-level member access expressions are not supported
+ // and thus transformed into a free variable
+ return this.nextVar();
+ }
+
+ private transformBinary(expr: BinaryExpr): string {
+ return (
+ match(expr.operator)
+ .with('&&', () => this.and(this.transformExpression(expr.left), this.transformExpression(expr.right)))
+ .with('||', () => this.or(this.transformExpression(expr.left), this.transformExpression(expr.right)))
+ .with(P.union('==', '!=', '<', '<=', '>', '>='), () => this.transformComparison(expr))
+ // unsupported operators (e.g., collection predicate) are transformed into a free variable
+ .otherwise(() => this.nextVar())
+ );
+ }
+
+ private transformUnary(expr: UnaryExpr): string {
+ return match(expr.operator)
+ .with('!', () => this.not(this.transformExpression(expr.operand)))
+ .exhaustive();
+ }
+
+ private transformComparison(expr: BinaryExpr) {
+ if (isAuthInvocation(expr.left) || isAuthInvocation(expr.right)) {
+ // handle the case if any operand is `auth()` invocation
+ const authComparison = this.transformAuthComparison(expr);
+ return authComparison ?? this.nextVar();
+ }
+
+ const leftOperand = this.getComparisonOperand(expr.left);
+ const rightOperand = this.getComparisonOperand(expr.right);
+
+ const op = this.mapOperatorToConstraintKind(expr.operator);
+ const result = `{ kind: '${op}', left: ${leftOperand}, right: ${rightOperand} }`;
+
+ // `auth()` member access can be undefined, when that happens, we assume a false condition
+ // for the comparison
+
+ const leftAuthAccess = this.getAuthAccess(expr.left);
+ const rightAuthAccess = this.getAuthAccess(expr.right);
+
+ if (leftAuthAccess && rightOperand) {
+ // `auth().f op x` => `auth().f !== undefined && auth().f op x`
+ return this.and(this.value(`${this.normalizeToNull(leftAuthAccess)} !== null`, 'boolean'), result);
+ } else if (rightAuthAccess && leftOperand) {
+ // `x op auth().f` => `auth().f !== undefined && x op auth().f`
+ return this.and(this.value(`${this.normalizeToNull(rightAuthAccess)} !== null`, 'boolean'), result);
+ }
+
+ if (leftOperand === undefined || rightOperand === undefined) {
+ // if either operand is not supported, transform into a free variable
+ return this.nextVar();
+ }
+
+ return result;
+ }
+
+ private transformAuthComparison(expr: BinaryExpr) {
+ if (this.isAuthEqualNull(expr)) {
+ // `auth() == null` => `user === null`
+ return this.value(`${this.options.authAccessor} === null`, 'boolean');
+ }
+
+ if (this.isAuthNotEqualNull(expr)) {
+ // `auth() != null` => `user !== null`
+ return this.value(`${this.options.authAccessor} !== null`, 'boolean');
+ }
+
+ // auth() equality check against a relation, translate to id-fk comparison
+ const operand = isAuthInvocation(expr.left) ? expr.right : expr.left;
+ if (!isDataModelFieldReference(operand)) {
+ return undefined;
+ }
+
+ // get id-fk field pairs from the relation field
+ const relationField = operand.target.ref as DataModelField;
+ const idFkPairs = getRelationKeyPairs(relationField);
+
+ // build id-fk field comparison constraints
+ const fieldConstraints: string[] = [];
+
+ idFkPairs.forEach(({ id, foreignKey }) => {
+ const idFieldType = this.mapType(id.type.type as ExpressionType);
+ if (!idFieldType) {
+ return;
+ }
+ const fkFieldType = this.mapType(foreignKey.type.type as ExpressionType);
+ if (!fkFieldType) {
+ return;
+ }
+
+ const op = this.mapOperatorToConstraintKind(expr.operator);
+ const authIdAccess = `${this.options.authAccessor}?.${id.name}`;
+
+ fieldConstraints.push(
+ this.and(
+ // `auth()?.id != null` guard
+ this.value(`${this.normalizeToNull(authIdAccess)} !== null`, 'boolean'),
+ // `auth()?.id [op] fkField`
+ `{ kind: '${op}', left: ${this.value(authIdAccess, idFieldType)}, right: ${this.variable(
+ foreignKey.name,
+ fkFieldType
+ )} }`
+ )
+ );
+ });
+
+ // combine field constraints
+ if (fieldConstraints.length > 0) {
+ return this.and(...fieldConstraints);
+ }
+
+ return undefined;
+ }
+
+ // normalize `auth()` access undefined value to null
+ private normalizeToNull(expr: string) {
+ return `(${expr} ?? null)`;
+ }
+
+ private isAuthEqualNull(expr: BinaryExpr) {
+ return (
+ expr.operator === '==' &&
+ ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) ||
+ (isAuthInvocation(expr.right) && isNullExpr(expr.left)))
+ );
+ }
+
+ private isAuthNotEqualNull(expr: BinaryExpr) {
+ return (
+ expr.operator === '!=' &&
+ ((isAuthInvocation(expr.left) && isNullExpr(expr.right)) ||
+ (isAuthInvocation(expr.right) && isNullExpr(expr.left)))
+ );
+ }
+
+ private getComparisonOperand(expr: Expression) {
+ if (isLiteralExpr(expr)) {
+ return this.transformLiteral(expr);
+ }
+
+ if (isEnumFieldReference(expr)) {
+ return this.value(`'${expr.target.$refText}'`, 'string');
+ }
+
+ const fieldAccess = this.getFieldAccess(expr);
+ if (fieldAccess) {
+ // model field access is transformed into a named variable
+ const mappedType = this.mapExpressionType(expr);
+ if (mappedType) {
+ return this.variable(fieldAccess.name, mappedType);
+ } else {
+ return undefined;
+ }
+ }
+
+ const authAccess = this.getAuthAccess(expr);
+ if (authAccess) {
+ const mappedType = this.mapExpressionType(expr);
+ if (mappedType) {
+ return `${this.value(authAccess, mappedType)}`;
+ } else {
+ return undefined;
+ }
+ }
+
+ return undefined;
+ }
+
+ private mapExpressionType(expression: Expression) {
+ if (isEnum(expression.$resolvedType?.decl)) {
+ return 'string';
+ } else {
+ return this.mapType(expression.$resolvedType?.decl as ExpressionType);
+ }
+ }
+
+ private mapType(type: ExpressionType) {
+ return match(type)
+ .with('Boolean', () => 'boolean')
+ .with('Int', () => 'number')
+ .with('String', () => 'string')
+ .otherwise(() => undefined);
+ }
+
+ private mapOperatorToConstraintKind(operator: BinaryExpr['operator']) {
+ return match(operator)
+ .with('==', () => 'eq')
+ .with('!=', () => 'ne')
+ .with('<', () => 'lt')
+ .with('<=', () => 'lte')
+ .with('>', () => 'gt')
+ .with('>=', () => 'gte')
+ .otherwise(() => {
+ throw new Error(`Unsupported operator: ${operator}`);
+ });
+ }
+
+ private getFieldAccess(expr: Expression) {
+ if (isReferenceExpr(expr)) {
+ return isDataModelField(expr.target.ref) ? { name: expr.target.$refText } : undefined;
+ }
+ if (isMemberAccessExpr(expr)) {
+ return isThisExpr(expr.operand) ? { name: expr.member.$refText } : undefined;
+ }
+ return undefined;
+ }
+
+ private getAuthAccess(expr: Expression): string | undefined {
+ if (!isMemberAccessExpr(expr)) {
+ return undefined;
+ }
+
+ if (isAuthInvocation(expr.operand)) {
+ return `${this.options.authAccessor}?.${expr.member.$refText}`;
+ } else {
+ const operand = this.getAuthAccess(expr.operand);
+ return operand ? `${operand}?.${expr.member.$refText}` : undefined;
+ }
+ }
+
+ private nextVar(type = 'boolean') {
+ return this.variable(`__var${this.varCounter++}`, type);
+ }
+
+ private variable(name: string, type: string) {
+ return `{ kind: 'variable', name: '${name}', type: '${type}' }`;
+ }
+
+ private value(value: string, type: string) {
+ return `{ kind: 'value', value: ${value}, type: '${type}' }`;
+ }
+}
diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts
index 753ef8f19..a36a52126 100644
--- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts
+++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts
@@ -54,6 +54,7 @@ import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, Writ
import { name } from '..';
import { isCollectionPredicate } from '../../../utils/ast-utils';
import { ALL_OPERATION_KINDS } from '../../plugin-utils';
+import { ConstraintTransformer } from './constraint-transformer';
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';
/**
@@ -70,6 +71,8 @@ export class PolicyGenerator {
{ name: 'type CrudContract' },
{ name: 'allFieldsEqual' },
{ name: 'type PolicyDef' },
+ { name: 'type CheckerContext' },
+ { name: 'type CheckerConstraint' },
],
moduleSpecifier: `${RUNTIME_PACKAGE}`,
});
@@ -85,11 +88,22 @@ export class PolicyGenerator {
const models = getDataModels(model);
+ // policy guard functions
const policyMap: Record> = {};
for (const model of models) {
policyMap[model.name] = await this.generateQueryGuardForModel(model, sf);
}
+ const generatePermissionChecker = options.generatePermissionChecker === true;
+
+ // CRUD checker functions
+ const checkerMap: Record> = {};
+ if (generatePermissionChecker) {
+ for (const model of models) {
+ checkerMap[model.name] = await this.generateCheckerForModel(model, sf);
+ }
+ }
+
const authSelector = this.generateAuthSelector(models);
sf.addVariableStatement({
@@ -118,6 +132,22 @@ export class PolicyGenerator {
});
writer.writeLine(',');
+ if (generatePermissionChecker) {
+ writer.write('checker:');
+ writer.inlineBlock(() => {
+ for (const [model, map] of Object.entries(checkerMap)) {
+ writer.write(`${lowerCaseFirst(model)}:`);
+ writer.inlineBlock(() => {
+ Object.entries(map).forEach(([op, func]) => {
+ writer.write(`${op}: ${func},`);
+ });
+ });
+ writer.writeLine(',');
+ }
+ });
+ writer.writeLine(',');
+ }
+
writer.write('validation:');
writer.inlineBlock(() => {
for (const model of models) {
@@ -301,7 +331,6 @@ export class PolicyGenerator {
}
const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies);
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
result[kind] = guardFunc.getName()!;
if (kind === 'postUpdate') {
@@ -313,7 +342,6 @@ export class PolicyGenerator {
if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) {
const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies);
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
result[kind + '_input'] = inputCheckFunc.getName()!;
}
}
@@ -847,4 +875,70 @@ export class PolicyGenerator {
statements.push(`const user: any = context.user ?? null;`);
}
}
+
+ private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) {
+ const result: Record = {};
+
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ const policies = analyzePolicies(model);
+
+ for (const kind of ['create', 'read', 'update', 'delete'] as const) {
+ if (policies[kind] === true || policies[kind] === false) {
+ result[kind] = policies[kind] as boolean;
+ continue;
+ }
+
+ const denies = this.getPolicyExpressions(model, 'deny', kind);
+ const allows = this.getPolicyExpressions(model, 'allow', kind);
+
+ if (kind === 'update' && allows.length === 0) {
+ // no allow rule for 'update', policy is constant based on if there's
+ // post-update counterpart
+ if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
+ result[kind] = false;
+ continue;
+ } else {
+ result[kind] = true;
+ continue;
+ }
+ }
+
+ const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies);
+ result[kind] = guardFunc.getName()!;
+ }
+
+ return result;
+ }
+
+ private generateCheckerFunction(
+ sourceFile: SourceFile,
+ model: DataModel,
+ kind: string,
+ allows: Expression[],
+ denies: Expression[]
+ ) {
+ const statements: string[] = [];
+
+ this.generateNormalizedAuthRef(model, allows, denies, statements);
+
+ const transformed = new ConstraintTransformer({
+ authAccessor: 'user',
+ }).transformRules(allows, denies);
+
+ statements.push(`return ${transformed};`);
+
+ const func = sourceFile.addFunction({
+ name: `${model.name}$checker$${kind}`,
+ returnType: 'CheckerConstraint',
+ parameters: [
+ {
+ name: 'context',
+ type: 'CheckerContext',
+ },
+ ],
+ statements,
+ });
+
+ return func;
+ }
}
diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts
index 5440e03cf..58ecd68f4 100644
--- a/packages/schema/src/plugins/prisma/schema-generator.ts
+++ b/packages/schema/src/plugins/prisma/schema-generator.ts
@@ -374,6 +374,11 @@ export class PrismaSchemaGenerator {
// for the given model, find relation fields of delegate model type, find all concrete models
// of the delegate model and generate an auxiliary opposite relation field to each of them
decl.fields.forEach((f) => {
+ // don't process fields inherited from a delegate model
+ if (f.$inheritedFrom && isDelegateModel(f.$inheritedFrom)) {
+ return;
+ }
+
const fieldType = f.type.reference?.ref;
if (!isDataModel(fieldType)) {
return;
diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts
index a9b963d30..91f152af8 100644
--- a/packages/schema/src/plugins/zod/generator.ts
+++ b/packages/schema/src/plugins/zod/generator.ts
@@ -13,7 +13,7 @@ import {
} from '@zenstackhq/sdk';
import { DataModel, EnumField, Model, isDataModel, isEnum } from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
-import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
+import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { streamAllContents } from 'langium';
import path from 'path';
import type { SourceFile } from 'ts-morph';
@@ -106,8 +106,9 @@ export class ZodSchemaGenerator {
aggregateOperationSupport,
project: this.project,
inputObjectTypes,
+ zmodel: this.model,
});
- await transformer.generateInputSchemas(this.options);
+ await transformer.generateInputSchemas(this.options, this.model);
this.sourceFiles.push(...transformer.sourceFiles);
}
@@ -189,7 +190,10 @@ export class ZodSchemaGenerator {
);
}
- private async generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) {
+ private async generateEnumSchemas(
+ prismaSchemaEnum: readonly DMMF.SchemaEnum[],
+ modelSchemaEnum: readonly DMMF.SchemaEnum[]
+ ) {
const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum];
const enumNames = enumTypes.map((enumItem) => upperCaseFirst(enumItem.name));
Transformer.enumNames = enumNames ?? [];
@@ -197,6 +201,7 @@ export class ZodSchemaGenerator {
enumTypes,
project: this.project,
inputObjectTypes: [],
+ zmodel: this.model,
});
await transformer.generateEnumSchemas();
this.sourceFiles.push(...transformer.sourceFiles);
@@ -210,14 +215,21 @@ export class ZodSchemaGenerator {
for (let i = 0; i < inputObjectTypes.length; i += 1) {
const fields = inputObjectTypes[i]?.fields;
const name = inputObjectTypes[i]?.name;
+
if (!generateUnchecked && name.includes('Unchecked')) {
continue;
}
+
+ if (name.includes('CreateMany') && !supportCreateMany(this.model)) {
+ continue;
+ }
+
const transformer = new Transformer({
name,
fields,
project: this.project,
inputObjectTypes,
+ zmodel: this.model,
});
const moduleName = transformer.generateObjectSchema(generateUnchecked, this.options);
moduleNames.push(moduleName);
diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts
index 4fb04cdcc..09804f42b 100644
--- a/packages/schema/src/plugins/zod/transformer.ts
+++ b/packages/schema/src/plugins/zod/transformer.ts
@@ -1,7 +1,8 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { indentString, type PluginOptions } from '@zenstackhq/sdk';
+import type { Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
-import { type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
+import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import path from 'path';
import type { Project, SourceFile } from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
@@ -11,12 +12,12 @@ import { AggregateOperationSupport, TransformerParams } from './types';
export default class Transformer {
name: string;
originalName: string;
- fields: PrismaDMMF.SchemaArg[];
+ fields: readonly PrismaDMMF.SchemaArg[];
schemaImports = new Set();
- models: PrismaDMMF.Model[];
+ models: readonly PrismaDMMF.Model[];
modelOperations: PrismaDMMF.ModelMapping[];
aggregateOperationSupport: AggregateOperationSupport;
- enumTypes: PrismaDMMF.SchemaEnum[];
+ enumTypes: readonly PrismaDMMF.SchemaEnum[];
static enumNames: string[] = [];
static rawOpsMap: { [name: string]: string } = {};
@@ -26,6 +27,7 @@ export default class Transformer {
private project: Project;
private inputObjectTypes: PrismaDMMF.InputType[];
public sourceFiles: SourceFile[] = [];
+ private zmodel: Model;
constructor(params: TransformerParams) {
this.originalName = params.name ?? '';
@@ -37,6 +39,7 @@ export default class Transformer {
this.enumTypes = params.enumTypes ?? [];
this.project = params.project;
this.inputObjectTypes = params.inputObjectTypes;
+ this.zmodel = params.zmodel;
}
static setOutputPath(outPath: string) {
@@ -118,6 +121,10 @@ export default class Transformer {
return result;
}
+ if (inputType.type.includes('CreateMany') && !supportCreateMany(this.zmodel)) {
+ return result;
+ }
+
// TODO: unify the following with `schema-gen.ts`
if (inputType.type === 'String') {
@@ -389,7 +396,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
return wrapped;
}
- async generateInputSchemas(options: PluginOptions) {
+ async generateInputSchemas(options: PluginOptions, zmodel: Model) {
const globalExports: string[] = [];
// whether Prisma's Unchecked* series of input types should be generated
@@ -489,7 +496,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
operations.push(['create', origModelName]);
}
- if (createMany) {
+ if (createMany && supportCreateMany(zmodel)) {
imports.push(
`import { ${modelName}CreateManyInputObjectSchema } from '../objects/${modelName}CreateManyInput.schema'`
);
diff --git a/packages/schema/src/plugins/zod/types.ts b/packages/schema/src/plugins/zod/types.ts
index b64995448..f74d690e4 100644
--- a/packages/schema/src/plugins/zod/types.ts
+++ b/packages/schema/src/plugins/zod/types.ts
@@ -1,17 +1,19 @@
+import type { Model } from '@zenstackhq/sdk/ast';
import type { DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import { Project } from 'ts-morph';
export type TransformerParams = {
- enumTypes?: PrismaDMMF.SchemaEnum[];
- fields?: PrismaDMMF.SchemaArg[];
+ enumTypes?: readonly PrismaDMMF.SchemaEnum[];
+ fields?: readonly PrismaDMMF.SchemaArg[];
name?: string;
- models?: PrismaDMMF.Model[];
+ models?: readonly PrismaDMMF.Model[];
modelOperations?: PrismaDMMF.ModelMapping[];
aggregateOperationSupport?: AggregateOperationSupport;
isDefaultPrismaClientOutput?: boolean;
prismaClientOutputPath?: string;
project: Project;
inputObjectTypes: PrismaDMMF.InputType[];
+ zmodel: Model;
};
export type AggregateOperationSupport = {
diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts
index e6a335221..5f3321b94 100644
--- a/packages/schema/src/plugins/zod/utils/schema-gen.ts
+++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts
@@ -141,10 +141,14 @@ export function makeFieldSchema(field: DataModelField) {
}
if (field.attributes.some(isDefaultWithAuth)) {
- // field uses `auth()` in `@default()`, this was transformed into a pseudo default
- // value, while compiling to zod we should turn it into an optional field instead
- // of `.default()`
- schema += '.nullish()';
+ if (field.type.optional) {
+ schema += '.nullish()';
+ } else {
+ // field uses `auth()` in `@default()`, this was transformed into a pseudo default
+ // value, while compiling to zod we should turn it into an optional field instead
+ // of `.default()`
+ schema += '.optional()';
+ }
} else {
const schemaDefault = getFieldSchemaDefault(field);
if (schemaDefault !== undefined) {
diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts
index 83a5a6a57..e891056c2 100644
--- a/packages/schema/src/utils/ast-utils.ts
+++ b/packages/schema/src/utils/ast-utils.ts
@@ -17,7 +17,7 @@ import {
ModelImport,
ReferenceExpr,
} from '@zenstackhq/language/ast';
-import { isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
+import { getModelFieldsWithBases, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
@@ -47,16 +47,13 @@ type BuildReference = (
refText: string
) => Reference;
-export function mergeBaseModel(model: Model, linker: Linker) {
+export function mergeBaseModels(model: Model, linker: Linker) {
const buildReference = linker.buildReference.bind(linker);
- model.declarations.filter(isDataModel).forEach((decl) => {
- const dataModel = decl as DataModel;
-
+ model.declarations.filter(isDataModel).forEach((dataModel) => {
const bases = getRecursiveBases(dataModel).reverse();
if (bases.length > 0) {
dataModel.fields = bases
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
.flatMap((base) => base.fields)
// don't inherit skip-level fields
.filter((f) => !f.$inheritedFrom)
@@ -67,16 +64,25 @@ export function mergeBaseModel(model: Model, linker: Linker) {
.flatMap((base) => base.attributes.filter((attr) => filterBaseAttribute(base, attr)))
.map((attr) => cloneAst(attr, dataModel, buildReference))
.concat(dataModel.attributes);
-
- // fix $containerIndex
- linkContentToContainer(dataModel);
}
+ // mark base merged
dataModel.$baseMerged = true;
});
// remove abstract models
model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract));
+
+ model.declarations.filter(isDataModel).forEach((dm) => {
+ // remove abstract super types
+ dm.superTypes = dm.superTypes.filter((t) => t.ref && isDelegateModel(t.ref));
+
+ // fix $containerIndex
+ linkContentToContainer(dm);
+ });
+
+ // fix $containerIndex after deleting abstract models
+ linkContentToContainer(model);
}
function filterBaseAttribute(base: DataModel, attr: DataModelAttribute) {
@@ -244,29 +250,6 @@ export function getContainingDataModel(node: Expression): DataModel | undefined
return undefined;
}
-export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) {
- if (model.$baseMerged) {
- return model.fields;
- } else {
- return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)];
- }
-}
-
-export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] {
- const result: DataModel[] = [];
- dataModel.superTypes.forEach((superType) => {
- const baseDecl = superType.ref;
- if (baseDecl) {
- if (!includeDelegate && isDelegateModel(baseDecl)) {
- return;
- }
- result.push(baseDecl);
- result.push(...getRecursiveBases(baseDecl));
- }
- });
- return result;
-}
-
/**
* Walk upward from the current AST node to find the first node that satisfies the predicate.
*/
@@ -280,3 +263,36 @@ export function findUpAst(node: AstNode, predicate: (node: AstNode) => boolean):
}
return undefined;
}
+
+/**
+ * Gets all data models from all loaded documents
+ */
+export function getAllLoadedDataModels(langiumDocuments: LangiumDocuments) {
+ return langiumDocuments.all
+ .map((doc) => doc.parseResult.value as Model)
+ .flatMap((model) => model.declarations.filter(isDataModel))
+ .toArray();
+}
+
+/**
+ * Gets all data models from loaded and reachable documents
+ */
+export function getAllLoadedAndReachableDataModels(langiumDocuments: LangiumDocuments, fromModel?: DataModel) {
+ // get all data models from loaded documents
+ const allDataModels = getAllLoadedDataModels(langiumDocuments);
+
+ if (fromModel) {
+ // merge data models transitively reached from the current model
+ const model = getContainerOfType(fromModel, isModel);
+ if (model) {
+ const transitiveDataModels = getAllDataModelsIncludingImports(langiumDocuments, model);
+ transitiveDataModels.forEach((dm) => {
+ if (!allDataModels.includes(dm)) {
+ allDataModels.push(dm);
+ }
+ });
+ }
+ }
+
+ return allDataModels;
+}
diff --git a/packages/schema/src/utils/pkg-utils.ts b/packages/schema/src/utils/pkg-utils.ts
index 82b5ef019..067810a8e 100644
--- a/packages/schema/src/utils/pkg-utils.ts
+++ b/packages/schema/src/utils/pkg-utils.ts
@@ -13,7 +13,7 @@ export type PackageManagers = 'npm' | 'yarn' | 'pnpm';
* @export
* @template e A type parameter that extends boolean
*/
-export type FindUp = e extends true ? string[] | undefined : string | undefined
+export type FindUp = e extends true ? string[] | undefined : string | undefined;
/**
* Find and return file paths by searching parent directories based on the given names list and current working directory (cwd) path.
* Optionally return a single path or multiple paths.
@@ -28,7 +28,12 @@ export type FindUp = e extends true ? string[] | undefined :
* @param [result=[]] An array of strings representing the accumulated results used in multiple results
* @returns Path(s) to a specific file or folder within the directory or parent directories
*/
-export function findUp(names: string[], cwd: string = process.cwd(), multiple: e = false as e, result: string[] = []): FindUp {
+export function findUp(
+ names: string[],
+ cwd: string = process.cwd(),
+ multiple: e = false as e,
+ result: string[] = []
+): FindUp {
if (!names.some((name) => !!name)) return undefined;
const target = names.find((name) => fs.existsSync(path.join(cwd, name)));
if (multiple == false && target) return path.join(cwd, target) as FindUp;
@@ -38,7 +43,6 @@ export function findUp(names: string[], cwd: string =
return findUp(names, up, multiple, result);
}
-
/**
* Find a Node module/file given its name in a specific directory, with a fallback to the current working directory.
* If the name is empty, return undefined.
@@ -54,11 +58,11 @@ export function findNodeModulesFile(name: string, cwd: string = process.cwd()) {
if (!name) return undefined;
try {
// Use require.resolve to find the module/file. The paths option allows specifying the directory to start from.
- const resolvedPath = require.resolve(name, { paths: [cwd] })
- return resolvedPath
+ const resolvedPath = require.resolve(name, { paths: [cwd] });
+ return resolvedPath;
} catch (error) {
// If require.resolve fails to find the module/file, it will throw an error.
- return undefined
+ return undefined;
}
}
@@ -86,7 +90,7 @@ export function installPackage(
projectPath = '.',
exactVersion = true
) {
- const manager = pkgManager ?? getPackageManager(projectPath);
+ const manager = pkgManager ?? getPackageManager(projectPath).packageManager;
console.log(`Installing package "${pkg}@${tag}" with ${manager}`);
switch (manager) {
case 'yarn':
diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts
index aca9e2674..380836e21 100644
--- a/packages/schema/tests/schema/validation/attribute-validation.test.ts
+++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts
@@ -1051,7 +1051,7 @@ describe('Attribute tests', () => {
@@allow('all', auth() != null)
}
`)
- ).toContain(`auth() cannot be resolved because no model marked wth "@@auth()" or named "User" is found`);
+ ).toContain(`auth() cannot be resolved because no model marked with "@@auth()" or named "User" is found`);
await loadModel(`
${prelude}
diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts
index 0bf12245d..e0778da51 100644
--- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts
+++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts
@@ -521,7 +521,7 @@ describe('Data Model Validation Tests', () => {
`)
).toMatchObject(
errorLike(
- `Field "aId" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`
+ `Field "aId" on model "B" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`
)
);
@@ -736,6 +736,26 @@ describe('Data Model Validation Tests', () => {
).toMatchObject(
errorLike(`The relation field "user" on model "A" is missing an opposite relation field on model "User"`)
);
+
+ // one-to-one relation field and @@unique defined in different models
+ await loadModel(`
+ ${prelude}
+
+ abstract model Base {
+ id String @id
+ a A @relation(fields: [aId], references: [id])
+ aId String
+ }
+
+ model A {
+ id String @id
+ b B?
+ }
+
+ model B extends Base{
+ @@unique([aId])
+ }
+ `);
});
it('delegate base type', async () => {
diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts
index 0c92c6ee2..e72373f82 100644
--- a/packages/schema/tests/utils.ts
+++ b/packages/schema/tests/utils.ts
@@ -5,7 +5,7 @@ import * as path from 'path';
import * as tmp from 'tmp';
import { URI } from 'vscode-uri';
import { createZModelServices } from '../src/language-server/zmodel-module';
-import { mergeBaseModel } from '../src/utils/ast-utils';
+import { mergeBaseModels } from '../src/utils/ast-utils';
tmp.setGracefulCleanup();
@@ -68,7 +68,7 @@ export async function loadModel(content: string, validate = true, verbose = true
const model = (await doc.parseResult.value) as Model;
if (mergeBase) {
- mergeBaseModel(model, ZModel.references.Linker);
+ mergeBaseModels(model, ZModel.references.Linker);
}
return model;
@@ -87,13 +87,13 @@ export async function loadModelWithError(content: string, verbose = false) {
}
export async function safelyLoadModel(content: string, validate = true, verbose = false) {
- const [ result ] = await Promise.allSettled([ loadModel(content, validate, verbose) ]);
+ const [result] = await Promise.allSettled([loadModel(content, validate, verbose)]);
- return result
+ return result;
}
export const errorLike = (msg: string) => ({
reason: {
- message: expect.stringContaining(msg)
+ message: expect.stringContaining(msg),
},
-})
+});
diff --git a/packages/sdk/package.json b/packages/sdk/package.json
index 7bb3d3d2f..e0489d48f 100644
--- a/packages/sdk/package.json
+++ b/packages/sdk/package.json
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/sdk",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack plugin development SDK",
"main": "index.js",
"scripts": {
@@ -18,8 +18,8 @@
"author": "",
"license": "MIT",
"dependencies": {
- "@prisma/generator-helper": "5.7.0",
- "@prisma/internals": "5.7.0",
+ "@prisma/generator-helper": "^5.13.0",
+ "@prisma/internals": "^5.13.0",
"@zenstackhq/language": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"langium": "1.3.1",
diff --git a/packages/sdk/src/dmmf-helpers/include-helpers.ts b/packages/sdk/src/dmmf-helpers/include-helpers.ts
index c09c72426..23c9a9e12 100644
--- a/packages/sdk/src/dmmf-helpers/include-helpers.ts
+++ b/packages/sdk/src/dmmf-helpers/include-helpers.ts
@@ -1,7 +1,10 @@
import type { DMMF } from '../prisma';
import { checkIsModelRelationField, checkModelHasManyModelRelation, checkModelHasModelRelation } from './model-helpers';
-export function addMissingInputObjectTypesForInclude(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) {
+export function addMissingInputObjectTypesForInclude(
+ inputObjectTypes: DMMF.InputType[],
+ models: readonly DMMF.Model[]
+) {
// generate input object types necessary to support ModelInclude with relation support
const generatedIncludeInputObjectTypes = generateModelIncludeInputObjectTypes(models);
@@ -9,7 +12,7 @@ export function addMissingInputObjectTypesForInclude(inputObjectTypes: DMMF.Inpu
inputObjectTypes.push(includeInputObjectType);
}
}
-function generateModelIncludeInputObjectTypes(models: DMMF.Model[]) {
+function generateModelIncludeInputObjectTypes(models: readonly DMMF.Model[]) {
const modelIncludeInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName, fields: modelFields } = model;
diff --git a/packages/sdk/src/dmmf-helpers/model-helpers.ts b/packages/sdk/src/dmmf-helpers/model-helpers.ts
index 62bbd9980..2528eeeed 100644
--- a/packages/sdk/src/dmmf-helpers/model-helpers.ts
+++ b/packages/sdk/src/dmmf-helpers/model-helpers.ts
@@ -31,6 +31,6 @@ export function checkIsManyModelRelationField(modelField: DMMF.Field) {
return checkIsModelRelationField(modelField) && modelField.isList;
}
-export function findModelByName(models: DMMF.Model[], modelName: string) {
+export function findModelByName(models: readonly DMMF.Model[], modelName: string) {
return models.find(({ name }) => name === modelName);
}
diff --git a/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts b/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts
index 79b3a9f98..fb47b0096 100644
--- a/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts
+++ b/packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts
@@ -1,14 +1,17 @@
import type { DMMF } from '../prisma';
import { checkModelHasModelRelation } from './model-helpers';
-export function addMissingInputObjectTypesForModelArgs(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) {
+export function addMissingInputObjectTypesForModelArgs(
+ inputObjectTypes: DMMF.InputType[],
+ models: readonly DMMF.Model[]
+) {
const modelArgsInputObjectTypes = generateModelArgsInputObjectTypes(models);
for (const modelArgsInputObjectType of modelArgsInputObjectTypes) {
inputObjectTypes.push(modelArgsInputObjectType);
}
}
-function generateModelArgsInputObjectTypes(models: DMMF.Model[]) {
+function generateModelArgsInputObjectTypes(models: readonly DMMF.Model[]) {
const modelArgsInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName } = model;
diff --git a/packages/sdk/src/dmmf-helpers/select-helpers.ts b/packages/sdk/src/dmmf-helpers/select-helpers.ts
index 6037eecd8..a5e3e68a1 100644
--- a/packages/sdk/src/dmmf-helpers/select-helpers.ts
+++ b/packages/sdk/src/dmmf-helpers/select-helpers.ts
@@ -4,7 +4,7 @@ import { checkIsModelRelationField, checkModelHasManyModelRelation } from './mod
export function addMissingInputObjectTypesForSelect(
inputObjectTypes: DMMF.InputType[],
outputObjectTypes: DMMF.OutputType[],
- models: DMMF.Model[]
+ models: readonly DMMF.Model[]
) {
// generate input object types necessary to support ModelSelect._count
const modelCountOutputTypes = getModelCountOutputTypes(outputObjectTypes);
@@ -89,7 +89,7 @@ function generateModelCountOutputTypeArgsInputObjectTypes(modelCountOutputTypes:
return modelCountOutputTypeArgsInputObjectTypes;
}
-function generateModelSelectInputObjectTypes(models: DMMF.Model[]) {
+function generateModelSelectInputObjectTypes(models: readonly DMMF.Model[]) {
const modelSelectInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName, fields: modelFields } = model;
@@ -97,16 +97,9 @@ function generateModelSelectInputObjectTypes(models: DMMF.Model[]) {
for (const modelField of modelFields) {
const { name: modelFieldName, isList, type } = modelField;
+ const inputTypes = [{ isList: false, type: 'Boolean', location: 'scalar' }];
const isRelationField = checkIsModelRelationField(modelField);
-
- const field: DMMF.SchemaArg = {
- name: modelFieldName,
- isRequired: false,
- isNullable: false,
- inputTypes: [{ isList: false, type: 'Boolean', location: 'scalar' }],
- };
-
if (isRelationField) {
const schemaArgInputType: DMMF.InputTypeRef = {
isList: false,
@@ -114,9 +107,15 @@ function generateModelSelectInputObjectTypes(models: DMMF.Model[]) {
location: 'inputObjectTypes',
namespace: 'prisma',
};
- field.inputTypes.push(schemaArgInputType);
+ inputTypes.push(schemaArgInputType);
}
+ const field: DMMF.SchemaArg = {
+ name: modelFieldName,
+ isRequired: false,
+ isNullable: false,
+ inputTypes: inputTypes as DMMF.InputTypeRef[],
+ };
fields.push(field);
}
diff --git a/packages/sdk/src/prisma.ts b/packages/sdk/src/prisma.ts
index b45dd7cfb..1edab3c3d 100644
--- a/packages/sdk/src/prisma.ts
+++ b/packages/sdk/src/prisma.ts
@@ -4,8 +4,11 @@ import type { DMMF } from '@prisma/generator-helper';
import { getDMMF as _getDMMF, type GetDMMFOptions } from '@prisma/internals';
import { DEFAULT_RUNTIME_LOAD_PATH } from '@zenstackhq/runtime';
import path from 'path';
+import semver from 'semver';
+import { Model } from './ast';
import { RUNTIME_PACKAGE } from './constants';
import type { PluginOptions } from './types';
+import { getDataSourceProvider } from './utils';
/**
* Given an import context directory and plugin options, compute the import spec for the Prisma Client.
@@ -75,4 +78,14 @@ export function getPrismaVersion(): string | undefined {
}
}
+/**
+ * Returns if the given model supports `createMany` operation.
+ */
+export function supportCreateMany(model: Model) {
+ // `createMany` is supported for sqlite since Prisma 5.12.0
+ const prismaVersion = getPrismaVersion();
+ const dsProvider = getDataSourceProvider(model);
+ return dsProvider !== 'sqlite' || (prismaVersion && semver.gte(prismaVersion, '5.12.0'));
+}
+
export type { DMMF } from '@prisma/generator-helper';
diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts
index 6617983aa..df83186b5 100644
--- a/packages/sdk/src/utils.ts
+++ b/packages/sdk/src/utils.ts
@@ -17,6 +17,7 @@ import {
isConfigArrayExpr,
isDataModel,
isDataModelField,
+ isDataSource,
isEnumField,
isExpression,
isGeneratorDecl,
@@ -298,9 +299,9 @@ export function isForeignKeyField(field: DataModelField) {
}
/**
- * Gets the foreign key fields of the given relation field.
+ * Gets the foreign key-id field pairs from the given relation field.
*/
-export function getForeignKeyFields(relationField: DataModelField) {
+export function getRelationKeyPairs(relationField: DataModelField) {
if (!isRelationshipField(relationField)) {
return [];
}
@@ -309,11 +310,31 @@ export function getForeignKeyFields(relationField: DataModelField) {
if (relAttr) {
// find "fields" arg
const fieldsArg = getAttributeArg(relAttr, 'fields');
+ let fkFields: DataModelField[];
if (fieldsArg && isArrayExpr(fieldsArg)) {
- return fieldsArg.items
+ fkFields = fieldsArg.items
.filter((item): item is ReferenceExpr => isReferenceExpr(item))
.map((item) => item.target.ref as DataModelField);
+ } else {
+ return [];
+ }
+
+ // find "references" arg
+ const referencesArg = getAttributeArg(relAttr, 'references');
+ let idFields: DataModelField[];
+ if (referencesArg && isArrayExpr(referencesArg)) {
+ idFields = referencesArg.items
+ .filter((item): item is ReferenceExpr => isReferenceExpr(item))
+ .map((item) => item.target.ref as DataModelField);
+ } else {
+ return [];
+ }
+
+ if (idFields.length !== fkFields.length) {
+ throw new Error(`Relation's references arg and fields are must have equal length`);
}
+
+ return idFields.map((idField, i) => ({ id: idField, foreignKey: fkFields[i] }));
}
return [];
@@ -347,7 +368,7 @@ export function resolvePath(_path: string, options: Pick(options: PluginDeclaredOptions, name: string, pluginName: string): T {
const value = options[name];
if (value === undefined) {
- throw new PluginError(pluginName, `Plugin "${options.name}" is missing required option: ${name}`);
+ throw new PluginError(pluginName, `required option "${name}" is not provided`);
}
return value as T;
}
@@ -481,17 +502,24 @@ export function getDataModelFieldReference(expr: Expression): DataModelField | u
}
}
-export function getModelFieldsWithBases(model: DataModel) {
- return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)];
+export function getModelFieldsWithBases(model: DataModel, includeDelegate = true) {
+ if (model.$baseMerged) {
+ return model.fields;
+ } else {
+ return [...model.fields, ...getRecursiveBases(model, includeDelegate).flatMap((base) => base.fields)];
+ }
}
-export function getRecursiveBases(dataModel: DataModel): DataModel[] {
+export function getRecursiveBases(dataModel: DataModel, includeDelegate = true): DataModel[] {
const result: DataModel[] = [];
dataModel.superTypes.forEach((superType) => {
const baseDecl = superType.ref;
if (baseDecl) {
+ if (!includeDelegate && isDelegateModel(baseDecl)) {
+ return;
+ }
result.push(baseDecl);
- result.push(...getRecursiveBases(baseDecl));
+ result.push(...getRecursiveBases(baseDecl, includeDelegate));
}
});
return result;
@@ -511,3 +539,18 @@ export function ensureEmptyDir(dir: string) {
throw new Error(`Path "${dir}" already exists and is not a directory`);
}
}
+
+/**
+ * Gets the data source provider from the given model.
+ */
+export function getDataSourceProvider(model: Model) {
+ const dataSource = model.declarations.find(isDataSource);
+ if (!dataSource) {
+ return undefined;
+ }
+ const provider = dataSource?.fields.find((f) => f.name === 'provider');
+ if (!provider) {
+ return undefined;
+ }
+ return getLiteral(provider.value);
+}
diff --git a/packages/server/package.json b/packages/server/package.json
index 412979a3f..390378154 100644
--- a/packages/server/package.json
+++ b/packages/server/package.json
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/server",
- "version": "2.0.3",
+ "version": "2.1.0",
"displayName": "ZenStack Server-side Adapters",
"description": "ZenStack server-side adapters",
"homepage": "https://zenstack.dev",
diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts
index a7fb44d72..a8882b8be 100644
--- a/packages/server/src/api/rpc/index.ts
+++ b/packages/server/src/api/rpc/index.ts
@@ -81,6 +81,7 @@ class RequestHandler extends APIHandlerBase {
case 'aggregate':
case 'groupBy':
case 'count':
+ case 'check':
if (method !== 'GET') {
return {
status: 400,
diff --git a/packages/server/tests/api/rpc.test.ts b/packages/server/tests/api/rpc.test.ts
index 432abec2c..56ad8c5e0 100644
--- a/packages/server/tests/api/rpc.test.ts
+++ b/packages/server/tests/api/rpc.test.ts
@@ -15,7 +15,7 @@ describe('RPC API Handler Tests', () => {
let zodSchemas: any;
beforeAll(async () => {
- const params = await loadSchema(schema, { fullZod: true });
+ const params = await loadSchema(schema, { fullZod: true, generatePermissionChecker: true });
prisma = params.prisma;
enhance = params.enhance;
modelMeta = params.modelMeta;
@@ -131,6 +131,37 @@ describe('RPC API Handler Tests', () => {
expect(r.data.count).toBe(1);
});
+ it('check', async () => {
+ const handleRequest = makeHandler();
+
+ let r = await handleRequest({
+ method: 'get',
+ path: '/post/check',
+ query: { q: JSON.stringify({ operation: 'read' }) },
+ prisma: enhance(),
+ });
+ expect(r.status).toBe(200);
+ expect(r.data).toEqual(true);
+
+ r = await handleRequest({
+ method: 'get',
+ path: '/post/check',
+ query: { q: JSON.stringify({ operation: 'read', where: { published: false } }) },
+ prisma: enhance(),
+ });
+ expect(r.status).toBe(200);
+ expect(r.data).toEqual(false);
+
+ r = await handleRequest({
+ method: 'get',
+ path: '/post/check',
+ query: { q: JSON.stringify({ operation: 'read', where: { authorId: '1', published: false } }) },
+ prisma: enhance({ id: '1' }),
+ });
+ expect(r.status).toBe(200);
+ expect(r.data).toEqual(true);
+ });
+
it('policy violation', async () => {
await prisma.user.create({
data: {
diff --git a/packages/testtools/package.json b/packages/testtools/package.json
index 7e24868d6..8cb635d5b 100644
--- a/packages/testtools/package.json
+++ b/packages/testtools/package.json
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/testtools",
- "version": "2.0.3",
+ "version": "2.1.0",
"description": "ZenStack Test Tools",
"main": "index.js",
"private": true,
diff --git a/packages/testtools/src/model.ts b/packages/testtools/src/model.ts
index adbc89453..b6e2a3b71 100644
--- a/packages/testtools/src/model.ts
+++ b/packages/testtools/src/model.ts
@@ -5,7 +5,7 @@ import * as path from 'path';
import * as tmp from 'tmp';
import { URI } from 'vscode-uri';
import { createZModelServices } from 'zenstack/language-server/zmodel-module';
-import { mergeBaseModel } from 'zenstack/utils/ast-utils';
+import { mergeBaseModels } from 'zenstack/utils/ast-utils';
tmp.setGracefulCleanup();
@@ -53,7 +53,7 @@ export async function loadModel(content: string, validate = true, verbose = true
const model = (await doc.parseResult.value) as Model;
- mergeBaseModel(model, ZModel.references.Linker);
+ mergeBaseModels(model, ZModel.references.Linker);
return model;
}
diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts
index 7249e6c4a..4495ddf14 100644
--- a/packages/testtools/src/schema.ts
+++ b/packages/testtools/src/schema.ts
@@ -97,11 +97,13 @@ datasource db {
generator js {
provider = 'prisma-client-js'
+ ${options.previewFeatures ? `previewFeatures = ${JSON.stringify(options.previewFeatures)}` : ''}
}
plugin enhancer {
provider = '@core/enhancer'
${options.preserveTsFiles ? 'preserveTsFiles = true' : ''}
+ ${options.generatePermissionChecker ? 'generatePermissionChecker = true' : ''}
}
plugin zod {
@@ -131,6 +133,8 @@ export type SchemaLoadOptions = {
extraSourceFiles?: { name: string; content: string }[];
projectDir?: string;
preserveTsFiles?: boolean;
+ generatePermissionChecker?: boolean;
+ previewFeatures?: string[];
};
const defaultOptions: SchemaLoadOptions = {
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 5ce4ac030..a8e8c7318 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -30,7 +30,7 @@ importers:
specifier: ^2.4.1
version: 2.4.1
eslint:
- specifier: ^8.56.0
+ specifier: ^8.57.0
version: 8.57.0
eslint-plugin-jest:
specifier: ^28.2.0
@@ -211,6 +211,9 @@ importers:
ts-morph:
specifier: ^16.0.0
version: 16.0.0
+ ts-pattern:
+ specifier: ^4.3.0
+ version: 4.3.0
upper-case-first:
specifier: ^2.0.2
version: 2.0.2
@@ -409,6 +412,9 @@ importers:
deepmerge:
specifier: ^4.3.1
version: 4.3.1
+ logic-solver:
+ specifier: ^2.0.1
+ version: 2.0.1
lower-case-first:
specifier: ^2.0.2
version: 2.0.2
@@ -427,6 +433,9 @@ importers:
tiny-invariant:
specifier: ^1.3.1
version: 1.3.1
+ ts-pattern:
+ specifier: ^4.3.0
+ version: 4.3.0
tslib:
specifier: ^2.4.1
version: 2.4.1
@@ -612,11 +621,11 @@ importers:
packages/sdk:
dependencies:
'@prisma/generator-helper':
- specifier: 5.7.0
- version: 5.7.0
+ specifier: ^5.13.0
+ version: 5.13.0
'@prisma/internals':
- specifier: 5.7.0
- version: 5.7.0
+ specifier: ^5.13.0
+ version: 5.13.0
'@zenstackhq/language':
specifier: workspace:*
version: link:../language/dist
@@ -3868,7 +3877,6 @@ packages:
/@prisma/debug@5.13.0:
resolution: {integrity: sha512-699iqlEvzyCj9ETrXhs8o8wQc/eVW+FigSsHpiskSFydhjVuwTJEfj/nIYqTaWFYuxiWQRfm3r01meuW97SZaQ==}
- dev: true
/@prisma/debug@5.7.0:
resolution: {integrity: sha512-tZ+MOjWlVvz1kOEhNYMa4QUGURY+kgOUBqLHYIV8jmCsMuvA1tWcn7qtIMLzYWCbDcQT4ZS8xDgK0R2gl6/0wA==}
@@ -3876,7 +3884,6 @@ packages:
/@prisma/engines-version@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b:
resolution: {integrity: sha512-AyUuhahTINGn8auyqYdmxsN+qn0mw3eg+uhkp8zwknXYIqoT3bChG4RqNY/nfDkPvzWAPBa9mrDyBeOnWSgO6A==}
- dev: true
/@prisma/engines-version@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9:
resolution: {integrity: sha512-V6tgRVi62jRwTm0Hglky3Scwjr/AKFBFtS+MdbsBr7UOuiu1TKLPc6xfPiyEN1+bYqjEtjxwGsHgahcJsd1rNg==}
@@ -3890,7 +3897,6 @@ packages:
'@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b
'@prisma/fetch-engine': 5.13.0
'@prisma/get-platform': 5.13.0
- dev: true
/@prisma/engines@5.7.0:
resolution: {integrity: sha512-TkOMgMm60n5YgEKPn9erIvFX2/QuWnl3GBo6yTRyZKk5O5KQertXiNnrYgSLy0SpsKmhovEPQb+D4l0SzyE7XA==}
@@ -3908,7 +3914,6 @@ packages:
'@prisma/debug': 5.13.0
'@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b
'@prisma/get-platform': 5.13.0
- dev: true
/@prisma/fetch-engine@5.7.0:
resolution: {integrity: sha512-zIn/qmO+N/3FYe7/L9o+yZseIU8ivh4NdPKSkQRIHfg2QVTVMnbhGoTcecbxfVubeTp+DjcbjS0H9fCuM4W04w==}
@@ -3918,6 +3923,12 @@ packages:
'@prisma/get-platform': 5.7.0
dev: false
+ /@prisma/generator-helper@5.13.0:
+ resolution: {integrity: sha512-i+53beJ0dxkDrkHdsXxmeMf+eVhyhOIpL0SdBga8vwe0qHPrAIJ/lpuT/Hj0y5awTmq40qiUEmhXwCEuM/Z17w==}
+ dependencies:
+ '@prisma/debug': 5.13.0
+ dev: false
+
/@prisma/generator-helper@5.7.0:
resolution: {integrity: sha512-Fn4hJHKGJ49+E8sxpfslRauB3Goa3RAENJ/W25NMR754B9KxvmbCJyE3MT/lIZxML2nGgIdXYUtoDHZHnRaKDw==}
dependencies:
@@ -3928,7 +3939,6 @@ packages:
resolution: {integrity: sha512-B/WrQwYTzwr7qCLifQzYOmQhZcFmIFhR81xC45gweInSUn2hTEbfKUPd2keAog+y5WI5xLAFNJ3wkXplvSVkSw==}
dependencies:
'@prisma/debug': 5.13.0
- dev: true
/@prisma/get-platform@5.7.0:
resolution: {integrity: sha512-ZeV/Op4bZsWXuw5Tg05WwRI8BlKiRFhsixPcAM+5BKYSiUZiMKIi713tfT3drBq8+T0E1arNZgYSA9QYcglWNA==}
@@ -3936,6 +3946,20 @@ packages:
'@prisma/debug': 5.7.0
dev: false
+ /@prisma/internals@5.13.0:
+ resolution: {integrity: sha512-OPMzS+IBPzCLT4s+IfGUbOhGFY51CFbokIFMZuoSeLKWE8UvDlitiXZ3OlVqDPUc0AlH++ysQHzDISHbZD+ZUg==}
+ dependencies:
+ '@prisma/debug': 5.13.0
+ '@prisma/engines': 5.13.0
+ '@prisma/fetch-engine': 5.13.0
+ '@prisma/generator-helper': 5.13.0
+ '@prisma/get-platform': 5.13.0
+ '@prisma/prisma-schema-wasm': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b
+ '@prisma/schema-files-loader': 5.13.0
+ arg: 5.0.2
+ prompts: 2.4.2
+ dev: false
+
/@prisma/internals@5.7.0:
resolution: {integrity: sha512-O9x47W1DECAyvNjYUx6oZHmTX10emKuBgsFHZemUbkIcJdCsp3X8Cy2JMJ5z3hqkRX6a6omMamFsWjuTARoaSw==}
dependencies:
@@ -3949,10 +3973,20 @@ packages:
prompts: 2.4.2
dev: false
+ /@prisma/prisma-schema-wasm@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b:
+ resolution: {integrity: sha512-+IhHvuE1wKlyOpJgwAhGop1oqEt+1eixrCeikBIshRhdX6LwjmtRxVxVMlP5nS1yyughmpfkysIW4jZTa+Zjuw==}
+ dev: false
+
/@prisma/prisma-schema-wasm@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9:
resolution: {integrity: sha512-w+HdQtux0dJDEn6BG3fgNn+fXErXiekj9n//uHRAgrmZghockJkhnikOmG8aSXjTb1Tu5DrGasBX+rYX6rHT1w==}
dev: false
+ /@prisma/schema-files-loader@5.13.0:
+ resolution: {integrity: sha512-6sVMoqobkWKsmzb98LfLiIt/aFRucWfkzSUBsqk7sc+h99xjynJt6aKtM2SSkyndFdWpRU0OiCHfQ9UlYUEJIw==}
+ dependencies:
+ fs-extra: 11.1.1
+ dev: false
+
/@readme/better-ajv-errors@1.6.0(ajv@8.12.0):
resolution: {integrity: sha512-9gO9rld84Jgu13kcbKRU+WHseNhaVt76wYMeRDGsUGYxwJtI3RmEJ9LY9dZCYQGI8eUZLuxb5qDja0nqklpFjQ==}
engines: {node: '>=14'}
@@ -8492,7 +8526,6 @@ packages:
graceful-fs: 4.2.11
jsonfile: 6.1.0
universalify: 2.0.0
- dev: true
/fs-extra@7.0.1:
resolution: {integrity: sha512-YJDaCJZEnBmcbw13fvdAM9AwNOJwOzrE4pqMqBq5nFiEqXUqHwlK4B+3pUw6JNvfSPtX05xFHtYy/1ni01eGCw==}
@@ -8793,7 +8826,6 @@ packages:
/graceful-fs@4.2.11:
resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==}
- dev: true
/grapheme-splitter@1.0.4:
resolution: {integrity: sha512-bzh50DW9kTPM00T8y4o8vQg89Di9oLJVLW/KaOGIXJWP/iqCN6WKYkbNOF04vFLJhwcpYUh9ydh/+5vpOqV4YQ==}
@@ -10151,7 +10183,6 @@ packages:
universalify: 2.0.0
optionalDependencies:
graceful-fs: 4.2.11
- dev: true
/jsonpointer@5.0.1:
resolution: {integrity: sha512-p/nXbhSEcu3pZRdkW1OfJhpsVtW1gd4Wa1fnQc9YLiTfAjn0312eMKimbdIQzuZl9aa9xUGaRlP9T/CJE/ditQ==}
@@ -10495,6 +10526,12 @@ packages:
wrap-ansi: 8.1.0
dev: false
+ /logic-solver@2.0.1:
+ resolution: {integrity: sha512-F1oCywXUzvAF4Z98mMyXySUCpUU3hNyc+JfYV3g2x/4BupC/xv94iPJuHh9us2XX5UrvY5lnKUXNvjcJNQBJ/g==}
+ dependencies:
+ underscore: 1.13.6
+ dev: false
+
/loose-envify@1.4.0:
resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==}
hasBin: true
@@ -14372,7 +14409,6 @@ packages:
/underscore@1.13.6:
resolution: {integrity: sha512-+A5Sja4HP1M08MaXya7p5LvjuM7K6q/2EaC0+iovj/wOcMsTzMvDFbasi/oSapiwOlt252IqsKqPjCl7huKS0A==}
- dev: true
/undici-types@5.26.5:
resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==}
@@ -14441,7 +14477,6 @@ packages:
/universalify@2.0.0:
resolution: {integrity: sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==}
engines: {node: '>= 10.0.0'}
- dev: true
/unpipe@1.0.0:
resolution: {integrity: sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==}
diff --git a/tests/integration/tests/e2e/type-coverage.test.ts b/tests/integration/tests/e2e/type-coverage.test.ts
index c8c88211c..2a41b3ffd 100644
--- a/tests/integration/tests/e2e/type-coverage.test.ts
+++ b/tests/integration/tests/e2e/type-coverage.test.ts
@@ -13,14 +13,14 @@ describe('Type Coverage Tests', () => {
model Foo {
id String @id @default(cuid())
- string String
- int Int
- bigInt BigInt
- date DateTime
- float Float
- decimal Decimal
- boolean Boolean
- bytes Bytes
+ String String
+ Int Int
+ BigInt BigInt
+ DateTime DateTime
+ Float Float
+ Decimal Decimal
+ Boolean Boolean
+ Bytes Bytes
@@allow('all', true)
}
@@ -41,14 +41,14 @@ describe('Type Coverage Tests', () => {
const date = new Date();
const data = {
id: '1',
- string: 'string',
- int: 100,
- bigInt: BigInt(9007199254740991),
- date,
- float: 1.23,
- decimal: new Decimal(1.2345),
- boolean: true,
- bytes: Buffer.from('hello'),
+ String: 'string',
+ Int: 100,
+ BigInt: BigInt(9007199254740991),
+ DateTime: date,
+ Float: 1.23,
+ Decimal: new Decimal(1.2345),
+ Boolean: true,
+ Bytes: Buffer.from('hello'),
};
await db.foo.create({
diff --git a/tests/integration/tests/enhancements/with-policy/checker.test.ts b/tests/integration/tests/enhancements/with-policy/checker.test.ts
new file mode 100644
index 000000000..e4ca61fad
--- /dev/null
+++ b/tests/integration/tests/enhancements/with-policy/checker.test.ts
@@ -0,0 +1,652 @@
+import { SchemaLoadOptions, createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools';
+
+describe('Permission checker', () => {
+ const PRELUDE = `
+ datasource db {
+ provider = 'sqlite'
+ url = 'file:./dev.db'
+ }
+
+ generator js {
+ provider = 'prisma-client-js'
+ }
+
+ plugin enhancer {
+ provider = '@core/enhancer'
+ generatePermissionChecker = true
+ }
+ `;
+
+ const load = (schema: string, options?: SchemaLoadOptions) =>
+ loadSchema(schema, {
+ ...options,
+ generatePermissionChecker: true,
+ });
+
+ it('checker generation not enabled', async () => {
+ const { enhance } = await loadSchema(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('all', true)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).rejects.toThrow('Generated permission checkers not found');
+ });
+
+ it('empty rules', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveFalsy();
+ });
+
+ it('unconditional allow', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('all', true)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveTruthy();
+ });
+
+ it('multiple allow rules', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('all', value == 1)
+ @@allow('all', value == 2)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy();
+ });
+
+ it('deny rule', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('all', value > 0)
+ @@deny('all', value == 1)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy();
+ });
+
+ it('int field condition', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', value == 1)
+ @@allow('create', value != 1)
+ @@allow('update', value > 1)
+ @@allow('delete', value <= 1)
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 0 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+
+ await expect(db.model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { value: 0 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { value: 1 } })).toResolveFalsy();
+
+ await expect(db.model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { value: 1 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'update', where: { value: 2 } })).toResolveTruthy();
+
+ await expect(db.model.check({ operation: 'delete' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'delete', where: { value: 0 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'delete', where: { value: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'delete', where: { value: 2 } })).toResolveFalsy();
+ });
+
+ it('boolean field toplevel condition', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Boolean
+ @@allow('read', value)
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: false } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: true } })).toResolveTruthy();
+ });
+
+ it('boolean field condition', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Boolean
+ @@allow('read', value == true)
+ @@allow('create', value == false)
+ @@allow('update', value != true)
+ @@allow('delete', value != false)
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: false } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: true } })).toResolveTruthy();
+
+ await expect(db.model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { value: true } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'create', where: { value: false } })).toResolveTruthy();
+
+ await expect(db.model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { value: true } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'update', where: { value: false } })).toResolveTruthy();
+
+ await expect(db.model.check({ operation: 'delete' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'delete', where: { value: false } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'delete', where: { value: true } })).toResolveTruthy();
+ });
+
+ it('string field condition', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value String
+ @@allow('read', value == 'admin')
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 'user' } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { value: 'admin' } })).toResolveTruthy();
+ });
+
+ it('enum', async () => {
+ const dbUrl = await createPostgresDb('permission-checker-enum');
+ let prisma: any;
+ try {
+ const r = await loadSchema(
+ `
+ datasource db {
+ provider = 'postgresql'
+ url = '${dbUrl}'
+ }
+
+ generator js {
+ provider = 'prisma-client-js'
+ }
+
+ plugin enhancer {
+ provider = '@core/enhancer'
+ generatePermissionChecker = true
+ }
+
+ enum Role {
+ USER
+ ADMIN
+ }
+ model User {
+ id Int @id @default(autoincrement())
+ role Role
+ }
+ model Model {
+ id Int @id @default(autoincrement())
+ @@allow('read', auth().role == ADMIN)
+ }
+ `,
+ { addPrelude: false, generatePermissionChecker: true }
+ );
+
+ prisma = r.prisma;
+ const enhance = r.enhance;
+
+ await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, role: 'USER' }).model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, role: 'ADMIN' }).model.check({ operation: 'read' })).toResolveTruthy();
+ } finally {
+ await prisma.$disconnect();
+ await dropPostgresDb('permission-checker-enum');
+ }
+ });
+
+ it('function noop', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value String
+ @@allow('read', startsWith(value, 'admin'))
+ @@allow('update', !startsWith(value, 'admin'))
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 'user' } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 'admin' } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { value: 'user' } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { value: 'admin' } })).toResolveTruthy();
+ });
+
+ it('relation noop', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value String
+ foo Foo?
+
+ @@allow('read', foo.x > 0)
+ }
+
+ model Foo {
+ id Int @id @default(autoincrement())
+ x Int
+ modelId Int @unique
+ model Model @relation(fields: [modelId], references: [id])
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { foo: { x: 0 } } })).rejects.toThrow(
+ 'Providing filter for field "foo"'
+ );
+ });
+
+ it('collection predicate noop', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value String
+ foo Foo[]
+
+ @@allow('read', foo?[x > 0])
+ }
+
+ model Foo {
+ id Int @id @default(autoincrement())
+ x Int
+ modelId Int
+ model Model @relation(fields: [modelId], references: [id])
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { foo: [{ x: 0 }] } })).rejects.toThrow(
+ 'Providing filter for field "foo"'
+ );
+ });
+
+ it('field complex condition', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ x Int
+ y Int
+ @@allow('read', x > 0 && x > y)
+ @@allow('create', x > 1 || x > y)
+ @@allow('update', !(x >= y))
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { x: 0 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1, y: 0 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1, y: 1 } })).toResolveFalsy();
+
+ await expect(db.model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { x: 0 } })).toResolveFalsy(); // numbers are non-negative
+ await expect(db.model.check({ operation: 'create', where: { x: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { x: 1, y: 0 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { x: 1, y: 1 } })).toResolveFalsy();
+
+ await expect(db.model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { x: 0 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'update', where: { y: 0 } })).toResolveFalsy(); // numbers are non-negative
+ await expect(db.model.check({ operation: 'update', where: { x: 1, y: 1 } })).toResolveFalsy();
+ });
+
+ it('field condition unsolvable', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ x Int
+ y Int
+ @@allow('read', x > 0 && x < y && y <= 1)
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { x: 0 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1, y: 2 } })).toResolveFalsy();
+ await expect(db.model.check({ operation: 'read', where: { x: 1, y: 1 } })).toResolveFalsy();
+ });
+
+ it('simple auth condition', async () => {
+ const { enhance } = await load(
+ `
+ model User {
+ id Int @id @default(autoincrement())
+ level Int
+ admin Boolean
+ }
+
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', auth().level > 0)
+ @@allow('create', auth().admin)
+ @@allow('update', !auth().admin)
+ }
+ `
+ );
+
+ await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'read' })).toResolveTruthy();
+
+ await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy();
+ await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'create' })).toResolveFalsy();
+ await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'create' })).toResolveTruthy();
+
+ await expect(enhance().model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(enhance({ id: 1, admin: true }).model.check({ operation: 'update' })).toResolveFalsy();
+ await expect(enhance({ id: 1, admin: false }).model.check({ operation: 'update' })).toResolveTruthy();
+ });
+
+ it('auth compared with relation field', async () => {
+ const { enhance } = await load(
+ `
+ model User {
+ id Int @id @default(autoincrement())
+ models Model[]
+ }
+
+ model Model {
+ id Int @id @default(autoincrement())
+ owner User @relation(fields: [ownerId], references: [id])
+ ownerId Int
+ @@allow('read', auth().id == ownerId)
+ @@allow('create', auth().id != ownerId)
+ @@allow('update', auth() == owner)
+ @@allow('delete', auth() != owner)
+ }
+ `,
+ { preserveTsFiles: true }
+ );
+
+ await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read', where: { ownerId: 1 } })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read', where: { ownerId: 2 } })).toResolveFalsy();
+
+ await expect(enhance().model.check({ operation: 'create' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'create', where: { ownerId: 1 } })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'create', where: { ownerId: 2 } })).toResolveTruthy();
+
+ await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'update', where: { ownerId: 1 } })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'update', where: { ownerId: 2 } })).toResolveFalsy();
+
+ await expect(enhance().model.check({ operation: 'delete' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'delete' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'delete', where: { ownerId: 1 } })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'delete', where: { ownerId: 2 } })).toResolveTruthy();
+ });
+
+ it('auth null check', async () => {
+ const { enhance } = await load(
+ `
+ model User {
+ id Int @id @default(autoincrement())
+ level Int
+ }
+
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', auth() != null)
+ @@allow('create', auth() == null)
+ @@allow('update', auth().level > 0)
+ }
+ `
+ );
+
+ await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveTruthy();
+
+ await expect(enhance().model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'create' })).toResolveFalsy();
+
+ await expect(enhance().model.check({ operation: 'update' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'update' })).toResolveFalsy();
+ await expect(enhance({ id: 1, level: 0 }).model.check({ operation: 'update' })).toResolveFalsy();
+ await expect(enhance({ id: 1, level: 1 }).model.check({ operation: 'update' })).toResolveTruthy();
+ });
+
+ it('auth with relation access', async () => {
+ const { enhance } = await load(
+ `
+ model User {
+ id Int @id @default(autoincrement())
+ profile Profile?
+ }
+
+ model Profile {
+ id Int @id @default(autoincrement())
+ level Int
+ user User @relation(fields: [userId], references: [id])
+ userId Int @unique
+ }
+
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', auth().profile.level > 0)
+ }
+ `
+ );
+
+ await expect(enhance().model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1 }).model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, profile: { level: 0 } }).model.check({ operation: 'read' })).toResolveFalsy();
+ await expect(enhance({ id: 1, profile: { level: 1 } }).model.check({ operation: 'read' })).toResolveTruthy();
+ });
+
+ it('nullable field', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int?
+ @@allow('read', value != null)
+ @@allow('create', value == null)
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'create', where: { value: 1 } })).toResolveTruthy();
+ });
+
+ it('compilation', async () => {
+ await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', value == 1)
+ }
+ `,
+ {
+ compile: true,
+ extraSourceFiles: [
+ {
+ name: 'main.ts',
+ content: `
+ import { PrismaClient } from '@prisma/client';
+ import { enhance } from '.zenstack/enhance';
+
+ const prisma = new PrismaClient();
+ const db = enhance(prisma);
+ db.model.check({ operation: 'read' });
+ db.model.check({ operation: 'read', where: { value: 1 }});
+ `,
+ },
+ ],
+ }
+ );
+ });
+
+ it('invalid filter', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ foo Foo?
+ d DateTime
+
+ @@allow('read', value == 1)
+ }
+
+ model Foo {
+ id Int @id @default(autoincrement())
+ x Int
+ model Model @relation(fields: [modelId], references: [id])
+ modelId Int @unique
+ }
+ `
+ );
+
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read', where: { foo: { x: 1 } } })).rejects.toThrow(
+ `Providing filter for field "foo" is not supported. Only scalar fields are allowed.`
+ );
+ await expect(db.model.check({ operation: 'read', where: { d: new Date() } })).rejects.toThrow(
+ `Providing filter for field "d" is not supported. Only number, string, and boolean fields are allowed.`
+ );
+ await expect(db.model.check({ operation: 'read', where: { value: null } })).rejects.toThrow(
+ `Using "null" as filter value is not supported yet`
+ );
+ await expect(db.model.check({ operation: 'read', where: { value: {} } })).rejects.toThrow(
+ 'Invalid value type for field "value". Only number, string or boolean is allowed.'
+ );
+ await expect(db.model.check({ operation: 'read', where: { value: 'abc' } })).rejects.toThrow(
+ 'Invalid value type for field "value". Expected "number"'
+ );
+ await expect(db.model.check({ operation: 'read', where: { value: -1 } })).rejects.toThrow(
+ 'Invalid value for field "value". Only non-negative integers are allowed.'
+ );
+ });
+
+ it('float field ignored', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Float
+ @@allow('read', value == 1.1)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+ });
+
+ it('float value ignored', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', value > 1.1)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy();
+ });
+
+ it('negative value ignored', async () => {
+ const { enhance } = await load(
+ `
+ model Model {
+ id Int @id @default(autoincrement())
+ value Int
+ @@allow('read', value >-1)
+ }
+ `
+ );
+ const db = enhance();
+ await expect(db.model.check({ operation: 'read' })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 1 } })).toResolveTruthy();
+ await expect(db.model.check({ operation: 'read', where: { value: 2 } })).toResolveTruthy();
+ });
+});
diff --git a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts
new file mode 100644
index 000000000..d46c31245
--- /dev/null
+++ b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts
@@ -0,0 +1,57 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('prisma omit', () => {
+ it('test', async () => {
+ const { prisma, enhance } = await loadSchema(
+ `
+ model User {
+ id String @id @default(cuid())
+ name String
+ profile Profile?
+ age Int
+ value Int @allow('read', age > 20)
+ @@allow('all', age > 18)
+ }
+
+ model Profile {
+ id String @id @default(cuid())
+ user User @relation(fields: [userId], references: [id])
+ userId String @unique
+ level Int
+ @@allow('all', level > 1)
+ }
+ `,
+ { previewFeatures: ['omitApi'], logPrismaQuery: true }
+ );
+
+ await prisma.user.create({
+ data: {
+ name: 'John',
+ age: 25,
+ value: 10,
+ profile: {
+ create: { level: 2 },
+ },
+ },
+ });
+
+ const db = enhance();
+ let found = await db.user.findFirst({
+ include: { profile: { omit: { level: true } } },
+ omit: {
+ age: true,
+ },
+ });
+ expect(found.age).toBeUndefined();
+ expect(found.value).toEqual(10);
+ expect(found.profile.level).toBeUndefined();
+
+ found = await db.user.findFirst({
+ select: { value: true, profile: { omit: { level: true } } },
+ });
+ console.log(found);
+ expect(found.age).toBeUndefined();
+ expect(found.value).toEqual(10);
+ expect(found.profile.level).toBeUndefined();
+ });
+});
diff --git a/tests/integration/tsconfig.json b/tests/integration/tsconfig.json
index c6cc8d4a7..2771cd805 100644
--- a/tests/integration/tsconfig.json
+++ b/tests/integration/tsconfig.json
@@ -8,5 +8,5 @@
"skipLibCheck": true,
"experimentalDecorators": true
},
- "include": ["**/*.ts", "**/*.d.ts", "../regression/tests/issue-177.test.ts", "../regression/tests/issue-416.test.ts", "../regression/tests/issue-646.test.ts", "../regression/tests/issue-657.test.ts", "../regression/tests/issue-665.test.ts", "../regression/tests/issue-674.test.ts", "../regression/tests/issue-689.test.ts", "../regression/tests/issue-703.test.ts", "../regression/tests/issue-714.test.ts", "../regression/tests/issue-724.test.ts", "../regression/tests/issue-735.test.ts", "../regression/tests/issue-744.test.ts", "../regression/tests/issue-756.test.ts", "../regression/tests/issue-764.test.ts", "../regression/tests/issue-765.test.ts", "../regression/tests/issue-804.test.ts", "../regression/tests/issue-811.test.ts", "../regression/tests/issue-814.test.ts", "../regression/tests/issue-825.test.ts", "../regression/tests/issue-864.test.ts", "../regression/tests/issue-886.test.ts", "../regression/tests/issue-925.test.ts", "../regression/tests/issue-947.test.ts", "../regression/tests/issue-961.test.ts", "../regression/tests/issue-965.test.ts", "../regression/tests/issue-971.test.ts", "../regression/tests/issue-992.test.ts", "../regression/tests/issue-1014.test.ts", "../regression/tests/issue-1078.test.ts", "../regression/tests/issue-1080.test.ts", "../regression/tests/issue-1129.test.ts", "../regression/tests/issue-1167.test.ts", "../regression/tests/issue-1179.test.ts", "../regression/tests/issue-1186.test.ts", "../regression/tests/issue-1210.test.ts", "../regression/tests/issue-1235.test.ts", "../regression/tests/issue-1241.test.ts", "../regression/tests/issue-1257.test.ts", "../regression/tests/issue-1265.test.ts", "../regression/tests/issues.test.ts"]
+ "include": ["**/*.ts", "**/*.d.ts"]
}
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts b/tests/regression/tests/issue-1058.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1058.test.ts
rename to tests/regression/tests/issue-1058.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts b/tests/regression/tests/issue-1064.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1064.test.ts
rename to tests/regression/tests/issue-1064.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1100.test.ts b/tests/regression/tests/issue-1100.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1100.test.ts
rename to tests/regression/tests/issue-1100.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1123.test.ts b/tests/regression/tests/issue-1123.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1123.test.ts
rename to tests/regression/tests/issue-1123.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1135.test.ts b/tests/regression/tests/issue-1135.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1135.test.ts
rename to tests/regression/tests/issue-1135.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1149.test.ts b/tests/regression/tests/issue-1149.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1149.test.ts
rename to tests/regression/tests/issue-1149.test.ts
diff --git a/tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts b/tests/regression/tests/issue-1243.test.ts
similarity index 100%
rename from tests/integration/tests/enhancements/with-delegate/issue-1243.test.ts
rename to tests/regression/tests/issue-1243.test.ts
diff --git a/tests/regression/tests/issue-1378.test.ts b/tests/regression/tests/issue-1378.test.ts
new file mode 100644
index 000000000..29d4b16a8
--- /dev/null
+++ b/tests/regression/tests/issue-1378.test.ts
@@ -0,0 +1,47 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1378', () => {
+ it('regression', async () => {
+ await loadSchema(
+ `
+ model User {
+ id String @id @default(cuid())
+ todos Todo[]
+ }
+
+ model Todo {
+ id String @id @default(cuid())
+ name String @length(3,255)
+ userId String @default(auth().id)
+
+ user User @relation(fields: [userId], references: [id], onDelete: Cascade)
+ @@allow("all", auth() == user)
+ }
+ `,
+ {
+ extraDependencies: ['zod'],
+ extraSourceFiles: [
+ {
+ name: 'main.ts',
+ content: `
+ import { z } from 'zod';
+ import { PrismaClient } from '@prisma/client';
+ import { enhance } from '.zenstack/enhance';
+ import { TodoCreateSchema } from '.zenstack/zod/models';
+
+ const prisma = new PrismaClient();
+ const db = enhance(prisma);
+
+ export const onSubmit = async (values: z.infer) => {
+ await db.todo.create({
+ data: values,
+ });
+ };
+ `,
+ },
+ ],
+ compile: true,
+ }
+ );
+ });
+});
diff --git a/tests/regression/tests/issue-1388.test.ts b/tests/regression/tests/issue-1388.test.ts
new file mode 100644
index 000000000..3ffbc967b
--- /dev/null
+++ b/tests/regression/tests/issue-1388.test.ts
@@ -0,0 +1,26 @@
+import { FILE_SPLITTER, loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1388', () => {
+ it('regression', async () => {
+ await loadSchema(
+ `schema.zmodel
+ import './auth'
+ import './post'
+
+ ${FILE_SPLITTER}auth.zmodel
+ model User {
+ id String @id @default(cuid())
+ role String
+ }
+
+ ${FILE_SPLITTER}post.zmodel
+ model Post {
+ id String @id @default(nanoid(6))
+ title String
+ @@deny('all', auth() == null)
+ @@allow('all', auth().id == 'user1')
+ }
+ `
+ );
+ });
+});
diff --git a/tests/regression/tests/issue-1410.test.ts b/tests/regression/tests/issue-1410.test.ts
new file mode 100644
index 000000000..488cd1bf0
--- /dev/null
+++ b/tests/regression/tests/issue-1410.test.ts
@@ -0,0 +1,146 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1410', () => {
+ it('regression', async () => {
+ const { enhance } = await loadSchema(
+ `
+ model Drink {
+ id Int @id @default(autoincrement())
+ slug String @unique
+
+ manufacturer_id Int
+ manufacturer Manufacturer @relation(fields: [manufacturer_id], references: [id])
+
+ type String
+
+ name String @unique
+ description String
+ abv Float
+ image String?
+
+ gluten Boolean
+ lactose Boolean
+ organic Boolean
+
+ containers Container[]
+
+ @@delegate(type)
+
+ @@allow('all', true)
+ }
+
+ model Beer extends Drink {
+ style_id Int
+ style BeerStyle @relation(fields: [style_id], references: [id])
+
+ ibu Float?
+
+ @@allow('all', true)
+ }
+
+ model BeerStyle {
+ id Int @id @default(autoincrement())
+
+ name String @unique
+ color String
+
+ beers Beer[]
+
+ @@allow('all', true)
+ }
+
+ model Wine extends Drink {
+ style_id Int
+ style WineStyle @relation(fields: [style_id], references: [id])
+
+ heavy_score Int?
+ tannine_score Int?
+ dry_score Int?
+ fresh_score Int?
+ notes String?
+
+ @@allow('all', true)
+ }
+
+ model WineStyle {
+ id Int @id @default(autoincrement())
+
+ name String @unique
+ color String
+
+ wines Wine[]
+
+ @@allow('all', true)
+ }
+
+ model Soda extends Drink {
+ carbonated Boolean
+
+ @@allow('all', true)
+ }
+
+ model Cocktail extends Drink {
+ mix Boolean
+
+ @@allow('all', true)
+ }
+
+ model Container {
+ barcode String @id
+
+ drink_id Int
+ drink Drink @relation(fields: [drink_id], references: [id])
+
+ type String
+ volume Int
+ portions Int?
+
+ inventory Int @default(0)
+
+ @@allow('all', true)
+ }
+
+ model Manufacturer {
+ id Int @id @default(autoincrement())
+
+ country_id String
+ country Country @relation(fields: [country_id], references: [code])
+
+ name String @unique
+ description String?
+ image String?
+
+ drinks Drink[]
+
+ @@allow('all', true)
+ }
+
+ model Country {
+ code String @id
+ name String
+
+ manufacturers Manufacturer[]
+
+ @@allow('all', true)
+ }
+ `
+ );
+
+ const db = enhance();
+
+ await db.beer.findMany({
+ include: { style: true, manufacturer: true },
+ where: { NOT: { gluten: true } },
+ });
+
+ await db.beer.findMany({
+ include: { style: true, manufacturer: true },
+ where: { AND: [{ gluten: true }, { abv: { gt: 50 } }] },
+ });
+
+ await db.beer.findMany({
+ include: { style: true, manufacturer: true },
+ where: { OR: [{ AND: [{ NOT: { gluten: true } }] }, { abv: { gt: 50 } }] },
+ });
+ });
+});
diff --git a/tests/regression/tests/issue-1415.test.ts b/tests/regression/tests/issue-1415.test.ts
new file mode 100644
index 000000000..791413557
--- /dev/null
+++ b/tests/regression/tests/issue-1415.test.ts
@@ -0,0 +1,22 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1415', () => {
+ it('regression', async () => {
+ await loadSchema(
+ `
+ model User {
+ id String @id @default(cuid())
+ prices Price[]
+ }
+
+ model Price {
+ id String @id @default(cuid())
+ owner User @relation(fields: [ownerId], references: [id])
+ ownerId String @default(auth().id)
+ priceType String
+ @@delegate(priceType)
+ }
+ `
+ );
+ });
+});
diff --git a/tests/regression/tests/issue-1416.test.ts b/tests/regression/tests/issue-1416.test.ts
new file mode 100644
index 000000000..5c18d6d4e
--- /dev/null
+++ b/tests/regression/tests/issue-1416.test.ts
@@ -0,0 +1,37 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1416', () => {
+ it('regression', async () => {
+ await loadSchema(
+ `
+ model User {
+ id String @id @default(cuid())
+ role String
+ }
+
+ model Price {
+ id String @id @default(nanoid(6))
+ entity Entity? @relation(fields: [entityId], references: [id])
+ entityId String?
+ priceType String
+ @@delegate(priceType)
+ }
+
+ model MyPrice extends Price {
+ foo String
+ }
+
+ model Entity {
+ id String @id @default(nanoid(6))
+ price Price[]
+ type String
+ @@delegate(type)
+ }
+
+ model MyEntity extends Entity {
+ foo String
+ }
+ `
+ );
+ });
+});
diff --git a/tests/regression/tests/issue-1427.test.ts b/tests/regression/tests/issue-1427.test.ts
new file mode 100644
index 000000000..0d7c7c07e
--- /dev/null
+++ b/tests/regression/tests/issue-1427.test.ts
@@ -0,0 +1,42 @@
+import { loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1427', () => {
+ it('regression', async () => {
+ const { prisma, enhance } = await loadSchema(
+ `
+ model User {
+ id String @id @default(cuid())
+ name String
+ profile Profile?
+ @@allow('all', true)
+ }
+
+ model Profile {
+ id String @id @default(cuid())
+ user User @relation(fields: [userId], references: [id])
+ userId String @unique
+ @@allow('all', true)
+ }
+ `
+ );
+
+ await prisma.user.create({
+ data: {
+ name: 'John',
+ profile: {
+ create: {},
+ },
+ },
+ });
+
+ const db = enhance();
+ const found = await db.user.findFirst({
+ select: {
+ id: true,
+ name: true,
+ profile: false,
+ },
+ });
+ expect(found.profile).toBeUndefined();
+ });
+});
diff --git a/tests/regression/tests/issue-1435.test.ts b/tests/regression/tests/issue-1435.test.ts
new file mode 100644
index 000000000..0093aff8b
--- /dev/null
+++ b/tests/regression/tests/issue-1435.test.ts
@@ -0,0 +1,119 @@
+import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools';
+
+describe('issue 1435', () => {
+ it('regression', async () => {
+ let prisma: any;
+ let dbUrl: string;
+
+ try {
+ dbUrl = await createPostgresDb('issue-1435');
+ const r = await loadSchema(
+ `
+ /* Interfaces */
+ abstract model IBase {
+ updatedAt DateTime @updatedAt
+ createdAt DateTime @default(now())
+ }
+
+ abstract model IAuth extends IBase {
+ user User @relation(fields: [userId], references: [id], onDelete: Cascade)
+ userId String @unique
+
+ @@allow('create', true)
+ @@allow('all', auth() == user)
+ }
+
+ abstract model IIntegration extends IBase {
+ organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+ organizationId String @unique
+
+ @@allow('all', organization.members?[user == auth() && type == OWNER])
+ @@allow('read', organization.members?[user == auth()])
+ }
+
+ /* Auth Stuff */
+ model User extends IBase {
+ id String @id @default(cuid())
+ firstName String
+ lastName String
+ google GoogleAuth?
+ memberships Member[]
+
+ @@allow('create', true)
+ @@allow('all', auth() == this)
+ }
+
+ model GoogleAuth extends IAuth {
+ reference String @id
+ refreshToken String
+ }
+
+ /* Org Stuff */
+ enum MemberType {
+ OWNER
+ MEMBER
+ }
+
+ model Organization extends IBase {
+ id String @id @default(cuid())
+ name String
+ members Member[]
+ google GoogleIntegration?
+
+ @@allow('create', true)
+ @@allow('all', members?[user == auth() && type == OWNER])
+ @@allow('read', members?[user == auth()])
+ }
+
+
+ model Member extends IBase {
+ type MemberType @default(MEMBER)
+ organization Organization @relation(fields: [organizationId], references: [id], onDelete: Cascade)
+ organizationId String
+ user User @relation(fields: [userId], references: [id], onDelete: Cascade)
+ userId String
+
+ @@id([organizationId, userId])
+ @@allow('all', organization.members?[user == auth() && type == OWNER])
+ @@allow('read', user == auth())
+ }
+
+ /* Google Stuff */
+ model GoogleIntegration extends IIntegration {
+ reference String @id
+ }
+ `,
+ { provider: 'postgresql', dbUrl, logPrismaQuery: true }
+ );
+
+ prisma = r.prisma;
+ const enhance = r.enhance;
+
+ await prisma.organization.create({
+ data: {
+ name: 'My Organization',
+ members: {
+ create: {
+ type: 'OWNER',
+ user: {
+ create: {
+ id: '1',
+ firstName: 'John',
+ lastName: 'Doe',
+ },
+ },
+ },
+ },
+ },
+ });
+
+ const db = enhance({ id: '1' });
+ await expect(db.organization.findMany()).resolves.toHaveLength(1);
+ } finally {
+ if (prisma) {
+ await prisma.$disconnect();
+ }
+ await dropPostgresDb('issue-1435');
+ }
+ });
+});
diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts
index 7bc42071b..f6dc3a135 100644
--- a/tests/regression/tests/issue-961.test.ts
+++ b/tests/regression/tests/issue-961.test.ts
@@ -123,10 +123,8 @@ describe('Regression: issue 961', () => {
await expect(db.userColumn.findMany()).resolves.toHaveLength(1);
});
- // disabled because of Prisma V4 bug: https://github.com/prisma/prisma/issues/18371
- // eslint-disable-next-line jest/no-disabled-tests
- it.skip('updateMany', async () => {
- const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true });
+ it('updateMany', async () => {
+ const { prisma, enhance } = await loadSchema(schema);
const user = await prisma.user.create({
data: {
diff --git a/tests/regression/tsconfig.json b/tests/regression/tsconfig.json
index c6cc8d4a7..2771cd805 100644
--- a/tests/regression/tsconfig.json
+++ b/tests/regression/tsconfig.json
@@ -8,5 +8,5 @@
"skipLibCheck": true,
"experimentalDecorators": true
},
- "include": ["**/*.ts", "**/*.d.ts", "../regression/tests/issue-177.test.ts", "../regression/tests/issue-416.test.ts", "../regression/tests/issue-646.test.ts", "../regression/tests/issue-657.test.ts", "../regression/tests/issue-665.test.ts", "../regression/tests/issue-674.test.ts", "../regression/tests/issue-689.test.ts", "../regression/tests/issue-703.test.ts", "../regression/tests/issue-714.test.ts", "../regression/tests/issue-724.test.ts", "../regression/tests/issue-735.test.ts", "../regression/tests/issue-744.test.ts", "../regression/tests/issue-756.test.ts", "../regression/tests/issue-764.test.ts", "../regression/tests/issue-765.test.ts", "../regression/tests/issue-804.test.ts", "../regression/tests/issue-811.test.ts", "../regression/tests/issue-814.test.ts", "../regression/tests/issue-825.test.ts", "../regression/tests/issue-864.test.ts", "../regression/tests/issue-886.test.ts", "../regression/tests/issue-925.test.ts", "../regression/tests/issue-947.test.ts", "../regression/tests/issue-961.test.ts", "../regression/tests/issue-965.test.ts", "../regression/tests/issue-971.test.ts", "../regression/tests/issue-992.test.ts", "../regression/tests/issue-1014.test.ts", "../regression/tests/issue-1078.test.ts", "../regression/tests/issue-1080.test.ts", "../regression/tests/issue-1129.test.ts", "../regression/tests/issue-1167.test.ts", "../regression/tests/issue-1179.test.ts", "../regression/tests/issue-1186.test.ts", "../regression/tests/issue-1210.test.ts", "../regression/tests/issue-1235.test.ts", "../regression/tests/issue-1241.test.ts", "../regression/tests/issue-1257.test.ts", "../regression/tests/issue-1265.test.ts", "../regression/tests/issues.test.ts"]
+ "include": ["**/*.ts", "**/*.d.ts"]
}