Skip to content

Commit

Permalink
fix: make sure both fk and relation fields are optional in create inp…
Browse files Browse the repository at this point in the history
…ut types (#1862)
  • Loading branch information
ymc9 authored Nov 15, 2024
1 parent ad07053 commit 285b258
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 115 deletions.
50 changes: 50 additions & 0 deletions packages/runtime/src/enhancements/node/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
let curr = args;
let base = this.getBaseModel(model);
let sub = this.getModelInfo(model);
const hasDelegateBase = !!base;

while (base) {
const baseRelationName = this.makeAuxRelationName(base);
Expand Down Expand Up @@ -615,6 +616,55 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
sub = base;
base = this.getBaseModel(base.name);
}

if (hasDelegateBase) {
// A delegate base model creation is added, this can be incompatible if
// the user-provided payload assigns foreign keys directly, because Prisma
// doesn't permit mixed "checked" and "unchecked" fields in a payload.
//
// {
// delegate_aux_base: { ... },
// [fkField]: value // <- this is not compatible
// }
//
// We need to convert foreign key assignments to `connect`.
this.fkAssignmentToConnect(model, args);
}
}

// convert foreign key assignments to `connect` payload
// e.g.: { authorId: value } -> { author: { connect: { id: value } } }
private fkAssignmentToConnect(model: string, args: any) {
const keysToDelete: string[] = [];
for (const [key, value] of Object.entries(args)) {
if (value === undefined) {
continue;
}

const fieldInfo = this.queryUtils.getModelField(model, key);
if (
!fieldInfo?.inheritedFrom && // fields from delegate base are handled outside
fieldInfo?.isForeignKey
) {
const relationInfo = this.queryUtils.getRelationForForeignKey(model, key);
if (relationInfo) {
// turn { [fk]: value } into { [relation]: { connect: { [id]: value } } }
const relationName = relationInfo.relation.name;
if (!args[relationName]) {
args[relationName] = {};
}
if (!args[relationName].connect) {
args[relationName].connect = {};
}
if (!(relationInfo.idField in args[relationName].connect)) {
args[relationName].connect[relationInfo.idField] = value;
keysToDelete.push(key);
}
}
}
}

keysToDelete.forEach((key) => delete args[key]);
}

// inject field data that belongs to base type into proper nesting structure
Expand Down
21 changes: 21 additions & 0 deletions packages/runtime/src/enhancements/node/query-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -232,4 +232,25 @@ export class QueryUtils {

return model;
}

/**
* Gets relation info for a foreign key field.
*/
getRelationForForeignKey(model: string, fkField: string) {
const modelInfo = getModelInfo(this.options.modelMeta, model);
if (!modelInfo) {
return undefined;
}

for (const field of Object.values(modelInfo.fields)) {
if (field.foreignKeyMapping) {
const entry = Object.entries(field.foreignKeyMapping).find(([, v]) => v === fkField);
if (entry) {
return { relation: field, idField: entry[0], fkField: entry[1] };
}
}
}

return undefined;
}
}
209 changes: 127 additions & 82 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
getDataModelAndTypeDefs,
getDataModels,
getLiteral,
getRelationField,
isDelegateModel,
isDiscriminatorField,
normalizedRelative,
Expand Down Expand Up @@ -55,12 +56,23 @@ type DelegateInfo = [DataModel, DataModel[]][];
const LOGICAL_CLIENT_GENERATION_PATH = './.logical-prisma-client';

export class EnhancerGenerator {
// regex for matching "ModelCreateXXXInput" and "ModelUncheckedCreateXXXInput" type
// names for models that use `auth()` in `@default` attribute
private readonly modelsWithAuthInDefaultCreateInputPattern: RegExp;

constructor(
private readonly model: Model,
private readonly options: PluginOptions,
private readonly project: Project,
private readonly outDir: string
) {}
) {
const modelsWithAuthInDefault = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => f.attributes.some(isDefaultWithAuth))
);
this.modelsWithAuthInDefaultCreateInputPattern = new RegExp(
`^(${modelsWithAuthInDefault.map((m) => m.name).join('|')})(Unchecked)?Create.*?Input$`
);
}

