Skip to content

Commit

Permalink
fix: compatibility with Prisma's "omit" feature (#1432)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored May 11, 2024
1 parent 23a9bbb commit 296ca25
Show file tree
Hide file tree
Showing 19 changed files with 196 additions and 62 deletions.
19 changes: 8 additions & 11 deletions packages/plugins/openapi/src/rpc-generator.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
// Inspired by: https://github.com/omar-dulaimi/prisma-trpc-generator

import { analyzePolicies, PluginError, requireOption, resolvePath } from '@zenstackhq/sdk';
import { PluginError, analyzePolicies, requireOption, resolvePath, supportCreateMany } from '@zenstackhq/sdk';
import { DataModel, isDataModel } from '@zenstackhq/sdk/ast';
import {
AggregateOperationSupport,
addMissingInputObjectTypesForAggregate,
addMissingInputObjectTypesForInclude,
addMissingInputObjectTypesForModelArgs,
addMissingInputObjectTypesForSelect,
AggregateOperationSupport,
resolveAggregateOperationSupport,
} from '@zenstackhq/sdk/dmmf-helpers';
import type { DMMF } from '@zenstackhq/sdk/prisma';
import { type DMMF } from '@zenstackhq/sdk/prisma';
import * as fs from 'fs';
import { lowerCaseFirst } from 'lower-case-first';
import type { OpenAPIV3_1 as OAPI } from 'openapi-types';
import * as path from 'path';
import invariant from 'tiny-invariant';
import { match, P } from 'ts-pattern';
import { P, match } from 'ts-pattern';
import { upperCaseFirst } from 'upper-case-first';
import YAML from 'yaml';
import { name } from '.';
Expand Down Expand Up @@ -166,7 +166,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
});
}

