Skip to content

Commit

Permalink
Add vikingdb as new vector provider
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoice committed Oct 12, 2024
1 parent d97d3ff commit 06f4405
Show file tree
Hide file tree
Showing 15 changed files with 634 additions and 3 deletions.
11 changes: 10 additions & 1 deletion api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ SUPABASE_URL=your-server-url
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector
# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, vikingdb
VECTOR_STORE=weaviate

# Weaviate configuration
Expand Down Expand Up @@ -220,6 +220,15 @@ BAIDU_VECTOR_DB_DATABASE=dify
BAIDU_VECTOR_DB_SHARD=1
BAIDU_VECTOR_DB_REPLICAS=3

# ViKingDB configuration
VIKINGDB_ACCESS_KEY=your-ak
VIKINGDB_SECRET_KEY=your-sk
VIKINGDB_REGION=cn-shanghai
VIKINGDB_HOST=api-vikingdb.xxx.volces.com
VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30

# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
Expand Down
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from configs.middleware.vdb.relyt_config import RelytConfig
from configs.middleware.vdb.tencent_vector_config import TencentVectorDBConfig
from configs.middleware.vdb.tidb_vector_config import TiDBVectorConfig
from configs.middleware.vdb.vikingdb_config import VikingDBConfig
from configs.middleware.vdb.weaviate_config import WeaviateConfig


Expand Down Expand Up @@ -243,5 +244,6 @@ class MiddlewareConfig(
WeaviateConfig,
ElasticsearchConfig,
InternalTestConfig,
VikingDBConfig,
):
pass
37 changes: 37 additions & 0 deletions api/configs/middleware/vdb/vikingdb_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional

from pydantic import BaseModel, Field


class VikingDBConfig(BaseModel):
"""
Configuration for connecting to Volcengine VikingDB.
Refer to the following documentation for details on obtaining credentials:
https://www.volcengine.com/docs/6291/65568
"""

