diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..2766c310e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ +* text=auto eol=lf + +*.bat text=auto eol=crlf \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 659fe2184..e31698560 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,10 +5,24 @@ I want to think you first for considering contributing to ZenStack 🙏🏻. It' ## Prerequisites - [Node.js](https://nodejs.org/): v18 or above -- [pnpm](https://pnpm.io/): latest version +- [pnpm](https://pnpm.io/): v8.x + +If you want to run the tests, you should be aware that some of the integration tests run against postgres. These tests will attempt to set up and subsequently their own database, so you'll need to provide a connection details for a postgres user with at least those permissions. To provide connection details, you can configure the following environment variables or provide them when executing `pnpm test` commands. + +- `ZENSTACK_TEST_DB_USER`: The postgres username, for a user with permission to create/drop databases. Default: `postgres`. +- `ZENSTACK_TEST_DB_PASS`: Password for said user. Default: `abc123`. +- `ZENSTACK_TEST_DB_NAME`: Default database to connect onto. This database isn't used any further, so it's recommended to just use the default `postgres` database. Default: `postgres`. +- `ZENSTACK_TEST_DB_HOST`: Hostname or IP to connect onto. Default: `localhost`. +- `ZENSTACK_TEST_DB_PORT`: Port number to connect onto. Default: `5432`. ## Get started +1. (Windows only) Your environment should support symlinks, by enabling "Developer mode" in `Settings => System => For developers` (Windows 10/11 only) and setting the `core.symlinks` setting in git to `true`. For more info [refer to this StackOverflow answer](https://stackoverflow.com/questions/5917249/git-symbolic-links-in-windows/59761201#59761201). + + ```pwsh + git config --global core.symlinks true + ``` + 1. Make a fork of the repository Make sure all branches are included. diff --git a/package.json b/package.json index 7c62b9b36..a8eb656c3 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "2.1.2", + "version": "2.2.0", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/ide/jetbrains/build.gradle.kts b/packages/ide/jetbrains/build.gradle.kts index 7647afcf1..e1547b907 100644 --- a/packages/ide/jetbrains/build.gradle.kts +++ b/packages/ide/jetbrains/build.gradle.kts @@ -9,7 +9,7 @@ plugins { } group = "dev.zenstack" -version = "2.1.2" +version = "2.2.0" repositories { mavenCentral() diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 7f18e5e4a..d94e641c4 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -1,17 +1,20 @@ { "name": "jetbrains", - "version": "2.1.2", + "version": "2.2.0", "displayName": "ZenStack JetBrains IDE Plugin", "description": "ZenStack JetBrains IDE plugin", "homepage": "https://zenstack.dev", "private": true, "scripts": { - "build": "./gradlew buildPlugin" + "build": "run-script-os", + "build:win32": "gradlew.bat buildPlugin", + "build:default": "./gradlew buildPlugin" }, "author": "ZenStack Team", "license": "MIT", "devDependencies": { - "zenstack": "workspace:*", - "@zenstackhq/language": "workspace:*" + "@zenstackhq/language": "workspace:*", + "run-script-os": "^1.1.6", + "zenstack": "workspace:*" } } diff --git a/packages/language/package.json b/packages/language/package.json index 01ae8ec10..5703ef590 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "2.1.2", + "version": "2.2.0", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", @@ -9,7 +9,7 @@ "generate": "langium generate && npx ts-node script/generate-plist.ts", "watch": "concurrently \"langium generate --watch\" \"tsc --watch\"", "lint": "eslint src --ext ts", - "build": "pnpm lint --max-warnings=0 && pnpm clean && pnpm generate && tsc && copyfiles -F ./README.md ./LICENSE ./package.json 'syntaxes/**/*' dist && pnpm pack dist --pack-destination ../../../.build", + "build": "pnpm lint --max-warnings=0 && pnpm clean && pnpm generate && tsc && copyfiles -F ./README.md ./LICENSE ./package.json \"syntaxes/**/*\" dist && pnpm pack dist --pack-destination ../../../.build", "prepublishOnly": "pnpm build" }, "publishConfig": { diff --git a/packages/misc/redwood/package.json b/packages/misc/redwood/package.json index a6184a83a..5e5253ac1 100644 --- a/packages/misc/redwood/package.json +++ b/packages/misc/redwood/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/redwood", "displayName": "ZenStack RedwoodJS Integration", - "version": "2.1.2", + "version": "2.2.0", "description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.", "repository": { "type": "git", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 0e764cbce..cff7fcd0d 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index abe9910b7..9483c87da 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index fe86c1419..6611d6789 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { @@ -110,6 +110,7 @@ "replace-in-file": "^7.0.1", "svelte": "^4.2.1", "swr": "^2.0.3", + "tmp": "^0.2.3", "vue": "^3.3.4" } } diff --git a/packages/plugins/tanstack-query/src/generator.ts b/packages/plugins/tanstack-query/src/generator.ts index 6a4ba53d9..e29a9ef0e 100644 --- a/packages/plugins/tanstack-query/src/generator.ts +++ b/packages/plugins/tanstack-query/src/generator.ts @@ -136,8 +136,8 @@ function generateQueryHook( }); if (version === 'v5' && infinite && ['react', 'svelte'].includes(target)) { - // initialPageParam and getNextPageParam options are required in v5 - func.addStatements([`options = options ?? { initialPageParam: undefined, getNextPageParam: () => null };`]); + // getNextPageParam option is required in v5 + func.addStatements([`options = options ?? { getNextPageParam: () => null };`]); } func.addStatements([ @@ -668,20 +668,20 @@ function makeQueryOptions( ? `Omit, 'queryKey'>` : `Omit>, 'queryKey'>` + }InfiniteQueryOptions<${returnType}, TError, InfiniteData<${dataType}>>, 'queryKey' | 'initialPageParam'>` : `Omit, 'queryKey'>` ) .with('vue', () => { - const baseOption = `Omit, 'queryKey'>`; + const baseOption = infinite + ? `Omit>, 'queryKey' | 'initialPageParam'>` + : `Omit, 'queryKey'>`; return `MaybeRefOrGetter<${baseOption}> | ComputedRef<${baseOption}>`; }) .with('svelte', () => infinite ? version === 'v4' ? `Omit, 'queryKey'>` - : `StoreOrVal>, 'queryKey'>>` + : `StoreOrVal>, 'queryKey' | 'initialPageParam'>>` : version === 'v4' ? `Omit, 'queryKey'>` : `StoreOrVal, 'queryKey'>>` diff --git a/packages/plugins/tanstack-query/src/runtime-v5/react.ts b/packages/plugins/tanstack-query/src/runtime-v5/react.ts index b169ce4fa..e8017befa 100644 --- a/packages/plugins/tanstack-query/src/runtime-v5/react.ts +++ b/packages/plugins/tanstack-query/src/runtime-v5/react.ts @@ -120,7 +120,7 @@ export function useInfiniteModelQuery( model: string, url: string, args: unknown, - options: Omit>, 'queryKey'>, + options: Omit>, 'queryKey' | 'initialPageParam'>, fetch?: FetchFn ) { return useInfiniteQuery({ @@ -128,6 +128,7 @@ export function useInfiniteModelQuery( queryFn: ({ pageParam }) => { return fetcher(makeUrl(url, pageParam ?? args), undefined, fetch, false); }, + initialPageParam: args, ...options, }); } @@ -146,7 +147,10 @@ export function useSuspenseInfiniteModelQuery( model: string, url: string, args: unknown, - options: Omit>, 'queryKey'>, + options: Omit< + UseSuspenseInfiniteQueryOptions>, + 'queryKey' | 'initialPageParam' + >, fetch?: FetchFn ) { return useSuspenseInfiniteQuery({ @@ -154,6 +158,7 @@ export function useSuspenseInfiniteModelQuery( queryFn: ({ pageParam }) => { return fetcher(makeUrl(url, pageParam ?? args), undefined, fetch, false); }, + initialPageParam: args, ...options, }); } diff --git a/packages/plugins/tanstack-query/src/runtime-v5/svelte.ts b/packages/plugins/tanstack-query/src/runtime-v5/svelte.ts index 1c58f83be..2c554aec9 100644 --- a/packages/plugins/tanstack-query/src/runtime-v5/svelte.ts +++ b/packages/plugins/tanstack-query/src/runtime-v5/svelte.ts @@ -91,7 +91,7 @@ export function useModelQuery( ...options, }; } - return createQuery(mergedOpt); + return createQuery(mergedOpt); } /** @@ -107,7 +107,9 @@ export function useInfiniteModelQuery( model: string, url: string, args: unknown, - options: StoreOrVal>, 'queryKey'>>, + options: StoreOrVal< + Omit>, 'queryKey' | 'initialPageParam'> + >, fetch?: FetchFn ) { const queryKey = getQueryKey(model, url, args, { infinite: true, optimisticUpdate: false }); @@ -115,12 +117,17 @@ export function useInfiniteModelQuery( fetcher(makeUrl(url, pageParam ?? args), undefined, fetch, false); let mergedOpt: StoreOrVal>>; - if (isStore>>(options)) { + if ( + isStore< + Omit>, 'queryKey' | 'initialPageParam'> + >(options) + ) { // options is store mergedOpt = derived([options], ([$opt]) => { return { queryKey, queryFn, + initialPageParam: args, ...$opt, }; }); @@ -129,6 +136,7 @@ export function useInfiniteModelQuery( mergedOpt = { queryKey, queryFn, + initialPageParam: args, ...options, }; } diff --git a/packages/plugins/tanstack-query/src/runtime/vue.ts b/packages/plugins/tanstack-query/src/runtime/vue.ts index 016414722..b024d8940 100644 --- a/packages/plugins/tanstack-query/src/runtime/vue.ts +++ b/packages/plugins/tanstack-query/src/runtime/vue.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/ban-types */ /* eslint-disable @typescript-eslint/no-explicit-any */ +import type { InfiniteData } from '@tanstack/react-query-v5'; import { useInfiniteQuery, useMutation, @@ -103,8 +104,12 @@ export function useInfiniteModelQuery( url: string, args?: MaybeRefOrGetter | ComputedRef, options?: - | MaybeRefOrGetter, 'queryKey'>> - | ComputedRef, 'queryKey'>>, + | MaybeRefOrGetter< + Omit>, 'queryKey' | 'initialPageParam'> + > + | ComputedRef< + Omit>, 'queryKey' | 'initialPageParam'> + >, fetch?: FetchFn ) { // CHECKME: vue-query's `useInfiniteQuery`'s input typing seems wrong @@ -115,10 +120,11 @@ export function useInfiniteModelQuery( const reqUrl = makeUrl(url, pageParam ?? toValue(args)); return fetcher(reqUrl, undefined, fetch, false); }, + initialPageParam: toValue(args), ...toValue(options), })); - return useInfiniteQuery(queryOptions); + return useInfiniteQuery>(queryOptions); } /** diff --git a/packages/plugins/tanstack-query/tests/plugin.test.ts b/packages/plugins/tanstack-query/tests/plugin.test.ts index 3dfb2dec0..88c8fd7e8 100644 --- a/packages/plugins/tanstack-query/tests/plugin.test.ts +++ b/packages/plugins/tanstack-query/tests/plugin.test.ts @@ -18,7 +18,7 @@ describe('Tanstack Query Plugin Tests', () => { const sharedModel = ` model User { - id String @id + id String @id @default(cuid()) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt email String @unique @@ -32,7 +32,7 @@ enum role { } model post_Item { - id String @id + id String @id @default(cuid()) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt title String @@ -48,6 +48,34 @@ model Foo { } `; + const reactAppSource = { + name: 'main.ts', + content: ` + import { useFindFirstpost_Item, useInfiniteFindManypost_Item, useCreatepost_Item } from './hooks'; + + function query() { + const { data } = useFindFirstpost_Item({include: { author: true }}); + console.log(data?.viewCount); + console.log(data?.author?.email); + } + + function infiniteQuery() { + const { data, fetchNextPage, hasNextPage } = useInfiniteFindManypost_Item(); + useInfiniteFindManypost_Item({ where: { published: true } }); + useInfiniteFindManypost_Item(undefined, { getNextPageParam: () => null }); + console.log(data?.pages[0][0].published); + console.log(data?.pageParams[0]); + } + + async function mutation() { + const { mutateAsync } = useCreatepost_Item(); + const data = await mutateAsync({ data: { title: 'hello' }, include: { author: true } }); + console.log(data?.viewCount); + console.log(data?.author?.email); + } + `, + }; + it('react-query run plugin v4', async () => { await loadSchema( ` @@ -66,6 +94,7 @@ ${sharedModel} extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@4.29.7'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [reactAppSource], } ); }); @@ -87,10 +116,55 @@ ${sharedModel} extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@^5.0.0'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [ + reactAppSource, + { + name: 'suspense.ts', + content: ` + import { useSuspenseInfiniteFindManypost_Item } from './hooks'; + + function suspenseInfiniteQuery() { + const { data, fetchNextPage, hasNextPage } = useSuspenseInfiniteFindManypost_Item(); + useSuspenseInfiniteFindManypost_Item({ where: { published: true } }); + useSuspenseInfiniteFindManypost_Item(undefined, { getNextPageParam: () => null }); + console.log(data?.pages[0][0].published); + console.log(data?.pageParams[0]); + } + `, + }, + ], } ); }); + const vueAppSource = { + name: 'main.ts', + content: ` + import { useFindFirstpost_Item, useInfiniteFindManypost_Item, useCreatepost_Item } from './hooks'; + + function query() { + const { data } = useFindFirstpost_Item({include: { author: true }}); + console.log(data.value?.viewCount); + console.log(data.value?.author?.email); + } + + function infiniteQuery() { + const { data, fetchNextPage, hasNextPage } = useInfiniteFindManypost_Item(); + useInfiniteFindManypost_Item({ where: { published: true } }); + useInfiniteFindManypost_Item(undefined, { getNextPageParam: () => null }); + console.log(data.value?.pages[0][0].published); + console.log(data.value?.pageParams[0]); + } + + async function mutation() { + const { mutateAsync } = useCreatepost_Item(); + const data = await mutateAsync({ data: { title: 'hello' }, include: { author: true } }); + console.log(data?.viewCount); + console.log(data?.author?.email); + } + `, + }; + it('vue-query run plugin v4', async () => { await loadSchema( ` @@ -109,6 +183,7 @@ ${sharedModel} extraDependencies: ['vue@^3.3.4', '@tanstack/vue-query@4.37.0'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [vueAppSource], } ); }); @@ -130,10 +205,40 @@ ${sharedModel} extraDependencies: ['vue@^3.3.4', '@tanstack/vue-query@latest'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [vueAppSource], } ); }); + const svelteAppSource = { + name: 'main.ts', + content: ` + import { get } from 'svelte/store'; + import { useFindFirstpost_Item, useInfiniteFindManypost_Item, useCreatepost_Item } from './hooks'; + + function query() { + const { data } = get(useFindFirstpost_Item({include: { author: true }})); + console.log(data?.viewCount); + console.log(data?.author?.email); + } + + function infiniteQuery() { + const { data, fetchNextPage, hasNextPage } = get(useInfiniteFindManypost_Item()); + useInfiniteFindManypost_Item({ where: { published: true } }); + useInfiniteFindManypost_Item(undefined, { getNextPageParam: () => null }); + console.log(data?.pages[0][0].published); + console.log(data?.pageParams[0]); + } + + async function mutation() { + const { mutateAsync } = get(useCreatepost_Item()); + const data = await mutateAsync({ data: { title: 'hello' }, include: { author: true } }); + console.log(data?.viewCount); + console.log(data?.author?.email); + } + `, + }; + it('svelte-query run plugin v4', async () => { await loadSchema( ` @@ -152,6 +257,7 @@ ${sharedModel} extraDependencies: ['svelte@^3.0.0', '@tanstack/svelte-query@4.29.7'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [svelteAppSource], } ); }); @@ -173,6 +279,7 @@ ${sharedModel} extraDependencies: ['svelte@^3.0.0', '@tanstack/svelte-query@^5.0.0'], copyDependencies: [path.resolve(__dirname, '../dist')], compile: true, + extraSourceFiles: [svelteAppSource], } ); }); diff --git a/packages/plugins/tanstack-query/tests/react-hooks-v5.test.tsx b/packages/plugins/tanstack-query/tests/react-hooks-v5.test.tsx index d5e23c374..e772c12e1 100644 --- a/packages/plugins/tanstack-query/tests/react-hooks-v5.test.tsx +++ b/packages/plugins/tanstack-query/tests/react-hooks-v5.test.tsx @@ -10,7 +10,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query-v5'; import { act, renderHook, waitFor } from '@testing-library/react'; import nock from 'nock'; import React from 'react'; -import { RequestHandlerContext, useModelMutation, useModelQuery } from '../src/runtime-v5/react'; +import { RequestHandlerContext, useInfiniteModelQuery, useModelMutation, useModelQuery } from '../src/runtime-v5/react'; import { getQueryKey } from '../src/runtime/common'; import { modelMeta } from './test-model-meta'; @@ -60,6 +60,45 @@ describe('Tanstack Query React Hooks V5 Test', () => { }); }); + it('infinite query', async () => { + const { queryClient, wrapper } = createWrapper(); + + const queryArgs = { where: { id: '1' } }; + const data = [{ id: '1', name: 'foo' }]; + + nock(makeUrl('User', 'findMany', queryArgs)) + .get(/.*/) + .reply(200, () => { + console.log('Query findMany:', queryArgs); + return { + data: data, + }; + }); + + const { result } = renderHook( + () => + useInfiniteModelQuery('User', makeUrl('User', 'findMany'), queryArgs, { + getNextPageParam: () => null, + }), + { + wrapper, + } + ); + await waitFor(() => { + expect(result.current.isSuccess).toBe(true); + const resultData = result.current.data!; + expect(resultData.pages).toHaveLength(1); + expect(resultData.pages[0]).toMatchObject(data); + expect(resultData?.pageParams).toHaveLength(1); + expect(resultData?.pageParams[0]).toMatchObject(queryArgs); + expect(result.current.hasNextPage).toBe(false); + const cacheData: any = queryClient.getQueryData( + getQueryKey('User', 'findMany', queryArgs, { infinite: true, optimisticUpdate: false }) + ); + expect(cacheData.pages[0]).toMatchObject(data); + }); + }); + it('independent mutation and query', async () => { const { wrapper } = createWrapper(); diff --git a/packages/plugins/tanstack-query/tests/react-hooks.test.tsx b/packages/plugins/tanstack-query/tests/react-hooks.test.tsx index 7bd952fad..df913da7a 100644 --- a/packages/plugins/tanstack-query/tests/react-hooks.test.tsx +++ b/packages/plugins/tanstack-query/tests/react-hooks.test.tsx @@ -11,7 +11,7 @@ import { act, renderHook, waitFor } from '@testing-library/react'; import nock from 'nock'; import React from 'react'; import { getQueryKey } from '../src/runtime/common'; -import { RequestHandlerContext, useModelMutation, useModelQuery } from '../src/runtime/react'; +import { RequestHandlerContext, useInfiniteModelQuery, useModelMutation, useModelQuery } from '../src/runtime/react'; import { modelMeta } from './test-model-meta'; describe('Tanstack Query React Hooks V4 Test', () => { @@ -60,6 +60,45 @@ describe('Tanstack Query React Hooks V4 Test', () => { }); }); + it('infinite query', async () => { + const { queryClient, wrapper } = createWrapper(); + + const queryArgs = { where: { id: '1' } }; + const data = [{ id: '1', name: 'foo' }]; + + nock(makeUrl('User', 'findMany', queryArgs)) + .get(/.*/) + .reply(200, () => { + console.log('Query findMany:', queryArgs); + return { + data: data, + }; + }); + + const { result } = renderHook( + () => + useInfiniteModelQuery('User', makeUrl('User', 'findMany'), queryArgs, { + getNextPageParam: () => null, + }), + { + wrapper, + } + ); + await waitFor(() => { + expect(result.current.isSuccess).toBe(true); + const resultData = result.current.data!; + expect(resultData.pages).toHaveLength(1); + expect(resultData.pages[0]).toMatchObject(data); + expect(resultData?.pageParams).toHaveLength(1); + expect(resultData?.pageParams[0]).toBeUndefined(); + expect(result.current.hasNextPage).toBe(false); + const cacheData: any = queryClient.getQueryData( + getQueryKey('User', 'findMany', queryArgs, { infinite: true, optimisticUpdate: false }) + ); + expect(cacheData.pages[0]).toMatchObject(data); + }); + }); + it('independent mutation and query', async () => { const { wrapper } = createWrapper(); diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 7027f2d0a..a91ea4e45 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { @@ -10,7 +10,7 @@ }, "scripts": { "clean": "rimraf dist", - "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE 'res/**/*' dist && pnpm pack dist --pack-destination ../../../../.build", + "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE \"res/**/*\" dist && pnpm pack dist --pack-destination ../../../../.build", "watch": "tsc --watch", "lint": "eslint src --ext ts", "test": "jest", diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 12aa58df1..9a6bfb923 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "2.1.2", + "version": "2.2.0", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", @@ -9,7 +9,7 @@ }, "scripts": { "clean": "rimraf dist", - "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && tsup-node --config ./tsup-browser.config.ts && tsup-node --config ./tsup-cross.config.ts && copyfiles ./package.json ./README.md ../../LICENSE dist && copyfiles -u1 'res/**/*' dist && pnpm pack dist --pack-destination '../../../.build'", + "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && tsup-node --config ./tsup-browser.config.ts && tsup-node --config ./tsup-cross.config.ts && copyfiles ./package.json ./README.md ../../LICENSE dist && copyfiles -u1 \"res/**/*\" dist && pnpm pack dist --pack-destination ../../../.build", "watch": "concurrently \"tsc --watch\" \"tsup-node --config ./tsup-browser.config.ts --watch\" \"tsup-node --config ./tsup-cross.config.ts --watch\"", "lint": "eslint src --ext ts", "prepublishOnly": "pnpm build" @@ -83,6 +83,7 @@ "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", "deepmerge": "^4.3.1", + "is-plain-object": "^5.0.0", "logic-solver": "^2.0.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", @@ -98,7 +99,7 @@ "zod-validation-error": "^1.5.0" }, "peerDependencies": { - "@prisma/client": "5.0.0 - 5.13.x" + "@prisma/client": "5.0.0 - 5.15.x" }, "author": { "name": "ZenStack Team" diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index a85392887..5fd8c2901 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -63,41 +63,6 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; */ export const PRISMA_MINIMUM_VERSION = '5.0.0'; -/** - * Selector function name for fetching pre-update entity values. - */ -export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect'; - -/** - * Prefix for field-level read checker function name - */ -export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$'; - -/** - * Field-level access control evaluation selector function name - */ -export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect'; - -/** - * Prefix for field-level override read guard function name - */ -export const FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX = 'readFieldGuardOverride$'; - -/** - * Prefix for field-level update guard function name - */ -export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldGuard$'; - -/** - * Prefix for field-level override update guard function name - */ -export const FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX = 'updateFieldGuardOverride$'; - -/** - * Flag that indicates if the model has field-level access control - */ -export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy'; - /** * Prefix for auxiliary relation field generated for delegated models */ diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts index 8e4c3569a..edbaed78d 100644 --- a/packages/runtime/src/enhancements/delegate.ts +++ b/packages/runtime/src/enhancements/delegate.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import deepcopy from 'deepcopy'; import deepmerge, { type ArrayMergeOptions } from 'deepmerge'; +import { isPlainObject } from 'is-plain-object'; import { lowerCaseFirst } from 'lower-case-first'; import { DELEGATE_AUX_RELATION_PREFIX } from '../constants'; import { @@ -1094,11 +1095,16 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { const result = deepmerge(upMerged, downMerged, { arrayMerge: combineMerge, + isMergeableObject: (v) => isPlainObject(v) || Array.isArray(v), // avoid messing with Decimal, Date, etc. }); return result; } private assembleUp(model: string, entity: any) { + if (!entity) { + return entity; + } + const result: any = {}; const base = this.getBaseModel(model); @@ -1146,6 +1152,10 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } private assembleDown(model: string, entity: any) { + if (!entity) { + return entity; + } + const result: any = {}; const modelInfo = getModelInfo(this.options.modelMeta, model, true); diff --git a/packages/runtime/src/enhancements/policy/constraint-solver.ts b/packages/runtime/src/enhancements/policy/constraint-solver.ts index c87a528e7..9b792a0fa 100644 --- a/packages/runtime/src/enhancements/policy/constraint-solver.ts +++ b/packages/runtime/src/enhancements/policy/constraint-solver.ts @@ -1,10 +1,10 @@ import Logic from 'logic-solver'; import { match } from 'ts-pattern'; import type { - CheckerConstraint, ComparisonConstraint, ComparisonTerm, LogicalConstraint, + PermissionCheckerConstraint, ValueConstraint, VariableConstraint, } from '../types'; @@ -22,7 +22,7 @@ export class ConstraintSolver { /** * Check the satisfiability of the given constraint. */ - checkSat(constraint: CheckerConstraint): boolean { + checkSat(constraint: PermissionCheckerConstraint): boolean { // reset state this.stringTable = []; this.variables = new Map(); @@ -46,7 +46,7 @@ export class ConstraintSolver { return !!solver.solve(); } - private buildFormula(constraint: CheckerConstraint): Logic.Formula { + private buildFormula(constraint: PermissionCheckerConstraint): Logic.Formula { return match(constraint) .when( (c): c is ValueConstraint => c.kind === 'value', @@ -100,11 +100,11 @@ export class ConstraintSolver { return Logic.not(this.buildFormula(constraint.children[0])); } - private isTrue(constraint: CheckerConstraint): unknown { + private isTrue(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === true; } - private isFalse(constraint: CheckerConstraint): unknown { + private isFalse(constraint: PermissionCheckerConstraint): unknown { return constraint.kind === 'value' && constraint.value === false; } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 997e727d5..b6088ed25 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import invariant from 'tiny-invariant'; import { P, match } from 'ts-pattern'; @@ -23,7 +24,7 @@ import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import type { CheckerConstraint } from '../types'; +import type { EntityCheckerFunc, PermissionCheckerConstraint } from '../types'; import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { ConstraintSolver } from './constraint-solver'; import { PolicyUtil } from './policy-utils'; @@ -152,8 +153,7 @@ export class PolicyProxyHandler implements Pr } const result = await this.modelClient[actionName](_args); - this.policyUtils.postProcessForRead(result, this.model, origArgs); - return result; + return this.policyUtils.postProcessForRead(result, this.model, origArgs); } //#endregion @@ -447,31 +447,10 @@ export class PolicyProxyHandler implements Pr // go through create items, statically check input to determine if post-create // check is needed, and also validate zod schema - let needPostCreateCheck = false; - for (const item of enumerate(args.data)) { - const validationResult = this.validateCreateInputSchema(this.model, item); - if (validationResult !== item) { - this.policyUtils.replace(item, validationResult); - } - - const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); - if (inputCheck === false) { - // unconditionally deny - throw this.policyUtils.deniedByPolicy( - this.model, - 'create', - undefined, - CrudFailureReason.ACCESS_POLICY_VIOLATION - ); - } else if (inputCheck === true) { - // unconditionally allow - } else if (inputCheck === undefined) { - // static policy check is not possible, need to do post-create check - needPostCreateCheck = true; - } - } + const needPostCreateCheck = this.validateCreateInput(args); if (!needPostCreateCheck) { + // direct create return this.modelClient.createMany(args); } else { // create entities in a transaction with post-create checks @@ -479,12 +458,95 @@ export class PolicyProxyHandler implements Pr const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); // post-create check await this.runPostWriteChecks(postWriteChecks, tx); - return result; + return { count: result.length }; }); } }); } + createManyAndReturn(args: { select: any; include: any; data: any; skipDuplicates?: boolean }) { + if (!args) { + throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.prismaModule, + 'data field is required in query argument' + ); + } + + return createDeferredPromise(async () => { + this.policyUtils.tryReject(this.prisma, this.model, 'create'); + + const origArgs = args; + args = clone(args); + + // go through create items, statically check input to determine if post-create + // check is needed, and also validate zod schema + const needPostCreateCheck = this.validateCreateInput(args); + + let result: { result: unknown; error?: Error }[]; + + if (!needPostCreateCheck) { + // direct create + const created = await this.modelClient.createManyAndReturn(args); + + // process read-back + result = await Promise.all( + created.map((item) => this.policyUtils.readBack(this.prisma, this.model, 'create', origArgs, item)) + ); + } else { + // create entities in a transaction with post-create checks + result = await this.queryUtils.transaction(this.prisma, async (tx) => { + const { result: created, postWriteChecks } = await this.doCreateMany(this.model, args, tx); + // post-create check + await this.runPostWriteChecks(postWriteChecks, tx); + + // process read-back + return Promise.all( + created.map((item) => this.policyUtils.readBack(tx, this.model, 'create', origArgs, item)) + ); + }); + } + + // throw read-back error if any of create result read-back fails + const error = result.find((r) => !!r.error)?.error; + if (error) { + throw error; + } else { + return result.map((r) => r.result); + } + }); + } + + private validateCreateInput(args: { data: any; skipDuplicates?: boolean | undefined }) { + let needPostCreateCheck = false; + for (const item of enumerate(args.data)) { + const validationResult = this.validateCreateInputSchema(this.model, item); + if (validationResult !== item) { + this.policyUtils.replace(item, validationResult); + } + + const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); + if (inputCheck === false) { + // unconditionally deny + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } else if (inputCheck === true) { + // unconditionally allow + } else if (inputCheck === undefined) { + // static policy check is not possible, need to do post-create check + needPostCreateCheck = true; + } + } + return needPostCreateCheck; + } + private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) { // We can't call the native "createMany" because we can't get back what was created // for post-create checks. Instead, do a "create" for each item and collect the results. @@ -511,7 +573,7 @@ export class PolicyProxyHandler implements Pr createResult = createResult.filter((p) => !!p); return { - result: { count: createResult.length }, + result: createResult, postWriteChecks: createResult.map((item) => ({ model, operation: 'create' as PolicyOperationKind, @@ -779,10 +841,27 @@ export class PolicyProxyHandler implements Pr } }; - const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { + const _connectDisconnect = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + operation: 'connect' | 'disconnect' + ) => { if (context.field?.backLink) { const backLinkField = this.policyUtils.getModelField(model, context.field.backLink); if (backLinkField?.isRelationOwner) { + let uniqueFilter = args; + if (operation === 'disconnect') { + // disconnect filter is not unique, need to build a reversed query to + // locate the entity and use its id fields as unique filter + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const found = await db[model].findUnique({ + where: reversedQuery, + select: this.policyUtils.makeIdSelection(model), + }); + uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found); + } + // update happens on the related model, require updatable, // translate args to foreign keys so field-level policies can be checked const checkArgs: any = {}; @@ -794,10 +873,15 @@ export class PolicyProxyHandler implements Pr } } } - await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs); - // register post-update check - await _registerPostUpdateCheck(model, args, args); + // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist + if (uniqueFilter) { + // check for update + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs); + + // register post-update check + await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + } } } }; @@ -970,14 +1054,14 @@ export class PolicyProxyHandler implements Pr } }, - connect: async (model, args, context) => _connectDisconnect(model, args, context), + connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'), connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect - await _connectDisconnect(model, args.where, context); + await _connectDisconnect(model, args.where, context, 'connect'); return true; } else { // create @@ -997,7 +1081,7 @@ export class PolicyProxyHandler implements Pr } }, - disconnect: async (model, args, context) => _connectDisconnect(model, args, context), + disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'), set: async (model, args, context) => { // find the set of items to be replaced @@ -1012,10 +1096,10 @@ export class PolicyProxyHandler implements Pr const currentSet = await db[model].findMany(findCurrSetArgs); // register current set for update (foreign key) - await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context))); + await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect'))); // proceed with connecting the new set - await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context))); + await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect'))); }, delete: async (model, args, context) => { @@ -1160,48 +1244,78 @@ export class PolicyProxyHandler implements Pr args.data = this.validateUpdateInputSchema(this.model, args.data); - if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { - // use a transaction to do post-update checks - const postWriteChecks: PostWriteCheckRecord[] = []; - return this.queryUtils.transaction(this.prisma, async (tx) => { - // collect pre-update values - let select = this.policyUtils.makeIdSelection(this.model); - const preValueSelect = this.policyUtils.getPreValueSelect(this.model); - if (preValueSelect) { - select = { ...select, ...preValueSelect }; - } - const currentSetQuery = { select, where: args.where }; - this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); - } - const currentSet = await tx[this.model].findMany(currentSetQuery); + const canProceedWithoutTransaction = + // no post-update rules + !this.policyUtils.hasAuthGuard(this.model, 'postUpdate') && + // no Zod schema + !this.policyUtils.getZodSchema(this.model) && + // no entity checker + !entityChecker; - postWriteChecks.push( - ...currentSet.map((preValue) => ({ - model: this.model, - operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), - preValue: preValueSelect ? preValue : undefined, - })) - ); - - // proceed with the update - const result = await tx[this.model].updateMany(args); - - // run post-write checks - await this.runPostWriteChecks(postWriteChecks, tx); - - return result; - }); - } else { + if (canProceedWithoutTransaction) { // proceed without a transaction if (this.shouldLogQuery) { this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); } return this.modelClient.updateMany(args); } + + // collect post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + + return this.queryUtils.transaction(this.prisma, async (tx) => { + // collect pre-update values + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + + // merge selection required for running additional checker + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'update'); + if (entityChecker?.selector) { + select = deepmerge(select, entityChecker.selector); + } + + const currentSetQuery = { select, where: args.where }; + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update'); + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); + } + let candidates = await tx[this.model].findMany(currentSetQuery); + + if (entityChecker) { + // filter candidates with additional checker and build an id filter + const r = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); + candidates = r.filteredCandidates; + + // merge id filter into update's where clause + args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter; + } + + postWriteChecks.push( + ...candidates.map((preValue) => ({ + model: this.model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), + preValue: preValueSelect ? preValue : undefined, + })) + ); + + // proceed with the update + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`); + } + const result = await tx[this.model].updateMany(args); + + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + return result; + }); }); } @@ -1328,14 +1442,49 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions - args = args ?? {}; + args = clone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); - // conduct the deletion - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + const entityChecker = this.policyUtils.getEntityChecker(this.model, 'delete'); + if (entityChecker) { + // additional checker exists, need to run deletion inside a transaction + return this.queryUtils.transaction(this.prisma, async (tx) => { + // find the delete candidates, selecting id fields and fields needed for + // running the additional checker + let candidateSelect = this.policyUtils.makeIdSelection(this.model); + if (entityChecker.selector) { + candidateSelect = deepmerge(candidateSelect, entityChecker.selector); + } + + if (this.shouldLogQuery) { + this.logger.info( + `[policy] \`findMany\` ${this.model}: ${formatObject({ + where: args.where, + select: candidateSelect, + })}` + ); + } + const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect }); + + // build a ID filter based on id values filtered by the additional checker + const { idFilter } = this.buildIdFilterWithEntityChecker(candidates, entityChecker.func); + + // merge the ID filter into the where clause + args.where = args.where ? { AND: [args.where, idFilter] } : idFilter; + + // finally, conduct the deletion with the combined where clause + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`); + } + return tx[this.model].deleteMany(args); + }); + } else { + // conduct the deletion directly + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.deleteMany(args); } - return this.modelClient.deleteMany(args); }); } @@ -1469,7 +1618,7 @@ export class PolicyProxyHandler implements Pr if (args.where) { // combine runtime filters with generated constraints - const extraConstraints: CheckerConstraint[] = []; + const extraConstraints: PermissionCheckerConstraint[] = []; for (const [field, value] of Object.entries(args.where)) { if (value === undefined) { continue; @@ -1599,5 +1748,17 @@ export class PolicyProxyHandler implements Pr } } + private buildIdFilterWithEntityChecker(candidates: any[], entityChecker: EntityCheckerFunc) { + const filteredCandidates = candidates.filter((value) => entityChecker(value, { user: this.context?.user })); + const idFields = this.policyUtils.getIdFields(this.model); + let idFilter: any; + if (idFields.length === 1) { + idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } }; + } else { + idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) }; + } + return { filteredCandidates, idFilter }; + } + //#endregion } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 1af05b03e..f5551b309 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1,28 +1,26 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import deepcopy from 'deepcopy'; +import deepmerge from 'deepmerge'; import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { ZodError } from 'zod'; import { fromZodError } from 'zod-validation-error'; -import { - CrudFailureReason, - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, - PrismaErrorCode, -} from '../../constants'; +import { CrudFailureReason, PrismaErrorCode } from '../../constants'; import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; -import { AuthUser, CrudContract, DbClientContract, PolicyCrudKind, PolicyOperationKind } from '../../types'; +import { + AuthUser, + CrudContract, + DbClientContract, + PolicyCrudKind, + PolicyOperationKind, + QueryContext, +} from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; import { Logger } from '../logger'; import { QueryUtils } from '../query-utils'; -import type { CheckerFunc, InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; +import type { EntityChecker, ModelPolicyDef, PermissionCheckerFunc, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; import { formatObject, prismaClientKnownRequestError } from '../utils'; /** @@ -230,23 +228,32 @@ export class PolicyUtil extends QueryUtils { //#region Auth guard - private readonly FULLY_OPEN_AUTH_GUARD = { - create: true, - read: true, - update: true, - delete: true, - postUpdate: true, - create_input: true, - update_input: true, + private readonly FULL_OPEN_MODEL_POLICY: ModelPolicyDef = { + modelLevel: { + read: { guard: true }, + create: { guard: true, inputChecker: true }, + update: { guard: true }, + delete: { guard: true }, + postUpdate: { guard: true }, + }, }; - private getModelAuthGuard(model: string): PolicyDef['guard']['string'] { + private getModelPolicyDef(model: string): ModelPolicyDef { if (this.options.kinds && !this.options.kinds.includes('policy')) { // policy enhancement not enabled, return an fully open guard - return this.FULLY_OPEN_AUTH_GUARD; - } else { - return this.policy.guard[lowerCaseFirst(model)]; + return this.FULL_OPEN_MODEL_POLICY; + } + + const def = this.policy.policy[lowerCaseFirst(model)]; + if (!def) { + throw this.unknownError(`unable to load policy guard for ${model}`); } + return def; + } + + private getModelGuardForOperation(model: string, operation: PolicyOperationKind): PolicyFunc | boolean { + const def = this.getModelPolicyDef(model); + return def.modelLevel[operation].guard ?? true; } /** @@ -256,20 +263,35 @@ export class PolicyUtil extends QueryUtils { * otherwise returns a guard object */ getAuthGuard(db: CrudContract, model: string, operation: PolicyOperationKind, preValue?: any) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); + const guard = this.getModelGuardForOperation(model, operation); + + // constant guard + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const provider = guard[operation]; - if (typeof provider === 'boolean') { - return this.reduce(provider); + // invoke guard function + const r = guard({ user: this.user, preValue }, db); + return this.reduce(r); + } + + /** + * Get field-level read auth guard + */ + getFieldReadAuthGuard(db: CrudContract, model: string, field: string) { + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.read?.[field]?.guard; + + if (guard === undefined) { + // field access is allowed by default + return this.makeTrue(); } - if (!provider) { - throw this.unknownError(`unable to load authorization guard for ${model}`); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user, preValue }, db); + + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -277,19 +299,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level read auth guard that overrides the model-level */ getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.read?.[field]?.overrideGuard; - const provider = guard[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is denied by default in override mode return this.makeFalse(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -297,19 +319,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level update auth guard */ getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.update?.[field]?.guard; - const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is allowed by default return this.makeTrue(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -317,19 +339,19 @@ export class PolicyUtil extends QueryUtils { * Get field-level update auth guard that overrides the model-level */ getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { - const guard = this.requireGuard(model); + const def = this.getModelPolicyDef(model); + const guard = def.fieldLevel?.update?.[field]?.overrideGuard; - const provider = guard[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field}`]; - if (provider === undefined) { + if (guard === undefined) { // field access is denied by default in override mode return this.makeFalse(); } - if (typeof provider === 'boolean') { - return this.reduce(provider); + if (typeof guard === 'boolean') { + return this.reduce(guard); } - const r = provider({ user: this.user }, db); + const r = guard({ user: this.user }, db); return this.reduce(r); } @@ -337,26 +359,24 @@ export class PolicyUtil extends QueryUtils { * Checks if the given model has a policy guard for the given operation. */ hasAuthGuard(model: string, operation: PolicyOperationKind) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - return false; - } - const provider = guard[operation]; - return typeof provider !== 'boolean' || provider !== true; + const guard = this.getModelGuardForOperation(model, operation); + return typeof guard !== 'boolean' || guard !== true; } /** * Checks if the given model has any field-level override policy guard for the given operation. */ hasOverrideAuthGuard(model: string, operation: PolicyOperationKind) { - const guard = this.requireGuard(model); - switch (operation) { - case 'read': - return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX)); - case 'update': - return Object.keys(guard).some((k) => k.startsWith(FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX)); - default: - return false; + if (operation !== 'read' && operation !== 'update') { + return false; + } + const def = this.getModelPolicyDef(model); + if (def.fieldLevel?.[operation]) { + return Object.values(def.fieldLevel[operation]).some( + (f) => f.overrideGuard !== undefined || f.overrideEntityChecker !== undefined + ); + } else { + return false; } } @@ -366,22 +386,18 @@ export class PolicyUtil extends QueryUtils { * @returns boolean if static analysis is enough to determine the result, undefined if not */ checkInputGuard(model: string, args: any, operation: 'create'): boolean | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - return undefined; - } - - const provider: InputCheckFunc | boolean | undefined = guard[`${operation}_input` as const]; + const def = this.getModelPolicyDef(model); - if (typeof provider === 'boolean') { - return provider; + const guard = def.modelLevel[operation].inputChecker; + if (guard === undefined) { + return undefined; } - if (!provider) { - return undefined; + if (typeof guard === 'boolean') { + return guard; } - return provider(args, { user: this.user }); + return guard(args, { user: this.user }); } /** @@ -423,57 +439,73 @@ export class PolicyUtil extends QueryUtils { return false; } + let mergedGuard = guard; if (args.where) { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - this.injectGuardForRelationFields(db, model, args.where, operation); + mergedGuard = this.injectReadGuardForRelationFields(db, model, args.where, guard); } - args.where = this.and(args.where, guard); + args.where = this.and(args.where, mergedGuard); return true; } - private injectGuardForRelationFields( - db: CrudContract, - model: string, - payload: any, - operation: PolicyOperationKind - ) { + // Injects guard for relation fields nested in `payload`. The `modelGuard` parameter represents the model-level guard for `model`. + // The function returns a modified copy of `modelGuard` with field-level policies combined. + private injectReadGuardForRelationFields(db: CrudContract, model: string, payload: any, modelGuard: any) { + if (!payload || typeof payload !== 'object' || Object.keys(payload).length === 0) { + return modelGuard; + } + + const allFieldGuards: object[] = []; + const allFieldOverrideGuards: object[] = []; + for (const [field, subPayload] of Object.entries(payload)) { if (!subPayload) { continue; } - const fieldInfo = resolveField(this.modelMeta, model, field); - if (!fieldInfo || !fieldInfo.isDataModel) { - continue; - } + allFieldGuards.push(this.getFieldReadAuthGuard(db, model, field)); + allFieldOverrideGuards.push(this.getFieldOverrideReadAuthGuard(db, model, field)); - if (fieldInfo.isArray) { - this.injectGuardForToManyField(db, fieldInfo, subPayload, operation); - } else { - this.injectGuardForToOneField(db, fieldInfo, subPayload, operation); + const fieldInfo = resolveField(this.modelMeta, model, field); + if (fieldInfo?.isDataModel) { + if (fieldInfo.isArray) { + this.injectReadGuardForToManyField(db, fieldInfo, subPayload); + } else { + this.injectReadGuardForToOneField(db, fieldInfo, subPayload); + } } } + + // all existing field-level guards must be true + const mergedGuard: object = this.and(...allFieldGuards); + + // all existing field-level override guards must be true for override to take effect; override is disabled by default + const mergedOverrideGuard: object = + allFieldOverrideGuards.length === 0 ? this.makeFalse() : this.and(...allFieldOverrideGuards); + + // (original-guard && field-level-guard) || field-level-override-guard + const updatedGuard = this.or(this.and(modelGuard, mergedGuard), mergedOverrideGuard); + return updatedGuard; } - private injectGuardForToManyField( + private injectReadGuardForToManyField( db: CrudContract, fieldInfo: FieldInfo, - payload: { some?: any; every?: any; none?: any }, - operation: PolicyOperationKind + payload: { some?: any; every?: any; none?: any } ) { - const guard = this.getAuthGuard(db, fieldInfo.type, operation); + const guard = this.getAuthGuard(db, fieldInfo.type, 'read'); if (payload.some) { - this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload.some, guard); // turn "some" into: { some: { AND: [guard, payload.some] } } - payload.some = this.and(payload.some, guard); + payload.some = this.and(payload.some, mergedGuard); } if (payload.none) { - this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload.none, guard); // turn none into: { none: { AND: [guard, payload.none] } } - payload.none = this.and(payload.none, guard); + payload.none = this.and(payload.none, mergedGuard); } if ( payload.every && @@ -481,40 +513,44 @@ export class PolicyUtil extends QueryUtils { // ignore empty every clause Object.keys(payload.every).length > 0 ) { - this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload.every, guard); // turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } } if (!payload.none) { payload.none = {}; } - payload.none = this.and(payload.none, guard, this.not(payload.every)); + payload.none = this.and(payload.none, mergedGuard, this.not(payload.every)); delete payload.every; } } - private injectGuardForToOneField( + private injectReadGuardForToOneField( db: CrudContract, fieldInfo: FieldInfo, - payload: { is?: any; isNot?: any } & Record, - operation: PolicyOperationKind + payload: { is?: any; isNot?: any } & Record ) { - const guard = this.getAuthGuard(db, fieldInfo.type, operation); + const guard = this.getAuthGuard(db, fieldInfo.type, 'read'); // is|isNot and flat fields conditions are mutually exclusive - if (payload.is || payload.isNot) { + // is and isNot can be null value + + if (payload.is !== undefined || payload.isNot !== undefined) { if (payload.is) { - this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload.is, guard); + // merge guard with existing "is": { is: { AND: [originalIs, guard] } } + payload.is = this.and(payload.is, mergedGuard); } + if (payload.isNot) { - this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload.isNot, guard); + // merge guard with existing "isNot": { isNot: { AND: [originalIsNot, guard] } } + payload.isNot = this.and(payload.isNot, mergedGuard); } - // merge guard with existing "is": { is: [originalIs, guard] } - payload.is = this.and(payload.is, guard); } else { - this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation); + const mergedGuard = this.injectReadGuardForRelationFields(db, fieldInfo.type, payload, guard); // turn direct conditions into: { is: { AND: [ originalConditions, guard ] } } - const combined = this.and(deepcopy(payload), guard); + const combined = this.and(deepcopy(payload), mergedGuard); Object.keys(payload).forEach((key) => delete payload[key]); payload.is = combined; } @@ -534,7 +570,7 @@ export class PolicyUtil extends QueryUtils { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - this.injectGuardForRelationFields(db, model, args.where, 'read'); + this.injectReadGuardForRelationFields(db, model, args.where, {}); } if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) { @@ -568,35 +604,30 @@ export class PolicyUtil extends QueryUtils { /** * Gets checker constraints for the given model and operation. */ - getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { - const checker = this.getModelChecker(model); - const provider = checker[operation]; - if (typeof provider === 'boolean') { - return provider; + getCheckerConstraint(model: string, operation: PolicyCrudKind): ReturnType | boolean { + if (this.options.kinds && !this.options.kinds.includes('policy')) { + // policy enhancement not enabled, return a constant true checker result + return true; } - if (typeof provider !== 'function') { - throw this.unknownError(`invalid ${operation} checker function for ${model}`); + const def = this.getModelPolicyDef(model); + const checker = def.modelLevel[operation].permissionChecker; + if (checker === undefined) { + throw new Error( + `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` + ); } - // call checker function - return provider({ user: this.user }); - } + if (typeof checker === 'boolean') { + return checker; + } - private getModelChecker(model: string) { - if (this.options.kinds && !this.options.kinds.includes('policy')) { - // policy enhancement not enabled, return a constant true checker - return { create: true, read: true, update: true, delete: true }; - } else { - const result = this.options.policy.checker?.[lowerCaseFirst(model)]; - if (!result) { - // checker generation not enabled, return constant false checker - throw new Error( - `Generated permission checkers not found. Please make sure the "generatePermissionChecker" option is set to true in the "@core/enhancer" plugin.` - ); - } - return result; + if (typeof checker !== 'function') { + throw this.unknownError(`invalid ${operation} checker function for ${model}`); } + + // call checker function + return checker({ user: this.user }); } //#endregion @@ -719,6 +750,8 @@ export class PolicyUtil extends QueryUtils { ); } + let entityChecker: EntityChecker | undefined; + if (operation === 'update' && args) { // merge field-level policy guards const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); @@ -732,33 +765,47 @@ export class PolicyUtil extends QueryUtils { }"`, CrudFailureReason.ACCESS_POLICY_VIOLATION ); - } else { - if (fieldUpdateGuard.guard) { - // merge field-level guard - guard = this.and(guard, fieldUpdateGuard.guard); - } + } - if (fieldUpdateGuard.overrideGuard) { - // merge field-level override guard - guard = this.or(guard, fieldUpdateGuard.overrideGuard); - } + if (fieldUpdateGuard.guard) { + // merge field-level guard with AND + guard = this.and(guard, fieldUpdateGuard.guard); } + + if (fieldUpdateGuard.overrideGuard) { + // merge field-level override guard with OR + guard = this.or(guard, fieldUpdateGuard.overrideGuard); + } + + // field-level entity checker + entityChecker = fieldUpdateGuard.entityChecker; } // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - if (this.isTrue(guard) && !schema) { + // combine field-level entity checker with model-level + const modelEntityChecker = this.getEntityChecker(model, operation); + entityChecker = this.combineEntityChecker(entityChecker, modelEntityChecker, 'and'); + + if (this.isTrue(guard) && !schema && !entityChecker) { // unconditionally allowed return; } - const select = schema + let select = schema ? // need to validate against schema, need to fetch all fields undefined : // only fetch id fields this.makeIdSelection(model); + if (entityChecker?.selector) { + if (!select) { + select = this.makeAllScalarFieldSelect(model); + } + select = { ...select, ...entityChecker.selector }; + } + let where = this.clone(uniqueFilter); // query args may have be of combined-id form, need to flatten it to call findFirst this.flattenGeneratedUniqueField(model, where); @@ -780,6 +827,20 @@ export class PolicyUtil extends QueryUtils { ); } + if (entityChecker) { + if (this.logger.enabled('info')) { + this.logger.info(`[policy] running entity checker on ${model} for ${operation}`); + } + if (!entityChecker.func(result, { user: this.user, preValue })) { + throw this.deniedByPolicy( + model, + operation, + `entity ${formatObject(uniqueFilter, false)} failed policy check`, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); + } + } + if (schema) { // TODO: push down schema check to the database const parseResult = schema.safeParse(result); @@ -799,6 +860,20 @@ export class PolicyUtil extends QueryUtils { } } + getEntityChecker(model: string, operation: PolicyOperationKind, field?: string) { + const def = this.getModelPolicyDef(model); + if (field) { + return def.fieldLevel?.[operation as 'read' | 'update']?.[field]?.entityChecker; + } else { + return def.modelLevel[operation].entityChecker; + } + } + + getUpdateOverrideEntityCheckerForField(model: string, field: string) { + const def = this.getModelPolicyDef(model); + return def.fieldLevel?.update?.[field]?.overrideEntityChecker; + } + private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { const allFields = Object.values(getFields(this.modelMeta, model)); @@ -825,19 +900,20 @@ export class PolicyUtil extends QueryUtils { private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { const allFieldGuards = []; const allOverrideFieldGuards = []; + let entityChecker: EntityChecker | undefined; - for (const [k, v] of Object.entries(args.data ?? args)) { - if (typeof v === 'undefined') { + for (const [field, value] of Object.entries(args.data ?? args)) { + if (typeof value === 'undefined') { continue; } - const field = resolveField(this.modelMeta, model, k); + const fieldInfo = resolveField(this.modelMeta, model, field); - if (field?.isDataModel) { + if (fieldInfo?.isDataModel) { // relation field update should be treated as foreign key update, // fetch and merge all foreign key guards - if (field.isRelationOwner && field.foreignKeyMapping) { - const foreignKeys = Object.values(field.foreignKeyMapping); + if (fieldInfo.isRelationOwner && fieldInfo.foreignKeyMapping) { + const foreignKeys = Object.values(fieldInfo.foreignKeyMapping); for (const fk of foreignKeys) { const fieldGuard = this.getFieldUpdateAuthGuard(db, model, fk); if (this.isFalse(fieldGuard)) { @@ -853,18 +929,26 @@ export class PolicyUtil extends QueryUtils { } } } else { - const fieldGuard = this.getFieldUpdateAuthGuard(db, model, k); + const fieldGuard = this.getFieldUpdateAuthGuard(db, model, field); if (this.isFalse(fieldGuard)) { - return { guard: fieldGuard, rejectedByField: k }; + return { guard: fieldGuard, rejectedByField: field }; } // add field guard allFieldGuards.push(fieldGuard); // add field override guard - const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, k); + const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, field); allOverrideFieldGuards.push(overrideFieldGuard); } + + // merge regular and override entity checkers with OR + let checker = this.getEntityChecker(model, 'update', field); + const overrideChecker = this.getUpdateOverrideEntityCheckerForField(model, field); + checker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + // accumulate entity checker across fields + entityChecker = this.combineEntityChecker(entityChecker, checker, 'and'); } const allFieldsCombined = this.and(...allFieldGuards); @@ -875,6 +959,31 @@ export class PolicyUtil extends QueryUtils { guard: allFieldsCombined, overrideGuard: allOverrideFieldsCombined, rejectedByField: undefined, + entityChecker, + }; + } + + private combineEntityChecker( + left: EntityChecker | undefined, + right: EntityChecker | undefined, + combiner: 'and' | 'or' + ): EntityChecker | undefined { + if (!left) { + return right; + } + + if (!right) { + return left; + } + + const func = + combiner === 'and' + ? (entity: any, context: QueryContext) => left.func(entity, context) && right.func(entity, context) + : (entity: any, context: QueryContext) => left.func(entity, context) || right.func(entity, context); + + return { + func, + selector: deepmerge(left.selector ?? {}, right.selector ?? {}), }; } @@ -956,8 +1065,8 @@ export class PolicyUtil extends QueryUtils { } /** - * Injects field selection needed for checking field-level read policy into query args. - * @returns + * Injects field selection needed for checking field-level read policy check and evaluating + * entity checker into query args. */ injectReadCheckSelect(model: string, args: any) { // we need to recurse into relation fields before injecting the current level, because @@ -974,11 +1083,16 @@ export class PolicyUtil extends QueryUtils { if (this.hasFieldLevelPolicy(model)) { // recursively inject selection for fields needed for field-level read checks - const readFieldSelect = this.getReadFieldSelect(model); + const readFieldSelect = this.getFieldReadCheckSelector(model); if (readFieldSelect) { this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } } + + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker?.selector) { + this.doInjectReadCheckSelect(model, args, { select: entityChecker.selector }); + } } private doInjectReadCheckSelect(model: string, args: any, input: any) { @@ -1091,32 +1205,41 @@ export class PolicyUtil extends QueryUtils { /** * Gets field selection for fetching pre-update entity values for the given model. */ - getPreValueSelect(model: string): object | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return guard[PRE_UPDATE_VALUE_SELECTOR]; + getPreValueSelect(model: string) { + const def = this.getModelPolicyDef(model); + return def.modelLevel.postUpdate.preUpdateSelector; } - private getReadFieldSelect(model: string): object | undefined { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); + // get a merged selector object for all field-level read policies + private getFieldReadCheckSelector(model: string) { + const def = this.getModelPolicyDef(model); + let result: any = {}; + const fieldLevel = def.fieldLevel?.read; + if (fieldLevel) { + for (const def of Object.values(fieldLevel)) { + if (def.entityChecker?.selector) { + result = deepmerge(result, def.entityChecker.selector); + } + if (def.overrideEntityChecker?.selector) { + result = deepmerge(result, def.overrideEntityChecker.selector); + } + } } - return guard[FIELD_LEVEL_READ_CHECKER_SELECTOR]; + return Object.keys(result).length > 0 ? result : undefined; } private checkReadField(model: string, field: string, entity: any) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - const func = guard[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field}`] as ReadFieldCheckFunc | undefined; - if (!func) { + const def = this.getModelPolicyDef(model); + + // combine regular and override field-level entity checkers with OR + const checker = def.fieldLevel?.read?.[field]?.entityChecker; + const overrideChecker = def.fieldLevel?.read?.[field]?.overrideEntityChecker; + const combinedChecker = this.combineEntityChecker(checker, overrideChecker, 'or'); + + if (combinedChecker === undefined) { return true; } else { - return func(entity, { user: this.user }); + return combinedChecker.func(entity, { user: this.user }); } } @@ -1125,11 +1248,8 @@ export class PolicyUtil extends QueryUtils { } private hasFieldLevelPolicy(model: string) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return !!guard[HAS_FIELD_LEVEL_POLICY_FLAG]; + const def = this.getModelPolicyDef(model); + return Object.keys(def.fieldLevel?.read ?? {}).length > 0; } /** @@ -1152,7 +1272,7 @@ export class PolicyUtil extends QueryUtils { // preserve the original data as it may be needed for checking field-level readability, // while the "data" will be manipulated during traversal (deleting unreadable fields) const origData = this.clone(data); - this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); + return this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); } private doPostProcessForRead( @@ -1164,12 +1284,44 @@ export class PolicyUtil extends QueryUtils { path = '' ) { if (data === null || data === undefined) { - return; + return data; + } + + let filteredData = data; + let filteredFullData = fullData; + + const entityChecker = this.getEntityChecker(model, 'read'); + if (entityChecker) { + if (Array.isArray(data)) { + filteredData = []; + filteredFullData = []; + for (const [entityData, entityFullData] of zip(data, fullData)) { + if (!entityChecker.func(entityData, { user: this.user })) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` + ); + } + } else { + filteredData.push(entityData); + filteredFullData.push(entityFullData); + } + } + } else { + if (!entityChecker.func(data, { user: this.user })) { + if (this.shouldLogQuery) { + this.logger.info( + `[policy] dropping ${model} entity${path ? ' at ' + path : ''} due to entity checker` + ); + } + return null; + } + } } - for (const [entityData, entityFullData] of zip(data, fullData)) { + for (const [entityData, entityFullData] of zip(filteredData, filteredFullData)) { if (typeof entityData !== 'object' || !entityData) { - return; + continue; } for (const [field, fieldData] of Object.entries(entityData)) { @@ -1225,7 +1377,7 @@ export class PolicyUtil extends QueryUtils { if (fieldInfo.isDataModel) { // recurse into nested fields const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; - this.doPostProcessForRead( + const nestedResult = this.doPostProcessForRead( fieldData, fieldInfo.type, entityFullData[field], @@ -1233,9 +1385,16 @@ export class PolicyUtil extends QueryUtils { this.hasFieldLevelPolicy(fieldInfo.type), path ? path + '.' + field : field ); + if (nestedResult === undefined) { + delete entityData[field]; + } else { + entityData[field] = nestedResult; + } } } } + + return filteredData; } /** @@ -1305,14 +1464,6 @@ export class PolicyUtil extends QueryUtils { } } - private requireGuard(model: string) { - const guard = this.getModelAuthGuard(model); - if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); - } - return guard; - } - /** * Given an entity data, returns an object only containing id fields. */ diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index e7f55a88c..70d8f27e9 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -35,6 +35,8 @@ export interface PrismaProxyHandler { createMany(args: { data: any; skipDuplicates?: boolean }): Promise; + createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }): Promise; + update(args: any): Promise; updateMany(args: any): Promise; @@ -122,6 +124,10 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.deferred<{ count: number }>('createMany', args, false); } + createManyAndReturn(args: { data: any; select: any; include: any; skipDuplicates?: boolean }) { + return this.deferred('createManyAndReturn', args); + } + update(args: any) { return this.deferred('update', args); } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 89d5ce9f6..a5cd85314 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,15 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; -import { - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, -} from '../constants'; -import type { CheckerContext, CrudContract, PolicyCrudKind, PolicyOperationKind, QueryContext } from '../types'; +import type { CrudContract, PermissionCheckerContext, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -33,10 +24,15 @@ export interface CommonEnhancementOptions { */ export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; +/** + * Function for checking an entity's data for permission + */ +export type EntityCheckerFunc = (input: any, context: QueryContext) => boolean; + /** * Function for checking if an operation is possibly allowed. */ -export type CheckerFunc = (context: CheckerContext) => CheckerConstraint; +export type PermissionCheckerFunc = (context: PermissionCheckerContext) => PermissionCheckerConstraint; /** * Supported checker constraint checking value types. @@ -76,53 +72,24 @@ export type ComparisonConstraint = { */ export type LogicalConstraint = { kind: 'and' | 'or' | 'not'; - children: CheckerConstraint[]; + children: PermissionCheckerConstraint[]; }; /** * Operation allowability checking constraint */ -export type CheckerConstraint = ValueConstraint | VariableConstraint | ComparisonConstraint | LogicalConstraint; - -/** - * Function for getting policy guard with a given context - */ -export type InputCheckFunc = (args: any, context: QueryContext) => boolean; - -/** - * Function for getting policy guard with a given context - */ -export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; +export type PermissionCheckerConstraint = + | ValueConstraint + | VariableConstraint + | ComparisonConstraint + | LogicalConstraint; /** * Policy definition */ export type PolicyDef = { - // Prisma query guards - guard: Record< - string, - // policy operation guard functions - Partial> & - // 'create_input' checker function - Partial> & - // field-level read checker functions or update guard functions - Record<`${typeof FIELD_LEVEL_READ_CHECKER_PREFIX}${string}`, ReadFieldCheckFunc> & - Record< - | `${typeof FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${string}` - | `${typeof FIELD_LEVEL_UPDATE_GUARD_PREFIX}${string}` - | `${typeof FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${string}`, - PolicyFunc - > & { - // pre-update value selector - [PRE_UPDATE_VALUE_SELECTOR]?: object; - // field-level read checker selector - [FIELD_LEVEL_READ_CHECKER_SELECTOR]?: object; - // flag that indicates if the model has field-level access control - [HAS_FIELD_LEVEL_POLICY_FLAG]?: boolean; - } - >; - - checker?: Record>; + // policy definitions for each model + policy: Record; // tracks which models have data validation rules validation: Record; @@ -131,10 +98,176 @@ export type PolicyDef = { authSelector?: object; }; +type ModelName = string; +type FieldName = string; + +/** + * Policy definition for a model + */ +export type ModelPolicyDef = { + /** + * Model-level CRUD policies + */ + modelLevel: ModelCrudDef; + + /** + * Field-level CRUD policies + */ + fieldLevel?: FieldCrudDef; +}; + +/** + * CRUD policy definitions for a model + */ +export type ModelCrudDef = { + read: ModelReadDef; + create: ModelCreateDef; + update: ModelUpdateDef; + delete: ModelDeleteDef; + postUpdate: ModelPostUpdateDef; +}; + +/** + * Information for checking entity data outside of Prisma + */ +export type EntityChecker = { + /** + * Checker function + */ + func: EntityCheckerFunc; + + /** + * Selector for fetching entity data + */ + selector?: object; +}; + +/** + * Common policy definition for a CRUD operation + */ +type ModelCrudCommon = { + /** + * Prisma query guard or a constant condition + */ + guard: PolicyFunc | boolean; + + /** + * Additional checker function for checking policies outside of Prisma + */ + /** + * Additional checker function for checking policies outside of Prisma + */ + entityChecker?: EntityChecker; + + /** + * Permission checker function or a constant condition + */ + permissionChecker?: PermissionCheckerFunc | boolean; +}; + +/** + * Policy definition for reading a model + */ +type ModelReadDef = ModelCrudCommon; + +/** + * Policy definition for creating a model + */ +type ModelCreateDef = ModelCrudCommon & { + /** + * Create input validation function. Only generated when a create + * can be approved or denied based on input values. + */ + inputChecker?: EntityCheckerFunc | boolean; +}; + +/** + * Policy definition for updating a model + */ +type ModelUpdateDef = ModelCrudCommon; + +/** + * Policy definition for deleting a model + */ +type ModelDeleteDef = ModelCrudCommon; + +/** + * Policy definition for post-update checking a model + */ +type ModelPostUpdateDef = Exclude & { + preUpdateSelector?: object; +}; + +/** + * CRUD policy definitions for a field + */ +type FieldCrudDef = { + /** + * Field-level read policy + */ + read: Record; + + /** + * Field-level update policy + */ + update: Record; +}; + +type FieldReadDef = { + /** + * Field-level Prisma query guard + */ + guard?: PolicyFunc; + + /** + * Entity checker + */ + entityChecker?: EntityChecker; + + /** + * Field-level read override Prisma query guard + */ + overrideGuard?: PolicyFunc; + + /** + * Entity checker for override policies + */ + overrideEntityChecker?: EntityChecker; +}; + +type FieldUpdateDef = { + /** + * Field-level update Prisma query guard + */ + guard?: PolicyFunc; + + /** + * Additional entity checker + */ + entityChecker?: EntityChecker; + + /** + * Field-level update override Prisma query guard + */ + overrideGuard?: PolicyFunc; + + /** + * Additional entity checker for override policies + */ + overrideEntityChecker?: EntityChecker; +}; + /** * Zod schemas for validation */ export type ZodSchemas = { + /** + * Zod schema for each model + */ models: Record; + + /** + * Zod schema for Prisma input types for each model + */ input?: Record>; }; diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 4c32480ba..bf1660090 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -12,7 +12,8 @@ export interface DbOperations { findUnique(args: unknown): PrismaPromise; findUniqueOrThrow(args: unknown): PrismaPromise; create(args: unknown): Promise; - createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>; + createMany(args: unknown): Promise<{ count: number }>; + createManyAndReturn(args: unknown): Promise; update(args: unknown): Promise; updateMany(args: unknown): Promise<{ count: number }>; upsert(args: unknown): Promise; @@ -62,7 +63,7 @@ export type QueryContext = { /** * Context for checking operation allowability. */ -export type CheckerContext = { +export type PermissionCheckerContext = { /** * Current user */ diff --git a/packages/schema/README.md b/packages/schema/README.md index 2c24ab102..60b1cf5bf 100644 --- a/packages/schema/README.md +++ b/packages/schema/README.md @@ -6,8 +6,31 @@ This VS Code extension provides code editing helpers for authoring ZenStack's sc ## Features -- Syntax highlighting +- Syntax highlighting of `*.zmodel` files + + - In case the schema file is not recognized automatically, add the following to your settings.json file: + + ```json + "files.associations": { + "*.zmodel": "zmodel" + }, + ``` + - Auto formatting + + - To automatically format on save, add the following to your settings.json file: + + ```json + "editor.formatOnSave": true + ``` + + - To enable formatting in combination with prettier, add the following to your settings.json file: + ```json + "[zmodel]": { + "editor.defaultFormatter": "zenstack.zenstack" + }, + ``` + - Inline error reporting - Go-to definition - Hover documentation diff --git a/packages/schema/package.json b/packages/schema/package.json index a27f41f3b..b78d49bfe 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database", - "version": "2.1.2", + "version": "2.2.0", "author": { "name": "ZenStack Team" }, @@ -80,7 +80,7 @@ "vscode:prepublish": "pnpm bundle", "vscode:package": "pnpm bundle && vsce package --no-dependencies", "clean": "rimraf dist", - "build": "pnpm clean && pnpm lint --max-warnings=0 && tsc && copyfiles -F \"bin/*\" dist && copyfiles ./README-global.md ./LICENSE ./package.json dist && renamer --replace \"README.md\" dist/README-global.md && copyfiles -u 1 \"src/res/*\" dist && node build/post-build.js && pnpm pack dist --pack-destination '../../../.build'", + "build": "pnpm clean && pnpm lint --max-warnings=0 && tsc && copyfiles -F \"bin/*\" dist && copyfiles ./README-global.md ./LICENSE ./package.json dist && renamer --replace \"README.md\" dist/README-global.md && copyfiles -u 1 \"src/res/*\" dist && node build/post-build.js && pnpm pack dist --pack-destination ../../../.build", "bundle": "rimraf bundle && pnpm lint --max-warnings=0 && node build/bundle.js --minify", "watch": "tsc --watch", "lint": "eslint src tests --ext ts", @@ -122,10 +122,10 @@ "zod-validation-error": "^1.5.0" }, "peerDependencies": { - "prisma": "5.0.0 - 5.13.x" + "prisma": "5.0.0 - 5.15.x" }, "devDependencies": { - "@prisma/client": "^5.13.0", + "@prisma/client": "^5.15.0", "@types/async-exit-hook": "^2.0.0", "@types/pluralize": "^0.0.29", "@types/semver": "^7.3.13", @@ -137,7 +137,7 @@ "@zenstackhq/runtime": "workspace:*", "dotenv": "^16.0.3", "esbuild": "^0.15.12", - "prisma": "^5.13.0", + "prisma": "^5.15.0", "renamer": "^4.0.0", "tmp": "^0.2.1", "tsc-alias": "^1.7.0", diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 9185443f3..eb6d06400 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -15,6 +15,7 @@ import { isDelegateModel, } from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, ValidationAcceptor, getDocument } from 'langium'; +import { findUpInheritance } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; @@ -238,7 +239,7 @@ export default class DataModelValidator implements AstValidator { return; } - if (field.$container !== contextModel && isDelegateModel(field.$container as DataModel)) { + if (this.isFieldInheritedFromDelegateModel(field, contextModel)) { // relation fields inherited from delegate model don't need opposite relation return; } @@ -390,6 +391,16 @@ export default class DataModelValidator implements AstValidator { } } + // checks if the given field is inherited directly or indirectly from a delegate model + private isFieldInheritedFromDelegateModel(field: DataModelField, contextModel: DataModel) { + const basePath = findUpInheritance(contextModel, field.$container as DataModel); + if (basePath && basePath.some(isDelegateModel)) { + return true; + } else { + return false; + } + } + private validateBaseAbstractModel(model: DataModel, accept: ValidationAcceptor) { model.superTypes.forEach((superType, index) => { if ( diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index d65e304dc..478db5ff7 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,6 +1,7 @@ import { AstNode, BinaryExpr, + DataModelAttribute, Expression, ExpressionType, isDataModel, @@ -13,7 +14,12 @@ import { isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { + getAttributeArgLiteral, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, +} from '@zenstackhq/sdk'; import { ValidationAcceptor, streamAst } from 'langium'; import { findUpAst, getContainingDataModel } from '../../utils/ast-utils'; import { AstValidator } from '../types'; @@ -151,6 +157,7 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'incompatible operand types', { node: expr }); break; } + // not supported: // - foo.a == bar // - foo.user.id == userId @@ -169,10 +176,24 @@ export default class ExpressionValidator implements AstValidator { // foo.user.id == null // foo.user.id == EnumValue if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { - accept('error', 'comparison between fields of different models are not supported', { - node: expr, - }); - break; + const containingPolicyAttr = findUpAst( + expr, + (node) => isDataModelAttribute(node) && ['@@allow', '@@deny'].includes(node.decl.$refText) + ) as DataModelAttribute | undefined; + + if (containingPolicyAttr) { + const operation = getAttributeArgLiteral(containingPolicyAttr, 'operation'); + if (operation?.split(',').includes('all') || operation?.split(',').includes('read')) { + accept( + 'error', + 'comparison between fields of different models is not supported in model-level "read" rules', + { + node: expr, + } + ); + break; + } + } } } @@ -246,16 +267,6 @@ export default class ExpressionValidator implements AstValidator { accept('error', 'collection predicate can only be used on an array of model type', { node: expr }); return; } - - // TODO: revisit this when we implement lambda inside collection predicate - const thisExpr = streamAst(expr).find(isThisExpr); - if (thisExpr) { - accept( - 'error', - 'using `this` in collection predicate is not supported. To compare entity identity, use id field comparison instead.', - { node: thisExpr } - ); - } } private isInValidationContext(node: AstNode) { diff --git a/packages/schema/src/plugins/enhancer/policy/index.ts b/packages/schema/src/plugins/enhancer/policy/index.ts index 8eaf1d00b..918bfba8c 100644 --- a/packages/schema/src/plugins/enhancer/policy/index.ts +++ b/packages/schema/src/plugins/enhancer/policy/index.ts @@ -4,5 +4,5 @@ import type { Project } from 'ts-morph'; import { PolicyGenerator } from './policy-guard-generator'; export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - return new PolicyGenerator().generate(project, model, options, outDir); + return new PolicyGenerator(options).generate(project, model, outDir); } diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index a36a52126..ce672dcc7 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -1,111 +1,58 @@ import { DataModel, DataModelField, - Enum, Expression, Model, isDataModel, isDataModelField, isEnum, - isExpression, - isInvocationExpr, isMemberAccessExpr, isReferenceExpr, isThisExpr, } from '@zenstackhq/language/ast'; -import { - FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, - FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX, - FIELD_LEVEL_READ_CHECKER_PREFIX, - FIELD_LEVEL_READ_CHECKER_SELECTOR, - FIELD_LEVEL_UPDATE_GUARD_PREFIX, - HAS_FIELD_LEVEL_POLICY_FLAG, - PRE_UPDATE_VALUE_SELECTOR, - type PolicyKind, - type PolicyOperationKind, -} from '@zenstackhq/runtime'; +import { PolicyCrudKind, type PolicyOperationKind } from '@zenstackhq/runtime'; import { ExpressionContext, - PluginError, PluginOptions, + PolicyAnalysisResult, RUNTIME_PACKAGE, TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, analyzePolicies, - getAttributeArg, - getAuthModel, getDataModels, - getIdFields, - getLiteral, hasAttribute, hasValidationAttributes, isAuthInvocation, - isEnumFieldReference, isForeignKeyField, - isFromStdlib, - isFutureExpr, - resolved, } from '@zenstackhq/sdk'; import { getPrismaClientImportSpec } from '@zenstackhq/sdk/prisma'; -import { streamAllContents, streamAst, streamContents } from 'langium'; +import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; -import { name } from '..'; -import { isCollectionPredicate } from '../../../utils/ast-utils'; -import { ALL_OPERATION_KINDS } from '../../plugin-utils'; +import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { ConstraintTransformer } from './constraint-transformer'; -import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; +import { + generateEntityCheckerFunction, + generateNormalizedAuthRef, + generateQueryGuardFunction, + generateSelectForRules, + getPolicyExpressions, + isEnumReferenced, +} from './utils'; /** * Generates source file that contains Prisma query guard objects used for injecting database queries */ export class PolicyGenerator { - async generate(project: Project, model: Model, options: PluginOptions, output: string) { + constructor(private options: PluginOptions) {} + + async generate(project: Project, model: Model, output: string) { const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); sf.addStatements('/* eslint-disable */'); - sf.addImportDeclaration({ - namedImports: [ - { name: 'type QueryContext' }, - { name: 'type CrudContract' }, - { name: 'allFieldsEqual' }, - { name: 'type PolicyDef' }, - { name: 'type CheckerContext' }, - { name: 'type CheckerConstraint' }, - ], - moduleSpecifier: `${RUNTIME_PACKAGE}`, - }); - - // import enums - const prismaImport = getPrismaClientImportSpec(output, options); - for (const e of model.declarations.filter((d) => isEnum(d) && this.isEnumReferenced(model, d))) { - sf.addImportDeclaration({ - namedImports: [{ name: e.name }], - moduleSpecifier: prismaImport, - }); - } + this.writeImports(model, output, sf); const models = getDataModels(model); - // policy guard functions - const policyMap: Record> = {}; - for (const model of models) { - policyMap[model.name] = await this.generateQueryGuardForModel(model, sf); - } - - const generatePermissionChecker = options.generatePermissionChecker === true; - - // CRUD checker functions - const checkerMap: Record> = {}; - if (generatePermissionChecker) { - for (const model of models) { - checkerMap[model.name] = await this.generateCheckerForModel(model, sf); - } - } - - const authSelector = this.generateAuthSelector(models); - sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [ @@ -114,55 +61,9 @@ export class PolicyGenerator { type: 'PolicyDef', initializer: (writer) => { writer.block(() => { - writer.write('guard:'); - writer.inlineBlock(() => { - for (const [model, map] of Object.entries(policyMap)) { - writer.write(`${lowerCaseFirst(model)}:`); - writer.inlineBlock(() => { - for (const [op, func] of Object.entries(map)) { - if (typeof func === 'object') { - writer.write(`${op}: ${JSON.stringify(func)},`); - } else { - writer.write(`${op}: ${func},`); - } - } - }); - writer.write(','); - } - }); - writer.writeLine(','); - - if (generatePermissionChecker) { - writer.write('checker:'); - writer.inlineBlock(() => { - for (const [model, map] of Object.entries(checkerMap)) { - writer.write(`${lowerCaseFirst(model)}:`); - writer.inlineBlock(() => { - Object.entries(map).forEach(([op, func]) => { - writer.write(`${op}: ${func},`); - }); - }); - writer.writeLine(','); - } - }); - writer.writeLine(','); - } - - writer.write('validation:'); - writer.inlineBlock(() => { - for (const model of models) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.inlineBlock(() => { - writer.write(`hasValidation: ${hasValidationAttributes(model)}`); - }); - writer.writeLine(','); - } - }); - - if (authSelector) { - writer.writeLine(','); - writer.write(`authSelector: ${JSON.stringify(authSelector)}`); - } + this.writePolicy(writer, models, sf); + this.writeValidationMeta(writer, models); + this.writeAuthSelector(models, writer); }); }, }, @@ -172,242 +73,170 @@ export class PolicyGenerator { sf.addStatements('export default policy'); // save ts files if requested explicitly or the user provided - const preserveTsFiles = options.preserveTsFiles === true || !!options.output; + const preserveTsFiles = this.options.preserveTsFiles === true || !!this.options.output; if (preserveTsFiles) { await sf.save(); } } - // Generates a { select: ... } object to select `auth()` fields used in policy rules - private generateAuthSelector(models: DataModel[]) { - const authRules: Expression[] = []; + private writeImports(model: Model, output: string, sf: SourceFile) { + sf.addImportDeclaration({ + namedImports: [ + { name: 'type QueryContext' }, + { name: 'type CrudContract' }, + { name: 'allFieldsEqual' }, + { name: 'type PolicyDef' }, + { name: 'type PermissionCheckerContext' }, + { name: 'type PermissionCheckerConstraint' }, + ], + moduleSpecifier: `${RUNTIME_PACKAGE}`, + }); - models.forEach((model) => { - // model-level rules - const modelPolicyAttrs = model.attributes.filter((attr) => - ['@@allow', '@@deny'].includes(attr.decl.$refText) - ); + // import enums + const prismaImport = getPrismaClientImportSpec(output, this.options); + for (const e of model.declarations.filter((d) => isEnum(d) && isEnumReferenced(model, d))) { + sf.addImportDeclaration({ + namedImports: [{ name: e.name }], + moduleSpecifier: prismaImport, + }); + } + } - // field-level rules - const fieldPolicyAttrs = model.fields - .flatMap((f) => f.attributes) - .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); + private writePolicy(writer: CodeBlockWriter, models: DataModel[], sourceFile: SourceFile) { + writer.write('policy:'); + writer.inlineBlock(() => { + for (const model of models) { + writer.write(`${lowerCaseFirst(model.name)}:`); - // all rule expression - const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs] - .filter((attr) => attr.args.length > 1) - .map((attr) => attr.args[1].value); + writer.block(() => { + // model-level guards + this.writeModelLevelDefs(model, writer, sourceFile); - // collect `auth()` member access - allExpressions.forEach((rule) => { - streamAst(rule).forEach((node) => { - if (isMemberAccessExpr(node) && isAuthInvocation(node.operand)) { - authRules.push(node); - } + // field-level guards + this.writeFieldLevelDefs(model, writer, sourceFile); }); - }); - }); - - if (authRules.length > 0) { - return this.generateSelectForRules(authRules, true); - } else { - return undefined; - } - } - private isEnumReferenced(model: Model, decl: Enum): unknown { - return streamAllContents(model).some((node) => { - if (isDataModelField(node) && node.type.reference?.ref === decl) { - // referenced as field type - return true; - } - if (isEnumFieldReference(node) && node.target.ref?.$container === decl) { - // enum field is referenced - return true; + writer.writeLine(','); } - return false; }); + writer.writeLine(','); } - private getPolicyExpressions( - target: DataModel | DataModelField, - kind: PolicyKind, - operation: PolicyOperationKind, - override = false - ) { - const attributes = target.attributes; - const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; - const attrs = attributes.filter((attr) => { - if (attr.decl.ref?.name !== attrName) { - return false; - } + // #region Model-level definitions - if (override) { - const overrideArg = getAttributeArg(attr, 'override'); - return overrideArg && getLiteral(overrideArg) === true; - } else { - return true; - } + // writes model-level policy def for each operation kind for a model + // `[modelName]: { [operationKind]: [funcName] },` + private writeModelLevelDefs(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + const policies = analyzePolicies(model); + writer.write('modelLevel:'); + writer.inlineBlock(() => { + this.writeModelReadDef(model, policies, writer, sourceFile); + this.writeModelCreateDef(model, policies, writer, sourceFile); + this.writeModelUpdateDef(model, policies, writer, sourceFile); + this.writeModelPostUpdateDef(model, policies, writer, sourceFile); + this.writeModelDeleteDef(model, policies, writer, sourceFile); }); + writer.writeLine(','); + } - const checkOperation = operation === 'postUpdate' ? 'update' : operation; - - let result = attrs - .filter((attr) => { - const opsValue = getLiteral(attr.args[0].value); - if (!opsValue) { - return false; - } - const ops = opsValue.split(',').map((s) => s.trim()); - return ops.includes(checkOperation) || ops.includes('all'); - }) - .map((attr) => attr.args[1].value); - - if (operation === 'update') { - result = this.processUpdatePolicies(result, false); - } else if (operation === 'postUpdate') { - result = this.processUpdatePolicies(result, true); - } - - return result; + // writes `read: ...` for a given model + private writeModelReadDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`read:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'read', policies, writer, sourceFile); + }); + writer.writeLine(','); } - private processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { - const hasFutureReference = expressions.some((expr) => this.hasFutureReference(expr)); - if (postUpdate) { - // when compiling post-update rules, if any rule contains `future()` reference, - // we include all as post-update rules - return hasFutureReference ? expressions : []; - } else { - // when compiling pre-update rules, if any rule contains `future()` reference, - // we completely skip pre-update check and defer them to post-update - return hasFutureReference ? [] : expressions; - } + // writes `create: ...` for a given model + private writeModelCreateDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`create:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'create', policies, writer, sourceFile); + + // create policy has an additional input checker for validating the payload + this.writeCreateInputChecker(model, writer, sourceFile); + }); + writer.writeLine(','); } - private hasFutureReference(expr: Expression) { - for (const node of streamAst(expr)) { - if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { - return true; - } + // writes `inputChecker: [funcName]` for a given model + private writeCreateInputChecker(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + if (this.canCheckCreateBasedOnInput(model)) { + const inputCheckFunc = this.generateCreateInputCheckerFunction(model, sourceFile); + writer.write(`inputChecker: ${inputCheckFunc.getName()!},`); } - return false; } - private async generateQueryGuardForModel(model: DataModel, sourceFile: SourceFile) { - const result: Record = {}; - - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const policies: any = analyzePolicies(model); - - for (const kind of ALL_OPERATION_KINDS) { - if (policies[kind] === true || policies[kind] === false) { - result[kind] = policies[kind]; - if (kind === 'create') { - result[kind + '_input'] = policies[kind]; - } - continue; - } - - const denies = this.getPolicyExpressions(model, 'deny', kind); - const allows = this.getPolicyExpressions(model, 'allow', kind); + private canCheckCreateBasedOnInput(model: DataModel) { + const allows = getPolicyExpressions(model, 'allow', 'create', false, 'all'); + const denies = getPolicyExpressions(model, 'deny', 'create', false, 'all'); - if (kind === 'update' && allows.length === 0) { - // no allow rule for 'update', policy is constant based on if there's - // post-update counterpart - if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - result[kind] = false; - continue; - } else { - result[kind] = true; - continue; + return [...allows, ...denies].every((rule) => { + return streamAst(rule).every((expr) => { + if (isThisExpr(expr)) { + return false; } - } - - if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { - // no rule 'postUpdate', always allow - result[kind] = true; - continue; - } + if (isReferenceExpr(expr)) { + if (isDataModel(expr.$resolvedType?.decl)) { + // if policy rules uses relation fields, + // we can't check based on create input + return false; + } - const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); - result[kind] = guardFunc.getName()!; + if ( + isDataModelField(expr.target.ref) && + expr.target.ref.$container === model && + hasAttribute(expr.target.ref, '@default') + ) { + // reference to field of current model + // if it has default value, we can't check + // based on create input + return false; + } - if (kind === 'postUpdate') { - const preValueSelect = this.generateSelectForRules([...allows, ...denies]); - if (preValueSelect) { - result[PRE_UPDATE_VALUE_SELECTOR] = preValueSelect; + if (isDataModelField(expr.target.ref) && isForeignKeyField(expr.target.ref)) { + // reference to foreign key field + // we can't check based on create input + return false; + } } - } - - if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { - const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); - result[kind + '_input'] = inputCheckFunc.getName()!; - } - } - - // generate field read checkers - this.generateReadFieldsCheckers(model, sourceFile, result); - // generate field read override guards - this.generateReadFieldsOverrideGuards(model, sourceFile, result); - - // generate field update guards - this.generateUpdateFieldsGuards(model, sourceFile, result); - - return result; - } - - private generateReadFieldsCheckers( - model: DataModel, - sourceFile: SourceFile, - result: Record - ) { - const allFieldsAllows: Expression[] = []; - const allFieldsDenies: Expression[] = []; - - for (const field of model.fields) { - const allows = this.getPolicyExpressions(field, 'allow', 'read'); - const denies = this.getPolicyExpressions(field, 'deny', 'read'); - if (denies.length === 0 && allows.length === 0) { - continue; - } - - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); - - const guardFunc = this.generateReadFieldCheckerFunction(sourceFile, field, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field.name}`] = guardFunc.getName()!; - } - - if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) { - result[HAS_FIELD_LEVEL_POLICY_FLAG] = true; - const readFieldCheckSelect = this.generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]); - if (readFieldCheckSelect) { - result[FIELD_LEVEL_READ_CHECKER_SELECTOR] = readFieldCheckSelect; - } - } + return true; + }); + }); } - private generateReadFieldCheckerFunction( - sourceFile: SourceFile, - field: DataModelField, - allows: Expression[], - denies: Expression[] - ) { + // generates a function for checking "create" input + private generateCreateInputCheckerFunction(model: DataModel, sourceFile: SourceFile) { const statements: (string | WriterFunction)[] = []; + const allows = getPolicyExpressions(model, 'allow', 'create'); + const denies = getPolicyExpressions(model, 'deny', 'create'); - this.generateNormalizedAuthRef(field.$container as DataModel, allows, denies, statements); + generateNormalizedAuthRef(model, allows, denies, statements); - // compile rules down to typescript expressions statements.push((writer) => { + if (allows.length === 0) { + writer.write('return false;'); + return; + } + const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', }); - const denyStmt = + let expr = denies.length > 0 ? '!(' + denies @@ -418,34 +247,18 @@ export class PolicyGenerator { ')' : undefined; - const allowStmt = - allows.length > 0 - ? '(' + - allows - .map((allow) => { - return transformer.transform(allow); - }) - .join(' || ') + - ')' - : undefined; - - let expr: string | undefined; - - if (denyStmt && allowStmt) { - expr = `${denyStmt} && ${allowStmt}`; - } else if (denyStmt) { - expr = denyStmt; - } else if (allowStmt) { - expr = allowStmt; - } else { - throw new Error('should not happen'); - } + const allowStmt = allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || '); + expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); const func = sourceFile.addFunction({ - name: `${field.$container.name}$${field.name}_read`, + name: model.name + '_create_input', returnType: 'boolean', parameters: [ { @@ -463,323 +276,222 @@ export class PolicyGenerator { return func; } - private generateReadFieldsOverrideGuards( + // writes `update: ...` for a given model + private writeModelUpdateDef( model: DataModel, - sourceFile: SourceFile, - result: Record + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile ) { - for (const field of model.fields) { - const overrideAllows = this.getPolicyExpressions(field, 'allow', 'read', true); - if (overrideAllows.length > 0) { - const denies = this.getPolicyExpressions(field, 'deny', 'read'); - const overrideGuardFunc = this.generateQueryGuardFunction( - sourceFile, - model, - 'read', - overrideAllows, - denies, - field, - true - ); - result[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field.name}`] = overrideGuardFunc.getName()!; - } - } + writer.write(`update:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'update', policies, writer, sourceFile); + }); + writer.writeLine(','); } - private generateUpdateFieldsGuards( + // writes `postUpdate: ...` for a given model + private writeModelPostUpdateDef( model: DataModel, - sourceFile: SourceFile, - result: Record + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile ) { - for (const field of model.fields) { - const allows = this.getPolicyExpressions(field, 'allow', 'update'); - const denies = this.getPolicyExpressions(field, 'deny', 'update'); + writer.write(`postUpdate:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'postUpdate', policies, writer, sourceFile); - if (denies.length === 0 && allows.length === 0) { - continue; - } + // post-update policy has an additional selector for reading the pre-update entity data + this.writePostUpdatePreValueSelector(model, writer); + }); + writer.writeLine(','); + } - const guardFunc = this.generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field.name}`] = guardFunc.getName()!; - - const overrideAllows = this.getPolicyExpressions(field, 'allow', 'update', true); - if (overrideAllows.length > 0) { - const overrideGuardFunc = this.generateQueryGuardFunction( - sourceFile, - model, - 'update', - overrideAllows, - denies, - field, - true - ); - result[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field.name}`] = overrideGuardFunc.getName()!; - } + private writePostUpdatePreValueSelector(model: DataModel, writer: CodeBlockWriter) { + const allows = getPolicyExpressions(model, 'allow', 'postUpdate'); + const denies = getPolicyExpressions(model, 'deny', 'postUpdate'); + const preValueSelect = generateSelectForRules([...allows, ...denies]); + if (preValueSelect) { + writer.writeLine(`preUpdateSelector: ${JSON.stringify(preValueSelect)},`); } } - private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { - return [...allows, ...denies].every((rule) => { - return streamAst(rule).every((expr) => { - if (isThisExpr(expr)) { - return false; - } - if (isReferenceExpr(expr)) { - if (isDataModel(expr.$resolvedType?.decl)) { - // if policy rules uses relation fields, - // we can't check based on create input - return false; - } + // writes `delete: ...` for a given model + private writeModelDeleteDef( + model: DataModel, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + writer.write(`delete:`); + writer.inlineBlock(() => { + this.writeCommonModelDef(model, 'delete', policies, writer, sourceFile); + }); + } - if ( - isDataModelField(expr.target.ref) && - expr.target.ref.$container === model && - hasAttribute(expr.target.ref, '@default') - ) { - // reference to field of current model - // if it has default value, we can't check - // based on create input - return false; - } + // writes `[kind]: ...` for a given model + private writeCommonModelDef( + model: DataModel, + kind: PolicyOperationKind, + policies: PolicyAnalysisResult, + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + const allows = getPolicyExpressions(model, 'allow', kind); + const denies = getPolicyExpressions(model, 'deny', kind); - if (isDataModelField(expr.target.ref) && isForeignKeyField(expr.target.ref)) { - // reference to foreign key field - // we can't check based on create input - return false; - } - } + // policy guard + this.writePolicyGuard(model, kind, policies, allows, denies, writer, sourceFile); - return true; - }); - }); + // permission checker + if (kind !== 'postUpdate') { + this.writePermissionChecker(model, kind, policies, allows, denies, writer, sourceFile); + } + + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(model, kind, writer, sourceFile, true); } - // generates a "select" object that contains (recursively) fields referenced by the - // given policy rules - private generateSelectForRules(rules: Expression[], forAuthContext = false): object { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const result: any = {}; - const addPath = (path: string[]) => { - let curr = result; - path.forEach((seg, i) => { - if (i === path.length - 1) { - curr[seg] = true; - } else { - if (!curr[seg]) { - curr[seg] = { select: {} }; - } - curr = curr[seg].select; - } - }); - }; + private writeEntityChecker( + target: DataModel | DataModelField, + kind: PolicyOperationKind, + writer: CodeBlockWriter, + sourceFile: SourceFile, + onlyCrossModelComparison = false, + forOverride = false + ) { + const allows = getPolicyExpressions( + target, + 'allow', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); + const denies = getPolicyExpressions( + target, + 'deny', + kind, + forOverride, + onlyCrossModelComparison ? 'onlyCrossModelComparison' : 'all' + ); - // visit a reference or member access expression to build a - // selection path - const visit = (node: Expression): string[] | undefined => { - if (isThisExpr(node)) { - return []; - } + if (allows.length === 0 && denies.length === 0) { + return; + } - if (isReferenceExpr(node)) { - const target = resolved(node.target); - if (isDataModelField(target)) { - // a field selection, it's a terminal - return [target.name]; - } + const model = isDataModel(target) ? target : (target.$container as DataModel); + const func = generateEntityCheckerFunction( + sourceFile, + model, + kind, + allows, + denies, + isDataModelField(target) ? target : undefined, + forOverride + ); + const selector = generateSelectForRules([...allows, ...denies], false, kind !== 'postUpdate') ?? {}; + const key = forOverride ? 'overrideEntityChecker' : 'entityChecker'; + writer.write(`${key}: { func: ${func.getName()!}, selector: ${JSON.stringify(selector)} },`); + } + + // writes `guard: ...` for a given policy operation kind + private writePolicyGuard( + model: DataModel, + kind: PolicyOperationKind, + policies: ReturnType, + allows: Expression[], + denies: Expression[], + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + writer.write(`guard: false,`); + } else { + writer.write(`guard: true,`); } + return; + } - if (isMemberAccessExpr(node)) { - if (forAuthContext && isAuthInvocation(node.operand)) { - return [node.member.$refText]; - } + if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) { + // no 'postUpdate' rule, always allow + writer.write(`guard: true,`); + return; + } - if (isFutureExpr(node.operand)) { - // future().field is not subject to pre-update select - return undefined; - } + if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') { + // constant policy + writer.write(`guard: ${policies[kind as keyof typeof policies]},`); + return; + } - // build a selection path inside-out for chained member access - const inner = visit(node.operand); - if (inner) { - return [...inner, node.member.$refText]; - } - } + // generate a policy function that evaluates a partial prisma query + const guardFunc = generateQueryGuardFunction(sourceFile, model, kind, allows, denies); + writer.write(`guard: ${guardFunc.getName()!},`); + } - return undefined; - }; - - // collect selection paths from the given expression - const collectReferencePaths = (expr: Expression): string[][] => { - if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { - // a standalone `this` expression, include all id fields - const model = expr.$resolvedType?.decl as DataModel; - const idFields = getIdFields(model); - return idFields.map((field) => [field.name]); - } + // writes `permissionChecker: ...` for a given policy operation kind + private writePermissionChecker( + model: DataModel, + kind: PolicyCrudKind, + policies: PolicyAnalysisResult, + allows: Expression[], + denies: Expression[], + writer: CodeBlockWriter, + sourceFile: SourceFile + ) { + if (this.options.generatePermissionChecker !== true) { + return; + } - if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) { - const path = visit(expr); - if (path) { - if (isDataModel(expr.$resolvedType?.decl)) { - // member selection ended at a data model field, include its id fields - const idFields = getIdFields(expr.$resolvedType?.decl as DataModel); - return idFields.map((field) => [...path, field.name]); - } else { - return [path]; - } - } else { - return []; - } - } else if (isCollectionPredicate(expr)) { - const path = visit(expr.left); - if (path) { - // recurse into RHS - const rhs = collectReferencePaths(expr.right); - // combine path of LHS and RHS - return rhs.map((r) => [...path, ...r]); - } else { - return []; - } - } else if (isInvocationExpr(expr)) { - // recurse into function arguments - return expr.args.flatMap((arg) => collectReferencePaths(arg.value)); + if (policies[kind] === true || policies[kind] === false) { + // constant policy + writer.write(`permissionChecker: ${policies[kind]},`); + return; + } + + if (kind === 'update' && allows.length === 0) { + // no allow rule for 'update', policy is constant based on if there's + // post-update counterpart + if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { + writer.write(`permissionChecker: false,`); } else { - // recurse - const children = streamContents(expr) - .filter((child): child is Expression => isExpression(child)) - .toArray(); - return children.flatMap((child) => collectReferencePaths(child)); + writer.write(`permissionChecker: true,`); } - }; - - for (const rule of rules) { - const paths = collectReferencePaths(rule); - paths.forEach((p) => addPath(p)); + return; } - return Object.keys(result).length === 0 ? undefined : result; + const guardFunc = this.generatePermissionCheckerFunction(model, kind, allows, denies, sourceFile); + writer.write(`permissionChecker: ${guardFunc.getName()!},`); } - private generateQueryGuardFunction( - sourceFile: SourceFile, + private generatePermissionCheckerFunction( model: DataModel, - kind: PolicyOperationKind, + kind: string, allows: Expression[], denies: Expression[], - forField?: DataModelField, - fieldOverride = false + sourceFile: SourceFile ) { - const statements: (string | WriterFunction)[] = []; + const statements: string[] = []; - this.generateNormalizedAuthRef(model, allows, denies, statements); - - const hasFieldAccess = [...denies, ...allows].some((rule) => - streamAst(rule).some( - (child) => - // this.??? - isThisExpr(child) || - // future().??? - isFutureExpr(child) || - // field reference - (isReferenceExpr(child) && isDataModelField(child.target.ref)) - ) - ); + generateNormalizedAuthRef(model, allows, denies, statements); - if (!hasFieldAccess) { - // none of the rules reference model fields, we can compile down to a plain boolean - // function in this case (so we can skip doing SQL queries when validating) - statements.push((writer) => { - const transformer = new TypeScriptExpressionTransformer({ - context: ExpressionContext.AccessPolicy, - isPostGuard: kind === 'postUpdate', - }); - try { - denies.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); - }); - allows.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); - }); - } catch (err) { - if (err instanceof TypeScriptExpressionTransformerError) { - throw new PluginError(name, err.message); - } else { - throw err; - } - } + const transformed = new ConstraintTransformer({ + authAccessor: 'user', + }).transformRules(allows, denies); - if (forField) { - if (allows.length === 0) { - // if there's no allow rule, for field-level rules, by default we allow - writer.write(`return ${TRUE};`); - } else { - // if there's any allow rule, we deny unless any allow rule evaluates to true - writer.write(`return ${FALSE};`); - } - } else { - // for model-level rules, the default is always deny - writer.write(`return ${FALSE};`); - } - }); - } else { - statements.push((writer) => { - writer.write('return '); - const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); - const writeDenies = () => { - writer.conditionalWrite(denies.length > 1, '{ AND: ['); - denies.forEach((expr, i) => { - writer.inlineBlock(() => { - writer.write('NOT: '); - exprWriter.write(expr); - }); - writer.conditionalWrite(i !== denies.length - 1, ','); - }); - writer.conditionalWrite(denies.length > 1, ']}'); - }; - - const writeAllows = () => { - writer.conditionalWrite(allows.length > 1, '{ OR: ['); - allows.forEach((expr, i) => { - exprWriter.write(expr); - writer.conditionalWrite(i !== allows.length - 1, ','); - }); - writer.conditionalWrite(allows.length > 1, ']}'); - }; - - if (allows.length > 0 && denies.length > 0) { - // include both allow and deny rules - writer.write('{ AND: ['); - writeDenies(); - writer.write(','); - writeAllows(); - writer.write(']}'); - } else if (denies.length > 0) { - // only deny rules - writeDenies(); - } else if (allows.length > 0) { - // only allow rules - writeAllows(); - } else { - // disallow any operation - writer.write(`{ OR: [] }`); - } - writer.write(';'); - }); - } + statements.push(`return ${transformed};`); const func = sourceFile.addFunction({ - name: `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, - returnType: 'any', + name: `${model.name}$checker$${kind}`, + returnType: 'PermissionCheckerConstraint', parameters: [ { name: 'context', - type: 'QueryContext', - }, - { - // for generating field references used by field comparison in the same model - name: 'db', - type: 'CrudContract', + type: 'PermissionCheckerContext', }, ], statements, @@ -788,29 +500,131 @@ export class PolicyGenerator { return func; } - private generateInputCheckFunction( + // #endregion + + // #region Field-level definitions + + private writeFieldLevelDefs(model: DataModel, writer: CodeBlockWriter, sf: SourceFile) { + writer.write('fieldLevel:'); + writer.inlineBlock(() => { + this.writeFieldReadDef(model, writer, sf); + this.writeFieldUpdateDef(model, writer, sf); + }); + writer.writeLine(','); + } + + private writeFieldReadDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + writer.writeLine('read:'); + writer.block(() => { + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'read'); + const denies = getPolicyExpressions(field, 'deny', 'read'); + const overrideAllows = getPolicyExpressions(field, 'allow', 'read', true); + + if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { + continue; + } + + writer.write(`${field.name}:`); + + writer.block(() => { + // guard + const guardFunc = generateQueryGuardFunction(sourceFile, model, 'read', allows, denies, field); + writer.write(`guard: ${guardFunc.getName()},`); + + // checker function + // write all field-level rules as entity checker function + this.writeEntityChecker(field, 'read', writer, sourceFile, false, false); + + if (overrideAllows.length > 0) { + // override guard function + const denies = getPolicyExpressions(field, 'deny', 'read'); + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'read', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // additional entity checker for override + this.writeEntityChecker(field, 'read', writer, sourceFile, false, true); + } + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } + + private writeFieldUpdateDef(model: DataModel, writer: CodeBlockWriter, sourceFile: SourceFile) { + writer.writeLine('update:'); + writer.block(() => { + for (const field of model.fields) { + const allows = getPolicyExpressions(field, 'allow', 'update'); + const denies = getPolicyExpressions(field, 'deny', 'update'); + const overrideAllows = getPolicyExpressions(field, 'allow', 'update', true); + + if (allows.length === 0 && denies.length === 0 && overrideAllows.length === 0) { + continue; + } + + writer.write(`${field.name}:`); + + writer.block(() => { + // guard + const guardFunc = generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); + writer.write(`guard: ${guardFunc.getName()},`); + + // write cross-model comparison rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, false); + + if (overrideAllows.length > 0) { + // override guard + const overrideGuardFunc = generateQueryGuardFunction( + sourceFile, + model, + 'update', + overrideAllows, + denies, + field, + true + ); + writer.write(`overrideGuard: ${overrideGuardFunc.getName()},`); + + // write cross-model comparison override rules as entity checker functions + // because they cannot be checked inside Prisma + this.writeEntityChecker(field, 'update', writer, sourceFile, true, true); + } + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); + } + + private generateFieldReadCheckerFunction( sourceFile: SourceFile, - model: DataModel, - kind: 'create' | 'update', + field: DataModelField, allows: Expression[], denies: Expression[] - ): FunctionDeclaration { + ) { const statements: (string | WriterFunction)[] = []; - this.generateNormalizedAuthRef(model, allows, denies, statements); + generateNormalizedAuthRef(field.$container as DataModel, allows, denies, statements); + // compile rules down to typescript expressions statements.push((writer) => { - if (allows.length === 0) { - writer.write('return false;'); - return; - } - const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', }); - let expr = + const denyStmt = denies.length > 0 ? '!(' + denies @@ -821,18 +635,34 @@ export class PolicyGenerator { ')' : undefined; - const allowStmt = allows - .map((allow) => { - return transformer.transform(allow); - }) - .join(' || '); + const allowStmt = + allows.length > 0 + ? '(' + + allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || ') + + ')' + : undefined; + + let expr: string | undefined; + + if (denyStmt && allowStmt) { + expr = `${denyStmt} && ${allowStmt}`; + } else if (denyStmt) { + expr = denyStmt; + } else if (allowStmt) { + expr = allowStmt; + } else { + throw new Error('should not happen'); + } - expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); const func = sourceFile.addFunction({ - name: model.name + '_' + kind + '_input', + name: `${field.$container.name}$${field.name}_read`, returnType: 'boolean', parameters: [ { @@ -850,95 +680,71 @@ export class PolicyGenerator { return func; } - private generateNormalizedAuthRef( - model: DataModel, - allows: Expression[], - denies: Expression[], - statements: (string | WriterFunction)[] - ) { - // check if any allow or deny rule contains 'auth()' invocation - const hasAuthRef = [...allows, ...denies].some((rule) => - streamAst(rule).some((child) => isAuthInvocation(child)) - ); + // #endregion - if (hasAuthRef) { - const authModel = getAuthModel(getDataModels(model.$container, true)); - if (!authModel) { - throw new PluginError(name, 'Auth model not found'); - } - const userIdFields = getIdFields(authModel); - if (!userIdFields || userIdFields.length === 0) { - throw new PluginError(name, 'User model does not have an id field'); - } + //#region Auth selector - // normalize user to null to avoid accidentally use undefined in filter - statements.push(`const user: any = context.user ?? null;`); + private writeAuthSelector(models: DataModel[], writer: CodeBlockWriter) { + const authSelector = this.generateAuthSelector(models); + if (authSelector) { + writer.write(`authSelector: ${JSON.stringify(authSelector)},`); } } - private async generateCheckerForModel(model: DataModel, sourceFile: SourceFile) { - const result: Record = {}; + // Generates a { select: ... } object to select `auth()` fields used in policy rules + private generateAuthSelector(models: DataModel[]) { + const authRules: Expression[] = []; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const policies = analyzePolicies(model); + models.forEach((model) => { + // model-level rules + const modelPolicyAttrs = model.attributes.filter((attr) => + ['@@allow', '@@deny'].includes(attr.decl.$refText) + ); - for (const kind of ['create', 'read', 'update', 'delete'] as const) { - if (policies[kind] === true || policies[kind] === false) { - result[kind] = policies[kind] as boolean; - continue; - } + // field-level rules + const fieldPolicyAttrs = model.fields + .flatMap((f) => f.attributes) + .filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText)); - const denies = this.getPolicyExpressions(model, 'deny', kind); - const allows = this.getPolicyExpressions(model, 'allow', kind); + // all rule expression + const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs] + .filter((attr) => attr.args.length > 1) + .map((attr) => attr.args[1].value); - if (kind === 'update' && allows.length === 0) { - // no allow rule for 'update', policy is constant based on if there's - // post-update counterpart - if (this.getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) { - result[kind] = false; - continue; - } else { - result[kind] = true; - continue; - } - } + // collect `auth()` member access + allExpressions.forEach((rule) => { + streamAst(rule).forEach((node) => { + if (isMemberAccessExpr(node) && isAuthInvocation(node.operand)) { + authRules.push(node); + } + }); + }); + }); - const guardFunc = this.generateCheckerFunction(sourceFile, model, kind, allows, denies); - result[kind] = guardFunc.getName()!; + if (authRules.length > 0) { + return generateSelectForRules(authRules, true); + } else { + return undefined; } - - return result; } - private generateCheckerFunction( - sourceFile: SourceFile, - model: DataModel, - kind: string, - allows: Expression[], - denies: Expression[] - ) { - const statements: string[] = []; - - this.generateNormalizedAuthRef(model, allows, denies, statements); + // #endregion - const transformed = new ConstraintTransformer({ - authAccessor: 'user', - }).transformRules(allows, denies); - - statements.push(`return ${transformed};`); + // #region Validation meta - const func = sourceFile.addFunction({ - name: `${model.name}$checker$${kind}`, - returnType: 'CheckerConstraint', - parameters: [ - { - name: 'context', - type: 'CheckerContext', - }, - ], - statements, + private writeValidationMeta(writer: CodeBlockWriter, models: DataModel[]) { + writer.write('validation:'); + writer.inlineBlock(() => { + for (const model of models) { + writer.write(`${lowerCaseFirst(model.name)}:`); + writer.inlineBlock(() => { + writer.write(`hasValidation: ${hasValidationAttributes(model)}`); + }); + writer.writeLine(','); + } }); - - return func; + writer.writeLine(','); } + + // #endregion } diff --git a/packages/schema/src/plugins/enhancer/policy/utils.ts b/packages/schema/src/plugins/enhancer/policy/utils.ts new file mode 100644 index 000000000..1085a6e88 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/policy/utils.ts @@ -0,0 +1,513 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; +import { + ExpressionContext, + PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, + getAttributeArg, + getAuthModel, + getDataModels, + getIdFields, + getLiteral, + isAuthInvocation, + isDataModelFieldReference, + isEnumFieldReference, + isFromStdlib, + isFutureExpr, + resolved, +} from '@zenstackhq/sdk'; +import { + Enum, + Model, + isBinaryExpr, + isDataModel, + isDataModelField, + isExpression, + isInvocationExpr, + isMemberAccessExpr, + isReferenceExpr, + isThisExpr, + type DataModel, + type DataModelField, + type Expression, +} from '@zenstackhq/sdk/ast'; +import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium'; +import { SourceFile, WriterFunction } from 'ts-morph'; +import { name } from '..'; +import { isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils'; +import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; + +/** + * Get policy expressions for the given model or field and operation kind + */ +export function getPolicyExpressions( + target: DataModel | DataModelField, + kind: PolicyKind, + operation: PolicyOperationKind, + forOverride = false, + filter: 'all' | 'withoutCrossModelComparison' | 'onlyCrossModelComparison' = 'all' +) { + const attributes = target.attributes; + const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; + const attrs = attributes.filter((attr) => { + if (attr.decl.ref?.name !== attrName) { + return false; + } + + const overrideArg = getAttributeArg(attr, 'override'); + const isOverride = !!overrideArg && getLiteral(overrideArg) === true; + + return (forOverride && isOverride) || (!forOverride && !isOverride); + }); + + const checkOperation = operation === 'postUpdate' ? 'update' : operation; + + let result = attrs + .filter((attr) => { + const opsValue = getLiteral(attr.args[0].value); + if (!opsValue) { + return false; + } + const ops = opsValue.split(',').map((s) => s.trim()); + return ops.includes(checkOperation) || ops.includes('all'); + }) + .map((attr) => attr.args[1].value); + + if (filter === 'onlyCrossModelComparison') { + result = result.filter((expr) => hasCrossModelComparison(expr)); + } else if (filter === 'withoutCrossModelComparison') { + result = result.filter((expr) => !hasCrossModelComparison(expr)); + } + + if (operation === 'update') { + result = processUpdatePolicies(result, false); + } else if (operation === 'postUpdate') { + result = processUpdatePolicies(result, true); + } + + return result; +} + +function hasFutureReference(expr: Expression) { + for (const node of streamAst(expr)) { + if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { + return true; + } + } + return false; +} + +function processUpdatePolicies(expressions: Expression[], postUpdate: boolean) { + const hasFutureRef = expressions.some(hasFutureReference); + if (postUpdate) { + // when compiling post-update rules, if any rule contains `future()` reference, + // we include all as post-update rules + return hasFutureRef ? expressions : []; + } else { + // when compiling pre-update rules, if any rule contains `future()` reference, + // we completely skip pre-update check and defer them to post-update + return hasFutureRef ? [] : expressions; + } +} + +/** + * Generates a "select" object that contains (recursively) fields referenced by the + * given policy rules + */ +export function generateSelectForRules(rules: Expression[], forAuthContext = false, ignoreFutureReference = true) { + const result: any = {}; + const addPath = (path: string[]) => { + const thisIndex = path.lastIndexOf('$this'); + if (thisIndex >= 0) { + // drop everything before $this + path = path.slice(thisIndex + 1); + } + let curr = result; + path.forEach((seg, i) => { + if (i === path.length - 1) { + curr[seg] = true; + } else { + if (!curr[seg]) { + curr[seg] = { select: {} }; + } + curr = curr[seg].select; + } + }); + }; + + // visit a reference or member access expression to build a + // selection path + const visit = (node: Expression): string[] | undefined => { + if (isThisExpr(node)) { + return ['$this']; + } + + if (isFutureExpr(node)) { + return []; + } + + if (isReferenceExpr(node)) { + const target = resolved(node.target); + if (isDataModelField(target)) { + // a field selection, it's a terminal + return [target.name]; + } + } + + if (isMemberAccessExpr(node)) { + if (forAuthContext && isAuthInvocation(node.operand)) { + return [node.member.$refText]; + } + + if (isFutureExpr(node.operand) && ignoreFutureReference) { + // future().field is not subject to pre-update select + return undefined; + } + + // build a selection path inside-out for chained member access + const inner = visit(node.operand); + if (inner) { + return [...inner, node.member.$refText]; + } + } + + return undefined; + }; + + // collect selection paths from the given expression + const collectReferencePaths = (expr: Expression): string[][] => { + if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { + // a standalone `this` expression, include all id fields + const model = expr.$resolvedType?.decl as DataModel; + const idFields = getIdFields(model); + return idFields.map((field) => [field.name]); + } + + if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) { + const path = visit(expr); + if (path) { + if (isDataModel(expr.$resolvedType?.decl)) { + // member selection ended at a data model field, include its id fields + const idFields = getIdFields(expr.$resolvedType?.decl as DataModel); + return idFields.map((field) => [...path, field.name]); + } else { + return [path]; + } + } else { + return []; + } + } else if (isCollectionPredicate(expr)) { + const path = visit(expr.left); + // recurse into RHS + const rhs = collectReferencePaths(expr.right); + if (path) { + // combine path of LHS and RHS + return rhs.map((r) => [...path, ...r]); + } else { + // LHS is not rooted from the current model, + // only keep RHS items that contains '$this' + return rhs.filter((r) => r.includes('$this')); + } + } else if (isInvocationExpr(expr)) { + // recurse into function arguments + return expr.args.flatMap((arg) => collectReferencePaths(arg.value)); + } else { + // recurse + const children = streamContents(expr) + .filter((child): child is Expression => isExpression(child)) + .toArray(); + return children.flatMap((child) => collectReferencePaths(child)); + } + }; + + for (const rule of rules) { + const paths = collectReferencePaths(rule); + paths.forEach((p) => addPath(p)); + } + + return Object.keys(result).length === 0 ? undefined : result; +} + +/** + * Generates a query guard function that returns a partial Prisma query for the given model or field + */ +export function generateQueryGuardFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[], + forField?: DataModelField, + fieldOverride = false +) { + const statements: (string | WriterFunction)[] = []; + + const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule)); + const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule)); + + generateNormalizedAuthRef(model, allowRules, denyRules, statements); + + const hasFieldAccess = [...denyRules, ...allowRules].some((rule) => + streamAst(rule).some( + (child) => + // this.??? + isThisExpr(child) || + // future().??? + isFutureExpr(child) || + // field reference + (isReferenceExpr(child) && isDataModelField(child.target.ref)) + ) + ); + + if (!hasFieldAccess) { + // none of the rules reference model fields, we can compile down to a plain boolean + // function in this case (so we can skip doing SQL queries when validating) + statements.push((writer) => { + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + isPostGuard: kind === 'postUpdate', + }); + try { + denyRules.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); + }); + allowRules.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); + }); + } catch (err) { + if (err instanceof TypeScriptExpressionTransformerError) { + throw new PluginError(name, err.message); + } else { + throw err; + } + } + + if (forField) { + if (allows.length === 0) { + // if there's no allow rule, for field-level rules, by default we allow + writer.write(`return ${TRUE};`); + } else { + if (allowRules.length < allows.length) { + writer.write(`return ${TRUE};`); + } else { + // if there's any allow rule, we deny unless any allow rule evaluates to true + writer.write(`return ${FALSE};`); + } + } + } else { + if (allowRules.length < allows.length) { + // some rules are filtered out here and will be generated as additional + // checker functions, so we allow here to avoid a premature denial + writer.write(`return ${TRUE};`); + } else { + // for model-level rules, the default is always deny unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + } + } + }); + } else { + statements.push((writer) => { + writer.write('return '); + const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); + const writeDenies = () => { + writer.conditionalWrite(denyRules.length > 1, '{ AND: ['); + denyRules.forEach((expr, i) => { + writer.inlineBlock(() => { + writer.write('NOT: '); + exprWriter.write(expr); + }); + writer.conditionalWrite(i !== denyRules.length - 1, ','); + }); + writer.conditionalWrite(denyRules.length > 1, ']}'); + }; + + const writeAllows = () => { + writer.conditionalWrite(allowRules.length > 1, '{ OR: ['); + allowRules.forEach((expr, i) => { + exprWriter.write(expr); + writer.conditionalWrite(i !== allowRules.length - 1, ','); + }); + writer.conditionalWrite(allowRules.length > 1, ']}'); + }; + + if (allowRules.length > 0 && denyRules.length > 0) { + // include both allow and deny rules + writer.write('{ AND: ['); + writeDenies(); + writer.write(','); + writeAllows(); + writer.write(']}'); + } else if (denyRules.length > 0) { + // only deny rules + writeDenies(); + } else if (allowRules.length > 0) { + // only allow rules + writeAllows(); + } else { + // disallow any operation unless for 'postUpdate' + writer.write(`return ${kind === 'postUpdate' ? TRUE : FALSE};`); + } + writer.write(';'); + }); + } + + const func = sourceFile.addFunction({ + name: `${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + returnType: 'any', + parameters: [ + { + name: 'context', + type: 'QueryContext', + }, + { + // for generating field references used by field comparison in the same model + name: 'db', + type: 'CrudContract', + }, + ], + statements, + }); + + return func; +} + +export function generateEntityCheckerFunction( + sourceFile: SourceFile, + model: DataModel, + kind: PolicyOperationKind, + allows: Expression[], + denies: Expression[], + forField?: DataModelField, + fieldOverride = false +) { + const statements: (string | WriterFunction)[] = []; + + generateNormalizedAuthRef(model, allows, denies, statements); + + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + thisExprContext: 'input', + fieldReferenceContext: 'input', + isPostGuard: kind === 'postUpdate', + futureRefContext: 'input', + }); + + denies.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return false; }`); + }); + + allows.forEach((rule) => { + const compiled = transformer.transform(rule); + statements.push(`if (${compiled}) { return true; }`); + }); + + // default: deny unless for 'postUpdate' + statements.push(kind === 'postUpdate' ? 'return true;' : 'return false;'); + + const func = sourceFile.addFunction({ + name: `$check_${model.name}${forField ? '$' + forField.name : ''}${fieldOverride ? '$override' : ''}_${kind}`, + returnType: 'any', + parameters: [ + { + name: 'input', + type: 'any', + }, + { + name: 'context', + type: 'QueryContext', + }, + ], + statements, + }); + + return func; +} + +/** + * Generates a normalized auth reference for the given policy rules + */ +export function generateNormalizedAuthRef( + model: DataModel, + allows: Expression[], + denies: Expression[], + statements: (string | WriterFunction)[] +) { + // check if any allow or deny rule contains 'auth()' invocation + const hasAuthRef = [...allows, ...denies].some((rule) => streamAst(rule).some((child) => isAuthInvocation(child))); + + if (hasAuthRef) { + const authModel = getAuthModel(getDataModels(model.$container, true)); + if (!authModel) { + throw new PluginError(name, 'Auth model not found'); + } + const userIdFields = getIdFields(authModel); + if (!userIdFields || userIdFields.length === 0) { + throw new PluginError(name, 'User model does not have an id field'); + } + + // normalize user to null to avoid accidentally use undefined in filter + statements.push(`const user: any = context.user ?? null;`); + } +} + +/** + * Check if the given enum is referenced in the model + */ +export function isEnumReferenced(model: Model, decl: Enum): unknown { + return streamAllContents(model).some((node) => { + if (isDataModelField(node) && node.type.reference?.ref === decl) { + // referenced as field type + return true; + } + if (isEnumFieldReference(node) && node.target.ref?.$container === decl) { + // enum field is referenced + return true; + } + return false; + }); +} + +function hasCrossModelComparison(expr: Expression) { + return streamAst(expr).some((node) => { + if (isBinaryExpr(node) && ['==', '!=', '>', '<', '>=', '<=', 'in'].includes(node.operator)) { + const leftRoot = getSourceModelOfFieldAccess(node.left); + const rightRoot = getSourceModelOfFieldAccess(node.right); + if (leftRoot && rightRoot && leftRoot !== rightRoot) { + return true; + } + } + return false; + }); +} + +function getSourceModelOfFieldAccess(expr: Expression) { + // an expression that resolves to a data model and is part of a member access, return the model + // e.g.: profile.age => Profile + if (isDataModel(expr.$resolvedType?.decl) && isMemberAccessExpr(expr.$container)) { + return expr.$resolvedType?.decl; + } + + // `this` reference + if (isThisExpr(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // `future()` + if (isFutureInvocation(expr)) { + return getContainerOfType(expr, isDataModel); + } + + // direct field reference, return the model + if (isDataModelFieldReference(expr)) { + return (expr.target.ref as DataModelField).$container; + } + + // member access + if (isMemberAccessExpr(expr)) { + return getSourceModelOfFieldAccess(expr.operand); + } + + return undefined; +} diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 58ecd68f4..e73d8203f 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -1,4 +1,5 @@ import { + AbstractDeclaration, AttributeArg, BooleanLiteral, ConfigArrayExpr, @@ -26,6 +27,7 @@ import { LiteralExpr, Model, NumberLiteral, + ReferenceExpr, StringLiteral, } from '@zenstackhq/language/ast'; import { getPrismaVersion } from '@zenstackhq/sdk/prisma'; @@ -38,6 +40,7 @@ import { getAttributeArg, getAttributeArgLiteral, getLiteral, + getRelationKeyPairs, isDelegateModel, isIdField, PluginError, @@ -50,7 +53,6 @@ import { writeFile } from 'fs/promises'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import semver from 'semver'; -import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; import { execPackage } from '../../utils/exec-utils'; @@ -79,6 +81,10 @@ const MODEL_PASSTHROUGH_ATTR = '@@prisma.passthrough'; const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; const PROVIDERS_SUPPORTING_NAMED_CONSTRAINTS = ['postgresql', 'mysql', 'cockroachdb']; +// Some database providers like postgres and mysql have default limit to the length of identifiers +// Here we use a conservative value that should work for most cases, and truncate names if needed +const IDENTIFIER_NAME_MAX_LENGTH = 50 - DELEGATE_AUX_RELATION_PREFIX.length; + /** * Generates Prisma schema file */ @@ -94,6 +100,9 @@ export class PrismaSchemaGenerator { private mode: 'logical' | 'physical' = 'physical'; + // a mapping from shortened names to their original full names + private shortNameMap = new Map(); + constructor(private readonly zmodel: Model) {} async generate(options: PluginOptions) { @@ -307,7 +316,7 @@ export class PrismaSchemaGenerator { // generate an optional relation field in delegate base model to each concrete model concreteModels.forEach((concrete) => { - const auxName = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`; + const auxName = `${DELEGATE_AUX_RELATION_PREFIX}_${this.truncate(lowerCaseFirst(concrete.name))}`; model.addField(auxName, new ModelFieldType(concrete.name, false, true)); }); } @@ -328,7 +337,7 @@ export class PrismaSchemaGenerator { const idFields = getIdFields(base); // add relation fields - const relationField = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`; + const relationField = `${DELEGATE_AUX_RELATION_PREFIX}_${this.truncate(lowerCaseFirst(base.name))}`; model.addField(relationField, base.name, [ new PrismaFieldAttribute('@relation', [ new PrismaAttributeArg( @@ -364,7 +373,7 @@ export class PrismaSchemaGenerator { }); } - private expandPolymorphicRelations(model: PrismaDataModel, decl: DataModel) { + private expandPolymorphicRelations(model: PrismaDataModel, dataModel: DataModel) { if (this.mode !== 'logical') { return; } @@ -373,58 +382,64 @@ export class PrismaSchemaGenerator { // for the given model, find relation fields of delegate model type, find all concrete models // of the delegate model and generate an auxiliary opposite relation field to each of them - decl.fields.forEach((f) => { + dataModel.fields.forEach((field) => { // don't process fields inherited from a delegate model - if (f.$inheritedFrom && isDelegateModel(f.$inheritedFrom)) { + if (field.$inheritedFrom && isDelegateModel(field.$inheritedFrom)) { return; } - const fieldType = f.type.reference?.ref; + const fieldType = field.type.reference?.ref; if (!isDataModel(fieldType)) { return; } // find concrete models that inherit from this field's model type - const concreteModels = decl.$container.declarations.filter( + const concreteModels = dataModel.$container.declarations.filter( (d) => isDataModel(d) && isDescendantOf(d, fieldType) ); - // aux relation name format: delegate_aux_[model]_[relationField]_[concrete] - // e.g., delegate_aux_User_myAsset_Video - concreteModels.forEach((concrete) => { - const relationField = model.addField( - `${DELEGATE_AUX_RELATION_PREFIX}_${decl.name}_${f.name}_${concrete.name}`, - new ModelFieldType(concrete.name, f.type.array, f.type.optional) + // aux relation name format: delegate_aux_[model]_[relationField]_[concrete] + // e.g., delegate_aux_User_myAsset_Video + const auxRelationName = `${dataModel.name}_${field.name}_${concrete.name}`; + const auxRelationField = model.addField( + `${DELEGATE_AUX_RELATION_PREFIX}_${this.truncate(auxRelationName)}`, + new ModelFieldType(concrete.name, field.type.array, field.type.optional) ); - const relAttr = getAttribute(f, '@relation'); + + const relAttr = getAttribute(field, '@relation'); if (relAttr) { - const fieldsArg = relAttr.args.find((arg) => arg.name === 'fields'); + const fieldsArg = getAttributeArg(relAttr, 'fields'); if (fieldsArg) { - const idFields = getIdFields(fieldType); - - // add fk fields, e.g., delegate_aux_User_myAsset_VideoId - const addedIdFields = idFields.map((idField) => - model.addField(`${relationField.name}${upperCaseFirst(idField.name)}`, idField.type.type!) - ); + // for reach foreign key field pointing to the delegate model, we need to create an aux foreign key + // to point to the concrete model + const relationFieldPairs = getRelationKeyPairs(field); + const addedFkFields: ModelField[] = []; + for (const { foreignKey } of relationFieldPairs) { + const addedFkField = this.replicateForeignKey(model, dataModel, concrete, foreignKey); + addedFkFields.push(addedFkField); + } + // the `@relation(..., fields: [...])` attribute argument const fieldsArg = new AttributeArgValue( 'Array', - addedIdFields.map( - (f) => new AttributeArgValue('FieldReference', new PrismaFieldReference(f.name)) + addedFkFields.map( + (addedFk) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(addedFk.name)) ) ); + // the `@relation(..., references: [...])` attribute argument const referencesArg = new AttributeArgValue( 'Array', - idFields.map( - (f) => new AttributeArgValue('FieldReference', new PrismaFieldReference(f.name)) + relationFieldPairs.map( + ({ id }) => new AttributeArgValue('FieldReference', new PrismaFieldReference(id.name)) ) ); const addedRel = new PrismaFieldAttribute('@relation', [ // use field name as relation name for disambiguation - new PrismaAttributeArg(undefined, new AttributeArgValue('String', relationField.name)), + new PrismaAttributeArg(undefined, new AttributeArgValue('String', auxRelationField.name)), new PrismaAttributeArg('fields', fieldsArg), new PrismaAttributeArg('references', referencesArg), ]); @@ -434,20 +449,20 @@ export class PrismaSchemaGenerator { // generate a `map` argument for foreign key constraint disambiguation new PrismaAttributeArg( 'map', - new PrismaAttributeArgValue('String', `${relationField.name}_fk`) + new PrismaAttributeArgValue('String', `${auxRelationField.name}_fk`) ) ); } - relationField.attributes.push(addedRel); + auxRelationField.attributes.push(addedRel); } else { - relationField.attributes.push(this.makeFieldAttribute(relAttr as DataModelFieldAttribute)); + auxRelationField.attributes.push(this.makeFieldAttribute(relAttr as DataModelFieldAttribute)); } } else { - relationField.attributes.push( + auxRelationField.attributes.push( new PrismaFieldAttribute('@relation', [ // use field name as relation name for disambiguation - new PrismaAttributeArg(undefined, new AttributeArgValue('String', relationField.name)), + new PrismaAttributeArg(undefined, new AttributeArgValue('String', auxRelationField.name)), ]) ); } @@ -455,6 +470,107 @@ export class PrismaSchemaGenerator { }); } + private replicateForeignKey( + model: PrismaDataModel, + dataModel: DataModel, + concreteModel: AbstractDeclaration, + origForeignKey: DataModelField + ) { + // aux fk name format: delegate_aux_[model]_[fkField]_[concrete] + // e.g., delegate_aux_User_myAssetId_Video + + // generate a fk field based on the original fk field + const addedFkField = this.generateModelField(model, origForeignKey); + + // fix its name + const addedFkFieldName = `${dataModel.name}_${origForeignKey.name}_${concreteModel.name}`; + addedFkField.name = `${DELEGATE_AUX_RELATION_PREFIX}_${this.truncate(addedFkFieldName)}`; + + // we also need to make sure `@unique` constraint's `map` parameter is fixed to avoid conflict + const uniqueAttr = addedFkField.attributes.find( + (attr) => (attr as PrismaFieldAttribute).name === '@unique' + ) as PrismaFieldAttribute; + if (uniqueAttr) { + const mapArg = uniqueAttr.args.find((arg) => arg.name === 'map'); + const constraintName = `${addedFkField.name}_unique`; + if (mapArg) { + mapArg.value = new AttributeArgValue('String', constraintName); + } else { + uniqueAttr.args.push(new PrismaAttributeArg('map', new AttributeArgValue('String', constraintName))); + } + } + + // we also need to go through model-level `@@unique` and replicate those involving fk fields + this.replicateForeignKeyModelLevelUnique(model, dataModel, origForeignKey, addedFkField); + + return addedFkField; + } + + private replicateForeignKeyModelLevelUnique( + model: PrismaDataModel, + dataModel: DataModel, + origForeignKey: DataModelField, + addedFkField: ModelField + ) { + for (const uniqueAttr of dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@unique')) { + const fields = getAttributeArg(uniqueAttr, 'fields'); + if (fields && isArrayExpr(fields)) { + const found = fields.items.find( + (fieldRef) => isReferenceExpr(fieldRef) && fieldRef.target.ref === origForeignKey + ); + if (found) { + // replicate the attribute and replace the field reference with the new FK field + const args: PrismaAttributeArgValue[] = []; + for (const arg of fields.items) { + if (isReferenceExpr(arg) && arg.target.ref === origForeignKey) { + // replace + args.push( + new PrismaAttributeArgValue( + 'FieldReference', + new PrismaFieldReference(addedFkField.name) + ) + ); + } else { + // copy + args.push( + new PrismaAttributeArgValue( + 'FieldReference', + new PrismaFieldReference((arg as ReferenceExpr).target.$refText) + ) + ); + } + } + + model.addAttribute('@@unique', [ + new PrismaAttributeArg(undefined, new PrismaAttributeArgValue('Array', args)), + ]); + } + } + } + } + + private truncate(name: string) { + if (name.length <= IDENTIFIER_NAME_MAX_LENGTH) { + return name; + } + + const shortName = name.slice(0, IDENTIFIER_NAME_MAX_LENGTH); + const entry = this.shortNameMap.get(shortName); + if (!entry) { + this.shortNameMap.set(shortName, [name]); + return `${shortName}_0`; + } else { + const index = entry.findIndex((n) => n === name); + if (index >= 0) { + return `${shortName}_${index}`; + } else { + const newIndex = entry.length; + entry.push(name); + return `${shortName}_${newIndex}`; + } + } + } + private nameRelationsInheritedFromDelegate(model: PrismaDataModel, decl: DataModel) { if (this.mode !== 'logical') { return; @@ -463,18 +579,30 @@ export class PrismaSchemaGenerator { // the logical schema needs to name relations inherited from delegate base models for disambiguation decl.fields.forEach((f) => { - if (!f.$inheritedFrom || !isDelegateModel(f.$inheritedFrom) || !isDataModel(f.type.reference?.ref)) { + if (!isDataModel(f.type.reference?.ref)) { + // only process relation fields return; } - const prismaField = model.fields.find((field) => field.name === f.name); - if (!prismaField) { + if (!f.$inheritedFrom) { + // only process inherited fields return; } - // find the base field that this field is inherited from - const baseField = f.$inheritedFrom.fields.find((field) => field.name === f.name); + // Walk up the inheritance chain to find a field with matching name + // which is where this field is inherited from. + // + // Note that we can't walk all the way up to the $inheritedFrom model + // because it may have been eliminated because of being abstract. + + const baseField = this.findUpMatchingFieldFromDelegate(decl, f); if (!baseField) { + // only process fields inherited from delegate models + return; + } + + const prismaField = model.fields.find((field) => field.name === f.name); + if (!prismaField) { return; } @@ -488,7 +616,8 @@ export class PrismaSchemaGenerator { // relation name format: delegate_aux_[relationType]_[oppositeRelationField]_[concrete] const relAttr = getAttribute(f, '@relation'); - const relName = `${DELEGATE_AUX_RELATION_PREFIX}_${fieldType.name}_${oppositeRelationField.name}_${decl.name}`; + const name = `${fieldType.name}_${oppositeRelationField.name}_${decl.name}`; + const relName = `${DELEGATE_AUX_RELATION_PREFIX}_${this.truncate(name)}`; if (relAttr) { const nameArg = getAttributeArg(relAttr, 'name'); @@ -512,6 +641,28 @@ export class PrismaSchemaGenerator { }); } + private findUpMatchingFieldFromDelegate(start: DataModel, target: DataModelField): DataModelField | undefined { + for (const base of start.superTypes) { + if (isDataModel(base.ref)) { + if (isDelegateModel(base.ref)) { + const field = base.ref.fields.find((f) => f.name === target.name); + if (field) { + if (!field.$inheritedFrom || !isDelegateModel(field.$inheritedFrom)) { + // if this field is not inherited from an upper delegate, we're done + return field; + } + } + } + + const upper = this.findUpMatchingFieldFromDelegate(base.ref, target); + if (upper) { + return upper; + } + } + } + return undefined; + } + private getOppositeRelationField(oppositeModel: DataModel, relationField: DataModelField) { const relName = this.getRelationName(relationField); return oppositeModel.fields.find( @@ -609,6 +760,7 @@ export class PrismaSchemaGenerator { // user defined comments pass-through field.comments.forEach((c) => result.addComment(c)); + return result; } private setDummyDefault(result: ModelField, field: DataModelField) { diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 5f3321b94..3df90bd95 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -152,7 +152,13 @@ export function makeFieldSchema(field: DataModelField) { } else { const schemaDefault = getFieldSchemaDefault(field); if (schemaDefault !== undefined) { - schema += `.default(${schemaDefault})`; + if (field.type.type === 'BigInt') { + // we can't use the `n` BigInt literal notation, since it needs + // ES2020 or later, which TypeScript doesn't use by default + schema += `.default(BigInt("${schemaDefault}"))`; + } else { + schema += `.default(${schemaDefault})`; + } } if (field.type.optional) { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index e891056c2..24af8862d 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -296,3 +296,19 @@ export function getAllLoadedAndReachableDataModels(langiumDocuments: LangiumDocu return allDataModels; } + +/** + * Walk up the inheritance chain to find the path from the start model to the target model + */ +export function findUpInheritance(start: DataModel, target: DataModel): DataModel[] | undefined { + for (const base of start.superTypes) { + if (base.ref === target) { + return [base.ref]; + } + const path = findUpInheritance(base.ref as DataModel, target); + if (path) { + return [base.ref as DataModel, ...path]; + } + } + return undefined; +} diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 6eaf06399..5affcec77 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -34,7 +34,6 @@ describe('Prisma generator test', () => { provider = 'postgresql' url = env("DATABASE_URL") directUrl = env("DATABASE_URL") - shadowDatabaseUrl = env("DATABASE_URL") extensions = [pg_trgm, postgis(version: "3.3.2"), uuid_ossp(map: "uuid-ossp", schema: "extensions")] schemas = ["auth", "public"] } @@ -67,7 +66,6 @@ describe('Prisma generator test', () => { expect(content).toContain('provider = "postgresql"'); expect(content).toContain('url = env("DATABASE_URL")'); expect(content).toContain('directUrl = env("DATABASE_URL")'); - expect(content).toContain('shadowDatabaseUrl = env("DATABASE_URL")'); expect(content).toContain( 'extensions = [pg_trgm, postgis(version: "3.3.2"), uuid_ossp(map: "uuid-ossp", schema: "extensions")]' ); @@ -253,9 +251,9 @@ describe('Prisma generator test', () => { expect(content).toContain(`@@map("_Role")`); expect(content).toContain(`@map("admin")`); expect(content).toContain(`@map("customer")`); - expect(content).toContain('/// Admin role documentation line 1\n' + - ' /// Admin role documentation line 2\n' + - ' ADMIN'); + expect(content).toContain( + '/// Admin role documentation line 1\n' + ' /// Admin role documentation line 2\n' + ' ADMIN' + ); }); it('attribute passthrough', async () => { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 380836e21..b2ac1544b 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -699,7 +699,36 @@ describe('Attribute tests', () => { } `) - ).toContain('comparison between fields of different models are not supported'); + ).toContain('comparison between fields of different models is not supported in model-level "read" rules'); + + expect( + await loadModel(` + ${prelude} + model User { + id Int @id + lists List[] + todos Todo[] + } + + model List { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + todos Todo[] + } + + model Todo { + id Int @id + user User @relation(fields: [userId], references: [id]) + userId Int + list List @relation(fields: [listId], references: [id]) + listId Int + + @@allow('create', list.user.id == userId) + } + + `) + ).toBeTruthy(); expect( await loadModelWithError(` diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index e0778da51..e7dd6bf84 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -88,7 +88,7 @@ describe('Data Model Validation Tests', () => { @@allow('all', members?[this == auth()]) } `) - ).toMatchObject(errorLike('using `this` in collection predicate is not supported')); + ).toBeTruthy(); expect( await loadModel(` diff --git a/packages/schema/tests/schema/validation/datasource-validation.test.ts b/packages/schema/tests/schema/validation/datasource-validation.test.ts index 469ba5ac1..0c90da4db 100644 --- a/packages/schema/tests/schema/validation/datasource-validation.test.ts +++ b/packages/schema/tests/schema/validation/datasource-validation.test.ts @@ -13,9 +13,9 @@ describe('Datasource Validation Tests', () => { cause: [ { message: 'datasource must include a "provider" field' }, { message: 'datasource must include a "url" field' }, - ] - } - }) + ], + }, + }); }); it('dup fields', async () => { @@ -63,14 +63,6 @@ describe('Datasource Validation Tests', () => { } `) ).toContain('"url" must be set to a string literal or an invocation of "env" function'); - - expect( - await loadModelWithError(` - datasource db { - shadowDatabaseUrl = 123 - } - `) - ).toContain('"shadowDatabaseUrl" must be set to a string literal or an invocation of "env" function'); }); it('invalid relationMode value', async () => { @@ -96,7 +88,6 @@ describe('Datasource Validation Tests', () => { datasource db { provider = "postgresql" url = "url" - shadowDatabaseUrl = "shadow" relationMode = "prisma" } `); @@ -105,7 +96,6 @@ describe('Datasource Validation Tests', () => { datasource db { provider = "postgresql" url = env("url") - shadowDatabaseUrl = env("shadowUrl") relationMode = "foreignKeys" } `); diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 436a715a6..4d1dedde7 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,12 +1,12 @@ { "name": "@zenstackhq/sdk", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { "clean": "rimraf dist", "lint": "eslint src --ext ts", - "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./LICENSE ./README.md dist && pnpm pack dist --pack-destination '../../../.build'", + "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./LICENSE ./README.md dist && pnpm pack dist --pack-destination ../../../.build", "watch": "tsc --watch", "prepublishOnly": "pnpm build" }, @@ -18,8 +18,8 @@ "author": "", "license": "MIT", "dependencies": { - "@prisma/generator-helper": "^5.13.0", - "@prisma/internals": "^5.13.0", + "@prisma/generator-helper": "^5.15.0", + "@prisma/internals": "^5.15.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", "langium": "1.3.1", diff --git a/packages/sdk/src/policy.ts b/packages/sdk/src/policy.ts index ccd3e851f..c9eea9865 100644 --- a/packages/sdk/src/policy.ts +++ b/packages/sdk/src/policy.ts @@ -2,6 +2,8 @@ import type { DataModel, DataModelAttribute } from './ast'; import { getLiteral } from './utils'; import { hasValidationAttributes } from './validation'; +export type PolicyAnalysisResult = ReturnType; + export function analyzePolicies(dataModel: DataModel) { const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); diff --git a/packages/sdk/src/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts index 8e33eb4a7..28ce1d345 100644 --- a/packages/sdk/src/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -34,6 +34,7 @@ type Options = { isPostGuard?: boolean; fieldReferenceContext?: string; thisExprContext?: string; + futureRefContext?: string; context: ExpressionContext; }; @@ -116,7 +117,9 @@ export class TypeScriptExpressionTransformer { if (this.options?.isPostGuard !== true) { throw new TypeScriptExpressionTransformerError(`future() is only supported in postUpdate rules`); } - return expr.member.ref.name; + return this.options.futureRefContext + ? `${this.options.futureRefContext}.${expr.member.ref.name}` + : expr.member.ref.name; } else { if (normalizeUndefined) { // normalize field access to null instead of undefined to avoid accidentally use undefined in filter @@ -449,7 +452,6 @@ export class TypeScriptExpressionTransformer { ...this.options, isPostGuard: false, fieldReferenceContext: '_item', - thisExprContext: '_item', }); const predicate = innerTransformer.transform(expr.right, normalizeUndefined); diff --git a/packages/server/package.json b/packages/server/package.json index 1802f0771..11c41bae1 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,12 +1,12 @@ { "name": "@zenstackhq/server", - "version": "2.1.2", + "version": "2.2.0", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", "scripts": { "clean": "rimraf dist", - "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE dist && pnpm pack dist --pack-destination '../../../.build'", + "build": "pnpm lint --max-warnings=0 && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE dist && pnpm pack dist --pack-destination ../../../.build", "watch": "tsc --watch", "lint": "eslint src --ext ts", "test": "jest", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 2c3bc1ae1..8c1a9dc5e 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "2.1.2", + "version": "2.2.0", "description": "ZenStack Test Tools", "main": "index.js", "private": true, @@ -11,7 +11,7 @@ "scripts": { "clean": "rimraf dist", "lint": "eslint src --ext ts", - "build": "pnpm lint && pnpm clean && tsc && copyfiles ./package.json ./LICENSE ./README.md dist && pnpm pack dist --pack-destination '../../../.build'", + "build": "pnpm lint && pnpm clean && tsc && copyfiles ./package.json ./LICENSE ./README.md dist && pnpm pack dist --pack-destination ../../../.build", "watch": "tsc --watch", "prepublishOnly": "pnpm build" }, diff --git a/packages/testtools/src/db.ts b/packages/testtools/src/db.ts index 8de49a7c1..16142d527 100644 --- a/packages/testtools/src/db.ts +++ b/packages/testtools/src/db.ts @@ -1,16 +1,31 @@ import { Pool } from 'pg'; -const USERNAME = 'postgres'; -const PASSWORD = 'abc123'; +const USERNAME = process.env.ZENSTACK_TEST_DB_USER || 'postgres'; +const PASSWORD = process.env.ZENSTACK_TEST_DB_PASS || 'abc123'; +const CONNECTION_DB = process.env.ZENSTACK_TEST_DB_NAME || 'postgres'; +const HOST = process.env.ZENSTACK_TEST_DB_HOST || 'localhost'; +const PORT = (process.env.ZENSTACK_TEST_DB_PORT ? parseInt(process.env.ZENSTACK_TEST_DB_PORT) : null) || 5432; + +function connect() { + return new Pool({ + user: USERNAME, + password: PASSWORD, + database: CONNECTION_DB, + host: HOST, + port: PORT + }); +} export async function createPostgresDb(db: string) { - const pool = new Pool({ user: USERNAME, password: PASSWORD }); + const pool = connect(); await pool.query(`DROP DATABASE IF EXISTS "${db}";`); await pool.query(`CREATE DATABASE "${db}";`); - return `postgresql://${USERNAME}:${PASSWORD}@localhost:5432/${db}`; + await pool.end(); + return `postgresql://${USERNAME}:${PASSWORD}@${HOST}:${PORT}/${db}`; } export async function dropPostgresDb(db: string) { - const pool = new Pool({ user: USERNAME, password: PASSWORD }); + const pool = connect(); await pool.query(`DROP DATABASE IF EXISTS "${db}";`); + await pool.end(); } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 4495ddf14..fb90fac4b 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -3,6 +3,7 @@ import type { Model } from '@zenstackhq/language/ast'; import { DEFAULT_RUNTIME_LOAD_PATH, + PolicyDef, type AuthUser, type CrudContract, type EnhancementKind, @@ -43,14 +44,12 @@ export type FullDbClientContract = CrudContract & { export function run(cmd: string, env?: Record, cwd?: string) { try { - const start = Date.now(); execSync(cmd, { stdio: 'pipe', encoding: 'utf-8', env: { ...process.env, DO_NOT_TRACK: '1', ...env }, cwd, }); - console.log('Execution took', Date.now() - start, 'ms', '-', cmd); } catch (err) { console.error('Command failed:', cmd, err); throw err; @@ -299,7 +298,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { projectDir, enhance: undefined as any, enhanceRaw: undefined as any, - policy: undefined as any, + policy: undefined as unknown as PolicyDef, modelMeta: undefined as any, zodSchemas: undefined as any, }; @@ -311,7 +310,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { : path.join(projectDir, opt.output) : path.join(projectDir, 'node_modules', DEFAULT_RUNTIME_LOAD_PATH); - const policy = require(path.join(outputPath, 'policy')).default; + const policy: PolicyDef = require(path.join(outputPath, 'policy')).default; const modelMeta = require(path.join(outputPath, 'model-meta')).default; let zodSchemas: any; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a8e8c7318..dfeda8f4e 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -65,6 +65,9 @@ importers: '@zenstackhq/language': specifier: workspace:* version: link:../../language/dist + run-script-os: + specifier: ^1.1.6 + version: 1.1.6 zenstack: specifier: workspace:* version: link:../../schema/dist @@ -334,6 +337,9 @@ importers: swr: specifier: ^2.0.3 version: 2.0.3(react@18.2.0) + tmp: + specifier: ^0.2.3 + version: 0.2.3 vue: specifier: ^3.3.4 version: 3.3.4 @@ -392,8 +398,8 @@ importers: packages/runtime: dependencies: '@prisma/client': - specifier: 5.0.0 - 5.13.x - version: 5.12.0 + specifier: 5.0.0 - 5.15.x + version: 5.15.0(prisma@5.15.0) bcryptjs: specifier: ^2.4.3 version: 2.4.3 @@ -412,6 +418,9 @@ importers: deepmerge: specifier: ^4.3.1 version: 4.3.1 + is-plain-object: + specifier: ^5.0.0 + version: 5.0.0 logic-solver: specifier: ^2.0.1 version: 2.0.1 @@ -566,8 +575,8 @@ importers: version: 1.5.0(zod@3.22.4) devDependencies: '@prisma/client': - specifier: ^5.13.0 - version: 5.13.0(prisma@5.13.0) + specifier: ^5.15.0 + version: 5.15.0(prisma@5.15.0) '@types/async-exit-hook': specifier: ^2.0.0 version: 2.0.0 @@ -602,8 +611,8 @@ importers: specifier: ^0.15.12 version: 0.15.12 prisma: - specifier: ^5.13.0 - version: 5.13.0 + specifier: ^5.15.0 + version: 5.15.0 renamer: specifier: ^4.0.0 version: 4.0.0 @@ -621,11 +630,11 @@ importers: packages/sdk: dependencies: '@prisma/generator-helper': - specifier: ^5.13.0 - version: 5.13.0 + specifier: ^5.15.0 + version: 5.15.0 '@prisma/internals': - specifier: ^5.13.0 - version: 5.13.0 + specifier: ^5.15.0 + version: 5.15.0 '@zenstackhq/language': specifier: workspace:* version: link:../language/dist @@ -3840,19 +3849,8 @@ packages: resolution: {integrity: sha512-a5Sab1C4/icpTZVzZc5Ghpz88yQtGOyNqYXcZgOssB2uuAr+wF/MvN6bgtW32q7HHrvBki+BsZ0OuNv6EV3K9g==} dev: true - /@prisma/client@5.12.0: - resolution: {integrity: sha512-bk/+KPpRm0+IzqFCtAxrj+/TNiHzulspnO+OkysaYY/atc/eX0Gx8V3tTLxbHKVX0LKD4Hi8KKCcSbU1U72n7Q==} - engines: {node: '>=16.13'} - requiresBuild: true - peerDependencies: - prisma: '*' - peerDependenciesMeta: - prisma: - optional: true - dev: false - - /@prisma/client@5.13.0(prisma@5.13.0): - resolution: {integrity: sha512-uYdfpPncbZ/syJyiYBwGZS8Gt1PTNoErNYMuqHDa2r30rNSFtgTA/LXsSk55R7pdRTMi5pHkeP9B14K6nHmwkg==} + /@prisma/client@5.15.0(prisma@5.15.0): + resolution: {integrity: sha512-wPTeTjbd2Q0abOeffN7zCDCbkp9C9cF+e9HPiI64lmpehyq2TepgXE+sY7FXr7Rhbb21prLMnhXX27/E11V09w==} engines: {node: '>=16.13'} requiresBuild: true peerDependencies: @@ -3861,8 +3859,7 @@ packages: prisma: optional: true dependencies: - prisma: 5.13.0 - dev: true + prisma: 5.15.0 /@prisma/client@5.7.0: resolution: {integrity: sha512-cZmglCrfNbYpzUtz7HscVHl38e9CrUs31nrVoGUK1nIPXGgt8hT4jj2s657UXcNdQ/jBUxDgGmHyu2Nyrq1txg==} @@ -3875,28 +3872,28 @@ packages: optional: true dev: true - /@prisma/debug@5.13.0: - resolution: {integrity: sha512-699iqlEvzyCj9ETrXhs8o8wQc/eVW+FigSsHpiskSFydhjVuwTJEfj/nIYqTaWFYuxiWQRfm3r01meuW97SZaQ==} + /@prisma/debug@5.15.0: + resolution: {integrity: sha512-QpEAOjieLPc/4sMny/WrWqtpIAmBYsgqwWlWwIctqZO0AbhQ9QcT6x2Ut3ojbDo/pFRCCA1Z1+xm2MUy7fAkZA==} /@prisma/debug@5.7.0: resolution: {integrity: sha512-tZ+MOjWlVvz1kOEhNYMa4QUGURY+kgOUBqLHYIV8jmCsMuvA1tWcn7qtIMLzYWCbDcQT4ZS8xDgK0R2gl6/0wA==} dev: false - /@prisma/engines-version@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b: - resolution: {integrity: sha512-AyUuhahTINGn8auyqYdmxsN+qn0mw3eg+uhkp8zwknXYIqoT3bChG4RqNY/nfDkPvzWAPBa9mrDyBeOnWSgO6A==} + /@prisma/engines-version@5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022: + resolution: {integrity: sha512-3BEgZ41Qb4oWHz9kZNofToRvNeS4LZYaT9pienR1gWkjhky6t6K1NyeWNBkqSj2llgraUNbgMOCQPY4f7Qp5wA==} /@prisma/engines-version@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9: resolution: {integrity: sha512-V6tgRVi62jRwTm0Hglky3Scwjr/AKFBFtS+MdbsBr7UOuiu1TKLPc6xfPiyEN1+bYqjEtjxwGsHgahcJsd1rNg==} dev: false - /@prisma/engines@5.13.0: - resolution: {integrity: sha512-hIFLm4H1boj6CBZx55P4xKby9jgDTeDG0Jj3iXtwaaHmlD5JmiDkZhh8+DYWkTGchu+rRF36AVROLnk0oaqhHw==} + /@prisma/engines@5.15.0: + resolution: {integrity: sha512-hXL5Sn9hh/ZpRKWiyPA5GbvF3laqBHKt6Vo70hYqqOhh5e0ZXDzHcdmxNvOefEFeqxra2DMz2hNbFoPvqrVe1w==} requiresBuild: true dependencies: - '@prisma/debug': 5.13.0 - '@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b - '@prisma/fetch-engine': 5.13.0 - '@prisma/get-platform': 5.13.0 + '@prisma/debug': 5.15.0 + '@prisma/engines-version': 5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022 + '@prisma/fetch-engine': 5.15.0 + '@prisma/get-platform': 5.15.0 /@prisma/engines@5.7.0: resolution: {integrity: sha512-TkOMgMm60n5YgEKPn9erIvFX2/QuWnl3GBo6yTRyZKk5O5KQertXiNnrYgSLy0SpsKmhovEPQb+D4l0SzyE7XA==} @@ -3908,12 +3905,12 @@ packages: '@prisma/get-platform': 5.7.0 dev: false - /@prisma/fetch-engine@5.13.0: - resolution: {integrity: sha512-Yh4W+t6YKyqgcSEB3odBXt7QyVSm0OQlBSldQF2SNXtmOgMX8D7PF/fvH6E6qBCpjB/yeJLy/FfwfFijoHI6sA==} + /@prisma/fetch-engine@5.15.0: + resolution: {integrity: sha512-z6AY5yyXxc20Klj7wwnfGP0iIUkVKzybqapT02zLYR/nf9ynaeN8bq73WRmi1TkLYn+DJ5Qy+JGu7hBf1pE78A==} dependencies: - '@prisma/debug': 5.13.0 - '@prisma/engines-version': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b - '@prisma/get-platform': 5.13.0 + '@prisma/debug': 5.15.0 + '@prisma/engines-version': 5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022 + '@prisma/get-platform': 5.15.0 /@prisma/fetch-engine@5.7.0: resolution: {integrity: sha512-zIn/qmO+N/3FYe7/L9o+yZseIU8ivh4NdPKSkQRIHfg2QVTVMnbhGoTcecbxfVubeTp+DjcbjS0H9fCuM4W04w==} @@ -3923,10 +3920,10 @@ packages: '@prisma/get-platform': 5.7.0 dev: false - /@prisma/generator-helper@5.13.0: - resolution: {integrity: sha512-i+53beJ0dxkDrkHdsXxmeMf+eVhyhOIpL0SdBga8vwe0qHPrAIJ/lpuT/Hj0y5awTmq40qiUEmhXwCEuM/Z17w==} + /@prisma/generator-helper@5.15.0: + resolution: {integrity: sha512-7pB3v57GU4Q/iBauGbvQQGenMJSu2ArQboge4Ca6bw0gA7nConfIHP48MdNIYCrBbNPcIVFmrNomyhqCb3IuWQ==} dependencies: - '@prisma/debug': 5.13.0 + '@prisma/debug': 5.15.0 dev: false /@prisma/generator-helper@5.7.0: @@ -3935,10 +3932,10 @@ packages: '@prisma/debug': 5.7.0 dev: false - /@prisma/get-platform@5.13.0: - resolution: {integrity: sha512-B/WrQwYTzwr7qCLifQzYOmQhZcFmIFhR81xC45gweInSUn2hTEbfKUPd2keAog+y5WI5xLAFNJ3wkXplvSVkSw==} + /@prisma/get-platform@5.15.0: + resolution: {integrity: sha512-1GULDkW4+/VQb73vihxCBSc4Chc2x88MA+O40tcZFjmBzG4/fF44PaXFxUqKSFltxU9L9GIMLhh0Gfkk/pUbtg==} dependencies: - '@prisma/debug': 5.13.0 + '@prisma/debug': 5.15.0 /@prisma/get-platform@5.7.0: resolution: {integrity: sha512-ZeV/Op4bZsWXuw5Tg05WwRI8BlKiRFhsixPcAM+5BKYSiUZiMKIi713tfT3drBq8+T0E1arNZgYSA9QYcglWNA==} @@ -3946,16 +3943,16 @@ packages: '@prisma/debug': 5.7.0 dev: false - /@prisma/internals@5.13.0: - resolution: {integrity: sha512-OPMzS+IBPzCLT4s+IfGUbOhGFY51CFbokIFMZuoSeLKWE8UvDlitiXZ3OlVqDPUc0AlH++ysQHzDISHbZD+ZUg==} + /@prisma/internals@5.15.0: + resolution: {integrity: sha512-RTqzD4fTb74jENYPrn3bhD2vl56W84crSx58f7CyyCs2U3hKYIKfZ2kqMZ6psrqsVRCu7PNeCHdhM5kAimCimQ==} dependencies: - '@prisma/debug': 5.13.0 - '@prisma/engines': 5.13.0 - '@prisma/fetch-engine': 5.13.0 - '@prisma/generator-helper': 5.13.0 - '@prisma/get-platform': 5.13.0 - '@prisma/prisma-schema-wasm': 5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b - '@prisma/schema-files-loader': 5.13.0 + '@prisma/debug': 5.15.0 + '@prisma/engines': 5.15.0 + '@prisma/fetch-engine': 5.15.0 + '@prisma/generator-helper': 5.15.0 + '@prisma/get-platform': 5.15.0 + '@prisma/prisma-schema-wasm': 5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022 + '@prisma/schema-files-loader': 5.15.0 arg: 5.0.2 prompts: 2.4.2 dev: false @@ -3973,17 +3970,18 @@ packages: prompts: 2.4.2 dev: false - /@prisma/prisma-schema-wasm@5.13.0-23.b9a39a7ee606c28e3455d0fd60e78c3ba82b1a2b: - resolution: {integrity: sha512-+IhHvuE1wKlyOpJgwAhGop1oqEt+1eixrCeikBIshRhdX6LwjmtRxVxVMlP5nS1yyughmpfkysIW4jZTa+Zjuw==} + /@prisma/prisma-schema-wasm@5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022: + resolution: {integrity: sha512-bZYtXnHSP6nkZf20QZm4A/vzz3Psh+u6pMld4t6cdcZlQW0ZOZQ3/WWTOf5Pe+cqS/k4kciEM5urtH2SE01GCg==} dev: false /@prisma/prisma-schema-wasm@5.7.0-41.79fb5193cf0a8fdbef536e4b4a159cad677ab1b9: resolution: {integrity: sha512-w+HdQtux0dJDEn6BG3fgNn+fXErXiekj9n//uHRAgrmZghockJkhnikOmG8aSXjTb1Tu5DrGasBX+rYX6rHT1w==} dev: false - /@prisma/schema-files-loader@5.13.0: - resolution: {integrity: sha512-6sVMoqobkWKsmzb98LfLiIt/aFRucWfkzSUBsqk7sc+h99xjynJt6aKtM2SSkyndFdWpRU0OiCHfQ9UlYUEJIw==} + /@prisma/schema-files-loader@5.15.0: + resolution: {integrity: sha512-ZDIX4Gr5MdGOiik23DSYQ8cOd/Bkat+6yo5TbAF8UlKor9tJsrEVyGRo6DFu1AEvedjSeiwS88jD1dn03sxvyA==} dependencies: + '@prisma/prisma-schema-wasm': 5.15.0-29.12e25d8d06f6ea5a0252864dd9a03b1bb51f3022 fs-extra: 11.1.1 dev: false @@ -5403,7 +5401,7 @@ packages: parse-semver: 1.1.1 read: 1.0.7 semver: 5.7.1 - tmp: 0.2.1 + tmp: 0.2.3 typed-rest-client: 1.8.10 url-join: 4.0.1 xml2js: 0.5.0 @@ -12484,14 +12482,13 @@ packages: hasBin: true dev: true - /prisma@5.13.0: - resolution: {integrity: sha512-kGtcJaElNRAdAGsCNykFSZ7dBKpL14Cbs+VaQ8cECxQlRPDjBlMHNFYeYt0SKovAVy2Y65JXQwB3A5+zIQwnTg==} + /prisma@5.15.0: + resolution: {integrity: sha512-JA81ACQSCi3a7NUOgonOIkdx8PAVkO+HbUOxmd00Yb8DgIIEpr2V9+Qe/j6MLxIgWtE/OtVQ54rVjfYRbZsCfw==} engines: {node: '>=16.13'} hasBin: true requiresBuild: true dependencies: - '@prisma/engines': 5.13.0 - dev: true + '@prisma/engines': 5.15.0 /process-nextick-args@2.0.1: resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==} @@ -13062,6 +13059,11 @@ packages: dependencies: queue-microtask: 1.2.3 + /run-script-os@1.1.6: + resolution: {integrity: sha512-ql6P2LzhBTTDfzKts+Qo4H94VUKpxKDFz6QxxwaUZN0mwvi7L3lpOI7BqPCq7lgDh3XLl0dpeXwfcVIitlrYrw==} + hasBin: true + dev: true + /rxjs@7.8.1: resolution: {integrity: sha512-AA3TVj+0A2iuIoQkWEK/tqFjBq2j+6PO6Y0zJcvzLAFhEFIO3HL0vls9hWLncZbAAbK0mar7oZ4V079I/qPMxg==} dependencies: diff --git a/tests/integration/test-run/package.json b/tests/integration/test-run/package.json index fc262c950..5788cf6d5 100644 --- a/tests/integration/test-run/package.json +++ b/tests/integration/test-run/package.json @@ -10,9 +10,9 @@ "author": "", "license": "ISC", "dependencies": { - "@prisma/client": "^5.13.0", + "@prisma/client": "^5.15.0", "@zenstackhq/runtime": "file:../../../packages/runtime/dist", - "prisma": "^5.13.0", + "prisma": "^5.15.0", "react": "^18.2.0", "swr": "^1.3.0", "typescript": "^4.9.3", diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index 6a31540d7..8acc832c6 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -6,7 +6,7 @@ describe('Polymorphism Test', () => { const schema = POLYMORPHIC_SCHEMA; async function setup() { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); @@ -21,7 +21,7 @@ describe('Polymorphism Test', () => { } it('create hierarchy', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); @@ -100,7 +100,7 @@ describe('Polymorphism Test', () => { name String } `, - { logPrismaQuery: true, enhancements: ['delegate'] } + { enhancements: ['delegate'] } ); const db = enhance(); @@ -109,7 +109,7 @@ describe('Polymorphism Test', () => { }); it('create with nesting', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); // nested create a relation from base @@ -122,7 +122,7 @@ describe('Polymorphism Test', () => { }); it('create many polymorphic model', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); await expect( @@ -140,7 +140,7 @@ describe('Polymorphism Test', () => { }); it('create many polymorphic relation', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const video1 = await db.ratedVideo.create({ @@ -898,7 +898,7 @@ describe('Polymorphism Test', () => { }); it('deleteMany', async () => { - const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); diff --git a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts index d0316595d..c8e5bd432 100644 --- a/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts @@ -89,14 +89,10 @@ describe('Polymorphic Policy Test', () => { for (const schema of [booleanCondition, booleanExpression]) { const { enhanceRaw: enhance, prisma } = await loadSchema(schema); - const fullDb = enhance(prisma, undefined, { kinds: ['delegate'], logPrismaQuery: true }); + const fullDb = enhance(prisma, undefined, { kinds: ['delegate'] }); const user = await fullDb.user.create({ data: { id: 1 } }); - const userDb = enhance( - prisma, - { user: { id: user.id } }, - { kinds: ['delegate', 'policy'], logPrismaQuery: true } - ); + const userDb = enhance(prisma, { user: { id: user.id } }, { kinds: ['delegate', 'policy'] }); // violating Asset create await expect( @@ -189,9 +185,7 @@ describe('Polymorphic Policy Test', () => { } `; - const { enhance } = await loadSchema(schema, { - logPrismaQuery: true, - }); + const { enhance } = await loadSchema(schema); const db = enhance(); const user = await db.user.create({ data: { id: 1 } }); diff --git a/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts b/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts new file mode 100644 index 000000000..c96f16256 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/create-many-and-return.test.ts @@ -0,0 +1,105 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Test API createManyAndReturn', () => { + it('model-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + posts Post[] + level Int + + @@allow('read', level > 0) + } + + model Post { + id Int @id @default(autoincrement()) + title String + published Boolean @default(false) + userId Int + user User @relation(fields: [userId], references: [id]) + + @@allow('read', published) + @@allow('create', contains(title, 'hello')) + } + ` + ); + + await prisma.user.createMany({ + data: [ + { id: 1, level: 1 }, + { id: 2, level: 0 }, + ], + }); + + const db = enhance(); + + // create rule violation + await expect( + db.post.createManyAndReturn({ + data: [{ title: 'foo', userId: 1 }], + }) + ).toBeRejectedByPolicy(); + + // success + let r = await db.post.createManyAndReturn({ + data: [{ id: 1, title: 'hello1', userId: 1, published: true }], + }); + expect(r.length).toBe(1); + + // read-back check + await expect( + db.post.createManyAndReturn({ + data: [ + { id: 2, title: 'hello2', userId: 1, published: true }, + { id: 3, title: 'hello3', userId: 1, published: false }, + ], + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + await expect(prisma.post.findMany()).resolves.toHaveLength(3); + + // return relation + await prisma.post.deleteMany(); + r = await db.post.createManyAndReturn({ + include: { user: true }, + data: [{ id: 1, title: 'hello1', userId: 1, published: true }], + }); + expect(r[0]).toMatchObject({ user: { id: 1 } }); + + // relation filtered + await prisma.post.deleteMany(); + await expect( + db.post.createManyAndReturn({ + include: { user: true }, + data: [{ id: 1, title: 'hello1', userId: 2, published: true }], + }) + ).toBeRejectedByPolicy(['result is not allowed to be read back']); + await expect(prisma.post.findMany()).resolves.toHaveLength(1); + }); + + it('field-level policies', async () => { + const { prisma, enhance } = await loadSchema( + ` + model Post { + id Int @id @default(autoincrement()) + title String @allow('read', published) + published Boolean @default(false) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const r = await db.post.createManyAndReturn({ + data: [ + { title: 'post1', published: true }, + { title: 'post2', published: false }, + ], + }); + expect(r).toHaveLength(2); + expect(r[0].title).toBe('post1'); + expect(r[1].title).toBeUndefined(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts new file mode 100644 index 000000000..75d694f12 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/cross-model-field-comparison.test.ts @@ -0,0 +1,1000 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Cross-model field comparison', () => { + it('to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int + + @@allow('read', true) + @@allow('create,update,delete', age == profile.age) + @@deny('update', future().age < future().profile.age && age > 0) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + await prisma.profile.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ data: [{ id: 1, age: 18, profile: { create: { id: 1, age: 20 } } }] }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.createMany({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profile: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profile: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-one relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile? + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profile: true } })).resolves.toMatchObject({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }); + // await expect(db.user.findMany({ include: { profile: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profile: expect.objectContaining({ age: 18 }), + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + // expect(r.profile).toBeUndefined(); + // r = await db.user.findMany({ include: { profile: true } }); + // expect(r[0].profile).toBeUndefined(); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + const r = await prisma.user.findUnique({ where: { id: 1 }, include: { profile: true } }); + expect(r.profile).toMatchObject({ age: 20 }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 18 } } } }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 15 } } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { update: { age: 20 } } } }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profile: { + upsert: { + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profile: { + upsert: { + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { delete: true } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profile: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profile: { disconnect: true } } })).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect(db.user.update({ where: { id: 3 }, data: { profile: { connect: { id: 1 } } } })).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('read', true) + @@allow('create,update,delete', profiles![this.age == age]) + @@deny('update', future().profiles?[this.age < age]) + } + + model Profile { + id Int @id + age Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int + + @@allow('all', true) + } + `, + { preserveTsFiles: true } + ); + + const db = enhance(); + + const reset = async () => { + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 20 }, + ], + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: [{ id: 1, age: 20 }] } } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.create({ + data: { + id: 1, + age: 18, + profiles: { + createMany: { + data: [ + { id: 1, age: 18 }, + { id: 2, age: 18 }, + ], + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + db.user.create({ + data: { id: 2, age: 18 }, + }) + ).toResolveTruthy(); + await reset(); + + // createMany + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } }, + { id: 2, age: 18, profiles: { create: { id: 2, age: 20 } } }, + ], + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.createMany({ + data: [ + { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } }, + { id: 2, age: 19, profiles: { create: { id: 2, age: 19 } } }, + ], + }) + ).resolves.toEqual({ count: 2 }); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + // await expect(db.user.findMany()).resolves.toHaveLength(1); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // await expect(db.user.findUnique({ where: { id: 1 } })).toResolveNull(); + // await expect(db.user.findMany()).resolves.toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 20 }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 18 } })).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 15 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { age: 20 } })).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toBeRejectedByPolicy(); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 18, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.upsert({ where: { id: 1 }, create: { id: 1, age: 25 }, update: { age: 25 } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 25 }); + await expect( + db.user.upsert({ + where: { id: 2 }, + create: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } }, + update: { age: 25 }, + }) + ).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(2); + await reset(); + + // updateMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + // non updatable + await expect(db.user.updateMany({ data: { age: 18 } })).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is updatable + await expect(db.user.updateMany({ data: { age: 30 } })).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ age: 18 }); + await expect(prisma.user.findUnique({ where: { id: 2 } })).resolves.toMatchObject({ age: 30 }); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.delete({ where: { id: 1 } })).toResolveTruthy(); + await expect(prisma.user.findMany()).resolves.toHaveLength(0); + await reset(); + + // deleteMany + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 0 }); + await prisma.user.create({ data: { id: 2, age: 25, profiles: { create: { id: 2, age: 25 } } } }); + // one of the two is deletable + await expect(db.user.deleteMany()).resolves.toMatchObject({ count: 1 }); + await expect(prisma.user.findMany()).resolves.toHaveLength(1); + }); + + it('nested inside to-many relation', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profiles Profile[] + age Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? @relation(fields: [userId], references: [id]) + userId Int? @unique + + @@allow('read', true) + @@allow('create,update,delete', user == null || age == user.age) + @@deny('update', future().user != null && future().age < future().user.age && age > 0) + } + ` + ); + + const db = enhance(); + + const reset = async () => { + await prisma.profile.deleteMany(); + await prisma.user.deleteMany(); + }; + + // create + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }) + ).toBeRejectedByPolicy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect( + db.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }) + ).toResolveTruthy(); + await expect(prisma.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await reset(); + + // TODO: cross-model field comparison is not supported for read rules yet + // // read + // await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + // await expect(db.user.findUnique({ where: { id: 1 }, include: { profiles: true } })).resolves.toMatchObject({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }); + // await expect(db.user.findMany({ include: { profiles: true } })).resolves.toEqual( + // expect.arrayContaining([ + // expect.objectContaining({ + // age: 18, + // profiles: [expect.objectContaining({ age: 18 })], + // }), + // ]) + // ); + // await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + // let r = await db.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + // expect(r.profiles).toHaveLength(0); + // r = await db.user.findMany({ include: { profiles: true } }); + // expect(r[0].profiles).toHaveLength(0); + // await reset(); + + // update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + let r = await prisma.user.findUnique({ where: { id: 1 }, include: { profiles: true } }); + expect(r.profiles[0]).toMatchObject({ age: 20 }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 18 } } } }, + }) + ).toBeRejectedByPolicy(); + await reset(); + + // post update + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 18 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 15 } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { profiles: { update: { where: { id: 1 }, data: { age: 20 } } } }, + }) + ).toResolveTruthy(); + await reset(); + + // upsert + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ + where: { id: 1 }, + data: { + profiles: { + upsert: { + where: { id: 1 }, + create: { id: 1, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 18 } }); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 25 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { + profiles: { + upsert: { + where: { id: 2 }, + create: { id: 2, age: 18 }, + update: { age: 25 }, + }, + }, + }, + }) + ).toResolveTruthy(); + await reset(); + + // delete + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect(db.user.update({ where: { id: 1 }, data: { profiles: { delete: { id: 1 } } } })).toResolveTruthy(); + await expect(await prisma.profile.findMany()).toHaveLength(0); + await reset(); + + // connect/disconnect + await prisma.user.create({ data: { id: 1, age: 18, profiles: { create: { id: 1, age: 20 } } } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + await expect( + db.user.update({ where: { id: 1 }, data: { profiles: { disconnect: { id: 1 } } } }) + ).toResolveTruthy(); + await prisma.user.create({ data: { id: 2, age: 25 } }); + await expect( + db.user.update({ where: { id: 2 }, data: { profiles: { connect: { id: 1 } } } }) + ).toBeRejectedByPolicy(); + await prisma.user.create({ data: { id: 3, age: 20 } }); + await expect( + db.user.update({ where: { id: 3 }, data: { profiles: { connect: { id: 1 } } } }) + ).toResolveTruthy(); + await expect(prisma.profile.findFirst()).resolves.toMatchObject({ userId: 3 }); + await reset(); + }); + + it('field-level simple', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age) @allow('update', age > profile.age) + level Int + + @@allow('all', true) + } + + model Profile { + id Int @id + age Int + user User? + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + // read + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r.age).toBeUndefined(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r.age).toBeUndefined(); + + // update + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toResolveTruthy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); + + it('field-level read override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('read', age == profile.age, true) + level Int + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r).toBeNull(); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(Object.keys(r).length).toBe(0); + await prisma.user.update({ where: { id: 1 }, data: { age: 20 } }); + r = await db.user.findUnique({ where: { id: 1 }, select: { age: true } }); + expect(r).toMatchObject({ age: 20 }); + }); + + it('field-level update override', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + age Int @allow('update', age > profile.age, true) + level Int + @@allow('read', true) + } + + model Profile { + id Int @id + age Int + user User? + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await prisma.user.create({ data: { id: 1, age: 18, level: 1, profile: { create: { id: 1, age: 20 } } } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 21 } })).toBeRejectedByPolicy(); + await expect(db.user.update({ where: { id: 1 }, data: { level: 2 } })).toBeRejectedByPolicy(); + await prisma.user.update({ where: { id: 1 }, data: { age: 21 } }); + await expect(db.user.update({ where: { id: 1 }, data: { age: 25 } })).toResolveTruthy(); + }); + + it('with auth case 1', async () => { + const { enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + permissions Permission[] + @@allow('all', true) + } + + model Permission { + id Int @id @default(autoincrement()) + user User @relation(fields: [userId], references: [id]) + userId Int + model String + level Int + @@allow('all', true) + } + + model Post { + id Int @id @default(autoincrement()) + title String + permission PostPermission? + + @@allow('read', true) + @@allow("create", auth().permissions?[model == 'Post' && level == this.permission.level]) + } + + model PostPermission { + id Int @id @default(autoincrement()) + post Post @relation(fields: [postId], references: [id]) + postId Int @unique + level Int + @@allow('all', true) + } + ` + ); + + await expect(enhance().post.create({ data: { title: 'P1' } })).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Foo', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ data: { title: 'P1' } }) + ).toBeRejectedByPolicy(); + await expect( + enhance({ id: 1, permissions: [{ model: 'Post', level: 1 }] }).post.create({ + data: { title: 'P1', permission: { create: { level: 1 } } }, + }) + ).toResolveTruthy(); + }); + + it('with auth case 2', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + teamMembership TeamMembership[] + @@allow('all', true) + } + + model Team { + id Int @id @default(autoincrement()) + permissions Permission[] + assets Asset[] + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + name String + team Team @relation(fields: [teamId], references: [id]) + teamId Int + @@allow('all', auth().teamMembership?[role.permissions?[name == 'ManageTeam' && teamId == this.teamId]]) + @@allow('read', true) + } + + model TeamMembership { + id Int @id @default(autoincrement()) + role TeamRole? + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } + + model TeamRole { + id Int @id @default(autoincrement()) + permissions Permission[] + membership TeamMembership @relation(fields: [membershipId], references: [id]) + membershipId Int @unique + @@allow('all', true) + } + + model Permission { + id Int @id @default(autoincrement()) + name String + team Team @relation(fields: [teamId], references: [id]) + teamId Int + role TeamRole @relation(fields: [roleId], references: [id]) + roleId Int + @@allow('all', true) + } + ` + ); + + const team1 = await prisma.team.create({ data: {} }); + const team2 = await prisma.team.create({ data: {} }); + + const user = await prisma.user.create({ + data: { + teamMembership: { + create: { + role: { + create: { + permissions: { create: [{ name: 'ManageTeam', team: { connect: { id: team1.id } } }] }, + }, + }, + }, + }, + }, + }); + + const asset = await prisma.asset.create({ + data: { name: 'Asset1', team: { connect: { id: team1.id } } }, + }); + + const dbTeam1 = enhance({ + id: user.id, + teamMembership: [{ role: { permissions: [{ name: 'ManageTeam', teamId: team1.id }] } }], + }); + await expect(dbTeam1.asset.update({ where: { id: asset.id }, data: { name: 'Asset2' } })).toResolveTruthy(); + + const dbTeam2 = enhance({ + id: user.id, + teamMembership: [{ role: { permissions: [{ name: 'ManageTeam', teamId: team2.id }] } }], + }); + await expect( + dbTeam2.asset.update({ where: { id: asset.id }, data: { name: 'Asset2' } }) + ).toBeRejectedByPolicy(); + }); + + it('with auth case 3', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + teamMembership TeamMembership[] + @@allow('all', true) + } + + model Team { + id Int @id @default(autoincrement()) + permissions Permission[] + assets Asset[] + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + name String + team Team @relation(fields: [teamId], references: [id]) + teamId Int + @@allow('all', auth().teamMembership?[role.permissions?[name == 'ManageTeam' && team == this.team]]) + @@allow('read', true) + } + + model TeamMembership { + id Int @id @default(autoincrement()) + role TeamRole? + user User @relation(fields: [userId], references: [id]) + userId Int + @@allow('all', true) + } + + model TeamRole { + id Int @id @default(autoincrement()) + permissions Permission[] + membership TeamMembership @relation(fields: [membershipId], references: [id]) + membershipId Int @unique + @@allow('all', true) + } + + model Permission { + id Int @id @default(autoincrement()) + name String + team Team @relation(fields: [teamId], references: [id]) + teamId Int + role TeamRole @relation(fields: [roleId], references: [id]) + roleId Int + @@allow('all', true) + } + ` + ); + + const team1 = await prisma.team.create({ data: {} }); + const team2 = await prisma.team.create({ data: {} }); + + const user = await prisma.user.create({ + data: { + teamMembership: { + create: { + role: { + create: { + permissions: { create: [{ name: 'ManageTeam', team: { connect: { id: team1.id } } }] }, + }, + }, + }, + }, + }, + }); + + const asset = await prisma.asset.create({ + data: { name: 'Asset1', team: { connect: { id: team1.id } } }, + }); + + const dbTeam1 = enhance({ + id: user.id, + teamMembership: [{ role: { permissions: [{ name: 'ManageTeam', team: { id: team1.id } }] } }], + }); + await expect(dbTeam1.asset.update({ where: { id: asset.id }, data: { name: 'Asset2' } })).toResolveTruthy(); + + const dbTeam2 = enhance({ + id: user.id, + teamMembership: [{ role: { permissions: [{ name: 'ManageTeam', teamId: team2.id }] } }], + }); + await expect( + dbTeam2.asset.update({ where: { id: asset.id }, data: { name: 'Asset2' } }) + ).toBeRejectedByPolicy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index de778e8e8..0297116a0 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -915,6 +915,10 @@ describe('Policy: field-level policy', () => { data: { models: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { models: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 }, @@ -1015,6 +1019,10 @@ describe('Policy: field-level policy', () => { data: { model: { connect: { id: 1 } } }, }) ).toBeRejectedByPolicy(); + await prisma.user.update({ + where: { id: 1 }, + data: { model: { connect: { id: 1 } } }, + }); await expect( db.user.update({ where: { id: 1 }, diff --git a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts index 9dd247d65..c910ff4f1 100644 --- a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts +++ b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts @@ -41,8 +41,7 @@ model Post { secret String @default("secret") @allow('read', published == false, true) @@allow('read', published) -}`, - { logPrismaQuery: true } +}` ); await prisma.user.create({ diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index e215a917b..59c968fb5 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -290,8 +290,7 @@ describe('With Policy:nested to-one', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index b101356cd..d43804787 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -110,8 +110,7 @@ describe('With Policy: post update', () => { @@allow('create,read', true) @@allow('update', x > 0 && startsWith(future().value, 'hello')) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts index d46c31245..a9a1b49d2 100644 --- a/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts +++ b/tests/integration/tests/enhancements/with-policy/prisma-omit.test.ts @@ -21,7 +21,7 @@ describe('prisma omit', () => { @@allow('all', level > 1) } `, - { previewFeatures: ['omitApi'], logPrismaQuery: true } + { previewFeatures: ['omitApi'] } ); await prisma.user.create({ diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index 3c725697d..6ee5c2343 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -26,7 +26,6 @@ describe('With Policy: refactor tests', () => { { provider: 'postgresql', dbUrl, - logPrismaQuery: true, } ); getDb = enhance; diff --git a/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts index 1a1c40406..450726b87 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-one-to-many-filter.test.ts @@ -1,17 +1,6 @@ import { loadSchema } from '@zenstackhq/testtools'; -import path from 'path'; - -describe('With Policy: relation one-to-many filter', () => { - let origDir: string; - - beforeAll(async () => { - origDir = path.resolve('.'); - }); - - afterEach(() => { - process.chdir(origDir); - }); +describe('Relation one-to-many filter', () => { const model = ` model M1 { id String @id @default(uuid()) @@ -456,3 +445,582 @@ describe('With Policy: relation one-to-many filter', () => { ).resolves.toMatchObject({ m2: [{ _count: { m3: 0 } }] }); }); }); + +describe('Relation one-to-many filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + m3 M3[] + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @deny('read', deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }) + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + some: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(0); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + some: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + }); + + it('none filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }) + ).toResolveFalsy(); + + // include clause + + let r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + r = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + none: { id: { equals: '3-2' } }, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(1); + }); + + it('every filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + m3: { + create: { + id: '3-1', + value: 1, + }, + }, + }, + { + id: '2-2', + value: 2, + deleted: true, + m3: { + create: { + id: '3-2', + value: 2, + deleted: true, + }, + }, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }) + ).toResolveTruthy(); + + // include clause + + const r = await db.m1.findFirst({ + where: { id: '1' }, + include: { + m2: { + where: { + m3: { + every: {}, + }, + }, + }, + }, + }); + expect(r.m2).toHaveLength(2); + + let r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { value: { gt: 1 } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(1); + + r1 = await db.m1.findFirst({ + where: { + id: '1', + }, + include: { + m2: { + where: { + m3: { + every: { id: { contains: '3' } }, + }, + }, + }, + }, + }); + expect(r1.m2).toHaveLength(2); + }); +}); + +describe('Relation one-to-many filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('some filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: {}, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + some: { id: '2-2' }, + }, + }, + }) + ).toResolveTruthy(); + }); + + it('none filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: {}, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + none: { id: '2-1' }, + }, + }, + }) + ).toResolveFalsy(); + }); + + it('every filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: [ + { + id: '2-1', + value: 1, + }, + { + id: '2-2', + value: 2, + deleted: true, + }, + ], + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: {}, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { value: { gt: 1 } }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + every: { id: { contains: '2' } }, + }, + }, + }) + ).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts index d076e18e5..1f8666fd5 100644 --- a/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts +++ b/tests/integration/tests/enhancements/with-policy/relation-one-to-one-filter.test.ts @@ -1,17 +1,6 @@ import { loadSchema } from '@zenstackhq/testtools'; -import path from 'path'; - -describe('With Policy: relation one-to-one filter', () => { - let origDir: string; - - beforeAll(async () => { - origDir = path.resolve('.'); - }); - - afterEach(() => { - process.chdir(origDir); - }); +describe('Relation one-to-one filter', () => { const model = ` model M1 { id String @id @default(uuid()) @@ -184,6 +173,17 @@ describe('With Policy: relation one-to-one filter', () => { }) ).toResolveTruthy(); + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + // m1 with m2 await db.m1.create({ data: { @@ -206,7 +206,18 @@ describe('With Policy: relation one-to-one filter', () => { }, }, }) - ).toResolveFalsy(); + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveTruthy(); // m1 with m2 and m3 await db.m1.create({ @@ -239,7 +250,22 @@ describe('With Policy: relation one-to-one filter', () => { }, }, }) - ).toResolveTruthy(); + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + isNot: { + m3: { + isNot: { value: 1 }, + }, + }, + }, + }, + }) + ).toResolveFalsy(); // m1 with null m2 await db.m1.create({ @@ -257,7 +283,7 @@ describe('With Policy: relation one-to-one filter', () => { }, }, }) - ).toResolveFalsy(); + ).toResolveTruthy(); }); it('direct object filter', async () => { @@ -365,3 +391,721 @@ describe('With Policy: relation one-to-one filter', () => { ).toResolveFalsy(); }); }); + +describe('Relation one-to-one filter with field-level rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', true) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) + value Int @allow('read', !deleted) + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', true) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }) + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }) + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }) + ).toResolveTruthy(); + }); +}); + +describe('Relation one-to-one filter with field-level override rules', () => { + const model = ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + m3 M3? + + @@allow('read', !deleted) + @@allow('create', true) + } + + model M3 { + id String @id @default(uuid()) @allow('read', true, true) + value Int + deleted Boolean @default(false) + m2 M2 @relation(fields: [m2Id], references:[id]) + m2Id String @unique + + @@allow('read', !deleted) + @@allow('create', true) + } + `; + + it('is filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + is: { value: 1 }, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + is: { id: '2' }, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { value: 1 }, + }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + is: { + m3: { id: '3' }, + }, + }, + }, + }) + ).toResolveTruthy(); + }); + + it('isNot filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + id: '1', + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 0 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 0 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { value: 1 }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + isNot: { id: '2' }, + }, + }, + }) + ).toResolveFalsy(); + }); + + it('direct object filter', async () => { + const { enhance } = await loadSchema(model); + + const db = enhance(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '1', + m2: { + create: { + id: '1', + value: 1, + m3: { + create: { + value: 1, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '1', + m2: { + value: 1, + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 + await db.m1.create({ + data: { + id: '2', + m2: { + create: { + id: '2', + value: 1, + deleted: true, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + value: 1, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '2', + m2: { + id: '2', + }, + }, + }) + ).toResolveTruthy(); + + // m1 with m2 and m3 + await db.m1.create({ + data: { + id: '3', + m2: { + create: { + id: '3', + value: 1, + m3: { + create: { + id: '3', + value: 1, + deleted: true, + }, + }, + }, + }, + }, + }); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { value: 1 }, + }, + }, + }) + ).toResolveFalsy(); + + await expect( + db.m1.findFirst({ + where: { + id: '3', + m2: { + m3: { id: '3' }, + }, + }, + }) + ).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/subscription.test.ts b/tests/integration/tests/enhancements/with-policy/subscription.test.ts index a4dccf807..aa93706d8 100644 --- a/tests/integration/tests/enhancements/with-policy/subscription.test.ts +++ b/tests/integration/tests/enhancements/with-policy/subscription.test.ts @@ -34,7 +34,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -88,7 +87,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -143,7 +141,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); @@ -198,7 +195,6 @@ describe.skip('With Policy: subscription test', () => { provider: 'postgresql', dbUrl: DB_URL, pulseApiKey: PULSE_API_KEY, - logPrismaQuery: true, } ); diff --git a/tests/integration/tests/frameworks/nextjs/test-project/package.json b/tests/integration/tests/frameworks/nextjs/test-project/package.json index 2508743a6..1840a40ee 100644 --- a/tests/integration/tests/frameworks/nextjs/test-project/package.json +++ b/tests/integration/tests/frameworks/nextjs/test-project/package.json @@ -9,7 +9,7 @@ "lint": "next lint" }, "dependencies": { - "@prisma/client": "^5.13.0", + "@prisma/client": "^5.15.0", "@types/node": "18.11.18", "@types/react": "18.0.27", "@types/react-dom": "18.0.10", @@ -22,6 +22,6 @@ "zod": "^3.22.4" }, "devDependencies": { - "prisma": "^5.13.0" + "prisma": "^5.15.0" } } diff --git a/tests/integration/tests/frameworks/trpc/test-project/package.json b/tests/integration/tests/frameworks/trpc/test-project/package.json index dba55073e..550f5d0eb 100644 --- a/tests/integration/tests/frameworks/trpc/test-project/package.json +++ b/tests/integration/tests/frameworks/trpc/test-project/package.json @@ -9,7 +9,7 @@ "lint": "next lint" }, "dependencies": { - "@prisma/client": "^5.13.0", + "@prisma/client": "^5.15.0", "@tanstack/react-query": "^4.22.4", "@trpc/client": "^10.34.0", "@trpc/next": "^10.34.0", @@ -26,6 +26,6 @@ "zod": "^3.22.4" }, "devDependencies": { - "prisma": "^5.13.0" + "prisma": "^5.15.0" } } diff --git a/tests/integration/tests/plugins/policy.test.ts b/tests/integration/tests/plugins/policy.test.ts index 5158584f4..3d9e75f98 100644 --- a/tests/integration/tests/plugins/policy.test.ts +++ b/tests/integration/tests/plugins/policy.test.ts @@ -36,18 +36,20 @@ model M { const { policy } = await loadSchema(model); - expect(policy.guard.m.read({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(TRUE); - - expect(policy.guard.m.create({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(FALSE); - expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(TRUE); - - expect(policy.guard.m.update({ user: undefined })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(FALSE); - expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(TRUE); + const m = policy.policy.m.modelLevel; + + expect((m.read.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.read.guard as Function)({ user: { id: '1' } })).toEqual(TRUE); + + expect((m.create.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1' } })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect((m.create.guard as Function)({ user: { id: '1', value: 1 } })).toEqual(TRUE); + + expect((m.update.guard as Function)({ user: undefined })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1' } })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect((m.update.guard as Function)({ user: { id: '1', value: 1 } })).toEqual(TRUE); }); it('no short-circuit', async () => { @@ -66,13 +68,14 @@ model M { const { policy } = await loadSchema(model); - expect(policy.guard.m.read({ user: undefined })).toEqual( + expect((policy.policy.m.modelLevel.read.guard as Function)({ user: undefined })).toEqual( expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 0 } }] }) ); - expect(policy.guard.m.read({ user: { id: '1' } })).toEqual( + expect((policy.policy.m.modelLevel.read.guard as Function)({ user: { id: '1' } })).toEqual( expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 0 } }] }) ); }); + it('auth() multiple level member access', async () => { const model = ` model User { @@ -97,12 +100,12 @@ model M { `; const { policy } = await loadSchema(model); - expect(policy.guard.task.read({ user: { cart: { tasks: [{ id: 1 }] } } })).toEqual( - expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 10 } }] }) - ); + expect( + (policy.policy.task.modelLevel.read.guard as Function)({ user: { cart: { tasks: [{ id: 1 }] } } }) + ).toEqual(expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 10 } }] })); - expect(policy.guard.task.read({ user: { cart: { tasks: [{ id: 123 }] } } })).toEqual( - expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 10 } }] }) - ); + expect( + (policy.policy.task.modelLevel.read.guard as Function)({ user: { cart: { tasks: [{ id: 123 }] } } }) + ).toEqual(expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 10 } }] })); }); }); diff --git a/tests/integration/tests/plugins/zod.test.ts b/tests/integration/tests/plugins/zod.test.ts index 00d40b755..5b896416d 100644 --- a/tests/integration/tests/plugins/zod.test.ts +++ b/tests/integration/tests/plugins/zod.test.ts @@ -61,6 +61,7 @@ describe('Zod plugin tests', () => { authorId Int? published Boolean @default(false) viewCount Int @default(0) + viewMilliseconds BigInt @default(0) } `, { addPrelude: false, pushDb: false } @@ -188,11 +189,13 @@ describe('Zod plugin tests', () => { expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'a' }).success).toBeFalsy(); expect(schemas.PostPrismaCreateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy(); expect(schemas.PostPrismaCreateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy(); + expect(schemas.PostPrismaCreateSchema.safeParse({ viewMilliseconds: 1n }).success).toBeTruthy(); expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'a' }).success).toBeFalsy(); expect(schemas.PostPrismaUpdateSchema.safeParse({ title: 'abcde' }).success).toBeTruthy(); expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: 1 }).success).toBeTruthy(); expect(schemas.PostPrismaUpdateSchema.safeParse({ viewCount: { increment: 1 } }).success).toBeTruthy(); + expect(schemas.PostPrismaUpdateSchema.safeParse({ viewMilliseconds: 1n }).success).toBeTruthy(); }); it('mixed casing', async () => { diff --git a/tests/regression/tests/issue-1014.test.ts b/tests/regression/tests/issue-1014.test.ts index ad862db42..66caa1b11 100644 --- a/tests/regression/tests/issue-1014.test.ts +++ b/tests/regression/tests/issue-1014.test.ts @@ -37,8 +37,7 @@ describe('issue 1014', () => { title String @allow('read', true, true) content String } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1080.test.ts b/tests/regression/tests/issue-1080.test.ts index 17ce998c2..69408fdf0 100644 --- a/tests/regression/tests/issue-1080.test.ts +++ b/tests/regression/tests/issue-1080.test.ts @@ -19,8 +19,7 @@ describe('issue 1080', () => { @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1241.test.ts b/tests/regression/tests/issue-1241.test.ts index e5d94c9b7..a555fcb8d 100644 --- a/tests/regression/tests/issue-1241.test.ts +++ b/tests/regression/tests/issue-1241.test.ts @@ -38,8 +38,7 @@ describe('issue 1241', () => { @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); const user = await prisma.user.create({ diff --git a/tests/regression/tests/issue-1271.test.ts b/tests/regression/tests/issue-1271.test.ts index d25cabb3b..9798664cb 100644 --- a/tests/regression/tests/issue-1271.test.ts +++ b/tests/regression/tests/issue-1271.test.ts @@ -39,8 +39,7 @@ describe('issue 1271', () => { @@allow("all", true) } - `, - { logPrismaQuery: true } + ` ); const db = enhance(); diff --git a/tests/regression/tests/issue-1435.test.ts b/tests/regression/tests/issue-1435.test.ts index 0093aff8b..d539b778f 100644 --- a/tests/regression/tests/issue-1435.test.ts +++ b/tests/regression/tests/issue-1435.test.ts @@ -83,7 +83,7 @@ describe('issue 1435', () => { reference String @id } `, - { provider: 'postgresql', dbUrl, logPrismaQuery: true } + { provider: 'postgresql', dbUrl } ); prisma = r.prisma; diff --git a/tests/regression/tests/issue-1451.test.ts b/tests/regression/tests/issue-1451.test.ts index f54a0ca4f..fb105561d 100644 --- a/tests/regression/tests/issue-1451.test.ts +++ b/tests/regression/tests/issue-1451.test.ts @@ -29,8 +29,7 @@ describe('issue 1452', () => { @@id([userId, spaceId]) @@allow('all', true) } - `, - { logPrismaQuery: true } + ` ); await prisma.user.create({ diff --git a/tests/regression/tests/issue-1454.test.ts b/tests/regression/tests/issue-1454.test.ts new file mode 100644 index 000000000..6c42fcf59 --- /dev/null +++ b/tests/regression/tests/issue-1454.test.ts @@ -0,0 +1,117 @@ +import { loadSchema } from '@zenstackhq/testtools'; +describe('issue 1454', () => { + it('regression1', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + sensitiveInformation String + username String + + purchases Purchase[] + + @@allow('read', auth() == this) + } + + model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) + } + ` + ); + + const db = enhance(); + await prisma.user.create({ + data: { username: 'user1', sensitiveInformation: 'sensitive', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(0); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(0); + }); + + it('regression2', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + username String @allow('read', false) + + purchases Purchase[] + + @@allow('read', true) + } + + model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) + } + `, + { logPrismaQuery: true } + ); + + const db = enhance(); + const user = await prisma.user.create({ + data: { username: 'user1', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { id: user.id } } })).resolves.toHaveLength(1); + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(0); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(0); + }); + + it('regression3', async () => { + const { prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + sensitiveInformation String + username String @allow('read', true, true) + + purchases Purchase[] + + @@allow('read', auth() == this) + } + + model Purchase { + id Int @id @default(autoincrement()) + purchasedAt DateTime @default(now()) + userId Int + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + + @@allow('read', true) + } + `, + { logPrismaQuery: true } + ); + + const db = enhance(); + await prisma.user.create({ + data: { username: 'user1', sensitiveInformation: 'sensitive', purchases: { create: {} } }, + }); + + await expect(db.purchase.findMany({ where: { user: { username: 'user1' } } })).resolves.toHaveLength(1); + await expect(db.purchase.findMany({ where: { user: { is: { username: 'user1' } } } })).resolves.toHaveLength(1); + await expect( + db.purchase.findMany({ where: { user: { sensitiveInformation: 'sensitive' } } }) + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ where: { user: { is: { sensitiveInformation: 'sensitive' } } } }) + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ where: { user: { username: 'user1', sensitiveInformation: 'sensitive' } } }) + ).resolves.toHaveLength(0); + await expect( + db.purchase.findMany({ + where: { OR: [{ user: { username: 'user1' } }, { user: { sensitiveInformation: 'sensitive' } }] }, + }) + ).resolves.toHaveLength(1); + }); +}); diff --git a/tests/regression/tests/issue-1466.test.ts b/tests/regression/tests/issue-1466.test.ts new file mode 100644 index 000000000..3ad17143f --- /dev/null +++ b/tests/regression/tests/issue-1466.test.ts @@ -0,0 +1,236 @@ +import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1466', () => { + it('regression1', async () => { + const dbUrl = await createPostgresDb('issue-1466-1'); + let prisma: any; + + try { + const r = await loadSchema( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName extends AssetLongLongLongLongName { + duration Int + } + `, + { + provider: 'postgresql', + dbUrl, + enhancements: ['delegate'], + } + ); + + prisma = r.prisma; + const db = r.enhance(); + + const video = await db.VideoLongLongLongLongName.create({ + data: { duration: 100 }, + }); + + const user = await db.UserLongLongLongLongName.create({ + data: { + asset: { connect: { id: video.id } }, + }, + }); + + const userWithAsset = await db.UserLongLongLongLongName.findFirst({ + include: { asset: true }, + }); + + expect(userWithAsset).toMatchObject({ + asset: { assetType: 'VideoLongLongLongLongName', duration: 100 }, + }); + } finally { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb('issue-1466-1'); + } + }); + + it('regression2', async () => { + const dbUrl = await createPostgresDb('issue-1466-2'); + let prisma: any; + + try { + const r = await loadSchema( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int + + @@unique([assetId]) + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName extends AssetLongLongLongLongName { + duration Int + } + `, + { + provider: 'postgresql', + dbUrl, + enhancements: ['delegate'], + } + ); + + prisma = r.prisma; + const db = r.enhance(); + + const video = await db.VideoLongLongLongLongName.create({ + data: { duration: 100 }, + }); + + const user = await db.UserLongLongLongLongName.create({ + data: { + asset: { connect: { id: video.id } }, + }, + }); + + const userWithAsset = await db.UserLongLongLongLongName.findFirst({ + include: { asset: true }, + }); + + expect(userWithAsset).toMatchObject({ + asset: { assetType: 'VideoLongLongLongLongName', duration: 100 }, + }); + } finally { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb('issue-1466-2'); + } + }); + + it('regression3', async () => { + await loadSchema( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int + } + + model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String + } + `, + { + provider: 'postgresql', + pushDb: false, + } + ); + }); + + it('regression4', async () => { + await loadSchema( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int + } + + model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String + } + `, + { + provider: 'postgresql', + pushDb: false, + } + ); + }); + + it('regression5', async () => { + await loadSchema( + ` + model UserLongLongLongLongName { + id Int @id @default(autoincrement()) + level Int @default(0) + asset AssetLongLongLongLongName @relation(fields: [assetId], references: [id]) + assetId Int @unique(map: 'assetId_unique') + } + + model AssetLongLongLongLongName { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner UserLongLongLongLongName? + assetType String + + @@delegate(assetType) + } + + model VideoLongLongLongLongName1 extends AssetLongLongLongLongName { + duration Int + } + + model VideoLongLongLongLongName2 extends AssetLongLongLongLongName { + format String + } + `, + { + provider: 'postgresql', + pushDb: false, + } + ); + }); +}); diff --git a/tests/regression/tests/issue-1474.test.ts b/tests/regression/tests/issue-1474.test.ts new file mode 100644 index 000000000..9e157d40d --- /dev/null +++ b/tests/regression/tests/issue-1474.test.ts @@ -0,0 +1,27 @@ +import { loadSchema } from '@zenstackhq/testtools'; +describe('issue 1474', () => { + it('regression', async () => { + await loadSchema( + ` + model A { + id Int @id + cs C[] + } + + abstract model B { + a A @relation(fields: [aId], references: [id]) + aId Int + } + + model C extends B { + id Int @id + type String + @@delegate(type) + } + + model D extends C { + } + ` + ); + }); +}); diff --git a/tests/regression/tests/issue-1483.test.ts b/tests/regression/tests/issue-1483.test.ts new file mode 100644 index 000000000..ee947ae2c --- /dev/null +++ b/tests/regression/tests/issue-1483.test.ts @@ -0,0 +1,68 @@ +import { loadSchema } from '@zenstackhq/testtools'; +describe('issue 1483', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model User { + @@auth + id String @id + edits Edit[] + @@allow('all', true) + } + + model Entity { + + id String @id @default(cuid()) + name String + edits Edit[] + + type String + @@delegate(type) + + @@allow('all', true) + } + + model Person extends Entity { + } + + model Edit { + id String @id @default(cuid()) + + authorId String? + author User? @relation(fields: [authorId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + entityId String + entity Entity @relation(fields: [entityId], references: [id], onDelete: Cascade, onUpdate: NoAction) + + @@allow('all', true) + } + ` + ); + + const db = enhance(); + await db.edit.deleteMany({}); + await db.person.deleteMany({}); + await db.user.deleteMany({}); + + const person = await db.person.create({ + data: { + name: 'test', + }, + }); + + await db.edit.create({ + data: { + entityId: person.id, + }, + }); + + await expect( + db.edit.findMany({ + include: { + author: true, + entity: true, + }, + }) + ).resolves.toHaveLength(1); + }); +}); diff --git a/tests/regression/tests/issue-1487.test.ts b/tests/regression/tests/issue-1487.test.ts new file mode 100644 index 000000000..6acfcdcfe --- /dev/null +++ b/tests/regression/tests/issue-1487.test.ts @@ -0,0 +1,71 @@ +import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; +import Decimal from 'decimal.js'; + +describe('issue 1487', () => { + it('regression2', async () => { + const dbUrl = await createPostgresDb('issue-1487'); + let prisma: any; + + try { + const r = await loadSchema( + ` + model LineItem { + id Int @id @default(autoincrement()) + price Decimal + createdAt DateTime @default(now()) + + orderId Int + order Order @relation(fields: [orderId], references: [id]) + } + model Order extends BaseType { + total Decimal + createdAt DateTime @default(now()) + lineItems LineItem[] + } + model BaseType { + id Int @id @default(autoincrement()) + entityType String + + @@delegate(entityType) + } + `, + { + provider: 'postgresql', + dbUrl, + enhancements: ['omit', 'delegate'], + } + ); + + prisma = r.prisma; + const db = r.enhance(); + + const create = await db.Order.create({ + data: { + total: new Decimal(100_100.99), + lineItems: { create: [{ price: 90_000.66 }, { price: 20_100.33 }] }, + }, + }); + + const order = await db.Order.findFirst({ where: { id: create.id }, include: { lineItems: true } }); + expect(Decimal.isDecimal(order.total)).toBe(true); + expect(order.createdAt instanceof Date).toBe(true); + expect(order.total.toString()).toEqual('100100.99'); + order.lineItems.forEach((item: any) => { + expect(Decimal.isDecimal(item.price)).toBe(true); + expect(item.price.toString()).not.toEqual('[object Object]'); + }); + + const lineItems = await db.LineItem.findMany(); + lineItems.forEach((item: any) => { + expect(item.createdAt instanceof Date).toBe(true); + expect(Decimal.isDecimal(item.price)).toBe(true); + expect(item.price.toString()).not.toEqual('[object Object]'); + }); + } finally { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb('issue-1487'); + } + }); +}); diff --git a/tests/regression/tests/issue-961.test.ts b/tests/regression/tests/issue-961.test.ts index f6dc3a135..1f622059e 100644 --- a/tests/regression/tests/issue-961.test.ts +++ b/tests/regression/tests/issue-961.test.ts @@ -35,7 +35,7 @@ describe('Regression: issue 961', () => { `; it('deleteMany', async () => { - const { prisma, enhance } = await loadSchema(schema, { logPrismaQuery: true }); + const { prisma, enhance } = await loadSchema(schema); const user = await prisma.user.create({ data: { diff --git a/tests/regression/tests/issues.test.ts b/tests/regression/tests/issues.test.ts index 318682aad..1418a309a 100644 --- a/tests/regression/tests/issues.test.ts +++ b/tests/regression/tests/issues.test.ts @@ -531,8 +531,7 @@ model tenant { model Equipment extends BaseEntityWithTenant { a String } -`, - { logPrismaQuery: true } +` ); await prisma.tenant.create({