diff --git a/packages/runtime/src/enhancements/node/delegate.ts b/packages/runtime/src/enhancements/node/delegate.ts index 78523b837..80fd09f17 100644 --- a/packages/runtime/src/enhancements/node/delegate.ts +++ b/packages/runtime/src/enhancements/node/delegate.ts @@ -587,6 +587,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { let curr = args; let base = this.getBaseModel(model); let sub = this.getModelInfo(model); + const hasDelegateBase = !!base; while (base) { const baseRelationName = this.makeAuxRelationName(base); @@ -615,6 +616,55 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { sub = base; base = this.getBaseModel(base.name); } + + if (hasDelegateBase) { + // A delegate base model creation is added, this can be incompatible if + // the user-provided payload assigns foreign keys directly, because Prisma + // doesn't permit mixed "checked" and "unchecked" fields in a payload. + // + // { + // delegate_aux_base: { ... }, + // [fkField]: value // <- this is not compatible + // } + // + // We need to convert foreign key assignments to `connect`. + this.fkAssignmentToConnect(model, args); + } + } + + // convert foreign key assignments to `connect` payload + // e.g.: { authorId: value } -> { author: { connect: { id: value } } } + private fkAssignmentToConnect(model: string, args: any) { + const keysToDelete: string[] = []; + for (const [key, value] of Object.entries(args)) { + if (value === undefined) { + continue; + } + + const fieldInfo = this.queryUtils.getModelField(model, key); + if ( + !fieldInfo?.inheritedFrom && // fields from delegate base are handled outside + fieldInfo?.isForeignKey + ) { + const relationInfo = this.queryUtils.getRelationForForeignKey(model, key); + if (relationInfo) { + // turn { [fk]: value } into { [relation]: { connect: { [id]: value } } } + const relationName = relationInfo.relation.name; + if (!args[relationName]) { + args[relationName] = {}; + } + if (!args[relationName].connect) { + args[relationName].connect = {}; + } + if (!(relationInfo.idField in args[relationName].connect)) { + args[relationName].connect[relationInfo.idField] = value; + keysToDelete.push(key); + } + } + } + } + + keysToDelete.forEach((key) => delete args[key]); } // inject field data that belongs to base type into proper nesting structure diff --git a/packages/runtime/src/enhancements/node/query-utils.ts b/packages/runtime/src/enhancements/node/query-utils.ts index 0effd4557..c09fe1f95 100644 --- a/packages/runtime/src/enhancements/node/query-utils.ts +++ b/packages/runtime/src/enhancements/node/query-utils.ts @@ -232,4 +232,25 @@ export class QueryUtils { return model; } + + /** + * Gets relation info for a foreign key field. + */ + getRelationForForeignKey(model: string, fkField: string) { + const modelInfo = getModelInfo(this.options.modelMeta, model); + if (!modelInfo) { + return undefined; + } + + for (const field of Object.values(modelInfo.fields)) { + if (field.foreignKeyMapping) { + const entry = Object.entries(field.foreignKeyMapping).find(([, v]) => v === fkField); + if (entry) { + return { relation: field, idField: entry[0], fkField: entry[1] }; + } + } + } + + return undefined; + } } diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 7236cac85..ce4e167c1 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -7,6 +7,7 @@ import { getDataModelAndTypeDefs, getDataModels, getLiteral, + getRelationField, isDelegateModel, isDiscriminatorField, normalizedRelative, @@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][]; const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client'; export class EnhancerGenerator { + // regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type + // names for models that use `auth()` in `@default` attribute + private readonly modelsWithAuthInDefaultCreateInputPattern: RegExp; + constructor( private readonly model: Model, private readonly options: PluginOptions, private readonly project: Project, private readonly outDir: string - ) {} + ) { + const modelsWithAuthInDefault = this.model.declarations.filter( + (d): d is DataModel => isDataModel(d) && d.fields.some((f) => f.attributes.some(isDefaultWithAuth)) + ); + this.modelsWithAuthInDefaultCreateInputPattern = new RegExp( + `^(${modelsWithAuthInDefault.map((m) => m.name).join('|')})(Unchecked)?Create.*?Input$` + ); + } async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> { let dmmf: DMMF.Document | undefined; @@ -69,7 +81,7 @@ export class EnhancerGenerator { let prismaTypesFixed = false; let resultPrismaImport = prismaImport; - if (this.needsLogicalClient || this.needsPrismaClientTypeFixes) { + if (this.needsLogicalClient) { prismaTypesFixed = true; resultPrismaImport = `${LOGICAL_CLIENT_GENERATION_PATH}/index-fixed`; const result = await this.generateLogicalPrisma(); @@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara } private get needsLogicalClient() { - return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model); - } - - private get needsPrismaClientTypeFixes() { - return this.hasTypeDef(this.model); + return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model) || this.hasTypeDef(this.model); } private hasDelegateModel(model: Model) { @@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const auxFields = this.findAuxDecls(variable); if (auxFields.length > 0) { structure.declarations.forEach((variable) => { - let source = variable.type?.toString(); - auxFields.forEach((f) => { - source = source?.replace(f.getText(), ''); - }); - variable.type = source; + if (variable.type) { + let source = variable.type.toString(); + auxFields.forEach((f) => { + source = this.removeFromSource(source, f.getText()); + }); + variable.type = source; + } }); } @@ -498,6 +508,9 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara // fix delegate payload union type source = this.fixDelegatePayloadType(typeAlias, delegateInfo, source); + // fix fk and relation fields related to using `auth()` in `@default` + source = this.fixDefaultAuthType(typeAlias, source); + // fix json field type source = this.fixJsonFieldType(typeAlias, source); @@ -505,65 +518,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara return structure; } - private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) { - const modelsWithTypeField = this.model.declarations.filter( - (d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref)) - ); - const typeName = typeAlias.getName(); - - const getTypedJsonFields = (model: DataModel) => { - return model.fields.filter((f) => isTypeDef(f.type.reference?.ref)); - }; - - const replacePrismaJson = (source: string, field: DataModelField) => { - return source.replace( - new RegExp(`(${field.name}\\??\\s*):[^\\n]+`), - `$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${ - field.type.optional ? ' | null' : '' - }` - ); - }; - - // fix "$[Model]Payload" type - const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName); - if (payloadModelMatch) { - const scalars = typeAlias - .getDescendantsOfKind(SyntaxKind.PropertySignature) - .find((p) => p.getName() === 'scalars'); - if (!scalars) { - return source; - } - - const fieldsToFix = getTypedJsonFields(payloadModelMatch); - for (const field of fieldsToFix) { - source = replacePrismaJson(source, field); - } - } - - // fix input/output types, "[Model]CreateInput", etc. - const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name)); - if (inputOutputModelMatch) { - const relevantTypePatterns = [ - 'GroupByOutputType', - '(Unchecked)?Create(\\S+?)?Input', - '(Unchecked)?Update(\\S+?)?Input', - 'CreateManyInput', - '(Unchecked)?UpdateMany(Mutation)?Input', - ]; - const typeRegex = modelsWithTypeField.map( - (m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`) - ); - if (typeRegex.some((r) => r.test(typeName))) { - const fieldsToFix = getTypedJsonFields(inputOutputModelMatch); - for (const field of fieldsToFix) { - source = replacePrismaJson(source, field); - } - } - } - - return source; - } - private fixDelegatePayloadType(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo, source: string) { // change the type of `$Payload` type of delegate model to a union of concrete types const typeName = typeAlias.getName(); @@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara .getDescendantsOfKind(SyntaxKind.PropertySignature) .filter((p) => ['create', 'createMany', 'connectOrCreate', 'upsert'].includes(p.getName())); toRemove.forEach((r) => { - source = source.replace(r.getText(), ''); + this.removeFromSource(source, r.getText()); }); } return source; @@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara if (isDiscriminatorField(field)) { const fieldDef = this.findNamedProperty(typeAlias, field.name); if (fieldDef) { - source = source.replace(fieldDef.getText(), ''); + source = this.removeFromSource(source, fieldDef.getText()); } } } @@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const auxDecls = this.findAuxDecls(typeAlias); if (auxDecls.length > 0) { auxDecls.forEach((d) => { - source = source.replace(d.getText(), ''); + source = this.removeFromSource(source, d.getText()); }); } return source; @@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara const fieldDef = this.findNamedProperty(typeAlias, relationFieldName); if (fieldDef) { // remove relation field of delegate type, e.g., `asset` - source = source.replace(fieldDef.getText(), ''); + source = this.removeFromSource(source, fieldDef.getText()); } // remove fk fields related to the delegate type relation, e.g., `assetId` @@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara fkFields.forEach((fkField) => { const fieldDef = this.findNamedProperty(typeAlias, fkField); if (fieldDef) { - source = source.replace(fieldDef.getText(), ''); + source = this.removeFromSource(source, fieldDef.getText()); } }); return source; } + private fixDefaultAuthType(typeAlias: TypeAliasDeclaration, source: string) { + const match = typeAlias.getName().match(this.modelsWithAuthInDefaultCreateInputPattern); + if (!match) { + return source; + } + + const modelName = match[1]; + const dataModel = this.model.declarations.find((d): d is DataModel => isDataModel(d) && d.name === modelName); + if (dataModel) { + for (const fkField of dataModel.fields.filter((f) => f.attributes.some(isDefaultWithAuth))) { + // change fk field to optional since it has a default + source = source.replace(new RegExp(`^(\\s*${fkField.name}\\s*):`, 'm'), `$1?:`); + + const relationField = getRelationField(fkField); + if (relationField) { + // change relation field to optional since its fk has a default + source = source.replace(new RegExp(`^(\\s*${relationField.name}\\s*):`, 'm'), `$1?:`); + } + } + } + return source; + } + + private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) { + const modelsWithTypeField = this.model.declarations.filter( + (d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref)) + ); + const typeName = typeAlias.getName(); + + const getTypedJsonFields = (model: DataModel) => { + return model.fields.filter((f) => isTypeDef(f.type.reference?.ref)); + }; + + const replacePrismaJson = (source: string, field: DataModelField) => { + return source.replace( + new RegExp(`(${field.name}\\??\\s*):[^\\n]+`), + `$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${ + field.type.optional ? ' | null' : '' + }` + ); + }; + + // fix "$[Model]Payload" type + const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName); + if (payloadModelMatch) { + const scalars = typeAlias + .getDescendantsOfKind(SyntaxKind.PropertySignature) + .find((p) => p.getName() === 'scalars'); + if (!scalars) { + return source; + } + + const fieldsToFix = getTypedJsonFields(payloadModelMatch); + for (const field of fieldsToFix) { + source = replacePrismaJson(source, field); + } + } + + // fix input/output types, "[Model]CreateInput", etc. + const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name)); + if (inputOutputModelMatch) { + const relevantTypePatterns = [ + 'GroupByOutputType', + '(Unchecked)?Create(\\S+?)?Input', + '(Unchecked)?Update(\\S+?)?Input', + 'CreateManyInput', + '(Unchecked)?UpdateMany(Mutation)?Input', + ]; + const typeRegex = modelsWithTypeField.map( + (m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`) + ); + if (typeRegex.some((r) => r.test(typeName))) { + const fieldsToFix = getTypedJsonFields(inputOutputModelMatch); + for (const field of fieldsToFix) { + source = replacePrismaJson(source, field); + } + } + } + + return source; + } + + private async generateExtraTypes(sf: SourceFile) { + for (const decl of this.model.declarations) { + if (isTypeDef(decl)) { + generateTypeDefType(sf, decl); + } + } + } + private findNamedProperty(typeAlias: TypeAliasDeclaration, name: string) { return typeAlias.getFirstDescendant((d) => d.isKind(SyntaxKind.PropertySignature) && d.getName() === name); } @@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara return this.options.generatePermissionChecker === true; } - private async generateExtraTypes(sf: SourceFile) { - for (const decl of this.model.declarations) { - if (isTypeDef(decl)) { - generateTypeDefType(sf, decl); - } - } + private removeFromSource(source: string, text: string) { + source = source.replace(text, ''); + return this.trimEmptyLines(source); + } + + private trimEmptyLines(source: string): string { + return source.replace(/^\s*[\r\n]/gm, ''); } } diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index b3d382795..d893d729f 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -31,7 +31,7 @@ import { StringLiteral, } from '@zenstackhq/language/ast'; import { getPrismaVersion } from '@zenstackhq/sdk/prisma'; -import { match, P } from 'ts-pattern'; +import { match } from 'ts-pattern'; import { getIdFields } from '../../utils/ast-utils'; import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; @@ -838,14 +838,6 @@ export class PrismaSchemaGenerator { const docs = [...field.comments, ...this.getCustomAttributesAsComments(field)]; const result = model.addField(field.name, type, attributes, docs, addToFront); - if (this.mode === 'logical') { - if (field.attributes.some((attr) => isDefaultWithAuth(attr))) { - // field has `@default` with `auth()`, turn it into a dummy default value, and the - // real default value setting is handled outside Prisma - this.setDummyDefault(result, field); - } - } - return result; } @@ -856,23 +848,6 @@ export class PrismaSchemaGenerator { } } - private setDummyDefault(result: ModelField, field: DataModelField) { - const dummyDefaultValue = match(field.type.type) - .with('String', () => new AttributeArgValue('String', '')) - .with(P.union('Int', 'BigInt', 'Float', 'Decimal'), () => new AttributeArgValue('Number', '0')) - .with('Boolean', () => new AttributeArgValue('Boolean', 'false')) - .with('DateTime', () => new AttributeArgValue('FunctionCall', new PrismaFunctionCall('now'))) - .with('Json', () => new AttributeArgValue('String', '{}')) - .with('Bytes', () => new AttributeArgValue('String', '')) - .otherwise(() => { - throw new PluginError(name, `Unsupported field type with default value: ${field.type.type}`); - }); - - result.attributes.push( - new PrismaFieldAttribute('@default', [new PrismaAttributeArg(undefined, dummyDefaultValue)]) - ); - } - private isInheritedFromDelegate(field: DataModelField) { return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom); } diff --git a/packages/schema/src/plugins/zod/transformer.ts b/packages/schema/src/plugins/zod/transformer.ts index a61eec3a2..db0c2a7bb 100644 --- a/packages/schema/src/plugins/zod/transformer.ts +++ b/packages/schema/src/plugins/zod/transformer.ts @@ -1,6 +1,12 @@ /* eslint-disable @typescript-eslint/ban-ts-comment */ -import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk'; -import { DataModel, Enum, isDataModel, isEnum, isTypeDef, type Model } from '@zenstackhq/sdk/ast'; +import { + getForeignKeyFields, + hasAttribute, + indentString, + isDiscriminatorField, + type PluginOptions, +} from '@zenstackhq/sdk'; +import { DataModel, DataModelField, Enum, isDataModel, isEnum, isTypeDef, type Model } from '@zenstackhq/sdk/ast'; import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers'; import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma'; import path from 'path'; @@ -241,7 +247,8 @@ export default class Transformer { this.addSchemaImport(inputType.type); } - result.push(this.generatePrismaStringLine(field, inputType, lines.length)); + const contextField = contextDataModel?.fields.find((f) => f.name === field.name); + result.push(this.generatePrismaStringLine(field, inputType, lines.length, contextField)); } } @@ -315,7 +322,12 @@ export default class Transformer { this.schemaImports.add(upperCaseFirst(name)); } - generatePrismaStringLine(field: PrismaDMMF.SchemaArg, inputType: PrismaDMMF.InputTypeRef, inputsLength: number) { + generatePrismaStringLine( + field: PrismaDMMF.SchemaArg, + inputType: PrismaDMMF.InputTypeRef, + inputsLength: number, + contextField: DataModelField | undefined + ) { const isEnum = inputType.location === 'enumTypes'; const { isModelQueryType, modelName, queryName } = this.checkIsModelQueryType(inputType.type as string); @@ -330,11 +342,36 @@ export default class Transformer { const arr = inputType.isList ? '.array()' : ''; - const opt = !field.isRequired ? '.optional()' : ''; + const optional = + !field.isRequired || + // also check if the zmodel field infers the field as optional + (contextField && this.isFieldOptional(contextField)); return inputsLength === 1 - ? ` ${field.name}: z.lazy(() => ${schema})${arr}${opt}` - : `z.lazy(() => ${schema})${arr}${opt}`; + ? ` ${field.name}: z.lazy(() => ${schema})${arr}${optional ? '.optional()' : ''}` + : `z.lazy(() => ${schema})${arr}${optional ? '.optional()' : ''}`; + } + + private isFieldOptional(dmField: DataModelField) { + if (hasAttribute(dmField, '@default')) { + // it's possible that ZModel field has a default but it's transformed away + // when generating Prisma schema, e.g.: `@default(auth().id)` + return true; + } + + if (isDataModel(dmField.type.reference?.ref)) { + // if field is a relation, we need to check if the corresponding fk field has a default + // { + // authorId Int @default(auth().id) + // author User @relation(...) // <- author should be optional + // } + const fkFields = getForeignKeyFields(dmField); + if (fkFields.every((fkField) => hasAttribute(fkField, '@default'))) { + return true; + } + } + + return false; } generateFieldValidators(zodStringWithMainType: string, field: PrismaDMMF.SchemaArg) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 8313d662c..5b34a0e0c 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -381,6 +381,13 @@ export function getRelationField(fkField: DataModelField) { }); } +/** + * Gets the foreign key fields of the given relation field. + */ +export function getForeignKeyFields(relationField: DataModelField) { + return getRelationKeyPairs(relationField).map((pair) => pair.foreignKey); +} + export function resolvePath(_path: string, options: Pick) { if (path.isAbsolute(_path)) { return _path; diff --git a/tests/regression/tests/issue-1843.test.ts b/tests/regression/tests/issue-1843.test.ts new file mode 100644 index 000000000..518262857 --- /dev/null +++ b/tests/regression/tests/issue-1843.test.ts @@ -0,0 +1,108 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1843', () => { + it('regression', async () => { + const { zodSchemas, enhance, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + email String @unique @email @length(6, 32) + password String @password @omit + contents Content[] + postsCoauthored PostWithCoauthor[] + + @@allow('all', true) + } + + abstract model Owner { + owner User @relation(fields: [ownerId], references: [id]) + ownerId String @default(auth().id) + } + + abstract model BaseContent extends Owner { + published Boolean @default(false) + + @@index([published]) + } + + model Content extends BaseContent { + id String @id @default(cuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + + contentType String + @@allow('all', true) + + @@delegate(contentType) + } + + model PostWithCoauthor extends Content { + title String + + coauthor User @relation(fields: [coauthorId], references: [id]) + coauthorId String + + @@allow('all', true) + } + + model Post extends Content { + title String + + @@allow('all', true) + } + `, + { + compile: true, + extraSourceFiles: [ + { + name: 'main.ts', + content: ` + import { PrismaClient } from '@prisma/client'; + import { enhance } from '.zenstack/enhance'; + + async function main() { + const enhanced = enhance(new PrismaClient()); + await enhanced.postWithCoauthor.create({ + data: { + title: "new post", + coauthor: { + connect: { + id: "1" + } + }, + } + }); + + await enhanced.postWithCoauthor.create({ + data: { + title: "new post", + coauthorId: "1" + } + }); + } + `, + }, + ], + } + ); + + const user = await prisma.user.create({ data: { email: 'abc', password: '123' } }); + const db = enhance({ id: user.id }); + + // connect + await expect( + db.postWithCoauthor.create({ data: { title: 'new post', coauthor: { connect: { id: user.id } } } }) + ).toResolveTruthy(); + + // fk setting + await expect( + db.postWithCoauthor.create({ data: { title: 'new post', coauthorId: user.id } }) + ).toResolveTruthy(); + + // zod validation + zodSchemas.models.PostWithCoauthorCreateSchema.parse({ + title: 'new post', + coauthorId: '1', + }); + }); +});