if (ops['createMany']) {
if (ops['createMany'] && supportCreateMany(zmodel.$container)) {
definitions.push({
method: 'post',
operation: 'createMany',
Expand Down Expand Up @@ -704,7 +704,7 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
private generateEnumComponent(_enum: DMMF.SchemaEnum): OAPI.SchemaObject {
const schema: OAPI.SchemaObject = {
type: 'string',
enum: _enum.values,
enum: _enum.values as string[],
};
return schema;
}
Expand Down Expand Up @@ -793,17 +793,14 @@ export class RPCOpenAPIGenerator extends OpenAPIGeneratorBase {
return result;
}

private setInputRequired(fields: { name: string; isRequired: boolean }[], result: OAPI.NonArraySchemaObject) {
private setInputRequired(fields: readonly DMMF.SchemaArg[], result: OAPI.NonArraySchemaObject) {
const required = fields.filter((f) => f.isRequired).map((f) => f.name);
if (required.length > 0) {
result.required = required;
}
}

private setOutputRequired(
fields: { name: string; isNullable?: boolean; outputType: DMMF.OutputTypeRef }[],
result: OAPI.NonArraySchemaObject
) {
private setOutputRequired(fields: readonly DMMF.SchemaField[], result: OAPI.NonArraySchemaObject) {
const required = fields.filter((f) => f.isNullable !== true).map((f) => f.name);
if (required.length > 0) {
result.required = required;
Expand Down
3 changes: 2 additions & 1 deletion packages/plugins/swr/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
requireOption,
resolvePath,
saveProject,
supportCreateMany,
} from '@zenstackhq/sdk';
import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast';
import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
Expand Down Expand Up @@ -85,7 +86,7 @@ function generateModelHooks(
}

// createMany
if (mapping.createMany) {
if (mapping.createMany && supportCreateMany(model.$container)) {
const argsType = `Prisma.${model.name}CreateManyArgs`;
mutationFuncs.push(generateMutation(sf, model, 'POST', 'createMany', argsType, true));
}
Expand Down
3 changes: 2 additions & 1 deletion packages/plugins/tanstack-query/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
requireOption,
resolvePath,
saveProject,
supportCreateMany,
} from '@zenstackhq/sdk';
import { DataModel, DataModelFieldType, Model, isEnum } from '@zenstackhq/sdk/ast';
import { getPrismaClientImportSpec, type DMMF } from '@zenstackhq/sdk/prisma';
Expand Down Expand Up @@ -348,7 +349,7 @@ function generateModelHooks(
}

// createMany
if (mapping.createMany) {
if (mapping.createMany && supportCreateMany(model.$container)) {
generateMutationHook(target, sf, model.name, 'createMany', 'post', false, 'Prisma.BatchPayload');
}

Expand Down
15 changes: 11 additions & 4 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
requireOption,
resolvePath,
saveProject,
supportCreateMany,
type PluginOptions,
} from '@zenstackhq/sdk';
import { Model } from '@zenstackhq/sdk/ast';
Expand Down Expand Up @@ -79,11 +80,11 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.

function createAppRouter(
outDir: string,
modelOperations: DMMF.ModelMapping[],
modelOperations: readonly DMMF.ModelMapping[],
hiddenModels: string[],
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined,
_zmodel: Model,
zmodel: Model,
zodSchemasImport: string,
options: PluginOptions
) {
Expand Down Expand Up @@ -171,7 +172,8 @@ function createAppRouter(
generateModelActions,
generateClientHelpers,
zodSchemasImport,
options
options,
zmodel
);

appRouter.addImportDeclaration({
Expand Down Expand Up @@ -241,7 +243,8 @@ function generateModelCreateRouter(
generateModelActions: string[] | undefined,
generateClientHelpers: string[] | undefined,
zodSchemasImport: string,
options: PluginOptions
options: PluginOptions,
zmodel: Model
) {
const modelRouter = project.createSourceFile(path.resolve(outputDir, 'routers', `${model}.router.ts`), undefined, {
overwrite: true,
Expand Down Expand Up @@ -298,6 +301,10 @@ function generateModelCreateRouter(
inputType &&
(!generateModelActions || generateModelActions.includes(generateOpName))
) {
if (generateOpName === 'createMany' && !supportCreateMany(zmodel)) {
continue;
}

generateProcedure(funcWriter, generateOpName, upperCaseFirst(inputType), model, baseOpType);

if (routerTypingStructure) {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/trpc/src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ export const getProcedureTypeByOpName = (opName: string) => {
return procType;
};

export function resolveModelsComments(models: DMMF.Model[], hiddenModels: string[]) {
export function resolveModelsComments(models: readonly DMMF.Model[], hiddenModels: string[]) {
const modelAttributeRegex = /(@@Gen\.)+([A-z])+(\()+(.+)+(\))+/;
const attributeNameRegex = /(?:\.)+([A-Za-z])+(?:\()+/;
const attributeArgsRegex = /(?:\()+([A-Za-z])+:+(.+)+(?:\))+/;
Expand Down
11 changes: 11 additions & 0 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,11 @@ export class PolicyUtil extends QueryUtils {
}

private doInjectReadCheckSelect(model: string, args: any, input: any) {
// omit should be ignored to avoid interfering with field selection
if (args.omit) {
delete args.omit;
}

if (!input?.select) {
return;
}
Expand Down Expand Up @@ -1178,6 +1183,12 @@ export class PolicyUtil extends QueryUtils {
continue;
}

if (queryArgs?.omit?.[field] === true) {
// respect `{ omit: { [field]: true } }`
delete entityData[field];
continue;
}

if (hasFieldLevelPolicy) {
// 1. remove fields selected for checking field-level policies but not selected by the original query args
// 2. evaluate field-level policies and remove fields that are not readable
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export class EnhancerGenerator {
private readonly outDir: string
) {}

async generate() {
async generate(): Promise<{ dmmf: DMMF.Document | undefined }> {
let logicalPrismaClientDir: string | undefined;
let dmmf: DMMF.Document | undefined;

Expand Down
7 changes: 5 additions & 2 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export class ZodSchemaGenerator {
project: this.project,
inputObjectTypes,
});
await transformer.generateInputSchemas(this.options);
await transformer.generateInputSchemas(this.options, this.model);
this.sourceFiles.push(...transformer.sourceFiles);
}

Expand Down Expand Up @@ -189,7 +189,10 @@ export class ZodSchemaGenerator {
);
}

private async generateEnumSchemas(prismaSchemaEnum: DMMF.SchemaEnum[], modelSchemaEnum: DMMF.SchemaEnum[]) {
private async generateEnumSchemas(
prismaSchemaEnum: readonly DMMF.SchemaEnum[],
modelSchemaEnum: readonly DMMF.SchemaEnum[]
) {
const enumTypes = [...prismaSchemaEnum, ...modelSchemaEnum];
const enumNames = enumTypes.map((enumItem) => upperCaseFirst(enumItem.name));
Transformer.enumNames = enumNames ?? [];
Expand Down
13 changes: 7 additions & 6 deletions packages/schema/src/plugins/zod/transformer.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { indentString, type PluginOptions } from '@zenstackhq/sdk';
import { indentString, supportCreateMany, type PluginOptions } from '@zenstackhq/sdk';
import type { Model } from '@zenstackhq/sdk/ast';
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
import { type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import path from 'path';
Expand All @@ -11,12 +12,12 @@ import { AggregateOperationSupport, TransformerParams } from './types';
export default class Transformer {
name: string;
originalName: string;
fields: PrismaDMMF.SchemaArg[];
fields: readonly PrismaDMMF.SchemaArg[];
schemaImports = new Set<string>();
models: PrismaDMMF.Model[];
models: readonly PrismaDMMF.Model[];
modelOperations: PrismaDMMF.ModelMapping[];
aggregateOperationSupport: AggregateOperationSupport;
enumTypes: PrismaDMMF.SchemaEnum[];
enumTypes: readonly PrismaDMMF.SchemaEnum[];

static enumNames: string[] = [];
static rawOpsMap: { [name: string]: string } = {};
Expand Down Expand Up @@ -389,7 +390,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
return wrapped;
}

async generateInputSchemas(options: PluginOptions) {
async generateInputSchemas(options: PluginOptions, zmodel: Model) {
const globalExports: string[] = [];

// whether Prisma's Unchecked* series of input types should be generated
Expand Down Expand Up @@ -489,7 +490,7 @@ export const ${this.name}ObjectSchema: SchemaType = ${schema} as SchemaType;`;
operations.push(['create', origModelName]);
}

if (createMany) {
if (createMany && supportCreateMany(zmodel)) {
imports.push(
`import { ${modelName}CreateManyInputObjectSchema } from '../objects/${modelName}CreateManyInput.schema'`
);
Expand Down
6 changes: 3 additions & 3 deletions packages/schema/src/plugins/zod/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import type { DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
import { Project } from 'ts-morph';

export type TransformerParams = {
enumTypes?: PrismaDMMF.SchemaEnum[];
fields?: PrismaDMMF.SchemaArg[];
enumTypes?: readonly PrismaDMMF.SchemaEnum[];
fields?: readonly PrismaDMMF.SchemaArg[];
name?: string;
models?: PrismaDMMF.Model[];
models?: readonly PrismaDMMF.Model[];
modelOperations?: PrismaDMMF.ModelMapping[];
aggregateOperationSupport?: AggregateOperationSupport;
isDefaultPrismaClientOutput?: boolean;
Expand Down
4 changes: 2 additions & 2 deletions packages/sdk/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
"author": "",
"license": "MIT",
"dependencies": {
"@prisma/generator-helper": "5.7.0",
"@prisma/internals": "5.7.0",
"@prisma/generator-helper": "^5.13.0",
"@prisma/internals": "^5.13.0",
"@zenstackhq/language": "workspace:*",
"@zenstackhq/runtime": "workspace:*",
"langium": "1.3.1",
Expand Down
7 changes: 5 additions & 2 deletions packages/sdk/src/dmmf-helpers/include-helpers.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import type { DMMF } from '../prisma';
import { checkIsModelRelationField, checkModelHasManyModelRelation, checkModelHasModelRelation } from './model-helpers';

export function addMissingInputObjectTypesForInclude(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) {
export function addMissingInputObjectTypesForInclude(
inputObjectTypes: DMMF.InputType[],
models: readonly DMMF.Model[]
) {
// generate input object types necessary to support ModelInclude with relation support
const generatedIncludeInputObjectTypes = generateModelIncludeInputObjectTypes(models);

for (const includeInputObjectType of generatedIncludeInputObjectTypes) {
inputObjectTypes.push(includeInputObjectType);
}
}
function generateModelIncludeInputObjectTypes(models: DMMF.Model[]) {
function generateModelIncludeInputObjectTypes(models: readonly DMMF.Model[]) {
const modelIncludeInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName, fields: modelFields } = model;
Expand Down
2 changes: 1 addition & 1 deletion packages/sdk/src/dmmf-helpers/model-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ export function checkIsManyModelRelationField(modelField: DMMF.Field) {
return checkIsModelRelationField(modelField) && modelField.isList;
}

export function findModelByName(models: DMMF.Model[], modelName: string) {
export function findModelByName(models: readonly DMMF.Model[], modelName: string) {
return models.find(({ name }) => name === modelName);
}
7 changes: 5 additions & 2 deletions packages/sdk/src/dmmf-helpers/modelArgs-helpers.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import type { DMMF } from '../prisma';
import { checkModelHasModelRelation } from './model-helpers';

export function addMissingInputObjectTypesForModelArgs(inputObjectTypes: DMMF.InputType[], models: DMMF.Model[]) {
export function addMissingInputObjectTypesForModelArgs(
inputObjectTypes: DMMF.InputType[],
models: readonly DMMF.Model[]
) {
const modelArgsInputObjectTypes = generateModelArgsInputObjectTypes(models);

for (const modelArgsInputObjectType of modelArgsInputObjectTypes) {
inputObjectTypes.push(modelArgsInputObjectType);
}
}
function generateModelArgsInputObjectTypes(models: DMMF.Model[]) {
function generateModelArgsInputObjectTypes(models: readonly DMMF.Model[]) {
const modelArgsInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName } = model;
Expand Down
21 changes: 10 additions & 11 deletions packages/sdk/src/dmmf-helpers/select-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { checkIsModelRelationField, checkModelHasManyModelRelation } from './mod
export function addMissingInputObjectTypesForSelect(
inputObjectTypes: DMMF.InputType[],
outputObjectTypes: DMMF.OutputType[],
models: DMMF.Model[]
models: readonly DMMF.Model[]
) {
// generate input object types necessary to support ModelSelect._count
const modelCountOutputTypes = getModelCountOutputTypes(outputObjectTypes);
Expand Down Expand Up @@ -89,34 +89,33 @@ function generateModelCountOutputTypeArgsInputObjectTypes(modelCountOutputTypes:
return modelCountOutputTypeArgsInputObjectTypes;
}

function generateModelSelectInputObjectTypes(models: DMMF.Model[]) {
function generateModelSelectInputObjectTypes(models: readonly DMMF.Model[]) {
const modelSelectInputObjectTypes: DMMF.InputType[] = [];
for (const model of models) {
const { name: modelName, fields: modelFields } = model;
const fields: DMMF.SchemaArg[] = [];

for (const modelField of modelFields) {
const { name: modelFieldName, isList, type } = modelField;
const inputTypes = [{ isList: false, type: 'Boolean', location: 'scalar' }];

const isRelationField = checkIsModelRelationField(modelField);

const field: DMMF.SchemaArg = {
name: modelFieldName,
isRequired: false,
isNullable: false,
inputTypes: [{ isList: false, type: 'Boolean', location: 'scalar' }],
};

if (isRelationField) {
const schemaArgInputType: DMMF.InputTypeRef = {
isList: false,
type: isList ? `${type}FindManyArgs` : `${type}Args`,
location: 'inputObjectTypes',
namespace: 'prisma',
};
field.inputTypes.push(schemaArgInputType);
inputTypes.push(schemaArgInputType);
}

const field: DMMF.SchemaArg = {
name: modelFieldName,
isRequired: false,
isNullable: false,
inputTypes: inputTypes as DMMF.InputTypeRef[],
};
fields.push(field);
}

Expand Down
Loading

0 comments on commit 296ca25

Please sign in to comment.