-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
557 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
|
||
class VikingDBConfig(BaseModel):
    """
    Configuration for connecting to Volcengine VikingDB.
    Refer to the following documentation for details on obtaining credentials:
    https://www.volcengine.com/docs/6291/65568

    Every setting is optional at this level; the two credentials default to
    ``None`` and the remaining settings carry a concrete default value.
    """

    # Credentials: no default, so they stay None until explicitly configured.
    VIKINGDB_ACCESS_KEY: Optional[str] = Field(
        default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
    )
    VIKINGDB_SECRET_KEY: Optional[str] = Field(
        default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
    )

    # Endpoint selection.
    VIKINGDB_REGION: Optional[str] = Field(
        default="cn-shanghai",
        description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
    )
    VIKINGDB_HOST: Optional[str] = Field(
        default="api-vikingdb.mlp.cn-shanghai.volces.com",
        # Implicit concatenation instead of a backslash continuation inside the
        # literal: a continuation would embed the next line's leading
        # indentation into the description text.
        description=(
            "The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', "
            "'api-vikingdb.mlp.cn-shanghai.volces.com')"
        ),
    )
    # NOTE(review): the default scheme is plain "http"; confirm whether "https"
    # should be preferred for deployments outside a trusted network.
    VIKINGDB_SCHEME: Optional[str] = Field(
        default="http",
        description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
    )

    # Timeouts; the unit is not stated here (presumably seconds - TODO confirm against the SDK).
    VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
        default=30, description="The connection timeout of the Volcengine VikingDB service."
    )
    VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
        default=30, description="The socket timeout of the Volcengine VikingDB service."
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
239 changes: 239 additions & 0 deletions
239
api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,239 @@ | ||
import json | ||
from typing import Any | ||
|
||
from pydantic import BaseModel | ||
from volcengine.viking_db import ( | ||
Data, | ||
DistanceType, | ||
Field, | ||
FieldType, | ||
IndexType, | ||
QuantType, | ||
VectorIndexParams, | ||
VikingDBService, | ||
) | ||
|
||
from configs import dify_config | ||
from core.rag.datasource.entity.embedding import Embeddings | ||
from core.rag.datasource.vdb.field import Field as vdb_Field | ||
from core.rag.datasource.vdb.vector_base import BaseVector | ||
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory | ||
from core.rag.datasource.vdb.vector_type import VectorType | ||
from core.rag.models.document import Document | ||
from extensions.ext_redis import redis_client | ||
from models.dataset import Dataset | ||
|
||
|
||
class VikingDBConfig(BaseModel):
    """Connection and index settings consumed by ``VikingDBVector``.

    The first seven fields are required and are handed to the VikingDB
    service client; the last three have defaults and parameterize the vector
    index that gets created for a collection.
    """

    # --- client connection (all required) ---
    access_key: str
    secret_key: str
    host: str
    region: str
    scheme: str
    connection_timeout: int
    socket_timeout: int

    # --- vector index parameters (defaults taken from the SDK constants) ---
    index_type: str = IndexType.HNSW
    distance: str = DistanceType.L2
    quant: str = QuantType.Float
