From 2ec6ffe478984539209725b754631565515da8bd Mon Sep 17 00:00:00 2001
From: Shili Cao
Date: Sat, 12 Oct 2024 23:24:17 +0800
Subject: [PATCH] feat:support baidu vector db (#9185)

---
 api/.env.example                              |   9 +
 api/commands.py                               |   8 +
 .../middleware/vdb/baidu_vector_config.py     |  45 +++
 api/controllers/console/datasets/datasets.py |   2 +
 api/core/rag/datasource/vdb/baidu/__init__.py |   0
 .../rag/datasource/vdb/baidu/baidu_vector.py  | 272 ++++++++++++++++++
 api/core/rag/datasource/vdb/vector_factory.py |   4 +
 api/core/rag/datasource/vdb/vector_type.py    |   1 +
 api/poetry.lock                               |  47 ++-
 api/pyproject.toml                            |   1 +
 .../vdb/__mock/baiduvectordb.py               | 154 ++++++++++
 .../integration_tests/vdb/baidu/__init__.py   |   0
 .../integration_tests/vdb/baidu/test_baidu.py |  36 +++
 docker/.env.example                           |   9 +
 docker/docker-compose.yaml                    |   7 +
 15 files changed, 582 insertions(+), 13 deletions(-)
 create mode 100644 api/configs/middleware/vdb/baidu_vector_config.py
 create mode 100644 api/core/rag/datasource/vdb/baidu/__init__.py
 create mode 100644 api/core/rag/datasource/vdb/baidu/baidu_vector.py
 create mode 100644 api/tests/integration_tests/vdb/__mock/baiduvectordb.py
 create mode 100644 api/tests/integration_tests/vdb/baidu/__init__.py
 create mode 100644 api/tests/integration_tests/vdb/baidu/test_baidu.py

diff --git a/api/.env.example b/api/.env.example
index 7b5e4950c82dd..3f88fb3cdf5b0 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -208,6 +208,15 @@ OPENSEARCH_USER=admin
 OPENSEARCH_PASSWORD=admin
 OPENSEARCH_SECURE=true
 
+# Baidu configuration
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
diff --git a/api/commands.py b/api/commands.py
index 7ef4aed7f7766..dbcd8a744d3a4 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
                     index_name = Dataset.gen_collection_name_by_id(dataset_id)
                     index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                     dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == VectorType.BAIDU:
+                    dataset_id = dataset.id
+                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+                    index_struct_dict = {
+                        "type": VectorType.BAIDU,
+                        "vector_store": {"class_prefix": collection_name},
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
                 else:
                     raise ValueError(f"Vector store {vector_type} is not supported.")
 
diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py
new file mode 100644
index 0000000000000..44742c2e2f434
--- /dev/null
+++ b/api/configs/middleware/vdb/baidu_vector_config.py
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from pydantic import Field, NonNegativeInt, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class BaiduVectorDBConfig(BaseSettings):
+    """
+    Configuration settings for Baidu Vector Database
+    """
+
+    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
+        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
+        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
+        default=30000,
+    )
+
+    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
+        description="Account for authenticating with the Baidu Vector Database",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
+        description="API key for authenticating with the Baidu Vector Database service",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
+        description="Name of the specific Baidu Vector Database to connect to",
+        default=None,
+    )
+
+    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
+        description="Number of shards for the Baidu Vector Database (default is 1)",
+        default=1,
+    )
+
+    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
+        description="Number of replicas for the Baidu Vector Database (default is 3)",
+        default=3,
+    )
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 9561fd8b70e4b..102089bf071ac 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -617,6 +617,7 @@ def get(self):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
@@ -653,6 +654,7 @@ def get(self, vector_type):
                 | VectorType.CHROMA
                 | VectorType.TENCENT
                 | VectorType.PGVECTO_RS
