Skip to content

Commit

Permalink
fix: incorrect prisma query executed when count using a where filter …
Browse files Browse the repository at this point in the history
…involving a polymorphic base field (#1586)
  • Loading branch information
ymc9 authored Jul 14, 2024
1 parent a11ab8c commit 3140d9b
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 49 deletions.
64 changes: 15 additions & 49 deletions packages/runtime/src/enhancements/delegate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ import {
FieldInfo,
ModelInfo,
NestedWriteVisitor,
clone,
enumerate,
getIdFields,
getModelInfo,
isDelegateModel,
resolveField,
} from '../cross';
import { clone } from '../cross';
import type { CrudContract, DbClientContract } from '../types';
import type { InternalEnhancementOptions } from './create-enhancement';
import { Logger } from './logger';
Expand Down Expand Up @@ -79,7 +79,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

if (args.orderBy) {
// `orderBy` may contain fields from base types
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy);
this.injectWhereHierarchy(this.model, args.orderBy);
}

if (this.options.logPrismaQuery) {
Expand All @@ -95,7 +95,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

private injectWhereHierarchy(model: string, where: any) {
if (!where || typeof where !== 'object') {
if (!where || !isPlainObject(where)) {
return;
}

Expand All @@ -108,44 +108,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {

const fieldInfo = resolveField(this.options.modelMeta, model, field);
if (!fieldInfo?.inheritedFrom) {
return;
}

let base = this.getBaseModel(model);
let target = where;

while (base) {
const baseRelationName = this.makeAuxRelationName(base);

// prepare base layer where
let thisLayer: any;
if (target[baseRelationName]) {
thisLayer = target[baseRelationName];
} else {
thisLayer = target[baseRelationName] = {};
}

if (base.name === fieldInfo.inheritedFrom) {
thisLayer[field] = value;
delete where[field];
break;
} else {
target = thisLayer;
base = this.getBaseModel(base.name);
if (fieldInfo?.isDataModel) {
this.injectWhereHierarchy(fieldInfo.type, value);
}
}
});
}

private buildWhereHierarchy(model: string, where: any) {
if (!where) {
return undefined;
}

where = clone(where);
Object.entries(where).forEach(([field, value]) => {
const fieldInfo = resolveField(this.options.modelMeta, model, field);
if (!fieldInfo?.inheritedFrom) {
return;
}

Expand All @@ -164,6 +129,9 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}

if (base.name === fieldInfo.inheritedFrom) {
if (fieldInfo.isDataModel) {
this.injectWhereHierarchy(base.name, value);
}
thisLayer[field] = value;
delete where[field];
break;
Expand All @@ -173,8 +141,6 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
}
}
});

return where;
}

private injectSelectIncludeHierarchy(model: string, args: any) {
Expand All @@ -189,7 +155,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
if (fieldInfo && value !== undefined) {
if (value?.orderBy) {
// `orderBy` may contain fields from base types
value.orderBy = this.buildWhereHierarchy(fieldInfo.type, value.orderBy);
this.injectWhereHierarchy(fieldInfo.type, value.orderBy);
}

if (this.injectBaseFieldSelect(model, field, value, args, kind)) {
Expand Down Expand Up @@ -921,15 +887,15 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args.cursor) {
args.cursor = this.buildWhereHierarchy(this.model, args.cursor);
this.injectWhereHierarchy(this.model, args.cursor);
}

if (args.orderBy) {
args.orderBy = this.buildWhereHierarchy(this.model, args.orderBy);
this.injectWhereHierarchy(this.model, args.orderBy);
}

if (args.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand All @@ -949,11 +915,11 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args?.cursor) {
args.cursor = this.buildWhereHierarchy(this.model, args.cursor);
this.injectWhereHierarchy(this.model, args.cursor);
}

if (args?.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand Down Expand Up @@ -989,7 +955,7 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler {
args = clone(args);

if (args.where) {
args.where = this.buildWhereHierarchy(this.model, args.where);
this.injectWhereHierarchy(this.model, args.where);
}

if (this.options.logPrismaQuery) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,70 @@ describe('Polymorphism Test', () => {
});
});

it('read with compound filter', async () => {
const { enhance } = await loadSchema(
`
model Base {
id Int @id @default(autoincrement())
type String
viewCount Int
@@delegate(type)
}
model Foo extends Base {
name String
}
`,
{ enhancements: ['delegate'] }
);

const db = enhance();
await db.foo.create({ data: { name: 'foo1', viewCount: 0 } });
await db.foo.create({ data: { name: 'foo2', viewCount: 1 } });

await expect(db.foo.findMany({ where: { viewCount: { gt: 0 } } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { AND: { viewCount: { gt: 0 } } } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { AND: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { OR: [{ viewCount: { gt: 0 } }] } })).resolves.toHaveLength(1);
await expect(db.foo.findMany({ where: { NOT: { viewCount: { lte: 0 } } } })).resolves.toHaveLength(1);
});

it('read with nested filter', async () => {
const { enhance } = await loadSchema(
`
model Base {
id Int @id @default(autoincrement())
type String
viewCount Int
@@delegate(type)
}
model Foo extends Base {
name String
bar Bar?
}
model Bar extends Base {
foo Foo @relation(fields: [fooId], references: [id])
fooId Int @unique
}
`,
{ enhancements: ['delegate'] }
);

const db = enhance();

await db.bar.create({
data: { foo: { create: { name: 'foo', viewCount: 2 } }, viewCount: 1 },
});

await expect(
db.bar.findMany({
where: { viewCount: { gt: 0 }, foo: { viewCount: { gt: 1 } } },
})
).resolves.toHaveLength(1);
});

it('order by base fields', async () => {
const { db, user } = await setup();

Expand Down Expand Up @@ -1013,6 +1077,18 @@ describe('Polymorphism Test', () => {
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

count = await db.ratedVideo.count({
select: { _all: true, rating: true },
where: { AND: { viewCount: { gt: 0 }, rating: { gt: 10 } } },
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

count = await db.ratedVideo.count({
select: { _all: true, rating: true },
where: { AND: [{ viewCount: { gt: 0 }, rating: { gt: 10 } }] },
});
expect(count).toMatchObject({ _all: 1, rating: 1 });

expect(() => db.ratedVideo.count({ select: { rating: true, viewCount: true } })).toThrow(
'count with fields from base type is not supported yet'
);
Expand Down
30 changes: 30 additions & 0 deletions tests/regression/tests/issue-1585.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { loadSchema } from '@zenstackhq/testtools';
describe('issue 1585', () => {
it('regression', async () => {
const { enhance } = await loadSchema(
`
model Asset {
id Int @id @default(autoincrement())
type String
views Int
@@allow('all', true)
@@delegate(type)
}
model Post extends Asset {
title String
}
`
);

const db = enhance();
await db.post.create({ data: { title: 'Post1', views: 0 } });
await db.post.create({ data: { title: 'Post2', views: 1 } });
await expect(
db.post.count({
where: { views: { gt: 0 } },
})
).resolves.toBe(1);
});
});

0 comments on commit 3140d9b

Please sign in to comment.