feat:support baidu vector db #9185

Merged · 6 commits · Oct 12, 2024

Changes from 2 commits
9 changes: 9 additions & 0 deletions api/.env.example
@@ -203,6 +203,15 @@ OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true

# Baidu configuration
BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287
BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000
BAIDU_VECTOR_DB_ACCOUNT=root
BAIDU_VECTOR_DB_API_KEY=dify
BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
8 changes: 8 additions & 0 deletions api/commands.py
@@ -347,6 +347,14 @@ def migrate_knowledge_vector_database():
                    index_name = Dataset.gen_collection_name_by_id(dataset_id)
                    index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}}
                    dataset.index_struct = json.dumps(index_struct_dict)
                elif vector_type == VectorType.BAIDU:
                    dataset_id = dataset.id
                    collection_name = Dataset.gen_collection_name_by_id(dataset_id)
                    index_struct_dict = {
                        "type": VectorType.BAIDU,
                        "vector_store": {"class_prefix": collection_name},
                    }
                    dataset.index_struct = json.dumps(index_struct_dict)
                else:
                    raise ValueError(f"Vector store {vector_type} is not supported.")

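For context on what this migration branch produces: it only rewrites the dataset's `index_struct` metadata, it does not move any stored vectors. A minimal sketch of the resulting JSON, using a hypothetical stand-in for `Dataset.gen_collection_name_by_id` (the real helper derives the collection name from the dataset id; the exact format below is illustrative only):

```python
import json
import uuid


def gen_collection_name_by_id(dataset_id: str) -> str:
    # Hypothetical stand-in for Dataset.gen_collection_name_by_id; the real method
    # derives a collection name from the dataset id.
    return "Vector_index_{}_Node".format(dataset_id.replace("-", "_"))


dataset_id = str(uuid.uuid4())
index_struct = {
    "type": "baidu",  # VectorType.BAIDU serializes to its string value
    "vector_store": {"class_prefix": gen_collection_name_by_id(dataset_id)},
}
print(json.dumps(index_struct))
```
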
45 changes: 45 additions & 0 deletions api/configs/middleware/vdb/baidu_vector_config.py
@@ -0,0 +1,45 @@
from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class BaiduVectorDBConfig(BaseSettings):
    """
    Configuration settings for Baidu Vector Database
    """

    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
        default=None,
    )

    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
        default=30000,
    )

    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
        description="Account for authenticating with the Baidu Vector Database",
        default=None,
    )

    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
        description="API key for authenticating with the Baidu Vector Database service",
        default=None,
    )

    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
        description="Name of the specific Baidu Vector Database to connect to",
        default=None,
    )

    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
        description="Number of shards for the Baidu Vector Database (default is 1)",
        default=1,
    )

    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
        description="Number of replicas for the Baidu Vector Database (default is 3)",
        default=3,
    )
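
A quick sketch of how these settings resolve at runtime: `BaseSettings` looks up each declared field in the environment when the model is constructed, so the example values from `api/.env.example` above are picked up automatically. This assumes the snippet is run from the `api/` directory so the `configs.middleware.vdb.baidu_vector_config` module path resolves:

```python
import os

from configs.middleware.vdb.baidu_vector_config import BaiduVectorDBConfig

# Example values copied from api/.env.example above.
os.environ["BAIDU_VECTOR_DB_ENDPOINT"] = "http://127.0.0.1:5287"
os.environ["BAIDU_VECTOR_DB_ACCOUNT"] = "root"
os.environ["BAIDU_VECTOR_DB_API_KEY"] = "dify"
os.environ["BAIDU_VECTOR_DB_DATABASE"] = "dify"

config = BaiduVectorDBConfig()
print(config.BAIDU_VECTOR_DB_ENDPOINT)               # http://127.0.0.1:5287
print(config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS)  # 30000 (field default)
print(config.BAIDU_VECTOR_DB_SHARD, config.BAIDU_VECTOR_DB_REPLICAS)  # 1 3
```
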
2 changes: 2 additions & 0 deletions api/controllers/console/datasets/datasets.py
@@ -617,6 +617,7 @@ def get(self):
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
@@ -653,6 +654,7 @@ def get(self, vector_type):
                | VectorType.CHROMA
                | VectorType.TENCENT
                | VectorType.PGVECTO_RS
                | VectorType.BAIDU
            ):
                return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
            case (
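
Both retrieval-setting endpoints above now advertise only semantic search for a Baidu-backed dataset, since full-text (BM25) search is not implemented for this backend (see `search_by_full_text` below). Assuming `RetrievalMethod.SEMANTIC_SEARCH.value` is the string `"semantic_search"`, the expected payload is simply:

```python
# Sketch of the response from these endpoints when the configured vector store is Baidu;
# the literal below assumes RetrievalMethod.SEMANTIC_SEARCH.value == "semantic_search".
expected = {"retrieval_method": ["semantic_search"]}
```
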
Empty file.
272 changes: 272 additions & 0 deletions api/core/rag/datasource/vdb/baidu/baidu_vector.py
@@ -0,0 +1,272 @@
import json
import time
import uuid
from typing import Any

from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class BaiduConfig(BaseModel):
    endpoint: str
    connection_timeout_in_mills: int = 30 * 1000
    account: str
    api_key: str
    database: str
    index_type: str = "HNSW"
    metric_type: str = "L2"
    shard: int = 1
    replicas: int = 3

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["endpoint"]:
            raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
        if not values["account"]:
            raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
        if not values["api_key"]:
            raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
        if not values["database"]:
            raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
        return values


class BaiduVector(BaseVector):
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"
    field_app_id: str = "app_id"
    field_annotation_id: str = "annotation_id"
    index_vector: str = "vector_idx"

    def __init__(self, collection_name: str, config: BaiduConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = self._init_client(config)
        self._db = self._init_database()

    def get_type(self) -> str:
        return VectorType.BAIDU

    def to_index_struct(self) -> dict:
        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self._create_table(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        total_count = len(documents)
        batch_size = 1000

        # upsert texts and embeddings batch by batch
        table = self._db.table(self._collection_name)
        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            rows = []
            for i in range(start, end, 1):
                row = Row(
                    id=metadatas[i].get("doc_id", str(uuid.uuid4())),
                    vector=embeddings[i],
                    text=texts[i],
                    metadata=json.dumps(metadatas[i]),
                    app_id=metadatas[i].get("app_id", ""),
                    annotation_id=metadatas[i].get("annotation_id", ""),
                )
                rows.append(row)
            table.upsert(rows=rows)

        # rebuild vector index after upsert finished
        table.rebuild_index(self.index_vector)
        while True:
            time.sleep(1)
            index = table.describe_index(self.index_vector)
            if index.state == IndexState.NORMAL:
                break

    def text_exists(self, id: str) -> bool:
        res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
        if res and res.code == 0:
            return True
        return False

    def delete_by_ids(self, ids: list[str]) -> None:
        quoted_ids = [f"'{id}'" for id in ids]
        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        anns = AnnSearch(
            vector_field=self.field_vector,
            vector_floats=query_vector,
            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
        )
        res = self._db.table(self._collection_name).search(
            anns=anns,
            projections=[self.field_id, self.field_text, self.field_metadata],
            retrieve_vector=True,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # baidu vector database doesn't support bm25 search on current version
        return []

    def _get_search_res(self, res, score_threshold):
        docs = []
        for row in res.rows:
            row_data = row.get("row", {})
            meta = row_data.get(self.field_metadata)
            if meta is not None:
                meta = json.loads(meta)
            score = row.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                docs.append(doc)

        return docs

    def delete(self) -> None:
        self._db.drop_table(table_name=self._collection_name)

    def _init_client(self, config) -> MochowClient:
        config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
        client = MochowClient(config)
        return client

    def _init_database(self):
        exists = False
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                exists = True
                break
        # Create the database if it does not exist
        if exists:
            return self._client.database(self._client_config.database)
        else:
            return self._client.create_database(database_name=self._client_config.database)

    def _table_existed(self) -> bool:
        tables = self._db.list_table()
        return any(table.table_name == self._collection_name for table in tables)

    def _create_table(self, dimension: int) -> None:
        # Try to grab distributed lock and create table
        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(table_exist_cache_key):
                return

            if self._table_existed():
                return

            self.delete()

            # check IndexType and MetricType
            index_type = None
            for k, v in IndexType.__members__.items():
                if k == self._client_config.index_type:
                    index_type = v
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = None
            for k, v in MetricType.__members__.items():
                if k == self._client_config.metric_type:
                    metric_type = v
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            # Construct field schema
            fields = []
            fields.append(
                Field(
                    self.field_id,
                    FieldType.STRING,
                    primary_key=True,
                    partition_key=True,
                    auto_increment=False,
                    not_null=True,
                )
            )
            fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
            fields.append(Field(self.field_app_id, FieldType.STRING))
            fields.append(Field(self.field_annotation_id, FieldType.STRING))
            fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
            fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

            # Construct vector index params
            indexes = []
            indexes.append(
                VectorIndex(
                    index_name="vector_idx",
                    index_type=index_type,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                )
            )

            # Create table
            self._db.create_table(
                table_name=self._collection_name,
                replication=self._client_config.replicas,
                partition=Partition(partition_num=self._client_config.shard),
                schema=Schema(fields=fields, indexes=indexes),
                description="Table for Dify",
            )

            redis_client.set(table_exist_cache_key, 1, ex=3600)

        # Wait for table created
        while True:
            time.sleep(1)
            table = self._db.describe_table(self._collection_name)
            if table.state == TableState.NORMAL:
                break


class BaiduVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))

        return BaiduVector(
            collection_name=collection_name,
            config=BaiduConfig(
                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
                connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
                database=dify_config.BAIDU_VECTOR_DB_DATABASE,
                shard=dify_config.BAIDU_VECTOR_DB_SHARD,
                replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
            ),
        )
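
To make the new class easier to review end to end, here is a hedged usage sketch: it builds a `BaiduConfig` from the example values in `api/.env.example`, creates a table from two toy documents, and runs a vector search. The collection name, document contents, and 4-dimensional embeddings are made up for illustration, and a reachable Baidu VectorDB (Mochow) endpoint plus Redis (for the table-creation lock) are assumed; in the application itself this wiring is done by `BaiduVectorFactory` using `dify_config`.

```python
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector
from core.rag.models.document import Document

# Hypothetical configuration mirroring api/.env.example.
config = BaiduConfig(
    endpoint="http://127.0.0.1:5287",
    account="root",
    api_key="dify",
    database="dify",
    shard=1,
    replicas=3,
)

vector = BaiduVector(collection_name="vector_index_example_node", config=config)

# Two toy documents; the table dimension is taken from the first embedding.
docs = [
    Document(page_content="Dify supports many vector stores.", metadata={"doc_id": "doc-1"}),
    Document(page_content="Baidu VectorDB (Mochow) is one of them.", metadata={"doc_id": "doc-2"}),
]
embeddings = [[0.1, 0.2, 0.3, 0.4], [0.2, 0.1, 0.4, 0.3]]

vector.create(docs, embeddings)

# ANN search; results at or below the score threshold are dropped by _get_search_res.
results = vector.search_by_vector([0.15, 0.2, 0.3, 0.4], top_k=2)
for doc in results:
    print(doc.metadata.get("doc_id"), doc.metadata.get("score"))
```
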
4 changes: 4 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
@@ -103,6 +103,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
                from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory

                return AnalyticdbVectorFactory
            case VectorType.BAIDU:
                from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory

                return BaiduVectorFactory
            case _:
                raise ValueError(f"Vector store {vector_type} is not supported.")

1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
@@ -16,3 +16,4 @@ class VectorType(str, Enum):
    TENCENT = "tencent"
    ORACLE = "oracle"
    ELASTICSEARCH = "elasticsearch"
    BAIDU = "baidu"
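
Because `VectorType` subclasses `str`, the vector-store type string stored in a dataset's `index_struct` compares equal to the enum member, which is what the `match`/`case` arms in `commands.py`, `datasets.py`, and `vector_factory.py` above rely on. A two-line check:

```python
from core.rag.datasource.vdb.vector_type import VectorType

assert VectorType.BAIDU == "baidu"              # str Enum: member equals its value
assert VectorType("baidu") is VectorType.BAIDU  # the raw string round-trips to the member
```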