+                | VectorType.BAIDU
             ):
                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
             case (
diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/core/rag/datasource/vdb/baidu/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
new file mode 100644
index 0000000000000..543cfa67b3540
--- /dev/null
+++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
@@ -0,0 +1,272 @@
+import json
+import time
+import uuid
+from typing import Any
+
+from pydantic import BaseModel, model_validator
+from pymochow import MochowClient
+from pymochow.auth.bce_credentials import BceCredentials
+from pymochow.configuration import Configuration
+from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
+from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
+from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row
+
+from configs import dify_config
+from core.rag.datasource.entity.embedding import Embeddings
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class BaiduConfig(BaseModel):
+    endpoint: str
+    connection_timeout_in_mills: int = 30 * 1000
+    account: str
+    api_key: str
+    database: str
+    index_type: str = "HNSW"
+    metric_type: str = "L2"
+    shard: int = 1
+    replicas: int = 3
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_config(cls, values: dict) -> dict:
+        if not values["endpoint"]:
+            raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
+        if not values["account"]:
+            raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
+        if not values["api_key"]:
+            raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
+        if not values["database"]:
+            raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
+        return values
+
+
+class BaiduVector(BaseVector):
+    field_id: str = "id"
+    field_vector: str = "vector"
+    field_text: str = "text"
+    field_metadata: str = "metadata"
field_app_id: str = "app_id" + field_annotation_id: str = "annotation_id" + index_vector: str = "vector_idx" + + def __init__(self, collection_name: str, config: BaiduConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._db = self._init_database() + + def get_type(self) -> str: + return VectorType.BAIDU + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self._create_table(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + total_count = len(documents) + batch_size = 1000 + + # upsert texts and embeddings batch by batch + table = self._db.table(self._collection_name) + for start in range(0, total_count, batch_size): + end = min(start + batch_size, total_count) + rows = [] + for i in range(start, end, 1): + row = Row( + id=metadatas[i].get("doc_id", str(uuid.uuid4())), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadatas[i]), + app_id=metadatas[i].get("app_id", ""), + annotation_id=metadatas[i].get("annotation_id", ""), + ) + rows.append(row) + table.upsert(rows=rows) + + # rebuild vector index after upsert finished + table.rebuild_index(self.index_vector) + while True: + time.sleep(1) + index = table.describe_index(self.index_vector) + if index.state == IndexState.NORMAL: + break + + def text_exists(self, id: str) -> bool: + res = self._db.table(self._collection_name).query(primary_key={self.field_id: id}) + if res and res.code == 0: + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + quoted_ids = [f"'{id}'" for id in ids] + self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + anns = AnnSearch( + vector_field=self.field_vector, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), + ) + res = self._db.table(self._collection_name).search( + anns=anns, + projections=[self.field_id, self.field_text, self.field_metadata], + retrieve_vector=True, + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(res, score_threshold) + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # baidu vector database doesn't support bm25 search on current version + return [] + + def _get_search_res(self, res, score_threshold): + docs = [] + for row in res.rows: + row_data = row.get("row", {}) + meta = row_data.get(self.field_metadata) + if meta is not None: + meta = json.loads(meta) + score = row.get("score", 0.0) + if score > score_threshold: + meta["score"] = score + doc = Document(page_content=row_data.get(self.field_text), metadata=meta) + docs.append(doc) + + return docs + + def delete(self) -> None: + self._db.drop_table(table_name=self._collection_name) + + def _init_client(self, config) -> MochowClient: + config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) + client = 
+        return client
+
+    def _init_database(self):
+        exists = False
+        for db in self._client.list_databases():
+            if db.database_name == self._client_config.database:
+                exists = True
+                break
+        # Create database if not existed
+        if exists:
+            return self._client.database(self._client_config.database)
+        else:
+            return self._client.create_database(database_name=self._client_config.database)
+
+    def _table_existed(self) -> bool:
+        tables = self._db.list_table()
+        return any(table.table_name == self._collection_name for table in tables)
+
+    def _create_table(self, dimension: int) -> None:
+        # Try to grab distributed lock and create table
+        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
+        with redis_client.lock(lock_name, timeout=20):
+            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
+            if redis_client.get(table_exist_cache_key):
+                return
+
+            if self._table_existed():
+                return
+
+            self.delete()
+
+            # check IndexType and MetricType
+            index_type = None
+            for k, v in IndexType.__members__.items():
+                if k == self._client_config.index_type:
+                    index_type = v
+            if index_type is None:
+                raise ValueError("unsupported index_type")
+            metric_type = None
+            for k, v in MetricType.__members__.items():
+                if k == self._client_config.metric_type:
+                    metric_type = v
+            if metric_type is None:
+                raise ValueError("unsupported metric_type")
+
+            # Construct field schema
+            fields = []
+            fields.append(
+                Field(
+                    self.field_id,
+                    FieldType.STRING,
+                    primary_key=True,
+                    partition_key=True,
+                    auto_increment=False,
+                    not_null=True,
+                )
+            )
+            fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
+            fields.append(Field(self.field_app_id, FieldType.STRING))
+            fields.append(Field(self.field_annotation_id, FieldType.STRING))
+            fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
+            fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))
+
+            # Construct vector index params
+            indexes = []
+            indexes.append(
+                VectorIndex(
+                    index_name="vector_idx",
+                    index_type=index_type,
+                    field="vector",
+                    metric_type=metric_type,
+                    params=HNSWParams(m=16, efconstruction=200),
+                )
+            )
+
+            # Create table
+            self._db.create_table(
+                table_name=self._collection_name,
+                replication=self._client_config.replicas,
+                partition=Partition(partition_num=self._client_config.shard),
+                schema=Schema(fields=fields, indexes=indexes),
+                description="Table for Dify",
+            )
+
+            redis_client.set(table_exist_cache_key, 1, ex=3600)
+
+            # Wait for table created
+            while True:
+                time.sleep(1)
+                table = self._db.describe_table(self._collection_name)
+                if table.state == TableState.NORMAL:
+                    break
+
+
+class BaiduVectorFactory(AbstractVectorFactory):
+    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
+        if dataset.index_struct_dict:
+            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+            collection_name = class_prefix.lower()
+        else:
+            dataset_id = dataset.id
+            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))
+
+        return BaiduVector(
+            collection_name=collection_name,
+            config=BaiduConfig(
+                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
+                connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
+                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
+                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
+                database=dify_config.BAIDU_VECTOR_DB_DATABASE,
+                shard=dify_config.BAIDU_VECTOR_DB_SHARD,
+                replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
+            ),
+        )
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 943b23870cc5c..1f4a4d44a23ee 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -103,6 +103,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
                 from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory
 
                 return AnalyticdbVectorFactory