|
||
|
||
class VikingDBVector(BaseVector):
    """Vector store implementation backed by a Volcengine VikingDB collection.

    Each stored record carries the ``group_id`` passed at construction time in
    its group field, and the index created by this class is partitioned by that
    same field.
    """

    def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
        """Build the VikingDB service client from ``config``.

        Args:
            collection_name: Name of the collection this instance operates on.
            group_id: Value written to the group field of every stored record.
            config: Connection and index settings.
        """
        super().__init__(collection_name)
        self._group_id = group_id
        self._client_config = config
        # The index name is derived from the collection name.
        self._index_name = f"{self._collection_name}_idx"
        self._client = VikingDBService(
            host=config.host,
            region=config.region,
            scheme=config.scheme,
            connection_timeout=config.connection_timeout,
            socket_timeout=config.socket_timeout,
            ak=config.access_key,
            sk=config.secret_key,
        )

    def _has_collection(self) -> bool:
        """Return True if fetching the collection succeeds, False on any exception."""
        # NOTE(review): any failure (not only "not found") is reported as
        # "does not exist"; narrow the exception type once the SDK's error
        # classes are confirmed.
        try:
            self._client.get_collection(self._collection_name)
        except Exception:
            return False
        return True

    def _has_index(self) -> bool:
        """Return True if fetching the index succeeds, False on any exception."""
        # NOTE(review): same broad-catch caveat as _has_collection.
        try:
            self._client.get_index(self._collection_name, self._index_name)
        except Exception:
            return False
        return True

    def _create_collection(self, dimension: int) -> None:
        """Create the collection and its index if they are missing.

        The work is done while holding a Redis lock named after the collection,
        and a Redis key (expiring after 3600 s) records that setup has been
        done so later calls return early.

        Args:
            dimension: Dimension of the vector field of the new collection.
        """
        lock_name = f"vector_indexing_lock_{self._collection_name}"
        with redis_client.lock(lock_name, timeout=20):
            collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
            if redis_client.get(collection_exist_cache_key):
                return

            if not self._has_collection():
                fields = [
                    Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
                    Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
                    Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
                    Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
                    Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
                ]

                self._client.create_collection(
                    collection_name=self._collection_name,
                    fields=fields,
                    description="Collection For Dify",
                )

            if not self._has_index():
                vector_index = VectorIndexParams(
                    distance=self._client_config.distance,
                    index_type=self._client_config.index_type,
                    quant=self._client_config.quant,
                )

                self._client.create_index(
                    collection_name=self._collection_name,
                    index_name=self._index_name,
                    vector_index=vector_index,
                    partition_by=vdb_Field.GROUP_KEY.value,
                    description="Index For Dify",
                )
            redis_client.set(collection_exist_cache_key, 1, ex=3600)

    def get_type(self) -> str:
        """Return the vector store type identifier for VikingDB."""
        return VectorType.VIKINGDB

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        """Ensure the collection and index exist, then insert ``texts``.

        The vector dimension is taken from the first embedding, so
        ``embeddings`` must be non-empty.
        """
        dimension = len(embeddings[0])
        self._create_collection(dimension)
        self.add_texts(texts, embeddings, **kwargs)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        """Upsert ``documents`` with their ``embeddings`` into the collection.

        Each document's metadata must contain ``"doc_id"``, which is used as
        the primary key. The metadata is stored as a JSON string.
        """
        if not documents:
            # Nothing to write; skip the remote upsert call.
            return

        docs = []
        for i, document in enumerate(documents):
            # Copy so the caller's metadata dict is never aliased.
            metadata = dict(document.metadata or {})
            docs.append(
                Data(
                    {
                        vdb_Field.PRIMARY_KEY.value: metadata["doc_id"],
                        vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
                        vdb_Field.CONTENT_KEY.value: document.page_content,
                        vdb_Field.METADATA_KEY.value: json.dumps(metadata),
                        vdb_Field.GROUP_KEY.value: self._group_id,
                    }
                )
            )

        self._client.get_collection(self._collection_name).upsert_data(docs)

    def text_exists(self, id: str) -> bool:
        """Return True if fetching ``id`` yields data without a "does not exist" message."""
        docs = self._client.get_collection(self._collection_name).fetch_data(id)
        not_exists_str = "data does not exist"
        return docs is not None and not_exists_str not in docs.fields.get("message", "")

    def delete_by_ids(self, ids: list[str]) -> None:
        """Delete the records with the given primary keys; a no-op for an empty list."""
        if not ids:
            # Avoid a pointless remote call (e.g. when a metadata lookup matched nothing).
            return
        self._client.get_collection(self._collection_name).delete_data(ids)

    def get_ids_by_metadata_field(self, key: str, value: str) -> list[str]:
        """Return ids of this group's records whose metadata has ``key == value``.

        Metadata is stored as a JSON string, so the match is performed here
        after fetching the group's records rather than by the service.
        """
        # Note: Metadata field value is a dict, but vikingdb fields
        # do not support a json type
        results = self._client.get_index(self._collection_name, self._index_name).search(
            filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
            # max value is 5000; records beyond this limit are not inspected
            limit=5000,
        )

        if not results:
            return []

        ids = []
        for result in results:
            metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
            if metadata is not None:
                metadata = json.loads(metadata)
                if metadata.get(key) == value:
                    ids.append(result.id)
        return ids

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        """Delete every record of this group whose metadata has ``key == value``."""
        ids = self.get_ids_by_metadata_field(key, value)
        self.delete_by_ids(ids)

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Run a vector search and return hits scoring above the threshold.

        Keyword Args:
            top_k: Maximum number of hits requested (default 50).
            score_threshold: Minimum score, exclusive; missing or falsy means 0.0.
        """
        results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
            query_vector, limit=kwargs.get("top_k", 50)
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(results, score_threshold)

    def _get_search_res(self, results, score_threshold) -> list[Document]:
        """Convert raw hits to Documents, keeping scores above ``score_threshold``.

        The score is added to each document's metadata under ``"score"`` and
        the returned list is sorted by score, highest first.
        """
        if not results:
            return []

        docs = []
        for result in results:
            if result.score <= score_threshold:
                continue
            raw_metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
            # A hit without a metadata field gets an empty dict so the score
            # can still be attached instead of failing on None.
            metadata = json.loads(raw_metadata) if raw_metadata is not None else {}
            metadata["score"] = result.score
            docs.append(Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata))
        return sorted(docs, key=lambda doc: doc.metadata["score"], reverse=True)

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        """Full-text search is not implemented here; always returns an empty list."""
        return []

    def delete(self) -> None:
        """Drop the index and then the collection, each only if it exists."""
        if self._has_index():
            self._client.drop_index(self._collection_name, self._index_name)
        if self._has_collection():
            self._client.drop_collection(self._collection_name)
|
||
|
||
class VikingDBVectorFactory(AbstractVectorFactory):
    """Factory that builds a ``VikingDBVector`` for a dataset from ``dify_config``."""

    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
        """Create the VikingDB vector store for ``dataset``.

        The collection name comes from the dataset's existing index struct when
        present; otherwise it is generated from the dataset id and written back
        to ``dataset.index_struct``. In both cases it is lower-cased.

        Raises:
            ValueError: If any required VikingDB setting is ``None``.
        """
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))

        # All settings are Optional in the config model, while the client
        # config requires concrete values, so reject None explicitly.
        if dify_config.VIKINGDB_ACCESS_KEY is None:
            raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
        if dify_config.VIKINGDB_SECRET_KEY is None:
            raise ValueError("VIKINGDB_SECRET_KEY should not be None")
        if dify_config.VIKINGDB_HOST is None:
            raise ValueError("VIKINGDB_HOST should not be None")
        if dify_config.VIKINGDB_REGION is None:
            raise ValueError("VIKINGDB_REGION should not be None")
        if dify_config.VIKINGDB_SCHEME is None:
            raise ValueError("VIKINGDB_SCHEME should not be None")
        # The timeouts are Optional[int] as well but feed required int fields.
        if dify_config.VIKINGDB_CONNECTION_TIMEOUT is None:
            raise ValueError("VIKINGDB_CONNECTION_TIMEOUT should not be None")
        if dify_config.VIKINGDB_SOCKET_TIMEOUT is None:
            raise ValueError("VIKINGDB_SOCKET_TIMEOUT should not be None")

        return VikingDBVector(
            collection_name=collection_name,
            group_id=dataset.id,
            config=VikingDBConfig(
                access_key=dify_config.VIKINGDB_ACCESS_KEY,
                secret_key=dify_config.VIKINGDB_SECRET_KEY,
                host=dify_config.VIKINGDB_HOST,
                region=dify_config.VIKINGDB_REGION,
                scheme=dify_config.VIKINGDB_SCHEME,
                connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
                socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
            ),
        )
Oops, something went wrong.