diff --git a/packages/plugins/openapi/tests/openapi-rpc.test.ts b/packages/plugins/openapi/tests/openapi-rpc.test.ts index 845116ac3..930341ffb 100644 --- a/packages/plugins/openapi/tests/openapi-rpc.test.ts +++ b/packages/plugins/openapi/tests/openapi-rpc.test.ts @@ -457,6 +457,36 @@ model post_Item { await OpenAPIParser.validate(output); }); + + it('auth() in @default()', async () => { + const { projectDir } = await loadSchema(` +plugin openapi { + provider = '${normalizePath(path.resolve(__dirname, '../dist'))}' + output = '$projectRoot/openapi.yaml' + flavor = 'rpc' +} + +model User { + id Int @id + posts Post[] +} + +model Post { + id Int @id + title String + author User @relation(fields: [authorId], references: [id]) + authorId Int @default(auth().id) +} + `); + + const output = path.join(projectDir, 'openapi.yaml'); + console.log('OpenAPI specification generated:', output); + + await OpenAPIParser.validate(output); + const parsed = YAML.parse(fs.readFileSync(output, 'utf-8')); + expect(parsed.components.schemas.PostCreateInput.required).not.toContain('author'); + expect(parsed.components.schemas.PostCreateManyInput.required).not.toContain('authorId'); + }); }); function buildOptions(model: Model, modelFile: string, output: string) { diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts index 404c16cd8..51f100557 100644 --- a/packages/schema/src/plugins/enhancer/enhance/index.ts +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -1,3 +1,4 @@ +import { ReadonlyDeep } from '@prisma/generator-helper'; import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; import { PluginError, @@ -6,8 +7,10 @@ import { getAuthDecl, getDataModelAndTypeDefs, getDataModels, + getForeignKeyFields, getLiteral, getRelationField, + hasAttribute, isDelegateModel, isDiscriminatorField, normalizedRelative, @@ -311,7 +314,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara // make a bunch of typing fixes to the generated prisma client await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH)); - const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }); + // get the dmmf of the logical prisma schema + const dmmf = await this.getLogicalDMMF(logicalPrismaFile); try { // clean up temp schema @@ -329,6 +333,56 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara }; } + private async getLogicalDMMF(logicalPrismaFile: string) { + const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) }); + + // make necessary fixes + + // fields that use `auth()` in `@default` are not handled by Prisma so in the DMMF + // they may be incorrectly represented as required, we need to fix that for input types + // also, if a FK field is of such case, its corresponding relation field should be optional + const createInputPattern = new RegExp(`^(.+?)(Unchecked)?Create.*Input$`); + for (const inputType of dmmf.schema.inputObjectTypes.prisma) { + const match = inputType.name.match(createInputPattern); + const modelName = match?.[1]; + if (modelName) { + const dataModel = this.model.declarations.find( + (d): d is DataModel => isDataModel(d) && d.name === modelName + ); + if (dataModel) { + for (const field of inputType.fields) { + if (field.isRequired && this.shouldBeOptional(field, dataModel)) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (field as any).isRequired = false; + } + } + } + } + } + return dmmf; + } + + private shouldBeOptional(field: ReadonlyDeep, dataModel: DataModel) { + const dmField = dataModel.fields.find((f) => f.name === field.name); + if (!dmField) { + return false; + } + + if (hasAttribute(dmField, '@default')) { + return true; + } + + if (isDataModel(dmField.type.reference?.ref)) { + // if FK field should be optional, the relation field should too + const fkFields = getForeignKeyFields(dmField); + if (fkFields.length > 0 && fkFields.every((f) => hasAttribute(f, '@default'))) { + return true; + } + } + + return false; + } + private getPrismaClientGeneratorName(model: Model) { for (const generator of model.declarations.filter(isGeneratorDecl)) { if (