+            case VectorType.BAIDU:
+                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory
+
+                return BaiduVectorFactory
             case _:
                 raise ValueError(f"Vector store {vector_type} is not supported.")
 
diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py
index ba04ea879d9b4..996ff48615c90 100644
--- a/api/core/rag/datasource/vdb/vector_type.py
+++ b/api/core/rag/datasource/vdb/vector_type.py
@@ -16,3 +16,4 @@ class VectorType(str, Enum):
     TENCENT = "tencent"
     ORACLE = "oracle"
     ELASTICSEARCH = "elasticsearch"
+    BAIDU = "baidu"
diff --git a/api/poetry.lock b/api/poetry.lock
index 52e0f3031151d..6565db27ad572 100644
--- a/api/poetry.lock
+++ b/api/poetry.lock
@@ -732,7 +732,7 @@ name = "bce-python-sdk"
 version = "0.9.23"
 description = "BCE SDK for python"
 optional = false
-python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4"
+python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,<4,>=2.7"
 files = [
     {file = "bce_python_sdk-0.9.23-py3-none-any.whl", hash = "sha256:8debe21a040e00060f6044877d594765ed7b18bc765c6bf16b878bca864140a3"},
     {file = "bce_python_sdk-0.9.23.tar.gz", hash = "sha256:19739fed5cd0725356fc5ffa2acbdd8fb23f2a81edb91db21a03174551d0cf41"},
@@ -847,7 +847,7 @@ name = "botocore"
 version = "1.35.38"
 description = "Low-level, data-driven core of boto 3."
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "botocore-1.35.38-py3-none-any.whl", hash = "sha256:2eb17d32fa2d3bb5d475132a83564d28e3acc2161534f24b75a54418a1d51359"},
     {file = "botocore-1.35.38.tar.gz", hash = "sha256:55d9305c44e5ba29476df456120fa4fb919f03f066afa82f2ae400485e7465f4"},
