Skip to content

Commit

Permalink
fix: several issues with using auth() in @default (#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Mar 7, 2024
1 parent 2e81a08 commit 36e515e
Show file tree
Hide file tree
Showing 12 changed files with 271 additions and 69 deletions.
9 changes: 8 additions & 1 deletion packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ export type FieldInfo = {
isForeignKey?: boolean;

/**
* Mapping from foreign key field names to relation field names
* If the field is a foreign key field, the field name of the corresponding relation field.
* Only available on foreign key fields.
*/
relationField?: string;

/**
* Mapping from foreign key field names to relation field names.
* Only available on relation fields.
*/
foreignKeyMapping?: Record<string, string>;

Expand Down
46 changes: 44 additions & 2 deletions packages/runtime/src/enhancements/default-auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import deepcopy from 'deepcopy';
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross';
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields, requireField } from '../cross';
import { DbClientContract } from '../types';
import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { isUnsafeMutate } from './utils';

/**
* Gets an enhanced Prisma client that supports `@default(auth())` attribute.
Expand Down Expand Up @@ -68,7 +69,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo);
if (authDefaultValue !== undefined) {
// set field value extracted from `auth()`
data[fieldInfo.name] = authDefaultValue;
this.setAuthDefaultValue(fieldInfo, model, data, authDefaultValue);
}
}
};
Expand All @@ -90,6 +91,47 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
return newArgs;
}

private setAuthDefaultValue(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) {
if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) {
// if the field is a fk, and the create payload is not unsafe, we need to translate
// the fk field setting to a `connect` of the corresponding relation field
const relFieldName = fieldInfo.relationField;
if (!relFieldName) {
throw new Error(
`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`
);
}
const relationField = requireField(this.options.modelMeta, model, relFieldName);

// construct a `{ connect: { ... } }` payload
let connect = data[relationField.name]?.connect;
if (!connect) {
connect = {};
data[relationField.name] = { connect };
}

// sets the opposite fk field to value `authDefaultValue`
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo);
if (!oppositeFkFieldName) {
throw new Error(
`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``
);
}
connect[oppositeFkFieldName] = authDefaultValue;
} else {
// set default value directly
data[fieldInfo.name] = authDefaultValue;
}
}

private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) {
if (!relationField.foreignKeyMapping) {
return undefined;
}
const entry = Object.entries(relationField.foreignKeyMapping).find(([, v]) => v === fieldInfo.name);
return entry?.[0];
}

private getDefaultValueFromAuth(fieldInfo: FieldInfo) {
if (!this.userContext) {
throw new Error(`Evaluating default value of field \`${fieldInfo.name}\` requires a user context`);
Expand Down
21 changes: 2 additions & 19 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import type { EnhancementContext, InternalEnhancementOptions } from '../create-e
import { Logger } from '../logger';
import { PrismaProxyHandler } from '../proxy';
import { QueryUtils } from '../query-utils';
import { formatObject, prismaClientValidationError } from '../utils';
import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
import { PolicyUtil } from './policy-utils';
import { createDeferredPromise } from './promise';

Expand Down Expand Up @@ -691,7 +691,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// operations. E.g.:
// - safe: { data: { user: { connect: { id: 1 }} } }
// - unsafe: { data: { userId: 1 } }
const unsafe = this.isUnsafeMutate(model, args);
const unsafe = isUnsafeMutate(model, args, this.modelMeta);

// handles the connection to upstream entity
const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe);
Expand Down Expand Up @@ -1083,23 +1083,6 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

private isUnsafeMutate(model: string, args: any) {
if (!args) {
return false;
}
for (const k of Object.keys(args)) {
const field = resolveField(this.modelMeta, model, k);
if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) {
return true;
}
}
return false;
}

private isAutoIncrementIdField(field: FieldInfo) {
return field.isId && field.isAutoIncrement;
}

async updateMany(args: any) {
if (!args) {
throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required');
Expand Down
19 changes: 19 additions & 0 deletions packages/runtime/src/enhancements/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as util from 'util';
import { FieldInfo, ModelMeta, resolveField } from '..';
import type { DbClientContract } from '../types';

/**
Expand All @@ -22,3 +23,21 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo
export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error {
throw new prismaModule.PrismaClientUnknownRequestError(...args);
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta) {
if (!args) {
return false;
}
for (const k of Object.keys(args)) {
const field = resolveField(modelMeta, model, k);
if (field && (isAutoIncrementIdField(field) || field.isForeignKey)) {
return true;
}
}
return false;
}

export function isAutoIncrementIdField(field: FieldInfo) {
return field.isId && field.isAutoIncrement;
}
26 changes: 22 additions & 4 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { name } from '..';
import { execPackage } from '../../../utils/exec-utils';
import { trackPrismaSchemaError } from '../../prisma';
import { PrismaSchemaGenerator } from '../../prisma/schema-generator';
import { isDefaultWithAuth } from '../enhancer-utils';

// information of delegate models and their sub models
type DelegateInfo = [DataModel, DataModel[]][];
Expand All @@ -35,7 +36,7 @@ export async function generate(model: Model, options: PluginOptions, project: Pr
let logicalPrismaClientDir: string | undefined;
let dmmf: DMMF.Document | undefined;

if (hasDelegateModel(model)) {
if (needsLogicalClient(model)) {
// schema contains delegate models, need to generate a logical prisma schema
const result = await generateLogicalPrisma(model, options, outDir);

Expand Down Expand Up @@ -86,13 +87,23 @@ export function enhance<DbClient extends object>(prisma: DbClient, context?: Enh
return { dmmf };
}

function needsLogicalClient(model: Model) {
return hasDelegateModel(model) || hasAuthInDefault(model);
}

function hasDelegateModel(model: Model) {
const dataModels = getDataModels(model);
return dataModels.some(
(dm) => isDelegateModel(dm) && dataModels.some((sub) => sub.superTypes.some((base) => base.ref === dm))
);
}

function hasAuthInDefault(model: Model) {
return getDataModels(model).some((dm) =>
dm.fields.some((f) => f.attributes.some((attr) => isDefaultWithAuth(attr)))
);
}

async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) {
const prismaGenerator = new PrismaSchemaGenerator(model);
const prismaClientOutDir = './.logical-prisma-client';
Expand Down Expand Up @@ -152,12 +163,19 @@ async function processClientTypes(model: Model, prismaClientDir: string) {
const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, {
overwrite: true,
});
transform(sf, sfNew, delegateInfo);
sfNew.formatText();

if (delegateInfo.length > 0) {
// transform types for delegated models
transformDelegate(sf, sfNew, delegateInfo);
sfNew.formatText();
} else {
// just copy
sfNew.replaceWithText(sf.getFullText());
}
await sfNew.save();
}

function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
function transformDelegate(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
// copy toplevel imports
sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure()));

Expand Down
20 changes: 20 additions & 0 deletions packages/schema/src/plugins/enhancer/enhancer-utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { isAuthInvocation } from '@zenstackhq/sdk';
import type { DataModelFieldAttribute } from '@zenstackhq/sdk/ast';
import { streamAst } from 'langium';

/**
* Check if the given field attribute is a `@default` with `auth()` invocation
*/
export function isDefaultWithAuth(attr: DataModelFieldAttribute) {
if (attr.decl.ref?.name !== '@default') {
return false;
}

const expr = attr.args[0]?.value;
if (!expr) {
return false;
}

// find `auth()` in default value expression
return streamAst(expr).some(isAuthInvocation);
}
38 changes: 21 additions & 17 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,27 @@ import { getIdFields } from '../../utils/ast-utils';
import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
import {
getAttribute,
getForeignKeyFields,
getLiteral,
getPrismaVersion,
isAuthInvocation,
isDelegateModel,
isIdField,
isRelationshipField,
PluginError,
PluginOptions,
resolved,
ZModelCodeGenerator,
} from '@zenstackhq/sdk';
import fs from 'fs';
import { writeFile } from 'fs/promises';
import { streamAst } from 'langium';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import semver from 'semver';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
AttributeArgValue,
ModelFieldType,
Expand Down Expand Up @@ -494,10 +495,27 @@ export class PrismaSchemaGenerator {

const type = new ModelFieldType(fieldType, field.type.array, field.type.optional);

if (this.mode === 'logical') {
if (field.attributes.some((attr) => isDefaultWithAuth(attr))) {
// field has `@default` with `auth()`, it should be set optional, and the
// default value setting is handled outside Prisma
type.optional = true;
}

if (isRelationshipField(field)) {
// if foreign key field has `@default` with `auth()`, the relation
// field should be set optional
const foreignKeyFields = getForeignKeyFields(field);
if (foreignKeyFields.some((fkField) => fkField.attributes.some((attr) => isDefaultWithAuth(attr)))) {
type.optional = true;
}
}
}

const attributes = field.attributes
.filter((attr) => this.isPrismaAttribute(attr))
// `@default` with `auth()` is handled outside Prisma
.filter((attr) => !this.isDefaultWithAuth(attr))
.filter((attr) => !isDefaultWithAuth(attr))
.filter(
(attr) =>
// when building physical schema, exclude `@default` for id fields inherited from delegate base
Expand All @@ -524,20 +542,6 @@ export class PrismaSchemaGenerator {
return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom);
}

private isDefaultWithAuth(attr: DataModelFieldAttribute) {
if (attr.decl.ref?.name !== '@default') {
return false;
}

const expr = attr.args[0]?.value;
if (!expr) {
return false;
}

// find `auth()` in default value expression
return streamAst(expr).some(isAuthInvocation);
}

private makeFieldAttribute(attr: DataModelFieldAttribute) {
const attrName = resolved(attr.decl).name;
if (attrName === FIELD_PASSTHROUGH_ATTR) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ describe('Attribute tests', () => {
`);

await loadModel(`
${ prelude }
${prelude}
model A {
id String @id
x String
Expand Down Expand Up @@ -1051,21 +1051,6 @@ describe('Attribute tests', () => {
}
`);

// expect(
// await loadModelWithError(`
// ${prelude}

// model User {
// id String @id
// name String
// }
// model B {
// id String @id
// userData String @default(auth())
// }
// `)
// ).toContain("Value is not assignable to parameter");

expect(
await loadModelWithError(`
${prelude}
Expand Down Expand Up @@ -1185,15 +1170,6 @@ describe('Attribute tests', () => {
});

it('incorrect function expression context', async () => {
// expect(
// await loadModelWithError(`
// ${prelude}
// model M {
// id String @id @default(auth())
// }
// `)
// ).toContain('function "auth" is not allowed in the current context: DefaultValue');

expect(
await loadModelWithError(`
${prelude}
Expand Down
7 changes: 6 additions & 1 deletion packages/sdk/src/model-meta-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import {
isIdField,
resolved,
TypeScriptExpressionTransformer,
getRelationField,
} from '.';

/**
Expand Down Expand Up @@ -247,6 +248,11 @@ function writeFields(
if (isForeignKeyField(f)) {
writer.write(`
isForeignKey: true,`);
const relationField = getRelationField(f);
if (relationField) {
writer.write(`
relationField: '${relationField.name}',`);
}
}

if (fkMapping && Object.keys(fkMapping).length > 0) {
Expand Down Expand Up @@ -408,7 +414,6 @@ function generateForeignKeyMapping(field: DataModelField) {
const fieldNames = fields.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));
const referenceNames = references.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const result: Record<string, string> = {};
referenceNames.forEach((name, i) => {
if (name) {
Expand Down
Loading

0 comments on commit 36e515e

Please sign in to comment.