Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make sure both fk and relation fields are optional in create input types #1862

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions packages/runtime/src/enhancements/node/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
let curr = args;
let base = this.getBaseModel(model);
let sub = this.getModelInfo(model);
const hasDelegateBase = !!base;

while (base) {
const baseRelationName = this.makeAuxRelationName(base);
Expand Down Expand Up @@ -615,6 +616,55 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
sub = base;
base = this.getBaseModel(base.name);
}

if (hasDelegateBase) {
// A delegate base model creation is added, this can be incompatible if
// the user-provided payload assigns foreign keys directly, because Prisma
// doesn't permit mixed "checked" and "unchecked" fields in a payload.
//
// {
// delegate_aux_base: { ... },
// [fkField]: value // <- this is not compatible
// }
//
// We need to convert foreign key assignments to `connect`.
this.fkAssignmentToConnect(model, args);
}
}

// convert foreign key assignments to `connect` payload
// e.g.: { authorId: value } -> { author: { connect: { id: value } } }
private fkAssignmentToConnect(model: string, args: any) {
const keysToDelete: string[] = [];
for (const [key, value] of Object.entries(args)) {
if (value === undefined) {
continue;
}

const fieldInfo = this.queryUtils.getModelField(model, key);
if (
!fieldInfo?.inheritedFrom && // fields from delegate base are handled outside
fieldInfo?.isForeignKey
) {
const relationInfo = this.queryUtils.getRelationForForeignKey(model, key);
if (relationInfo) {
// turn { [fk]: value } into { [relation]: { connect: { [id]: value } } }
const relationName = relationInfo.relation.name;
if (!args[relationName]) {
args[relationName] = {};
}
if (!args[relationName].connect) {
args[relationName].connect = {};
}
if (!(relationInfo.idField in args[relationName].connect)) {
args[relationName].connect[relationInfo.idField] = value;
keysToDelete.push(key);
}
}
}
}

keysToDelete.forEach((key) => delete args[key]);
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
}

// inject field data that belongs to base type into proper nesting structure
Expand Down
21 changes: 21 additions & 0 deletions packages/runtime/src/enhancements/node/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,25 @@ export class QueryUtils {

return model;
}

/**
* Gets relation info for a foreign key field.
*/
getRelationForForeignKey(model: string, fkField: string) {
const modelInfo = getModelInfo(this.options.modelMeta, model);
if (!modelInfo) {
return undefined;
}

for (const field of Object.values(modelInfo.fields)) {
if (field.foreignKeyMapping) {
const entry = Object.entries(field.foreignKeyMapping).find(([, v]) => v === fkField);
if (entry) {
return { relation: field, idField: entry[0], fkField: entry[1] };
}
}
}

return undefined;
}
}
209 changes: 127 additions & 82 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
getDataModelAndTypeDefs,
getDataModels,
getLiteral,
getRelationField,
isDelegateModel,
isDiscriminatorField,
normalizedRelative,
Expand Down Expand Up @@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][];
const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client';

export class EnhancerGenerator {
// regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type
// names for models that use `auth()` in `@default` attribute
private readonly modelsWithAuthInDefaultCreateInputPattern: RegExp;

constructor(
private readonly model: Model,
private readonly options: PluginOptions,
private readonly project: Project,
private readonly outDir: string
) {}
) {
const modelsWithAuthInDefault = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => f.attributes.some(isDefaultWithAuth))
);
this.modelsWithAuthInDefaultCreateInputPattern = new RegExp(
`^(${modelsWithAuthInDefault.map((m) => m.name).join('|')})(Unchecked)?Create.*?Input$`
);
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved

async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> {
let dmmf: DMMF.Document | undefined;
Expand All @@ -69,7 +81,7 @@ export class EnhancerGenerator {
let prismaTypesFixed = false;
let resultPrismaImport = prismaImport;

if (this.needsLogicalClient || this.needsPrismaClientTypeFixes) {
if (this.needsLogicalClient) {
prismaTypesFixed = true;
resultPrismaImport = `${LOGICAL_CLIENT_GENERATION_PATH}/index-fixed`;
const result = await this.generateLogicalPrisma();
Expand Down Expand Up @@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
}

private get needsLogicalClient() {
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model);
}

private get needsPrismaClientTypeFixes() {
return this.hasTypeDef(this.model);
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model) || this.hasTypeDef(this.model);
}

private hasDelegateModel(model: Model) {
Expand Down Expand Up @@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const auxFields = this.findAuxDecls(variable);
if (auxFields.length > 0) {
structure.declarations.forEach((variable) => {
let source = variable.type?.toString();
auxFields.forEach((f) => {
source = source?.replace(f.getText(), '');
});
variable.type = source;
if (variable.type) {
let source = variable.type.toString();
auxFields.forEach((f) => {
source = this.removeFromSource(source, f.getText());
});
variable.type = source;
}
});
}

Expand Down Expand Up @@ -498,72 +508,16 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
// fix delegate payload union type
source = this.fixDelegatePayloadType(typeAlias, delegateInfo, source);

// fix fk and relation fields related to using `auth()` in `@default`
source = this.fixDefaultAuthType(typeAlias, source);

// fix json field type
source = this.fixJsonFieldType(typeAlias, source);

structure.type = source;
return structure;
}

private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
const modelsWithTypeField = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
);
const typeName = typeAlias.getName();

