From 769cdb0957fb8428c299c24cb78ac3cc478a647a Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Thu, 14 Mar 2024 20:22:12 -0700 Subject: [PATCH] feat: generate strong typing for the `user` context of `enhance` API --- .../src/enhancements/create-enhancement.ts | 4 +- .../runtime/src/enhancements/policy/index.ts | 4 +- .../enhancer/enhance/auth-type-generator.ts | 141 +++++++++++++++ .../src/plugins/enhancer/enhance/index.ts | 20 ++- packages/testtools/src/schema.ts | 12 +- .../enhancements/with-policy/auth.test.ts | 170 +++++++++++++++++- 6 files changed, 339 insertions(+), 12 deletions(-) create mode 100644 packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index be2fc4579..9ab27448c 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -94,8 +94,8 @@ export type InternalEnhancementOptions = EnhancementOptions & { /** * Context for creating enhanced `PrismaClient` */ -export type EnhancementContext = { - user?: AuthUser; +export type EnhancementContext = { + user?: User; }; let hasPassword: boolean | undefined = undefined; diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index e197e18c1..c76812a51 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -4,6 +4,7 @@ import { getIdFields } from '../../cross'; import { DbClientContract } from '../../types'; import { hasAllFields } from '../../validation'; import type { EnhancementContext, InternalEnhancementOptions } from '../create-enhancement'; +import { Logger } from '../logger'; import { makeProxy } from '../proxy'; import { PolicyProxyHandler } from './handler'; @@ -44,7 +45,8 @@ export function withPolicy( if (authSelector) { Object.keys(authSelector).forEach((f) => { if (!(f in userContext)) { - console.warn(`User context does not have field "${f}" used in policy rules`); + const logger = new Logger(prisma); + logger.warn(`User context does not have field "${f}" used in policy rules`); } }); } diff --git a/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts new file mode 100644 index 000000000..d8e53c173 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/auth-type-generator.ts @@ -0,0 +1,141 @@ +import { getIdFields, isAuthInvocation, isDataModelFieldReference } from '@zenstackhq/sdk'; +import { + DataModel, + DataModelField, + Expression, + isDataModel, + isMemberAccessExpr, + type Model, +} from '@zenstackhq/sdk/ast'; +import { streamAst, type AstNode } from 'langium'; +import { isCollectionPredicate } from '../../../utils/ast-utils'; + +/** + * Generate types for typing the `user` context object passed to the `enhance` call, based + * on the fields (potentially deeply) access through `auth()`. + */ +export function generateAuthType(model: Model, authModel: DataModel) { + const types = new Map< + string, + { + // scalar fields to directly pick from Prisma-generated type + pickFields: string[]; + + // relation fields to include + addFields: { name: string; type: string }[]; + } + >(); + + types.set(authModel.name, { pickFields: getIdFields(authModel).map((f) => f.name), addFields: [] }); + + const ensureType = (model: string) => { + if (!types.has(model)) { + types.set(model, { pickFields: [], addFields: [] }); + } + }; + + const addPickField = (model: string, field: string) => { + let fields = types.get(model); + if (!fields) { + fields = { pickFields: [], addFields: [] }; + types.set(model, fields); + } + if (!fields.pickFields.includes(field)) { + fields.pickFields.push(field); + } + }; + + const addAddField = (model: string, name: string, type: string, array: boolean) => { + let fields = types.get(model); + if (!fields) { + fields = { pickFields: [], addFields: [] }; + types.set(model, fields); + } + if (!fields.addFields.find((f) => f.name === name)) { + fields.addFields.push({ name, type: array ? `${type}[]` : type }); + } + }; + + // get all policy expressions involving `auth()` + const authInvolvedExprs = streamAst(model).filter(isAuthAccess); + + // traverse the expressions and collect types and fields involved + authInvolvedExprs.forEach((expr) => { + streamAst(expr).forEach((node) => { + if (isMemberAccessExpr(node)) { + const exprType = node.operand.$resolvedType?.decl; + if (isDataModel(exprType)) { + const memberDecl = node.member.ref; + if (isDataModel(memberDecl?.type.reference?.ref)) { + // member is a relation + const fieldType = memberDecl.type.reference.ref.name; + ensureType(fieldType); + addAddField(exprType.name, memberDecl.name, fieldType, memberDecl.type.array); + } else { + // member is a scalar + addPickField(exprType.name, node.member.$refText); + } + } + } + + if (isDataModelFieldReference(node)) { + // this can happen inside collection predicates + const fieldDecl = node.target.ref as DataModelField; + const fieldType = fieldDecl.type.reference?.ref; + if (isDataModel(fieldType)) { + // field is a relation + ensureType(fieldType.name); + addAddField(fieldDecl.$container.name, node.target.$refText, fieldType.name, fieldDecl.type.array); + } else { + // field is a scalar + addPickField(fieldDecl.$container.name, node.target.$refText); + } + } + }); + }); + + // generate: + // ` + // namespace auth { + // export type User = WithRequired, 'id'> & { profile: Profile; }; + // export type Profile = WithRequired, 'age'>; + // } + // ` + + return `namespace auth { + type WithRequired = T & { [P in K]-?: T[P] }; +${Array.from(types.entries()) + .map(([model, fields]) => { + let result = `Partial<_P.${model}>`; + + if (fields.pickFields.length > 0) { + result = `WithRequired<${result}, ${fields.pickFields.map((f) => `'${f}'`).join('|')}>`; + } + + if (fields.addFields.length > 0) { + result = `${result} & { ${fields.addFields.map(({ name, type }) => `${name}: ${type}`).join('; ')} }`; + } + + return ` export type ${model} = ${result};`; + }) + .join('\n')} +}`; +} + +function isAuthAccess(node: AstNode): node is Expression { + if (isAuthInvocation(node)) { + return true; + } + + if (isMemberAccessExpr(node) && isAuthAccess(node.operand)) { + return true; + } + + if (isCollectionPredicate(node)) { + if (isAuthAccess(node.left)) { + return true; + } + } + + return false; +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index a379e5ad0..b18ed59b2 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -2,6 +2,7 @@ import type { DMMF } from '@prisma/generator-helper'; import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; import { getAttribute, + getAuthModel, getDataModels, getDMMF, getPrismaClientImportSpec, @@ -28,6 +29,7 @@ import { execPackage } from '../../../utils/exec-utils'; import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; import { isDefaultWithAuth } from '../enhancer-utils'; +import { generateAuthType } from './auth-type-generator'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; @@ -62,16 +64,26 @@ export async function generate(model: Model, options: PluginOptions, project: Pr await prismaDts.save(); } + const authModel = getAuthModel(model.declarations.filter(isDataModel)); + const authTypes = authModel ? generateAuthType(model, authModel) : ''; + const authTypeParam = authModel ? `auth.${authModel.name}` : 'AuthUser'; + const enhanceTs = project.createSourceFile( path.join(outDir, 'enhance.ts'), - `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime'; + `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas, type AuthUser } from '@zenstackhq/runtime'; import modelMeta from './model-meta'; import policy from './policy'; -${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} import { Prisma } from '${getPrismaClientImportSpec(outDir, options)}'; -${withLogicalClient ? `import { type PrismaClient } from '${logicalPrismaClientDir}/index-fixed';` : ``} +${ + withLogicalClient + ? `import type * as _P, { type PrismaClient } from '${logicalPrismaClientDir}/index-fixed';` + : `import type * as _P from '@prisma/client';` +} +${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} + +${authTypes} -export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions)${ +export function enhance(prisma: DbClient, context?: EnhancementContext<${authTypeParam}>, options?: EnhancementOptions)${ withLogicalClient ? ': PrismaClient' : '' } { return createEnhancement(prisma, { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index ecdad8336..b744aa578 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -251,13 +251,17 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey })); } + opt.extraSourceFiles?.forEach(({ name, content }) => { + fs.writeFileSync(path.join(projectRoot, name), content); + }); + + if (opt.extraSourceFiles && opt.extraSourceFiles.length > 0 && !opt.compile) { + console.warn('`extraSourceFiles` is true but `compile` is false.'); + } + if (opt.compile) { console.log('Compiling...'); - opt.extraSourceFiles?.forEach(({ name, content }) => { - fs.writeFileSync(path.join(projectRoot, name), content); - }); - run('npx tsc --init'); // add generated '.zenstack/zod' folder to typescript's search path, diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index e2655b36a..9079da045 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; -describe('With Policy: auth() test', () => { +describe('auth() runtime test', () => { let origDir: string; beforeAll(async () => { @@ -618,3 +618,171 @@ describe('With Policy: auth() test', () => { ).toResolveTruthy(); }); }); + +describe('auth() compile-time test', () => { + it('default enhanced typing', async () => { + await loadSchema( + ` + model User { + id1 Int + id2 Int + age Int + + @@id([id1, id2]) + @@allow('all', true) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id1: 1, id2: 2 } }); + `, + }, + ], + } + ); + }); + + it('custom auth model', async () => { + await loadSchema( + ` + model Foo { + id Int @id + age Int + + @@auth + @@allow('all', true) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id: 1 } }); + `, + }, + ], + } + ); + }); + + it('auth() selection', async () => { + await loadSchema( + ` + model User { + id Int @id + age Int + email String + + @@allow('all', auth().age > 0) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id: 1, age: 10 } }); + `, + }, + ], + } + ); + }); + + it('auth() to-one relation selection', async () => { + await loadSchema( + ` + model User { + id Int @id + email String + profile Profile? + + @@allow('all', auth().profile.age > 0 && auth().profile.job.level > 0) + } + + model Profile { + id Int @id + job Job? + age Int + user User @relation(fields: [userId], references: [id]) + userId Int @unique + } + + model Job { + id Int @id + level Int + profile Profile @relation(fields: [profileId], references: [id]) + profileId Int @unique + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id: 1, profile: { age: 1, job: { level: 10 } } } }); + `, + }, + ], + } + ); + }); + + it('auth() to-many relation selection', async () => { + await loadSchema( + ` + model User { + id Int @id + email String + posts Post[] + + @@allow('all', auth().posts?[viewCount > 0] && auth().posts?[comments?[level > 0]]) + } + + model Post { + id Int @id + viewCount Int + comments Comment[] + user User @relation(fields: [userId], references: [id]) + userId Int + } + + model Comment { + id Int @id + level Int + post Post @relation(fields: [postId], references: [id]) + postId Int + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { enhance } from ".zenstack/enhance"; + import { PrismaClient } from '@prisma/client'; + enhance(new PrismaClient(), { user: { id: 1, posts: [ { viewCount: 1, comments: [ { level: 1 } ] } ] } }); + `, + }, + ], + } + ); + }); +});