From 1ebc973e21daa93f3564cf751e9acc3bb7b7e7c4 Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Wed, 17 Jul 2024 15:39:31 +0700 Subject: [PATCH] feat: Support MetadataFilters for Milvus and SimpleVectorStore (#1033) Co-authored-by: Marcus Schiesser --- .changeset/famous-poets-hammer.md | 6 + examples/metadata-filter/milvus.ts | 40 +++ examples/metadata-filter/simple.ts | 61 +++- .../storage/vectorStore/MilvusVectorStore.ts | 64 +++- .../src/storage/vectorStore/PGVectorStore.ts | 5 +- .../vectorStore/PineconeVectorStore.ts | 10 +- .../storage/vectorStore/SimpleVectorStore.ts | 113 ++++-- .../src/storage/vectorStore/types.ts | 36 +- .../src/storage/vectorStore/utils.ts | 23 ++ .../tests/mocks/TestableMilvusVectorStore.ts | 24 ++ .../vectorStores/MilvusVectorStore.test.ts | 333 ++++++++++++++++++ .../vectorStores/SimpleVectorStore.test.ts | 274 ++++++++++++-- 12 files changed, 921 insertions(+), 68 deletions(-) create mode 100644 .changeset/famous-poets-hammer.md create mode 100644 examples/metadata-filter/milvus.ts create mode 100644 packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts create mode 100644 packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts diff --git a/.changeset/famous-poets-hammer.md b/.changeset/famous-poets-hammer.md new file mode 100644 index 0000000000..7d97257dce --- /dev/null +++ b/.changeset/famous-poets-hammer.md @@ -0,0 +1,6 @@ +--- +"llamaindex": patch +"@llamaindex/llamaindex-test": patch +--- + +Add support for Metadata filters diff --git a/examples/metadata-filter/milvus.ts b/examples/metadata-filter/milvus.ts new file mode 100644 index 0000000000..9415bca57f --- /dev/null +++ b/examples/metadata-filter/milvus.ts @@ -0,0 +1,40 @@ +import { MilvusVectorStore, VectorStoreIndex } from "llamaindex"; + +const collectionName = "movie_reviews"; + +async function main() { + try { + const milvus = new MilvusVectorStore({ collection: collectionName }); + const index = await VectorStoreIndex.fromVectorStore(milvus); + const retriever = index.asRetriever({ similarityTopK: 20 }); + + console.log("\n=====\nQuerying the index with filters"); + const queryEngineWithFilters = index.asQueryEngine({ + retriever, + preFilters: { + filters: [ + { + key: "document_id", + value: "./data/movie_reviews.csv_37", + operator: "==", + }, + { + key: "document_id", + value: "./data/movie_reviews.csv_37", + operator: "!=", + }, + ], + condition: "or", + }, + }); + const resultAfterFilter = await queryEngineWithFilters.query({ + query: "Get all movie titles.", + }); + console.log(`Query from ${resultAfterFilter.sourceNodes?.length} nodes`); + console.log(resultAfterFilter.response); + } catch (e) { + console.error(e); + } +} + +void main(); diff --git a/examples/metadata-filter/simple.ts b/examples/metadata-filter/simple.ts index 34226768e3..245e48c3df 100644 --- a/examples/metadata-filter/simple.ts +++ b/examples/metadata-filter/simple.ts @@ -66,29 +66,78 @@ async function main() { console.log("No filter response:", noFilterResponse.toString()); console.log( - "\n=============\nQuerying index with dogId 2. The output always should be red.", + "\n=============\nQuerying index with dogId 2 and private false. The output always should be red.", ); - const queryEngineDogId2 = index.asQueryEngine({ + const queryEngineEQ = index.asQueryEngine({ preFilters: { filters: [ { key: "private", value: "false", - filterType: "ExactMatch", + operator: "==", }, { key: "dogId", value: "3", - filterType: "ExactMatch", + operator: "==", }, ], }, similarityTopK: 3, }); - const response = await queryEngineDogId2.query({ + const responseEQ = await queryEngineEQ.query({ query: "What is the color of the dog?", }); - console.log("Filter with dogId 2 response:", response.toString()); + console.log("Filter with dogId 2 response:", responseEQ.toString()); + + console.log( + "\n=============\nQuerying index with dogId IN (1, 3). The output should be brown and red.", + ); + const queryEngineIN = index.asQueryEngine({ + preFilters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + similarityTopK: 3, + }); + const responseIN = await queryEngineIN.query({ + query: "What is the color of the dog?", + }); + console.log("Filter with dogId IN (1, 3) response:", responseIN.toString()); + + console.log( + "\n=============\nQuerying index with dogId IN (1, 3). The output should be any.", + ); + const queryEngineOR = index.asQueryEngine({ + preFilters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + condition: "or", + }, + similarityTopK: 3, + }); + const responseOR = await queryEngineOR.query({ + query: "What is the color of the dog?", + }); + console.log( + "Filter with dogId with OR operator response:", + responseOR.toString(), + ); } void main(); diff --git a/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts index 42b7df5687..15c3fe6e9f 100644 --- a/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/MilvusVectorStore.ts @@ -11,11 +11,66 @@ import { import { VectorStoreBase, type IEmbedModel, + type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, type VectorStoreQueryResult, } from "./types.js"; -import { metadataDictToNode, nodeToMetadata } from "./utils.js"; +import { + metadataDictToNode, + nodeToMetadata, + parseArrayValue, + parseNumberValue, + parsePrimitiveValue, +} from "./utils.js"; + +function parseScalarFilters(scalarFilters: MetadataFilters): string { + const condition = scalarFilters.condition ?? "and"; + const filters: string[] = []; + + for (const filter of scalarFilters.filters) { + switch (filter.operator) { + case "==": + case "!=": { + filters.push( + `metadata["${filter.key}"] ${filter.operator} "${parsePrimitiveValue(filter.value)}"`, + ); + break; + } + case "in": { + const filterValue = parseArrayValue(filter.value) + .map((v) => `"${v}"`) + .join(", "); + filters.push( + `metadata["${filter.key}"] ${filter.operator} [${filterValue}]`, + ); + break; + } + case "nin": { + // Milvus does not support `nin` operator, so we need to manually check every value + // Expected: not metadata["key"] != "value1" and not metadata["key"] != "value2" + const filterStr = parseArrayValue(filter.value) + .map((v) => `metadata["${filter.key}"] != "${v}"`) + .join(" && "); + filters.push(filterStr); + break; + } + case "<": + case "<=": + case ">": + case ">=": { + filters.push( + `metadata["${filter.key}"] ${filter.operator} ${parseNumberValue(filter.value)}`, + ); + break; + } + default: + throw new Error(`Operator ${filter.operator} is not supported.`); + } + } + + return filters.join(` ${condition} `); +} export class MilvusVectorStore extends VectorStoreBase @@ -183,6 +238,12 @@ export class MilvusVectorStore }); } + public toMilvusFilter(filters?: MetadataFilters): string | undefined { + if (!filters) return undefined; + // TODO: Milvus also support standard filters, we can add it later + return parseScalarFilters(filters); + } + public async query( query: VectorStoreQuery, _options?: any, @@ -193,6 +254,7 @@ export class MilvusVectorStore collection_name: this.collectionName, limit: query.similarityTopK, vector: query.queryEmbedding, + filter: this.toMilvusFilter(query.filters), }); const nodes: BaseNode[] = []; diff --git a/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts index e54c64aa29..468f5966d0 100644 --- a/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/PGVectorStore.ts @@ -272,7 +272,10 @@ export class PGVectorStore query.filters?.filters.forEach((filter, index) => { const paramIndex = params.length + 1; whereClauses.push(`metadata->>'${filter.key}' = $${paramIndex}`); - params.push(filter.value); + // TODO: support filter with other operators + if (!Array.isArray(filter.value)) { + params.push(filter.value); + } }); const where = diff --git a/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts index 81d1efa6cc..50a2a6241d 100644 --- a/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/PineconeVectorStore.ts @@ -1,7 +1,7 @@ import { VectorStoreBase, - type ExactMatchFilter, type IEmbedModel, + type MetadataFilter, type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, @@ -199,8 +199,12 @@ export class PineconeVectorStore } toPineconeFilter(stdFilters?: MetadataFilters) { - return stdFilters?.filters?.reduce((carry: any, item: ExactMatchFilter) => { - carry[item.key] = item.value; + return stdFilters?.filters?.reduce((carry: any, item: MetadataFilter) => { + // Use MetadataFilter with EQ operator to replace ExactMatchFilter + // TODO: support filter with other operators + if (item.operator === "==") { + carry[item.key] = item.value; + } return carry; }, {}); } diff --git a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts index 359c93e3ac..6d0847f3e4 100644 --- a/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts +++ b/packages/llamaindex/src/storage/vectorStore/SimpleVectorStore.ts @@ -8,14 +8,22 @@ import { import { exists } from "../FileSystem.js"; import { DEFAULT_PERSIST_DIR } from "../constants.js"; import { + FilterOperator, VectorStoreBase, VectorStoreQueryMode, type IEmbedModel, + type MetadataFilter, + type MetadataFilters, type VectorStoreNoEmbedModel, type VectorStoreQuery, type VectorStoreQueryResult, } from "./types.js"; -import { nodeToMetadata } from "./utils.js"; +import { + nodeToMetadata, + parseArrayValue, + parseNumberValue, + parsePrimitiveValue, +} from "./utils.js"; const LEARNER_MODES = new Set([ VectorStoreQueryMode.SVM, @@ -25,10 +33,85 @@ const LEARNER_MODES = new Set([ const MMR_MODE = VectorStoreQueryMode.MMR; +type MetadataValue = Record; + +// Mapping of filter operators to metadata filter functions +const OPERATOR_TO_FILTER: { + [key in FilterOperator]: ( + { key, value }: MetadataFilter, + metadata: MetadataValue, + ) => boolean; +} = { + [FilterOperator.EQ]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]) === parsePrimitiveValue(value); + }, + [FilterOperator.NE]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]) !== parsePrimitiveValue(value); + }, + [FilterOperator.IN]: ({ key, value }, metadata) => { + return parseArrayValue(value).includes(parsePrimitiveValue(metadata[key])); + }, + [FilterOperator.NIN]: ({ key, value }, metadata) => { + return !parseArrayValue(value).includes(parsePrimitiveValue(metadata[key])); + }, + [FilterOperator.ANY]: ({ key, value }, metadata) => { + return parseArrayValue(value).some((v) => + parseArrayValue(metadata[key]).includes(v), + ); + }, + [FilterOperator.ALL]: ({ key, value }, metadata) => { + return parseArrayValue(value).every((v) => + parseArrayValue(metadata[key]).includes(v), + ); + }, + [FilterOperator.TEXT_MATCH]: ({ key, value }, metadata) => { + return parsePrimitiveValue(metadata[key]).includes( + parsePrimitiveValue(value), + ); + }, + [FilterOperator.CONTAINS]: ({ key, value }, metadata) => { + return parseArrayValue(metadata[key]).includes(parsePrimitiveValue(value)); + }, + [FilterOperator.GT]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) > parseNumberValue(value); + }, + [FilterOperator.LT]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) < parseNumberValue(value); + }, + [FilterOperator.GTE]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) >= parseNumberValue(value); + }, + [FilterOperator.LTE]: ({ key, value }, metadata) => { + return parseNumberValue(metadata[key]) <= parseNumberValue(value); + }, +}; + +// Build a filter function based on the metadata and the preFilters +const buildFilterFn = ( + metadata: MetadataValue | undefined, + preFilters: MetadataFilters | undefined, +) => { + if (!preFilters) return true; + if (!metadata) return false; + + const { filters, condition } = preFilters; + const queryCondition = condition || "and"; // default to and + + const itemFilterFn = (filter: MetadataFilter) => { + const metadataLookupFn = OPERATOR_TO_FILTER[filter.operator]; + if (!metadataLookupFn) + throw new Error(`Unsupported operator: ${filter.operator}`); + return metadataLookupFn(filter, metadata); + }; + + if (queryCondition === "and") return filters.every(itemFilterFn); + return filters.some(itemFilterFn); +}; + class SimpleVectorStoreData { embeddingDict: Record = {}; textIdToRefDocId: Record = {}; - metadataDict: Record> = {}; + metadataDict: Record = {}; } export class SimpleVectorStore @@ -103,31 +186,9 @@ export class SimpleVectorStore embeddings: number[][]; }> { const items = Object.entries(this.data.embeddingDict); - - const metadataLookup = { - ExactMatch: ( - metadata: Record, - key: string, - value: string | number, - ) => { - return String(metadata[key]) === value.toString(); // compare as string - }, - }; - const queryFilterFn = (nodeId: string) => { - if (!query.filters) return true; - const filters = query.filters.filters; - for (const filter of filters) { - const { key, value, filterType } = filter; - const metadataLookupFn = metadataLookup[filterType]; - const metadata = this.data.metadataDict[nodeId]; - const isMatch = - metadataLookupFn && - metadata && - metadataLookupFn(metadata, key, value); - if (!isMatch) return false; // TODO: handle condition OR AND - } - return true; + const metadata = this.data.metadataDict[nodeId]; + return buildFilterFn(metadata, query.filters); }; const nodeFilterFn = (nodeId: string) => { diff --git a/packages/llamaindex/src/storage/vectorStore/types.ts b/packages/llamaindex/src/storage/vectorStore/types.ts index 1862631f92..88ffbe9677 100644 --- a/packages/llamaindex/src/storage/vectorStore/types.ts +++ b/packages/llamaindex/src/storage/vectorStore/types.ts @@ -20,19 +20,51 @@ export enum VectorStoreQueryMode { MMR = "mmr", } +/** + * @deprecated Use MetadataFilter with operator EQ instead + */ export interface ExactMatchFilter { filterType: "ExactMatch"; key: string; value: string | number; } +export enum FilterOperator { + EQ = "==", // default operator (string, number) + IN = "in", // In array (string or number) + GT = ">", // greater than (number) + LT = "<", // less than (number) + NE = "!=", // not equal to (string, number) + GTE = ">=", // greater than or equal to (number) + LTE = "<=", // less than or equal to (number) + NIN = "nin", // Not in array (string or number) + ANY = "any", // Contains any (array of strings) + ALL = "all", // Contains all (array of strings) + TEXT_MATCH = "text_match", // full text match (allows you to search for a specific substring, token or phrase within the text field) + CONTAINS = "contains", // metadata array contains value (string or number) +} + +export enum FilterCondition { + AND = "and", + OR = "or", +} + +export type MetadataFilterValue = string | number | string[] | number[]; + +export interface MetadataFilter { + key: string; + value: MetadataFilterValue; + operator: `${FilterOperator}`; // ==, any, all,... +} + export interface MetadataFilters { - filters: ExactMatchFilter[]; + filters: Array; + condition?: `${FilterCondition}`; // and, or } export interface VectorStoreQuerySpec { query: string; - filters: ExactMatchFilter[]; + filters: MetadataFilter[]; topK?: number; } diff --git a/packages/llamaindex/src/storage/vectorStore/utils.ts b/packages/llamaindex/src/storage/vectorStore/utils.ts index 1a00dee274..8bdc394744 100644 --- a/packages/llamaindex/src/storage/vectorStore/utils.ts +++ b/packages/llamaindex/src/storage/vectorStore/utils.ts @@ -1,5 +1,6 @@ import type { BaseNode, Metadata } from "@llamaindex/core/schema"; import { ObjectType, jsonToNode } from "@llamaindex/core/schema"; +import type { MetadataFilterValue } from "./types.js"; const DEFAULT_TEXT_KEY = "text"; @@ -77,3 +78,25 @@ export function metadataDictToNode( return jsonToNode(nodeObj, ObjectType.TEXT); } } + +export const parseNumberValue = (value: MetadataFilterValue): number => { + if (typeof value !== "number") throw new Error("Value must be a number"); + return value; +}; + +export const parsePrimitiveValue = (value: MetadataFilterValue): string => { + if (typeof value !== "number" && typeof value !== "string") { + throw new Error("Value must be a string or number"); + } + return value.toString(); +}; + +export const parseArrayValue = (value: MetadataFilterValue): string[] => { + const isPrimitiveArray = + Array.isArray(value) && + value.every((v) => typeof v === "string" || typeof v === "number"); + if (!isPrimitiveArray) { + throw new Error("Value must be an array of strings or numbers"); + } + return value.map(String); +}; diff --git a/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts b/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts new file mode 100644 index 0000000000..f5665c7af9 --- /dev/null +++ b/packages/llamaindex/tests/mocks/TestableMilvusVectorStore.ts @@ -0,0 +1,24 @@ +import type { BaseNode } from "@llamaindex/core/schema"; +import type { MilvusClient } from "@zilliz/milvus2-sdk-node"; +import { MilvusVectorStore } from "llamaindex"; +import { type Mocked } from "vitest"; + +export class TestableMilvusVectorStore extends MilvusVectorStore { + public nodes: BaseNode[] = []; + + private fakeTimeout = (ms: number) => { + return new Promise((resolve) => setTimeout(resolve, ms)); + }; + + public async add(nodes: BaseNode[]): Promise { + this.nodes.push(...nodes); + await this.fakeTimeout(100); + return nodes.map((node) => node.id_); + } + + constructor() { + super({ + milvusClient: {} as Mocked, + }); + } +} diff --git a/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts b/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts new file mode 100644 index 0000000000..7c2c5e50a9 --- /dev/null +++ b/packages/llamaindex/tests/vectorStores/MilvusVectorStore.test.ts @@ -0,0 +1,333 @@ +import type { BaseNode } from "@llamaindex/core/schema"; +import { TextNode } from "@llamaindex/core/schema"; +import { + MilvusVectorStore, + VectorStoreQueryMode, + type MetadataFilters, +} from "llamaindex"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { TestableMilvusVectorStore } from "../mocks/TestableMilvusVectorStore.js"; + +type FilterTestCase = { + title: string; + filters?: MetadataFilters; + expected: number; + expectedFilterStr: string | undefined; + mockResultIds: string[]; +}; + +describe("MilvusVectorStore", () => { + let store: MilvusVectorStore; + let nodes: BaseNode[]; + + beforeEach(() => { + store = new TestableMilvusVectorStore(); + nodes = [ + new TextNode({ + id_: "1", + embedding: [0.1, 0.2], + text: "The dog is brown", + metadata: { + name: "Anakin", + dogId: "1", + private: "true", + weight: 1.2, + type: ["husky", "puppy"], + }, + }), + new TextNode({ + id_: "2", + embedding: [0.1, 0.2], + text: "The dog is yellow", + metadata: { + name: "Luke", + dogId: "2", + private: "false", + weight: 2.3, + type: ["puppy"], + }, + }), + new TextNode({ + id_: "3", + embedding: [0.1, 0.2], + text: "The dog is red", + metadata: { + name: "Leia", + dogId: "3", + private: "false", + weight: 3.4, + type: ["husky"], + }, + }), + ]; + }); + + describe("[MilvusVectorStore] manage nodes", () => { + it("able to add nodes to store", async () => { + const ids = await store.add(nodes); + expect(ids).length(3); + }); + }); + + describe("[MilvusVectorStore] filter nodes with supported operators", () => { + const testcases: FilterTestCase[] = [ + { + title: "No filter", + expected: 3, + mockResultIds: ["1", "2", "3"], + expectedFilterStr: undefined, + }, + { + title: "Filter EQ", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + ], + }, + expected: 2, + mockResultIds: ["2", "3"], + expectedFilterStr: 'metadata["private"] == "false"', + }, + { + title: "Filter NE", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "!=", + }, + ], + }, + expected: 1, + mockResultIds: ["1"], + expectedFilterStr: 'metadata["private"] != "false"', + }, + { + title: "Filter GT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">", + }, + ], + }, + expected: 1, + mockResultIds: ["3"], + expectedFilterStr: 'metadata["weight"] > 2.3', + }, + { + title: "Filter GTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">=", + }, + ], + }, + expected: 2, + mockResultIds: ["2", "3"], + expectedFilterStr: 'metadata["weight"] >= 2.3', + }, + { + title: "Filter LT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<", + }, + ], + }, + expected: 1, + mockResultIds: ["1"], + expectedFilterStr: 'metadata["weight"] < 2.3', + }, + { + title: "Filter LTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<=", + }, + ], + }, + expected: 2, + mockResultIds: ["1", "2"], + expectedFilterStr: 'metadata["weight"] <= 2.3', + }, + { + title: "Filter IN", + filters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + expected: 2, + mockResultIds: ["1", "3"], + expectedFilterStr: 'metadata["dogId"] in ["1", "3"]', + }, + { + title: "Filter NIN", + filters: { + filters: [ + { + key: "name", + value: ["Anakin", "Leia"], + operator: "nin", + }, + ], + }, + expected: 1, + mockResultIds: ["2"], + expectedFilterStr: + 'metadata["name"] != "Anakin" && metadata["name"] != "Leia"', + }, + { + title: "Filter OR", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + condition: "or", + }, + expected: 3, + mockResultIds: ["1", "2", "3"], + expectedFilterStr: + 'metadata["private"] == "false" or metadata["dogId"] in ["1", "3"]', + }, + { + title: "Filter AND", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: "10", + operator: "==", + }, + ], + condition: "and", + }, + expected: 0, + mockResultIds: [], + expectedFilterStr: + 'metadata["private"] == "false" and metadata["dogId"] == "10"', + }, + ]; + + testcases.forEach((tc) => { + it(`[${tc.title}] should return ${tc.expected} nodes`, async () => { + expect(store.toMilvusFilter(tc.filters)).toBe(tc.expectedFilterStr); + + vi.spyOn(store, "query").mockResolvedValue({ + ids: tc.mockResultIds, + similarities: [0.1, 0.2, 0.3], + }); + + await store.add(nodes); + const result = await store.query({ + queryEmbedding: [0.1, 0.2], + similarityTopK: 3, + mode: VectorStoreQueryMode.DEFAULT, + filters: tc.filters, + }); + expect(result.ids).length(tc.expected); + }); + }); + }); + + describe("[MilvusVectorStore] filter nodes with unsupported operators", () => { + const testcases: Array< + Omit + > = [ + { + title: "Filter ANY", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "any", + }, + ], + }, + expected: 3, + }, + { + title: "Filter ALL", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "all", + }, + ], + }, + expected: 1, + }, + { + title: "Filter CONTAINS", + filters: { + filters: [ + { + key: "type", + value: "puppy", + operator: "contains", + }, + ], + }, + expected: 2, + }, + { + title: "Filter TEXT_MATCH", + filters: { + filters: [ + { + key: "name", + value: "Luk", + operator: "text_match", + }, + ], + }, + expected: 1, + }, + ]; + + testcases.forEach((tc) => { + it(`[Unsupported Operator] [${tc.title}] should throw error`, async () => { + const errorMsg = `Operator ${tc.filters?.filters[0].operator} is not supported.`; + expect(() => store.toMilvusFilter(tc.filters)).toThrow(errorMsg); + }); + }); + }); +}); diff --git a/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts index 67d91071aa..11423f0382 100644 --- a/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts +++ b/packages/llamaindex/tests/vectorStores/SimpleVectorStore.test.ts @@ -4,11 +4,19 @@ import { SimpleVectorStore, TextNode, VectorStoreQueryMode, + type Metadata, + type MetadataFilters, } from "llamaindex"; import { beforeEach, describe, expect, it, vi } from "vitest"; vi.mock("@qdrant/js-client-rest"); +type FilterTestCase = { + title: string; + filters?: MetadataFilters; + expected: number; +}; + describe("SimpleVectorStore", () => { let nodes: BaseNode[]; let store: SimpleVectorStore; @@ -19,19 +27,37 @@ describe("SimpleVectorStore", () => { id_: "1", embedding: [0.1, 0.2], text: "The dog is brown", - metadata: { dogId: "1", private: true }, + metadata: { + name: "Anakin", + dogId: "1", + private: "true", + weight: 1.2, + type: ["husky", "puppy"], + }, }), new TextNode({ id_: "2", - embedding: [0.2, 0.3], + embedding: [0.1, 0.2], text: "The dog is yellow", - metadata: { dogId: "2", private: false }, + metadata: { + name: "Luke", + dogId: "2", + private: "false", + weight: 2.3, + type: ["puppy"], + }, }), new TextNode({ id_: "3", - embedding: [0.3, 0.1], + embedding: [0.1, 0.2], text: "The dog is red", - metadata: { dogId: "3", private: false }, + metadata: { + name: "Leia", + dogId: "3", + private: "false", + weight: 3.4, + type: ["husky"], + }, }), ]; store = new SimpleVectorStore({ @@ -39,47 +65,237 @@ describe("SimpleVectorStore", () => { data: { embeddingDict: {}, textIdToRefDocId: {}, - metadataDict: { - // Mocking the metadataDict - "1": { dogId: "1", private: true }, - "2": { dogId: "2", private: false }, - "3": { dogId: "3", private: false }, - }, + metadataDict: nodes.reduce( + (acc, node) => { + acc[node.id_] = node.metadata; + return acc; + }, + {} as Record, + ), }, }); }); - describe("[SimpleVectorStore]", () => { + describe("[SimpleVectorStore] manage nodes", () => { it("able to add nodes to store", async () => { const ids = await store.add(nodes); expect(ids).length(3); }); - it("able to query nodes without filter", async () => { - await store.add(nodes); - const result = await store.query({ - queryEmbedding: [0.1, 0.2], - similarityTopK: 3, - mode: VectorStoreQueryMode.DEFAULT, - }); - expect(result.similarities).length(3); - }); - it("able to query nodes with filter", async () => { - await store.add(nodes); - const result = await store.query({ - queryEmbedding: [0.1, 0.2], - similarityTopK: 3, - mode: VectorStoreQueryMode.DEFAULT, + }); + + describe("[SimpleVectorStore] query nodes", () => { + const testcases: FilterTestCase[] = [ + { + title: "No filter", + expected: 3, + }, + { + title: "Filter EQ", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + ], + }, + expected: 2, + }, + { + title: "Filter NE", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "!=", + }, + ], + }, + expected: 1, + }, + { + title: "Filter GT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">", + }, + ], + }, + expected: 1, + }, + { + title: "Filter GTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: ">=", + }, + ], + }, + expected: 2, + }, + { + title: "Filter LT", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<", + }, + ], + }, + expected: 1, + }, + { + title: "Filter LTE", + filters: { + filters: [ + { + key: "weight", + value: 2.3, + operator: "<=", + }, + ], + }, + expected: 2, + }, + { + title: "Filter IN", + filters: { + filters: [ + { + key: "dogId", + value: ["1", "3"], + operator: "in", + }, + ], + }, + expected: 2, + }, + { + title: "Filter NIN", + filters: { + filters: [ + { + key: "name", + value: ["Anakin", "Leia"], + operator: "nin", + }, + ], + }, + expected: 1, + }, + { + title: "Filter ANY", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "any", + }, + ], + }, + expected: 3, + }, + { + title: "Filter ALL", + filters: { + filters: [ + { + key: "type", + value: ["husky", "puppy"], + operator: "all", + }, + ], + }, + expected: 1, + }, + { + title: "Filter CONTAINS", + filters: { + filters: [ + { + key: "type", + value: "puppy", + operator: "contains", + }, + ], + }, + expected: 2, + }, + { + title: "Filter TEXT_MATCH", + filters: { + filters: [ + { + key: "name", + value: "Luk", + operator: "text_match", + }, + ], + }, + expected: 1, + }, + { + title: "Filter OR", filters: { filters: [ { key: "private", value: "false", - filterType: "ExactMatch", + operator: "==", + }, + { + key: "dogId", + value: ["1", "3"], + operator: "in", }, ], + condition: "or", }, + expected: 3, + }, + { + title: "Filter AND", + filters: { + filters: [ + { + key: "private", + value: "false", + operator: "==", + }, + { + key: "dogId", + value: "10", + operator: "==", + }, + ], + condition: "and", + }, + expected: 0, + }, + ]; + + testcases.forEach((tc) => { + it(`[${tc.title}] should return ${tc.expected} nodes`, async () => { + await store.add(nodes); + const result = await store.query({ + queryEmbedding: [0.1, 0.2], + similarityTopK: 3, + mode: VectorStoreQueryMode.DEFAULT, + filters: tc.filters, + }); + expect(result.ids).length(tc.expected); }); - expect(result.similarities).length(2); }); }); });