VIKINGDB_ACCESS_KEY: Optional[str] = Field(
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
)
VIKINGDB_SECRET_KEY: Optional[str] = Field(
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
)
VIKINGDB_REGION: Optional[str] = Field(
default="cn-shanghai",
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
)
VIKINGDB_HOST: Optional[str] = Field(
default="api-vikingdb.mlp.cn-shanghai.volces.com",
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
'api-vikingdb.mlp.cn-shanghai.volces.com')",
)
VIKINGDB_SCHEME: Optional[str] = Field(
default="http",
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
)
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
default=30, description="The connection timeout of the Volcengine VikingDB service."
)
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
default=30, description="The socket timeout of the Volcengine VikingDB service."
)
2 changes: 2 additions & 0 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def get(self):
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down Expand Up @@ -655,6 +656,7 @@ def get(self, vector_type):
| VectorType.TENCENT
| VectorType.PGVECTO_RS
| VectorType.BAIDU
| VectorType.VIKINGDB
):
return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
case (
Expand Down
4 changes: 4 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory

return BaiduVectorFactory
case VectorType.VIKINGDB:
from core.rag.datasource.vdb.vikingdb.vikingdb_vector import VikingDBVectorFactory

return VikingDBVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")

Expand Down
1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ class VectorType(str, Enum):
ORACLE = "oracle"
ELASTICSEARCH = "elasticsearch"
BAIDU = "baidu"
VIKINGDB = "vikingdb"
Empty file.
239 changes: 239 additions & 0 deletions api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import json
from typing import Any

from pydantic import BaseModel
from volcengine.viking_db import (
Data,
DistanceType,
Field,
FieldType,
IndexType,
QuantType,
VectorIndexParams,
VikingDBService,
)

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.field import Field as vdb_Field
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class VikingDBConfig(BaseModel):
access_key: str
secret_key: str
host: str
region: str
scheme: str
connection_timeout: int
socket_timeout: int
index_type: str = IndexType.HNSW
distance: str = DistanceType.L2
quant: str = QuantType.Float


class VikingDBVector(BaseVector):
def __init__(self, collection_name: str, group_id: str, config: VikingDBConfig):
super().__init__(collection_name)
self._group_id = group_id
self._client_config = config
self._index_name = f"{self._collection_name}_idx"
self._client = VikingDBService(
host=config.host,
region=config.region,
scheme=config.scheme,
connection_timeout=config.connection_timeout,
socket_timeout=config.socket_timeout,
ak=config.access_key,
sk=config.secret_key,
)

def _has_collection(self) -> bool:
try:
self._client.get_collection(self._collection_name)
except Exception:
return False
return True

def _has_index(self) -> bool:
try:
self._client.get_index(self._collection_name, self._index_name)
except Exception:
return False
return True

def _create_collection(self, dimension: int):
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return

if not self._has_collection():
fields = [
Field(field_name=vdb_Field.PRIMARY_KEY.value, field_type=FieldType.String, is_primary_key=True),
Field(field_name=vdb_Field.METADATA_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.GROUP_KEY.value, field_type=FieldType.String),
Field(field_name=vdb_Field.CONTENT_KEY.value, field_type=FieldType.Text),
Field(field_name=vdb_Field.VECTOR.value, field_type=FieldType.Vector, dim=dimension),
]

self._client.create_collection(
collection_name=self._collection_name,
fields=fields,
description="Collection For Dify",
)

if not self._has_index():
vector_index = VectorIndexParams(
distance=self._client_config.distance,
index_type=self._client_config.index_type,
quant=self._client_config.quant,
)

self._client.create_index(
collection_name=self._collection_name,
index_name=self._index_name,
vector_index=vector_index,
partition_by=vdb_Field.GROUP_KEY.value,
description="Index For Dify",
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def get_type(self) -> str:
return VectorType.VIKINGDB

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
self.add_texts(texts, embeddings, **kwargs)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
page_contents = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
docs = []

for i, page_content in enumerate(page_contents):
metadata = {}
if metadatas is not None:
for key, val in metadatas[i].items():
metadata[key] = val
doc = Data(
{
vdb_Field.PRIMARY_KEY.value: metadatas[i]["doc_id"],
vdb_Field.VECTOR.value: embeddings[i] if embeddings else None,
vdb_Field.CONTENT_KEY.value: page_content,
vdb_Field.METADATA_KEY.value: json.dumps(metadata),
vdb_Field.GROUP_KEY.value: self._group_id,
}
)
docs.append(doc)

self._client.get_collection(self._collection_name).upsert_data(docs)

def text_exists(self, id: str) -> bool:
docs = self._client.get_collection(self._collection_name).fetch_data(id)
not_exists_str = "data does not exist"
if docs is not None and not_exists_str not in docs.fields.get("message", ""):
return True
return False

def delete_by_ids(self, ids: list[str]) -> None:
self._client.get_collection(self._collection_name).delete_data(ids)

def get_ids_by_metadata_field(self, key: str, value: str):
# Note: Metadata field value is an dict, but vikingdb field
# not support json type
results = self._client.get_index(self._collection_name, self._index_name).search(
filter={"op": "must", "field": vdb_Field.GROUP_KEY.value, "conds": [self._group_id]},
# max value is 5000
limit=5000,
)

if not results:
return []

ids = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if metadata.get(key) == value:
ids.append(result.id)
return ids

def delete_by_metadata_field(self, key: str, value: str) -> None:
ids = self.get_ids_by_metadata_field(key, value)
self.delete_by_ids(ids)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
query_vector, limit=kwargs.get("top_k", 50)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)

def _get_search_res(self, results, score_threshold):
if len(results) == 0:
return []

docs = []
for result in results:
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if result.score > score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return []

def delete(self) -> None:
if self._has_index():
self._client.drop_index(self._collection_name, self._index_name)
if self._has_collection():
self._client.drop_collection(self._collection_name)


class VikingDBVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VikingDBVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix.lower()
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VIKINGDB, collection_name))

if dify_config.VIKINGDB_ACCESS_KEY is None:
raise ValueError("VIKINGDB_ACCESS_KEY should not be None")
if dify_config.VIKINGDB_SECRET_KEY is None:
raise ValueError("VIKINGDB_SECRET_KEY should not be None")
if dify_config.VIKINGDB_HOST is None:
raise ValueError("VIKINGDB_HOST should not be None")
if dify_config.VIKINGDB_REGION is None:
raise ValueError("VIKINGDB_REGION should not be None")
if dify_config.VIKINGDB_SCHEME is None:
raise ValueError("VIKINGDB_SCHEME should not be None")
return VikingDBVector(
collection_name=collection_name,
group_id=dataset.id,
config=VikingDBConfig(
access_key=dify_config.VIKINGDB_ACCESS_KEY,
secret_key=dify_config.VIKINGDB_SECRET_KEY,
host=dify_config.VIKINGDB_HOST,
region=dify_config.VIKINGDB_REGION,
scheme=dify_config.VIKINGDB_SCHEME,
connection_timeout=dify_config.VIKINGDB_CONNECTION_TIMEOUT,
socket_timeout=dify_config.VIKINGDB_SOCKET_TIMEOUT,
),
)
Loading

0 comments on commit 06f4405

Please sign in to comment.