Skip to content

Commit

Permalink
fix: make sure the logical DMMF respects auth() in @default
Browse files Browse the repository at this point in the history
fixes #1893
  • Loading branch information
ymc9 committed Dec 4, 2024
1 parent d5c30f9 commit 4574154
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
30 changes: 30 additions & 0 deletions packages/plugins/openapi/tests/openapi-rpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,36 @@ model post_Item {

await OpenAPIParser.validate(output);
});

it('auth() in @default()', async () => {
const { projectDir } = await loadSchema(`
plugin openapi {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/openapi.yaml'
flavor = 'rpc'
}
model User {
id Int @id
posts Post[]
}
model Post {
id Int @id
title String
author User @relation(fields: [authorId], references: [id])
authorId Int @default(auth().id)
}
`);

const output = path.join(projectDir, 'openapi.yaml');
console.log('OpenAPI specification generated:', output);

await OpenAPIParser.validate(output);
const parsed = YAML.parse(fs.readFileSync(output, 'utf-8'));
expect(parsed.components.schemas.PostCreateInput.required).not.toContain('author');
expect(parsed.components.schemas.PostCreateManyInput.required).not.toContain('authorId');
});
});

function buildOptions(model: Model, modelFile: string, output: string) {
Expand Down
56 changes: 55 additions & 1 deletion packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { ReadonlyDeep } from '@prisma/generator-helper';
import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime';
import {
PluginError,
Expand All @@ -6,8 +7,10 @@ import {
getAuthDecl,
getDataModelAndTypeDefs,
getDataModels,
getForeignKeyFields,
getLiteral,
getRelationField,
hasAttribute,
isDelegateModel,
isDiscriminatorField,
normalizedRelative,
Expand Down Expand Up @@ -311,7 +314,8 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
// make a bunch of typing fixes to the generated prisma client
await this.processClientTypes(path.join(this.outDir, LOGICAL_CLIENT_GENERATION_PATH));

const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) });
// get the dmmf of the logical prisma schema
const dmmf = await this.getLogicalDMMF(logicalPrismaFile);

try {
// clean up temp schema
Expand All @@ -329,6 +333,56 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
};
}

private async getLogicalDMMF(logicalPrismaFile: string) {
const dmmf = await getDMMF({ datamodel: fs.readFileSync(logicalPrismaFile, { encoding: 'utf-8' }) });

// make necessary fixes

// fields that use `auth()` in `@default` are not handled by Prisma so in the DMMF
// they may be incorrectly represented as required, we need to fix that for input types
// also, if a FK field is of such case, its corresponding relation field should be optional
const createInputPattern = new RegExp(`^(.+?)(Unchecked)?Create.*Input$`);
for (const inputType of dmmf.schema.inputObjectTypes.prisma) {
const match = inputType.name.match(createInputPattern);
const modelName = match?.[1];
if (modelName) {
const dataModel = this.model.declarations.find(
(d): d is DataModel => isDataModel(d) && d.name === modelName
);
if (dataModel) {
for (const field of inputType.fields) {
if (field.isRequired && this.shouldBeOptional(field, dataModel)) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(field as any).isRequired = false;
}
}
}
}
}
return dmmf;
}

private shouldBeOptional(field: ReadonlyDeep<DMMF.SchemaArg>, dataModel: DataModel) {
const dmField = dataModel.fields.find((f) => f.name === field.name);
if (!dmField) {
return false;
}

if (hasAttribute(dmField, '@default')) {
return true;
}

if (isDataModel(dmField.type.reference?.ref)) {
// if FK field should be optional, the relation field should too
const fkFields = getForeignKeyFields(dmField);
if (fkFields.length > 0 && fkFields.every((f) => hasAttribute(f, '@default'))) {
return true;
}
}

return false;
}

private getPrismaClientGeneratorName(model: Model) {
for (const generator of model.declarations.filter(isGeneratorDecl)) {
if (
Expand Down

0 comments on commit 4574154

Please sign in to comment.