async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> {
let dmmf: DMMF.Document | undefined;
Expand All @@ -69,7 +81,7 @@ export class EnhancerGenerator {
let prismaTypesFixed = false;
let resultPrismaImport = prismaImport;

if (this.needsLogicalClient || this.needsPrismaClientTypeFixes) {
if (this.needsLogicalClient) {
prismaTypesFixed = true;
resultPrismaImport = `${LOGICAL_CLIENT_GENERATION_PATH}/index-fixed`;
const result = await this.generateLogicalPrisma();
Expand Down Expand Up @@ -230,11 +242,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
}

private get needsLogicalClient() {
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model);
}

private get needsPrismaClientTypeFixes() {
return this.hasTypeDef(this.model);
return this.hasDelegateModel(this.model) || this.hasAuthInDefault(this.model) || this.hasTypeDef(this.model);
}

private hasDelegateModel(model: Model) {
Expand Down Expand Up @@ -449,11 +457,13 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const auxFields = this.findAuxDecls(variable);
if (auxFields.length > 0) {
structure.declarations.forEach((variable) => {
let source = variable.type?.toString();
auxFields.forEach((f) => {
source = source?.replace(f.getText(), '');
});
variable.type = source;
if (variable.type) {
let source = variable.type.toString();
auxFields.forEach((f) => {
source = this.removeFromSource(source, f.getText());
});
variable.type = source;
}
});
}

Expand Down Expand Up @@ -498,72 +508,16 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
// fix delegate payload union type
source = this.fixDelegatePayloadType(typeAlias, delegateInfo, source);

// fix fk and relation fields related to using `auth()` in `@default`
source = this.fixDefaultAuthType(typeAlias, source);

// fix json field type
source = this.fixJsonFieldType(typeAlias, source);

structure.type = source;
return structure;
}

private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
const modelsWithTypeField = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
);
const typeName = typeAlias.getName();

const getTypedJsonFields = (model: DataModel) => {
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
};

const replacePrismaJson = (source: string, field: DataModelField) => {
return source.replace(
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
field.type.optional ? ' | null' : ''
}`
);
};

// fix "$[Model]Payload" type
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
if (payloadModelMatch) {
const scalars = typeAlias
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.find((p) => p.getName() === 'scalars');
if (!scalars) {
return source;
}

const fieldsToFix = getTypedJsonFields(payloadModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}

// fix input/output types, "[Model]CreateInput", etc.
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
if (inputOutputModelMatch) {
const relevantTypePatterns = [
'GroupByOutputType',
'(Unchecked)?Create(\\S+?)?Input',
'(Unchecked)?Update(\\S+?)?Input',
'CreateManyInput',
'(Unchecked)?UpdateMany(Mutation)?Input',
];
const typeRegex = modelsWithTypeField.map(
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
);
if (typeRegex.some((r) => r.test(typeName))) {
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}
}

return source;
}

private fixDelegatePayloadType(typeAlias: TypeAliasDeclaration, delegateInfo: DelegateInfo, source: string) {
// change the type of `$<DelegateModel>Payload` type of delegate model to a union of concrete types
const typeName = typeAlias.getName();
Expand Down Expand Up @@ -595,7 +549,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.filter((p) => ['create', 'createMany', 'connectOrCreate', 'upsert'].includes(p.getName()));
toRemove.forEach((r) => {
source = source.replace(r.getText(), '');
this.removeFromSource(source, r.getText());
});
}
return source;
Expand Down Expand Up @@ -633,7 +587,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
if (isDiscriminatorField(field)) {
const fieldDef = this.findNamedProperty(typeAlias, field.name);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}
}
}
Expand All @@ -646,7 +600,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const auxDecls = this.findAuxDecls(typeAlias);
if (auxDecls.length > 0) {
auxDecls.forEach((d) => {
source = source.replace(d.getText(), '');
source = this.removeFromSource(source, d.getText());
});
}
return source;
Expand Down Expand Up @@ -677,7 +631,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const fieldDef = this.findNamedProperty(typeAlias, relationFieldName);
if (fieldDef) {
// remove relation field of delegate type, e.g., `asset`
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}

// remove fk fields related to the delegate type relation, e.g., `assetId`
Expand Down Expand Up @@ -709,13 +663,103 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
fkFields.forEach((fkField) => {
const fieldDef = this.findNamedProperty(typeAlias, fkField);
if (fieldDef) {
source = source.replace(fieldDef.getText(), '');
source = this.removeFromSource(source, fieldDef.getText());
}
});

return source;
}

private fixDefaultAuthType(typeAlias: TypeAliasDeclaration, source: string) {
const match = typeAlias.getName().match(this.modelsWithAuthInDefaultCreateInputPattern);
if (!match) {
return source;
}

const modelName = match[1];
const dataModel = this.model.declarations.find((d): d is DataModel => isDataModel(d) && d.name === modelName);
if (dataModel) {
for (const fkField of dataModel.fields.filter((f) => f.attributes.some(isDefaultWithAuth))) {
// change fk field to optional since it has a default
source = source.replace(new RegExp(`^(\\s*${fkField.name}\\s*):`, 'm'), `$1?:`);

const relationField = getRelationField(fkField);
if (relationField) {
// change relation field to optional since its fk has a default
source = source.replace(new RegExp(`^(\\s*${relationField.name}\\s*):`, 'm'), `$1?:`);
}
}
}
return source;
}

private fixJsonFieldType(typeAlias: TypeAliasDeclaration, source: string) {
const modelsWithTypeField = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.fields.some((f) => isTypeDef(f.type.reference?.ref))
);
const typeName = typeAlias.getName();

const getTypedJsonFields = (model: DataModel) => {
return model.fields.filter((f) => isTypeDef(f.type.reference?.ref));
};

const replacePrismaJson = (source: string, field: DataModelField) => {
return source.replace(
new RegExp(`(${field.name}\\??\\s*):[^\\n]+`),
`$1: ${field.type.reference!.$refText}${field.type.array ? '[]' : ''}${
field.type.optional ? ' | null' : ''
}`
);
};

// fix "$[Model]Payload" type
const payloadModelMatch = modelsWithTypeField.find((m) => `$${m.name}Payload` === typeName);
if (payloadModelMatch) {
const scalars = typeAlias
.getDescendantsOfKind(SyntaxKind.PropertySignature)
.find((p) => p.getName() === 'scalars');
if (!scalars) {
return source;
}

const fieldsToFix = getTypedJsonFields(payloadModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}

// fix input/output types, "[Model]CreateInput", etc.
const inputOutputModelMatch = modelsWithTypeField.find((m) => typeName.startsWith(m.name));
if (inputOutputModelMatch) {
const relevantTypePatterns = [
'GroupByOutputType',
'(Unchecked)?Create(\\S+?)?Input',
'(Unchecked)?Update(\\S+?)?Input',
'CreateManyInput',
'(Unchecked)?UpdateMany(Mutation)?Input',
];
const typeRegex = modelsWithTypeField.map(
(m) => new RegExp(`^(${m.name})(${relevantTypePatterns.join('|')})$`)
);
if (typeRegex.some((r) => r.test(typeName))) {
const fieldsToFix = getTypedJsonFields(inputOutputModelMatch);
for (const field of fieldsToFix) {
source = replacePrismaJson(source, field);
}
}
}

return source;
}

private async generateExtraTypes(sf: SourceFile) {
for (const decl of this.model.declarations) {
if (isTypeDef(decl)) {
generateTypeDefType(sf, decl);
}
}
}

private findNamedProperty(typeAlias: TypeAliasDeclaration, name: string) {
return typeAlias.getFirstDescendant((d) => d.isKind(SyntaxKind.PropertySignature) && d.getName() === name);
}
Expand Down Expand Up @@ -745,11 +789,12 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
return this.options.generatePermissionChecker === true;
}

private async generateExtraTypes(sf: SourceFile) {
for (const decl of this.model.declarations) {
if (isTypeDef(decl)) {
generateTypeDefType(sf, decl);
}
}
private removeFromSource(source: string, text: string) {
source = source.replace(text, '');
return this.trimEmptyLines(source);
}

private trimEmptyLines(source: string): string {
return source.replace(/^\s*[\r\n]/gm, '');
}
}
Loading

0 comments on commit 285b258

Please sign in to comment.