From 36e515e485c580657b9edbfc52014f3542abfb96 Mon Sep 17 00:00:00 2001 From: Yiming Date: Wed, 6 Mar 2024 21:09:41 -0800 Subject: [PATCH] fix: several issues with using `auth()` in `@default` (#1088) --- packages/runtime/src/cross/model-meta.ts | 9 +- .../runtime/src/enhancements/default-auth.ts | 46 +++++++++- .../src/enhancements/policy/handler.ts | 21 +---- packages/runtime/src/enhancements/utils.ts | 19 +++++ .../src/plugins/enhancer/enhance/index.ts | 26 +++++- .../src/plugins/enhancer/enhancer-utils.ts | 20 +++++ .../src/plugins/prisma/schema-generator.ts | 38 +++++---- .../validation/attribute-validation.test.ts | 26 +----- packages/sdk/src/model-meta-generator.ts | 7 +- packages/sdk/src/utils.ts | 39 +++++++++ packages/testtools/src/schema.ts | 6 ++ .../enhancements/with-policy/auth.test.ts | 83 +++++++++++++++++++ 12 files changed, 271 insertions(+), 69 deletions(-) create mode 100644 packages/schema/src/plugins/enhancer/enhancer-utils.ts diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 9f767af0e..efa4d1a03 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -75,7 +75,14 @@ export type FieldInfo = { isForeignKey?: boolean; /** - * Mapping from foreign key field names to relation field names + * If the field is a foreign key field, the field name of the corresponding relation field. + * Only available on foreign key fields. + */ + relationField?: string; + + /** + * Mapping from foreign key field names to relation field names. + * Only available on relation fields. */ foreignKeyMapping?: Record; diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index bbbd35861..78294f28b 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -2,10 +2,11 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import deepcopy from 'deepcopy'; -import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; +import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields, requireField } from '../cross'; import { DbClientContract } from '../types'; import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement'; import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; +import { isUnsafeMutate } from './utils'; /** * Gets an enhanced Prisma client that supports `@default(auth())` attribute. @@ -68,7 +69,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo); if (authDefaultValue !== undefined) { // set field value extracted from `auth()` - data[fieldInfo.name] = authDefaultValue; + this.setAuthDefaultValue(fieldInfo, model, data, authDefaultValue); } } }; @@ -90,6 +91,47 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler { return newArgs; } + private setAuthDefaultValue(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) { + if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) { + // if the field is a fk, and the create payload is not unsafe, we need to translate + // the fk field setting to a `connect` of the corresponding relation field + const relFieldName = fieldInfo.relationField; + if (!relFieldName) { + throw new Error( + `Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found` + ); + } + const relationField = requireField(this.options.modelMeta, model, relFieldName); + + // construct a `{ connect: { ... } }` payload + let connect = data[relationField.name]?.connect; + if (!connect) { + connect = {}; + data[relationField.name] = { connect }; + } + + // sets the opposite fk field to value `authDefaultValue` + const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo); + if (!oppositeFkFieldName) { + throw new Error( + `Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\`` + ); + } + connect[oppositeFkFieldName] = authDefaultValue; + } else { + // set default value directly + data[fieldInfo.name] = authDefaultValue; + } + } + + private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) { + if (!relationField.foreignKeyMapping) { + return undefined; + } + const entry = Object.entries(relationField.foreignKeyMapping).find(([, v]) => v === fieldInfo.name); + return entry?.[0]; + } + private getDefaultValueFromAuth(fieldInfo: FieldInfo) { if (!this.userContext) { throw new Error(`Evaluating default value of field \`${fieldInfo.name}\` requires a user context`); diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index f31e145f9..f2bc4ad07 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -21,7 +21,7 @@ import type { EnhancementContext, InternalEnhancementOptions } from '../create-e import { Logger } from '../logger'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import { formatObject, prismaClientValidationError } from '../utils'; +import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { PolicyUtil } from './policy-utils'; import { createDeferredPromise } from './promise'; @@ -691,7 +691,7 @@ export class PolicyProxyHandler implements Pr // operations. E.g.: // - safe: { data: { user: { connect: { id: 1 }} } } // - unsafe: { data: { userId: 1 } } - const unsafe = this.isUnsafeMutate(model, args); + const unsafe = isUnsafeMutate(model, args, this.modelMeta); // handles the connection to upstream entity const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe); @@ -1083,23 +1083,6 @@ export class PolicyProxyHandler implements Pr } } - private isUnsafeMutate(model: string, args: any) { - if (!args) { - return false; - } - for (const k of Object.keys(args)) { - const field = resolveField(this.modelMeta, model, k); - if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) { - return true; - } - } - return false; - } - - private isAutoIncrementIdField(field: FieldInfo) { - return field.isId && field.isAutoIncrement; - } - async updateMany(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index ba2f9a2d8..9bc7ce0bc 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -1,4 +1,5 @@ import * as util from 'util'; +import { FieldInfo, ModelMeta, resolveField } from '..'; import type { DbClientContract } from '../types'; /** @@ -22,3 +23,21 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error { throw new prismaModule.PrismaClientUnknownRequestError(...args); } + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta) { + if (!args) { + return false; + } + for (const k of Object.keys(args)) { + const field = resolveField(modelMeta, model, k); + if (field && (isAutoIncrementIdField(field) || field.isForeignKey)) { + return true; + } + } + return false; +} + +export function isAutoIncrementIdField(field: FieldInfo) { + return field.isId && field.isAutoIncrement; +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 9488b24f7..63845ba1c 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -27,6 +27,7 @@ import { name } from '..'; import { execPackage } from '../../../utils/exec-utils'; import { trackPrismaSchemaError } from '../../prisma'; import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; +import { isDefaultWithAuth } from '../enhancer-utils'; // information of delegate models and their sub models type DelegateInfo = [DataModel, DataModel[]][]; @@ -35,7 +36,7 @@ export async function generate(model: Model, options: PluginOptions, project: Pr let logicalPrismaClientDir: string | undefined; let dmmf: DMMF.Document | undefined; - if (hasDelegateModel(model)) { + if (needsLogicalClient(model)) { // schema contains delegate models, need to generate a logical prisma schema const result = await generateLogicalPrisma(model, options, outDir); @@ -86,6 +87,10 @@ export function enhance(prisma: DbClient, context?: Enh return { dmmf }; } +function needsLogicalClient(model: Model) { + return hasDelegateModel(model) || hasAuthInDefault(model); +} + function hasDelegateModel(model: Model) { const dataModels = getDataModels(model); return dataModels.some( @@ -93,6 +98,12 @@ function hasDelegateModel(model: Model) { ); } +function hasAuthInDefault(model: Model) { + return getDataModels(model).some((dm) => + dm.fields.some((f) => f.attributes.some((attr) => isDefaultWithAuth(attr))) + ); +} + async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) { const prismaGenerator = new PrismaSchemaGenerator(model); const prismaClientOutDir = './.logical-prisma-client'; @@ -152,12 +163,19 @@ async function processClientTypes(model: Model, prismaClientDir: string) { const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, { overwrite: true, }); - transform(sf, sfNew, delegateInfo); - sfNew.formatText(); + + if (delegateInfo.length > 0) { + // transform types for delegated models + transformDelegate(sf, sfNew, delegateInfo); + sfNew.formatText(); + } else { + // just copy + sfNew.replaceWithText(sf.getFullText()); + } await sfNew.save(); } -function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) { +function transformDelegate(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) { // copy toplevel imports sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure())); diff --git a/packages/schema/src/plugins/enhancer/enhancer-utils.ts b/packages/schema/src/plugins/enhancer/enhancer-utils.ts new file mode 100644 index 000000000..9bb429ca5 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhancer-utils.ts @@ -0,0 +1,20 @@ +import { isAuthInvocation } from '@zenstackhq/sdk'; +import type { DataModelFieldAttribute } from '@zenstackhq/sdk/ast'; +import { streamAst } from 'langium'; + +/** + * Check if the given field attribute is a `@default` with `auth()` invocation + */ +export function isDefaultWithAuth(attr: DataModelFieldAttribute) { + if (attr.decl.ref?.name !== '@default') { + return false; + } + + const expr = attr.args[0]?.value; + if (!expr) { + return false; + } + + // find `auth()` in default value expression + return streamAst(expr).some(isAuthInvocation); +} diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 2519c3cd3..bc63d535a 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -34,11 +34,12 @@ import { getIdFields } from '../../utils/ast-utils'; import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { getAttribute, + getForeignKeyFields, getLiteral, getPrismaVersion, - isAuthInvocation, isDelegateModel, isIdField, + isRelationshipField, PluginError, PluginOptions, resolved, @@ -46,7 +47,6 @@ import { } from '@zenstackhq/sdk'; import fs from 'fs'; import { writeFile } from 'fs/promises'; -import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import semver from 'semver'; @@ -54,6 +54,7 @@ import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; import { execPackage } from '../../utils/exec-utils'; +import { isDefaultWithAuth } from '../enhancer/enhancer-utils'; import { AttributeArgValue, ModelFieldType, @@ -494,10 +495,27 @@ export class PrismaSchemaGenerator { const type = new ModelFieldType(fieldType, field.type.array, field.type.optional); + if (this.mode === 'logical') { + if (field.attributes.some((attr) => isDefaultWithAuth(attr))) { + // field has `@default` with `auth()`, it should be set optional, and the + // default value setting is handled outside Prisma + type.optional = true; + } + + if (isRelationshipField(field)) { + // if foreign key field has `@default` with `auth()`, the relation + // field should be set optional + const foreignKeyFields = getForeignKeyFields(field); + if (foreignKeyFields.some((fkField) => fkField.attributes.some((attr) => isDefaultWithAuth(attr)))) { + type.optional = true; + } + } + } + const attributes = field.attributes .filter((attr) => this.isPrismaAttribute(attr)) // `@default` with `auth()` is handled outside Prisma - .filter((attr) => !this.isDefaultWithAuth(attr)) + .filter((attr) => !isDefaultWithAuth(attr)) .filter( (attr) => // when building physical schema, exclude `@default` for id fields inherited from delegate base @@ -524,20 +542,6 @@ export class PrismaSchemaGenerator { return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom); } - private isDefaultWithAuth(attr: DataModelFieldAttribute) { - if (attr.decl.ref?.name !== '@default') { - return false; - } - - const expr = attr.args[0]?.value; - if (!expr) { - return false; - } - - // find `auth()` in default value expression - return streamAst(expr).some(isAuthInvocation); - } - private makeFieldAttribute(attr: DataModelFieldAttribute) { const attrName = resolved(attr.decl).name; if (attrName === FIELD_PASSTHROUGH_ATTR) { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 611f8dc60..c6d0db13b 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -227,7 +227,7 @@ describe('Attribute tests', () => { `); await loadModel(` - ${ prelude } + ${prelude} model A { id String @id x String @@ -1051,21 +1051,6 @@ describe('Attribute tests', () => { } `); - // expect( - // await loadModelWithError(` - // ${prelude} - - // model User { - // id String @id - // name String - // } - // model B { - // id String @id - // userData String @default(auth()) - // } - // `) - // ).toContain("Value is not assignable to parameter"); - expect( await loadModelWithError(` ${prelude} @@ -1185,15 +1170,6 @@ describe('Attribute tests', () => { }); it('incorrect function expression context', async () => { - // expect( - // await loadModelWithError(` - // ${prelude} - // model M { - // id String @id @default(auth()) - // } - // `) - // ).toContain('function "auth" is not allowed in the current context: DefaultValue'); - expect( await loadModelWithError(` ${prelude} diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 8adf42c4c..3dc0f3f1e 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -32,6 +32,7 @@ import { isIdField, resolved, TypeScriptExpressionTransformer, + getRelationField, } from '.'; /** @@ -247,6 +248,11 @@ function writeFields( if (isForeignKeyField(f)) { writer.write(` isForeignKey: true,`); + const relationField = getRelationField(f); + if (relationField) { + writer.write(` + relationField: '${relationField.name}',`); + } } if (fkMapping && Object.keys(fkMapping).length > 0) { @@ -408,7 +414,6 @@ function generateForeignKeyMapping(field: DataModelField) { const fieldNames = fields.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined)); const referenceNames = references.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined)); - // eslint-disable-next-line @typescript-eslint/no-explicit-any const result: Record = {}; referenceNames.forEach((name, i) => { if (name) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index f73d2e12c..641446a02 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -281,6 +281,45 @@ export function isForeignKeyField(field: DataModelField) { }); } +/** + * Gets the foreign key fields of the given relation field. + */ +export function getForeignKeyFields(relationField: DataModelField) { + if (!isRelationshipField(relationField)) { + return []; + } + + const relAttr = relationField.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (relAttr) { + // find "fields" arg + const fieldsArg = getAttributeArg(relAttr, 'fields'); + if (fieldsArg && isArrayExpr(fieldsArg)) { + return fieldsArg.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => item.target.ref as DataModelField); + } + } + + return []; +} + +/** + * Gets the relation field of the given foreign key field. + */ +export function getRelationField(fkField: DataModelField) { + const model = fkField.$container as DataModel; + return model.fields.find((f) => { + const relAttr = f.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (relAttr) { + const fieldsArg = getAttributeArg(relAttr, 'fields'); + if (fieldsArg && isArrayExpr(fieldsArg)) { + return fieldsArg.items.some((item) => isReferenceExpr(item) && item.target.ref === fkField); + } + } + return false; + }); +} + export function resolvePath(_path: string, options: Pick) { if (path.isAbsolute(_path)) { return _path; diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index df72c107c..392b4af4f 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -121,6 +121,7 @@ export type SchemaLoadOptions = { getPrismaOnly?: boolean; enhancements?: EnhancementKind[]; enhanceOptions?: Partial; + extraSourceFiles?: { name: string; content: string }[]; }; const defaultOptions: SchemaLoadOptions = { @@ -246,6 +247,11 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { if (opt.compile) { console.log('Compiling...'); + + opt.extraSourceFiles?.forEach(({ name, content }) => { + fs.writeFileSync(path.join(projectRoot, name), content); + }); + run('npx tsc --init'); // add generated '.zenstack/zod' folder to typescript's search path, diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index e1fff4f73..e2655b36a 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -534,4 +534,87 @@ describe('With Policy: auth() test', () => { ); await expect(db.post.findMany({})).toResolveTruthy(); }); + + it('Default auth() field optionality', async () => { + await loadSchema( + ` + model User { + id String @id + posts Post[] + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + const prisma = new PrismaClient(); + const db = enhance(prisma, { user: { id: 'user1' } }); + + // "author" and "authorId" are optional + db.post.create({ data: { title: 'abc' } }); +`, + }, + ], + } + ); + }); + + it('Default auth() safe unsafe mix', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + posts Post[] + + @@allow('all', true) + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + stats Stats @relation(fields: [statsId], references: [id]) + statsId String @unique + + @@allow('all', true) + } + + model Stats { + id String @id @default(uuid()) + viewCount Int @default(0) + post Post? + + @@allow('all', true) + + } + ` + ); + + const db = enhance({ id: 'userId-1' }); + await db.user.create({ data: { id: 'userId-1' } }); + + // safe + await db.stats.create({ data: { id: 'stats-1', viewCount: 10 } }); + await expect(db.post.create({ data: { title: 'title', statsId: 'stats-1' } })).toResolveTruthy(); + + // unsafe + await db.stats.create({ data: { id: 'stats-2', viewCount: 10 } }); + await expect( + db.post.create({ data: { title: 'title', stats: { connect: { id: 'stats-2' } } } }) + ).toResolveTruthy(); + }); });