@@ -1068,7 +1068,7 @@ name = "build"
 version = "1.2.2.post1"
 description = "A simple, correct Python build frontend"
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "build-1.2.2.post1-py3-none-any.whl", hash = "sha256:1d61c0887fa860c01971625baae8bdd338e517b836a2f70dd1f7aa3a6b2fc5b5"},
     {file = "build-1.2.2.post1.tar.gz", hash = "sha256:b36993e92ca9375a219c99e606a122ff365a760a2d4bba0caa09bd5278b608b7"},
@@ -3385,7 +3385,7 @@ name = "gotrue"
 version = "2.9.2"
 description = "Python Client Library for Supabase Auth"
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "gotrue-2.9.2-py3-none-any.whl", hash = "sha256:fcd5279e8f1cc630f3ac35af5485fe39f8030b23906776920d2c32a4e308cff4"},
     {file = "gotrue-2.9.2.tar.gz", hash = "sha256:57b3245e916c5efbf19a21b1181011a903c1276bb1df2d847558f2f24f29abb2"},
@@ -4415,7 +4415,7 @@ name = "langfuse"
 version = "2.51.5"
 description = "A client library for accessing langfuse"
 optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
 files = [
     {file = "langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb"},
     {file = "langfuse-2.51.5.tar.gz", hash = "sha256:55bc37b5c5d3ae133c1a95db09117cfb3117add110ba02ebbf2ce45ac4395c5b"},
@@ -4440,7 +4440,7 @@ name = "langsmith"
 version = "0.1.134"
 description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
 optional = false
-python-versions = ">=3.8.1,<4.0"
+python-versions = "<4.0,>=3.8.1"
 files = [
     {file = "langsmith-0.1.134-py3-none-any.whl", hash = "sha256:ada98ad80ef38807725f32441a472da3dd28394010877751f48f458d3289da04"},
     {file = "langsmith-0.1.134.tar.gz", hash = "sha256:23abee3b508875a0e63c602afafffc02442a19cfd88f9daae05b3e9054fd6b61"},
@@ -6429,7 +6429,7 @@ name = "postgrest"
 version = "0.17.1"
 description = "PostgREST client for Python. This library provides an ORM interface to PostgREST."
 optional = false
-python-versions = ">=3.8,<4.0"
+python-versions = "<4.0,>=3.8"
 files = [
     {file = "postgrest-0.17.1-py3-none-any.whl", hash = "sha256:ec1d00dc8532fe5ffb342cfc7c4e610a1e0e2272eb14f78f9b2b61094f9be510"},
     {file = "postgrest-0.17.1.tar.gz", hash = "sha256:e31d9977dbb80dc5f9fdd4d444014686606692dc4ddb9adc85639e56c6d54c92"},
@@ -7047,6 +7047,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r
 dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"]
 model = ["milvus-model (>=0.1.0)"]
 
+[[package]]
+name = "pymochow"
+version = "1.3.1"
+description = "Python SDK for mochow"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"},
+    {file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"},
+]
+
+[package.dependencies]
+future = "*"
+orjson = "*"
+requests = "*"
+
 [[package]]
 name = "pymysql"
 version = "1.1.1"
@@ -7746,7 +7762,7 @@ name = "realtime"
 version = "2.0.2"
 description = ""
 optional = false
-python-versions = ">=3.9,<4.0"
+python-versions = "<4.0,>=3.9"
 files = [
     {file = "realtime-2.0.2-py3-none-any.whl", hash = "sha256:2634c915bc38807f2013f21e8bcc4d2f79870dfd81460ddb9393883d0489928a"},
     {file = "realtime-2.0.2.tar.gz", hash = "sha256:519da9325b3b8102139d51785013d592f6b2403d81fa21d838a0b0234723ed7d"},
@@ -8173,7 +8189,7 @@ name = "s3transfer"
 version = "0.10.3"
 description = "An Amazon S3 Transfer Manager"
 optional = false
-python-versions = ">= 3.8"
+python-versions = ">=3.8"
 files = [
     {file = "s3transfer-0.10.3-py3-none-any.whl", hash = "sha256:263ed587a5803c6c708d3ce44dc4dfedaab4c1a32e8329bab818933d79ddcf5d"},
     {file = "s3transfer-0.10.3.tar.gz", hash = "sha256:4f50ed74ab84d474ce614475e0b8d5047ff080810aac5d01ea25231cfc944b0c"},
@@ -8417,6 +8433,11 @@ files = [
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"},
     {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"},
     {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"},
+    {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"},
     {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"},
"sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -8836,7 +8857,7 @@ name = "storage3" version = "0.8.1" description = "Supabase Storage client for Python." optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "storage3-0.8.1-py3-none-any.whl", hash = "sha256:0b21205f43eaf0d1dd33bde6c6d0612f88524b7865f017d2ae9827e3f63d9cdc"}, {file = "storage3-0.8.1.tar.gz", hash = "sha256:ea60b68b2221b3868ccc1a7f1294d57d0d9c51642cdc639d8115fe5d0adc8892"}, @@ -8882,7 +8903,7 @@ name = "supabase" version = "2.8.1" description = "Supabase client for Python." optional = false -python-versions = ">=3.9,<4.0" +python-versions = "<4.0,>=3.9" files = [ {file = "supabase-2.8.1-py3-none-any.whl", hash = "sha256:dfa8bef89b54129093521d5bba2136ff765baf67cd76d8ad0aa4984d61a7815c"}, {file = "supabase-2.8.1.tar.gz", hash = "sha256:711c70e6acd9e2ff48ca0dc0b1bb70c01c25378cc5189ec9f5ed9655b30bc41d"}, @@ -8902,7 +8923,7 @@ name = "supafunc" version = "0.6.1" description = "Library for Supabase Functions" optional = false -python-versions = ">=3.8,<4.0" +python-versions = "<4.0,>=3.8" files = [ {file = "supafunc-0.6.1-py3-none-any.whl", hash = "sha256:01aeeeb4bf429977664454a32c86418345140faf6d2e6eb0636d52e4547c5fbb"}, {file = "supafunc-0.6.1.tar.gz", hash = "sha256:3c8761e3999336ccdb7550498a395fd08afc8469382f55ea56f7f640e5a909aa"}, @@ -10615,4 +10636,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "cc10ee218369eb5576d1e5ac8aeeb72e8927bbcb8bd1ac1594167c45aa9d9a21" +content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774" diff --git a/api/pyproject.toml b/api/pyproject.toml index 277d1690c7270..594517771b34f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -242,6 +242,7 @@ oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" +pymochow = "1.3.1" qdrant-client = "1.7.3" tcvectordb = "1.3.2" tidb-vector = "0.0.9" diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py new file mode 100644 index 0000000000000..a8eaf42b7de1d --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -0,0 +1,154 @@ +import os + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex +from pymochow.model.table import Table +from requests.adapters import HTTPAdapter + + +class MockBaiduVectorDBClass: + def mock_vector_db_client( + self, + config=None, + adapter: HTTPAdapter = None, + ): + self._conn = None + self._config = None + + def list_databases(self, config=None) -> list[Database]: + return [ + Database( + conn=self._conn, + database_name="dify", + config=self._config, + ) + ] + + def create_database(self, database_name: str, config=None) -> Database: + return Database(conn=self._conn, database_name=database_name, config=config) + + def list_table(self, config=None) -> list[Table]: + return [] + + def drop_table(self, table_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def create_table( + self, 
+        table_name: str,
+        replication: int,
+        partition: int,
+        schema,
+        enable_dynamic_field=False,
+        description: str = "",
+        config=None,
+    ) -> Table:
+        return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config)
+
+    def describe_table(self, table_name: str, config=None) -> Table:
+        return Table(
+            self,
+            table_name,
+            3,
+            1,
+            None,
+            enable_dynamic_field=False,
+            description="table for dify",
+            config=config,
+            state=TableState.NORMAL,
+        )
+
+    def upsert(self, rows, config=None):
+        return {"code": 0, "msg": "operation success", "affectedCount": 1}
+
+    def rebuild_index(self, index_name: str, config=None):
+        return {"code": 0, "msg": "Success"}
+
+    def describe_index(self, index_name: str, config=None):
+        return VectorIndex(
+            index_name=index_name,
+            index_type=IndexType.HNSW,
+            field="vector",
+            metric_type=MetricType.L2,
+            params=HNSWParams(m=16, efconstruction=200),
+            auto_build=False,
+            state=IndexState.NORMAL,
+        )
+
+    def query(
+        self,
+        primary_key,
+        partition_key=None,
+        projections=None,
+        retrieve_vector=False,
+        read_consistency=ReadConsistency.EVENTUAL,
+        config=None,
+    ):
+        return {
+            "row": {
+                "id": "doc_id_001",
+                "vector": [0.23432432, 0.8923744, 0.89238432],
+                "text": "text",
+                "metadata": {"doc_id": "doc_id_001"},
+            },
+            "code": 0,
+            "msg": "Success",
+        }
+
+    def delete(self, primary_key=None, partition_key=None, filter=None, config=None):
+        return {"code": 0, "msg": "Success"}
+
+    def search(
+        self,
+        anns,
+        partition_key=None,
+        projections=None,
+        retrieve_vector=False,
+        read_consistency=ReadConsistency.EVENTUAL,
+        config=None,
+    ):
+        return {
+            "rows": [
+                {
+                    "row": {
+                        "id": "doc_id_001",
+                        "vector": [0.23432432, 0.8923744, 0.89238432],
+                        "text": "text",
+                        "metadata": {"doc_id": "doc_id_001"},
+                    },
+                    "distance": 0.1,
+                    "score": 0.5,
+                }
+            ],
+            "code": 0,
+            "msg": "Success",
+        }
+
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
+
+
+@pytest.fixture
+def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch):
+    if MOCK:
+        monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client)
+        monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases)
+        monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database)
+        monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table)
+        monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table)
+        monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table)
+        monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table)
+        monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table)
+        monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index)
+        monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index)
+        monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete)
+        monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search)
+
+    yield
+
+    if MOCK:
+        monkeypatch.undo()
diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/tests/integration_tests/vdb/baidu/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py
new file mode 100644
index 0000000000000..01a7f8853ac36
--- /dev/null
+++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py
@@ -0,0 +1,36 @@
+from unittest.mock import MagicMock
+
+from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
+from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock
+from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis
+
+mock_client = MagicMock()
+mock_client.list_databases.return_value = [{"name": "test"}]
+
+
+class BaiduVectorTest(AbstractVectorTest):
+    def __init__(self):
+        super().__init__()
+        self.vector = BaiduVector(
+            "dify",
+            BaiduConfig(
+                endpoint="http://127.0.0.1:5287",
+                account="root",
+                api_key="dify",
+                database="dify",
+                shard=1,
+                replicas=3,
+            ),
+        )
+
+    def search_by_vector(self):
+        hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding)
+        assert len(hits_by_vector) == 1
+
+    def search_by_full_text(self):
+        hits_by_full_text = self.vector.search_by_full_text(query=get_example_text())
+        assert len(hits_by_full_text) == 0
+
+
+def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock):
+    BaiduVectorTest().run_all_tests()
diff --git a/docker/.env.example b/docker/.env.example
index 87d7709a1830a..c4eae46cb0215 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200
 ELASTICSEARCH_USERNAME=elastic
 ELASTICSEARCH_PASSWORD=elastic
 
+# baidu vector configurations, only available when VECTOR_STORE is `baidu`
+BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
+BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
+BAIDU_VECTOR_DB_ACCOUNT=root
+BAIDU_VECTOR_DB_API_KEY=dify
+BAIDU_VECTOR_DB_DATABASE=dify
+BAIDU_VECTOR_DB_SHARD=1
+BAIDU_VECTOR_DB_REPLICAS=3
+
 # ------------------------------
 # Knowledge Configuration
 # ------------------------------
diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index 62d798a695967..c046c17ef8f2b 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env
   TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify}
   TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1}
   TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2}
+  BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287}
+  BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000}
+  BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root}
+  BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify}
+  BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify}
+  BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1}
+  BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3}
   UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15}
   UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5}
   ETL_TYPE: ${ETL_TYPE:-dify}