const getTypedJsonFields = (model: DataModel) => {
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
};

const replacePrismaJson = (source: string, field: DataModelField) => {
return source.replace(
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
field.type.optional ? ' | null' : ''
}`
);
};

// fix "$[Model]Payload" type
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
if (payloadModelMatch) {
const scalars = typeAlias
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.find((p) => p.getName() === 'scalars');
if (!scalars) {
return source;
}

const fieldsToFix = getTypedJsonFields(payloadModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}

// fix input/output types, "[Model]CreateInput", etc.
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
if (inputOutputModelMatch) {
const relevantTypePatterns = [
'GroupByOutputType',
'(Unchecked)?Create(\\S+?)?Input',
'(Unchecked)?Update(\\S+?)?Input',
'CreateManyInput',
'(Unchecked)?UpdateMany(Mutation)?Input',
];
const typeRegex = modelsWithTypeField.map(
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
);
if (typeRegex.some((r) => r.test(typeName))) {
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}
}

return source;
}

private fixDelegatePayloadType(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo, source: string) {
// change the type of `$<DelegateModel>Payload` type of delegate model to a union of concrete types
const typeName = typeAlias.getName();
Expand Down Expand Up @@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.filter((p) => ['create', 'createMany', 'connectOrCreate', 'upsert'].includes(p.getName()));
toRemove.forEach((r) => {
source = source.replace(r.getText(), '');
this.removeFromSource(source, r.getText());
});
}
return source;
Expand Down Expand Up @@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
if (isDiscriminatorField(field)) {
const fieldDef = this.findNamedProperty(typeAlias, field.name);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}
}
}
Expand All @@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const auxDecls = this.findAuxDecls(typeAlias);
if (auxDecls.length > 0) {
auxDecls.forEach((d) => {
source = source.replace(d.getText(), '');
source = this.removeFromSource(source, d.getText());
});
}
return source;
Expand Down Expand Up @@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const fieldDef = this.findNamedProperty(typeAlias, relationFieldName);
if (fieldDef) {
// remove relation field of delegate type, e.g., `asset`
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}

// remove fk fields related to the delegate type relation, e.g., `assetId`
Expand Down Expand Up @@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
fkFields.forEach((fkField) => {
const fieldDef = this.findNamedProperty(typeAlias, fkField);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}
});

return source;
}

private fixDefaultAuthType(typeAlias: TypeAliasDeclaration, source: string) {
const match = typeAlias.getName().match(this.modelsWithAuthInDefaultCreateInputPattern);
if (!match) {
return source;
}

const modelName = match[1];
const dataModel = this.model.declarations.find((d): d is DataModel => isDataModel(d) && d.name === modelName);
if (dataModel) {
for (const fkField of dataModel.fields.filter((f) => f.attributes.some(isDefaultWithAuth))) {
// change fk field to optional since it has a default
source = source.replace(new RegExp(`^(\\s*${fkField.name}\\s*):`, 'm'), `$1?:`);

const relationField = getRelationField(fkField);
if (relationField) {
// change relation field to optional since its fk has a default
source = source.replace(new RegExp(`^(\\s*${relationField.name}\\s*):`, 'm'), `$1?:`);
}
ymc9 marked this conversation as resolved.
Show resolved Hide resolved
}
}
return source;
}

private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
const modelsWithTypeField = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
);
const typeName = typeAlias.getName();

const getTypedJsonFields = (model: DataModel) => {
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
};

const replacePrismaJson = (source: string, field: DataModelField) => {
return source.replace(
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
field.type.optional ? ' | null' : ''
}`
);
};
ymc9 marked this conversation as resolved.
Show resolved Hide resolved

// fix "$[Model]Payload" type
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
if (payloadModelMatch) {
const scalars = typeAlias
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.find((p) => p.getName() === 'scalars');
if (!scalars) {
return source;
}

const fieldsToFix = getTypedJsonFields(payloadModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}

// fix input/output types, "[Model]CreateInput", etc.
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
if (inputOutputModelMatch) {
const relevantTypePatterns = [
'GroupByOutputType',
'(Unchecked)?Create(\\S+?)?Input',
'(Unchecked)?Update(\\S+?)?Input',
'CreateManyInput',
'(Unchecked)?UpdateMany(Mutation)?Input',
];
const typeRegex = modelsWithTypeField.map(
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
);
if (typeRegex.some((r) => r.test(typeName))) {
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}
}

return source;
}

private async generateExtraTypes(sf: SourceFile) {
for (const decl of this.model.declarations) {
if (isTypeDef(decl)) {
generateTypeDefType(sf, decl);
}
}
}

private findNamedProperty(typeAlias: TypeAliasDeclaration, name: string) {
return typeAlias.getFirstDescendant((d) => d.isKind(SyntaxKind.PropertySignature) && d.getName() === name);
}
Expand Down Expand Up @@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return this.options.generatePermissionChecker === true;
}

private async generateExtraTypes(sf: SourceFile) {
for (const decl of this.model.declarations) {
if (isTypeDef(decl)) {
generateTypeDefType(sf, decl);
}
}
private removeFromSource(source: string, text: string) {
source = source.replace(text, '');
return this.trimEmptyLines(source);
}

private trimEmptyLines(source: string): string {
return source.replace(/^\s*[\r\n]/gm, '');
}
}
Loading
Loading