From 3140d9bee91171665a8f1f69b8939a38643f9cb1 Mon Sep 17 00:00:00 2001 From: Yiming Date: Sat, 13 Jul 2024 23:33:43 -0700 Subject: [PATCH] fix: incorrect prisma query executed when count using a where filter involving a polymorphic base field (#1586) --- packages/runtime/src/enhancements/delegate.ts | 64 ++++------------ .../with-delegate/enhanced-client.test.ts | 76 +++++++++++++++++++ tests/regression/tests/issue-1585.test.ts | 30 ++++++++ 3 files changed, 121 insertions(+), 49 deletions(-) create mode 100644 tests/regression/tests/issue-1585.test.ts diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts index fd8ad633a..d3dd5b83c 100644 --- a/packages/runtime/src/enhancements/delegate.ts +++ b/packages/runtime/src/enhancements/delegate.ts @@ -7,13 +7,13 @@ import { FieldInfo, ModelInfo, NestedWriteVisitor, + clone, enumerate, getIdFields, getModelInfo, isDelegateModel, resolveField, } from '../cross'; -import { clone } from '../cross'; import type { CrudContract, DbClientContract } from '../types'; import type { InternalEnhancementOptions } from './create-enhancement'; import { Logger } from './logger'; @@ -79,7 +79,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { if (args.orderBy) { // `orderBy` may contain fields from base types - args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy); + this.injectWhereHierarchy(this.model, args.orderBy); } if (this.options.logPrismaQuery) { @@ -95,7 +95,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } private injectWhereHierarchy(model: string, where: any) { - if (!where || typeof where !== 'object') { + if (!where || !isPlainObject(where)) { return; } @@ -108,44 +108,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { const fieldInfo = resolveField(this.options.modelMeta, model, field); if (!fieldInfo?.inheritedFrom) { - return; - } - - let base = this.getBaseModel(model); - let target = where; - - while (base) { - const baseRelationName = this.makeAuxRelationName(base); - - // prepare base layer where - let thisLayer: any; - if (target[baseRelationName]) { - thisLayer = target[baseRelationName]; - } else { - thisLayer = target[baseRelationName] = {}; - } - - if (base.name === fieldInfo.inheritedFrom) { - thisLayer[field] = value; - delete where[field]; - break; - } else { - target = thisLayer; - base = this.getBaseModel(base.name); + if (fieldInfo?.isDataModel) { + this.injectWhereHierarchy(fieldInfo.type, value); } - } - }); - } - - private buildWhereHierarchy(model: string, where: any) { - if (!where) { - return undefined; - } - - where = clone(where); - Object.entries(where).forEach(([field, value]) => { - const fieldInfo = resolveField(this.options.modelMeta, model, field); - if (!fieldInfo?.inheritedFrom) { return; } @@ -164,6 +129,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } if (base.name === fieldInfo.inheritedFrom) { + if (fieldInfo.isDataModel) { + this.injectWhereHierarchy(base.name, value); + } thisLayer[field] = value; delete where[field]; break; @@ -173,8 +141,6 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } } }); - - return where; } private injectSelectIncludeHierarchy(model: string, args: any) { @@ -189,7 +155,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { if (fieldInfo && value !== undefined) { if (value?.orderBy) { // `orderBy` may contain fields from base types - value.orderBy = this.buildWhereHierarchy(fieldInfo.type, value.orderBy); + this.injectWhereHierarchy(fieldInfo.type, value.orderBy); } if (this.injectBaseFieldSelect(model, field, value, args, kind)) { @@ -921,15 +887,15 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); if (args.cursor) { - args.cursor = this.buildWhereHierarchy(this.model, args.cursor); + this.injectWhereHierarchy(this.model, args.cursor); } if (args.orderBy) { - args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy); + this.injectWhereHierarchy(this.model, args.orderBy); } if (args.where) { - args.where = this.buildWhereHierarchy(this.model, args.where); + this.injectWhereHierarchy(this.model, args.where); } if (this.options.logPrismaQuery) { @@ -949,11 +915,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); if (args?.cursor) { - args.cursor = this.buildWhereHierarchy(this.model, args.cursor); + this.injectWhereHierarchy(this.model, args.cursor); } if (args?.where) { - args.where = this.buildWhereHierarchy(this.model, args.where); + this.injectWhereHierarchy(this.model, args.where); } if (this.options.logPrismaQuery) { @@ -989,7 +955,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { args = clone(args); if (args.where) { - args.where = this.buildWhereHierarchy(this.model, args.where); + this.injectWhereHierarchy(this.model, args.where); } if (this.options.logPrismaQuery) { diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index ea9b8efca..1a8f996df 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -284,6 +284,70 @@ describe('Polymorphism Test', () => { }); }); + it('read with compound filter', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id @default(autoincrement()) + type String + viewCount Int + @@delegate(type) + } + + model Foo extends Base { + name String + } + `, + { enhancements: ['delegate'] } + ); + + const db = enhance(); + await db.foo.create({ data: { name: 'foo1', viewCount: 0 } }); + await db.foo.create({ data: { name: 'foo2', viewCount: 1 } }); + + await expect(db.foo.findMany({ where: { viewCount: { gt: 0 } } })).resolves.toHaveLength(1); + await expect(db.foo.findMany({ where: { AND: { viewCount: { gt: 0 } } } })).resolves.toHaveLength(1); + await expect(db.foo.findMany({ where: { AND: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1); + await expect(db.foo.findMany({ where: { OR: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1); + await expect(db.foo.findMany({ where: { NOT: { viewCount: { lte: 0 } } } })).resolves.toHaveLength(1); + }); + + it('read with nested filter', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id @default(autoincrement()) + type String + viewCount Int + @@delegate(type) + } + + model Foo extends Base { + name String + bar Bar? + } + + model Bar extends Base { + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int @unique + } + `, + { enhancements: ['delegate'] } + ); + + const db = enhance(); + + await db.bar.create({ + data: { foo: { create: { name: 'foo', viewCount: 2 } }, viewCount: 1 }, + }); + + await expect( + db.bar.findMany({ + where: { viewCount: { gt: 0 }, foo: { viewCount: { gt: 1 } } }, + }) + ).resolves.toHaveLength(1); + }); + it('order by base fields', async () => { const { db, user } = await setup(); @@ -1013,6 +1077,18 @@ describe('Polymorphism Test', () => { }); expect(count).toMatchObject({ _all: 1, rating: 1 }); + count = await db.ratedVideo.count({ + select: { _all: true, rating: true }, + where: { AND: { viewCount: { gt: 0 }, rating: { gt: 10 } } }, + }); + expect(count).toMatchObject({ _all: 1, rating: 1 }); + + count = await db.ratedVideo.count({ + select: { _all: true, rating: true }, + where: { AND: [{ viewCount: { gt: 0 }, rating: { gt: 10 } }] }, + }); + expect(count).toMatchObject({ _all: 1, rating: 1 }); + expect(() => db.ratedVideo.count({ select: { rating: true, viewCount: true } })).toThrow( 'count with fields from base type is not supported yet' ); diff --git a/tests/regression/tests/issue-1585.test.ts b/tests/regression/tests/issue-1585.test.ts new file mode 100644 index 000000000..49ec333d4 --- /dev/null +++ b/tests/regression/tests/issue-1585.test.ts @@ -0,0 +1,30 @@ +import { loadSchema } from '@zenstackhq/testtools'; +describe('issue 1585', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model Asset { + id Int @id @default(autoincrement()) + type String + views Int + + @@allow('all', true) + @@delegate(type) + } + + model Post extends Asset { + title String + } + ` + ); + + const db = enhance(); + await db.post.create({ data: { title: 'Post1', views: 0 } }); + await db.post.create({ data: { title: 'Post2', views: 1 } }); + await expect( + db.post.count({ + where: { views: { gt: 0 } }, + }) + ).resolves.toBe(1); + }); +});