From 37030f41f753d243898cd23ff7e5f881aacc0173 Mon Sep 17 00:00:00 2001 From: Wei Ouyang Date: Wed, 27 Nov 2024 07:21:51 -0800 Subject: [PATCH] Use redis as vector search (#719) * Use redis as vector search * Fix publish_to * Fix syntax --- .github/workflows/test.yml | 6 - docs/artifact-manager.md | 119 +++- hypha/VERSION | 2 +- hypha/artifact.py | 137 ++-- hypha/core/store.py | 35 -- hypha/core/workspace.py | 168 ++--- hypha/server.py | 7 - hypha/templates/apps/web-python.index.html | 4 +- .../hypha-core-app/hypha-app-webpython.js | 2 +- hypha/vectors.py | 583 +++++++++++++----- requirements.txt | 1 - setup.py | 1 - tests/__init__.py | 4 - tests/conftest.py | 65 +- tests/test_artifact.py | 17 +- tests/test_service_search.py | 12 +- tests/test_vectors.py | 102 ++- 17 files changed, 733 insertions(+), 532 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 215524fd..37b0bbaf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,12 +31,6 @@ jobs: ports: - 6338:6379 options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5 - qdrant: - image: qdrant/qdrant:latest - ports: - - 6333:6333 - - 6334:6334 - options: --health-cmd "bash -c ':> /dev/tcp/127.0.0.1/6333' || exit 1" --health-interval 10s --health-timeout 5s --health-retries 5 steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/docs/artifact-manager.md b/docs/artifact-manager.md index 857b039c..19fe9682 100644 --- a/docs/artifact-manager.md +++ b/docs/artifact-manager.md @@ -235,7 +235,7 @@ print("Valid dataset committed.") ## API References -### `create(parent_id: str, alias: str, type: str, manifest: dict, permissions: dict=None, config: dict=None, version: str = None, comment: str = None, overwrite: bool = False, publish_to: str = None) -> None` +### `create(parent_id: str, alias: str, type: str, manifest: dict, permissions: dict=None, config: dict=None, version: str = None, comment: str = None, overwrite: bool = False) -> None` Creates a new artifact or collection with the specified manifest. The artifact is staged until committed. For collections, the `collection` field should be an empty list. @@ -251,13 +251,14 @@ Creates a new artifact or collection with the specified manifest. The artifact i - **Id Parts**: You can also use id parts stored in the parent collection's config['id_parts'] to generate an id. For example, if the parent collection has `{"animals": ["dog", "cat", ...], "colors": ["red", "blue", ...]}`, you can use `"{colors}-{animals}"` to generate an id like `red-dog`. - `workspace`: Optional. The workspace id where the artifact will be created. If not set, it will be created in the default workspace. If specified, it should match the workspace in the alias and also the parent_id. - `parent_id`: The id of the parent collection where the artifact will be created. If the artifact is a top-level collection, leave this field empty or set to None. -- `type`: The type of the artifact. Supported values are `collection`, `generic` and any other custom type. By default, it's set to `generic` which contains fields tailored for displaying the artifact as cards on a webpage. +- `type`: The type of the artifact. Supported values are `collection`, `vector-collection`, `generic` and any other custom type. By default, it's set to `generic` which contains fields tailored for displaying the artifact as cards on a webpage. - `manifest`: The manifest of the new artifact. 
Ensure the manifest follows the required schema if applicable (e.g., for collections).
- `config`: Optional. A dictionary containing additional configuration options for the artifact (shared for both staged and committed). For collections, the config can contain the following special fields:
  - `collection_schema`: Optional. A JSON schema that defines the structure of child artifacts in the collection. This schema is used to validate child artifacts when they are created or edited. If a child artifact does not conform to the schema, the creation or edit operation will fail.
  - `id_parts`: Optional. A dictionary of id name parts to be used in generating the id for child artifacts. For example: `{"animals": ["dog", "cat", ...], "colors": ["red", "blue", ...]}`. This can be used for creating child artifacts with auto-generated ids based on the id parts. For example, when calling `create`, you can specify the alias as `my-pet-{colors}-{animals}`, and the id will be generated based on the id parts, e.g., `my-pet-red-dog`.
  - `permissions`: Optional. A dictionary containing user permissions. For example `{"*": "r+"}` gives read and create access to everyone, `{"@": "rw+"}` allows all authenticated users to read/write/create, and `{"user_id_1": "r+"}` grants read and create permissions to a specific user. You can also set permissions for specific operations, such as `{"user_id_1": ["read", "create"]}`. See detailed explanation about permissions below.
  - `list_fields`: Optional. A list of fields to be collected when calling the ``list`` function. By default, it collects all fields in the artifacts. If you want to collect only specific fields, you can set this field to a list of field names, e.g. `["manifest", "download_count"]`.
+  - `publish_to`: Optional. A string specifying the target platform to publish the artifact. Supported values are `zenodo` and `sandbox_zenodo`. If set, the artifact will be published to the specified platform. The artifact must have a valid Zenodo metadata schema to be published.
- `version`: Optional. The version of the artifact to create. By default, it is set to `None` or `"new"`, which will generate a version `v0`. If you want to create a staged version, you can set it to `"stage"`.
- `comment`: Optional. A comment to describe the changes made to the artifact.
- `secrets`: Optional. A dictionary containing secrets to be stored with the artifact. Secrets are encrypted and can only be accessed by the artifact owner or users with appropriate permissions. The following keys can be used:
@@ -271,7 +272,7 @@ Creates a new artifact or collection with the specified manifest. The artifact i
  - `S3_PREFIX`: The prefix of the S3 storage for the artifact. Default: `""`.
  - `S3_PUBLIC_ENDPOINT_URL`: The public endpoint URL of the S3 storage for the artifact. If the S3 server is not public, you can set this to the public endpoint URL. Default: `None`.
- `overwrite`: Optional. A boolean flag to overwrite the existing artifact with the same alias. Default is `False`.
-- `publish_to`: Optional. A string specifying the target platform to publish the artifact. Supported values are `zenodo` and `sandbox_zenodo`. If set, the artifact will be published to the specified platform. The artifact must have a valid Zenodo metadata schema to be published.
+ **Note 1: If you set `version="stage"`, you must call `commit()` to finalize the artifact.** @@ -438,6 +439,118 @@ await artifact_manager.delete(artifact_id="other_workspace/example-dataset", del --- +### `add_vectors(artifact_id: str, vectors: list, embedding_models: Optional[Dict[str, str]] = None, context: dict = None) -> None` + +Adds vectors to a vector collection artifact. + +**Parameters:** + +- `artifact_id`: The ID of the artifact to which vectors will be added. This must be a vector-collection artifact. +- `vectors`: A list of vectors to add to the collection. +- `embedding_models`: (Optional) A dictionary specifying embedding models to be used. If not provided, the default models from the artifact's configuration will be used. +- `context`: A dictionary containing user and session context information. + +**Returns:** None. + +**Example:** + +```python +await artifact_manager.add_vectors(artifact_id="example-id", vectors=[{"id": 1, "vector": [0.1, 0.2]}]) +``` + +--- + +### `search_vectors(artifact_id: str, query: Optional[Dict[str, Any]] = None, embedding_models: Optional[str] = None, filters: Optional[dict[str, Any]] = None, limit: Optional[int] = 5, offset: Optional[int] = 0, return_fields: Optional[List[str]] = None, order_by: Optional[str] = None, pagination: Optional[bool] = False, context: dict = None) -> list` + +Searches vectors in a vector collection artifact based on a query. + +**Parameters:** + +- `artifact_id`: The ID of the artifact to search within. This must be a vector-collection artifact. +- `query`: (Optional) A dictionary representing the query vector or conditions. +- `embedding_models`: (Optional) Specifies which embedding model to use. Defaults to the artifact's configuration. +- `filters`: (Optional) Filters for refining the search results. +- `limit`: (Optional) Maximum number of results to return. Defaults to 5. +- `offset`: (Optional) Number of results to skip. Defaults to 0. +- `return_fields`: (Optional) A list of fields to include in the results. +- `order_by`: (Optional) Field to order the results by. +- `pagination`: (Optional) Whether to include pagination metadata. Defaults to `False`. +- `context`: A dictionary containing user and session context information. + +**Returns:** A list of search results. + +**Example:** + +```python +results = await artifact_manager.search_vectors(artifact_id="example-id", query={"vector": [0.1, 0.2]}, limit=10) +``` + +--- + +### `remove_vectors(artifact_id: str, ids: list, context: dict = None) -> None` + +Removes vectors from a vector collection artifact. + +**Parameters:** + +- `artifact_id`: The ID of the artifact from which vectors will be removed. This must be a vector-collection artifact. +- `ids`: A list of vector IDs to remove. +- `context`: A dictionary containing user and session context information. + +**Returns:** None. + +**Example:** + +```python +await artifact_manager.remove_vectors(artifact_id="example-id", ids=[1, 2, 3]) +``` + +--- + +### `get_vector(artifact_id: str, id: int, context: dict = None) -> dict` + +Fetches a specific vector by its ID from a vector collection artifact. + +**Parameters:** + +- `artifact_id`: The ID of the artifact to fetch the vector from. This must be a vector-collection artifact. +- `id`: The ID of the vector to fetch. +- `context`: A dictionary containing user and session context information. + +**Returns:** A dictionary containing the vector data. 
+ +**Example:** + +```python +vector = await artifact_manager.get_vector(artifact_id="example-id", id=123) +``` + +--- + +### `list_vectors(artifact_id: str, offset: int = 0, limit: int = 10, return_fields: List[str] = None, order_by: str = None, pagination: bool = False, context: dict = None) -> list` + +Lists vectors in a vector collection artifact. + +**Parameters:** + +- `artifact_id`: The ID of the artifact to list vectors from. This must be a vector-collection artifact. +- `offset`: (Optional) Number of results to skip. Defaults to 0. +- `limit`: (Optional) Maximum number of results to return. Defaults to 10. +- `return_fields`: (Optional) A list of fields to include in the results. +- `order_by`: (Optional) Field to order the results by. +- `pagination`: (Optional) Whether to include pagination metadata. Defaults to `False`. +- `context`: A dictionary containing user and session context information. + +**Returns:** A list of vectors. + +**Example:** + +```python +vectors = await artifact_manager.list_vectors(artifact_id="example-id", limit=20) +``` + +--- + ### `put_file(artifact_id: str, file_path: str, download_weight: int = 0) -> str` Generates a pre-signed URL to upload a file to the artifact in S3. The URL can be used with an HTTP `PUT` request to upload the file. The file is staged until the artifact is committed. diff --git a/hypha/VERSION b/hypha/VERSION index 160f643a..a40dc949 100644 --- a/hypha/VERSION +++ b/hypha/VERSION @@ -1,3 +1,3 @@ { - "version": "0.20.40.post1" + "version": "0.20.40.post2" } diff --git a/hypha/artifact.py b/hypha/artifact.py index 149e69e1..36d7c078 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -5,6 +5,7 @@ import random import re import json +import math from io import BytesIO import zipfile from sqlalchemy import ( @@ -53,9 +54,10 @@ ) from hypha.vectors import VectorSearchEngine from hypha_rpc.utils import ObjectProxy +import numpy as np from jsonschema import validate from sqlmodel import SQLModel, Field, Relationship, UniqueConstraint -from typing import Optional, Union, List, Any +from typing import Optional, Union, List, Any, Dict # Logger setup logging.basicConfig(stream=sys.stdout) @@ -63,6 +65,21 @@ logger.setLevel(logging.INFO) +def make_json_safe(data): + if isinstance(data, dict): + return {k: make_json_safe(v) for k, v in data.items()} + elif isinstance(data, list): + return [make_json_safe(v) for v in data] + elif data == float("inf"): + return "Infinity" + elif data == float("-inf"): + return "-Infinity" + elif isinstance(data, float) and math.isnan(data): + return "NaN" + else: + return data + + # SQLModel model for storing artifacts class ArtifactModel(SQLModel, table=True): # `table=True` makes it a table model __tablename__ = "artifacts" @@ -175,7 +192,9 @@ def __init__( self.store = store self._cache = store.get_redis_cache() self._openai_client = self.store.get_openai_client() - self._vector_engine = VectorSearchEngine(store) + self._vector_engine = VectorSearchEngine( + store.get_redis(), store.get_cache_dir() + ) router = APIRouter() self._artifacts_dir = artifacts_dir @@ -190,12 +209,13 @@ async def get_artifact( ): """Get artifact metadata, manifest, and config (excluding secrets).""" try: - return await self.read( + artifact = await self.read( artifact_id=f"{workspace}/{artifact_alias}", version=version, silent=silent, context={"user": user_info.model_dump(), "ws": workspace}, ) + return artifact except KeyError: raise HTTPException(status_code=404, detail="Artifact not found") except PermissionError: @@ 
-255,6 +275,7 @@ async def list_children( context={"user": user_info.model_dump(), "ws": workspace}, ) await self._cache.set(cache_key, results, ttl=60) + return results except KeyError: raise HTTPException(status_code=404, detail="Parent artifact not found") @@ -898,7 +919,6 @@ def _expand_permission(self, permission): "create", "commit", "add_vectors", - "add_documents", ], "lf": ["list", "list_files"], "lf+": ["list", "list_files", "create", "commit", "put_file"], @@ -921,7 +941,6 @@ def _expand_permission(self, permission): "create", "commit", "add_vectors", - "add_documents", ], "rw": [ "read", @@ -935,7 +954,6 @@ def _expand_permission(self, permission): "commit", "put_file", "add_vectors", - "add_documents", "remove_file", "remove_vectors", ], @@ -951,7 +969,6 @@ def _expand_permission(self, permission): "commit", "put_file", "add_vectors", - "add_documents", "remove_file", "remove_vectors", "create", @@ -968,7 +985,6 @@ def _expand_permission(self, permission): "commit", "put_file", "add_vectors", - "add_documents", "remove_file", "remove_vectors", "create", @@ -1025,7 +1041,6 @@ async def _get_artifact_with_permission( "edit": UserPermission.read_write, "commit": UserPermission.read_write, "add_vectors": UserPermission.read_write, - "add_documents": UserPermission.read_write, "put_file": UserPermission.read_write, "remove_vectors": UserPermission.read_write, "remove_file": UserPermission.read_write, @@ -1362,7 +1377,6 @@ async def create( type="generic", config: dict = None, secrets: dict = None, - publish_to=None, version: str = None, comment: str = None, overwrite: bool = False, @@ -1383,9 +1397,11 @@ async def create( if isinstance(manifest, ObjectProxy): manifest = ObjectProxy.toDict(manifest) + manifest = manifest and make_json_safe(manifest) + config = config and make_json_safe(config) + if alias: alias = alias.strip() - assert "^" not in alias, "Alias cannot contain the '^' character." 
if "/" in alias: ws, alias = alias.split("/") if workspace and ws != workspace: @@ -1437,6 +1453,7 @@ async def create( "timestamp": str(int(time.time())), "user_id": user_info.id, } + publish_to = config.get("publish_to") if publish_to: zenodo_client = self._get_zenodo_client( parent_artifact, publish_to=publish_to @@ -1447,7 +1464,6 @@ async def create( deposition_info["conceptrecid"] ) config["zenodo"] = deposition_info - config["publish_to"] = publish_to if publish_to not in ["zenodo", "sandbox_zenodo"]: assert ( @@ -1540,8 +1556,8 @@ async def create( session.add(new_artifact) if new_artifact.type == "vector-collection": await self._vector_engine.create_collection( - f"{new_artifact.workspace}^{new_artifact.alias}", - vectors_config=config.get("vectors_config", {}), + f"{new_artifact.workspace}/{new_artifact.alias}", + config.get("vector_fields", []), ) await session.commit() await self._save_version_to_s3( @@ -1621,6 +1637,8 @@ async def edit( user_info = UserInfo.model_validate(context["user"]) artifact_id = self._validate_artifact_id(artifact_id, context) session = await self._get_session() + manifest = manifest and make_json_safe(manifest) + config = config and make_json_safe(config) try: async with session.begin(): artifact, parent_artifact = await self._get_artifact_with_permission( @@ -1747,7 +1765,7 @@ async def read( artifact_data["config"][ "vector_count" ] = await self._vector_engine.count( - f"{artifact.workspace}^{artifact.alias}" + f"{artifact.workspace}/{artifact.alias}" ) if not silent: await session.commit() @@ -1897,7 +1915,7 @@ async def delete( if artifact.type == "vector-collection": await self._vector_engine.delete_collection( - f"{artifact.workspace}^{artifact.alias}" + f"{artifact.workspace}/{artifact.alias}" ) s3_config = self._get_s3_config(artifact, parent_artifact) @@ -1953,6 +1971,7 @@ async def add_vectors( self, artifact_id: str, vectors: list, + embedding_models: Optional[Dict[str, str]] = None, context: dict = None, ): """ @@ -1971,7 +1990,10 @@ async def add_vectors( assert artifact.manifest, "Artifact must be committed before upserting." await self._vector_engine.add_vectors( - f"{artifact.workspace}^{artifact.alias}", vectors + f"{artifact.workspace}/{artifact.alias}", + vectors, + embedding_models=embedding_models + or artifact.config.get("embedding_models"), ) logger.info(f"Added vectors to artifact with ID: {artifact_id}") except Exception as e: @@ -1979,46 +2001,17 @@ async def add_vectors( finally: await session.close() - async def add_documents( - self, - artifact_id: str, - documents: str, # `id`, `text` and other fields - context: dict = None, - ): - """ - Add documents to the artifact. - """ - user_info = UserInfo.model_validate(context["user"]) - session = await self._get_session() - try: - async with session.begin(): - artifact, _ = await self._get_artifact_with_permission( - user_info, artifact_id, "add_documents", session - ) - assert ( - artifact.type == "vector-collection" - ), "Artifact must be a vector collection." 
- embedding_model = artifact.config.get("embedding_model") - await self._vector_engine.add_documents( - f"{artifact.workspace}^{artifact.alias}", documents, embedding_model - ) - logger.info(f"Added documents to artifact with ID: {artifact_id}") - except Exception as e: - raise e - finally: - await session.close() - async def search_vectors( self, artifact_id: str, - query_text: str = None, - query_vector: Any = None, - query_filter: dict = None, - offset: int = 0, - limit: int = 10, - with_payload: bool = True, - with_vectors: bool = False, - pagination: bool = False, + query: Optional[Dict[str, Any]] = None, + embedding_models: Optional[str] = None, + filters: Optional[dict[str, Any]] = None, + limit: Optional[int] = 5, + offset: Optional[int] = 0, + return_fields: Optional[List[str]] = None, + order_by: Optional[str] = None, + pagination: Optional[bool] = False, context: dict = None, ): user_info = UserInfo.model_validate(context["user"]) @@ -2032,17 +2025,18 @@ async def search_vectors( artifact.type == "vector-collection" ), "Artifact must be a vector collection." - embedding_model = artifact.config.get("embedding_model") + embedding_models = embedding_models or artifact.config.get( + "embedding_models" + ) return await self._vector_engine.search_vectors( - f"{artifact.workspace}^{artifact.alias}", - embedding_model=embedding_model, - query_text=query_text, - query_vector=query_vector, - query_filter=query_filter, - offset=offset, + f"{artifact.workspace}/{artifact.alias}", + query=query, + embedding_models=embedding_models, + filters=filters, limit=limit, - with_payload=with_payload, - with_vectors=with_vectors, + offset=offset, + return_fields=return_fields, + order_by=order_by, pagination=pagination, ) except Exception as e: @@ -2067,7 +2061,7 @@ async def remove_vectors( artifact.type == "vector-collection" ), "Artifact must be a vector collection." await self._vector_engine.remove_vectors( - f"{artifact.workspace}^{artifact.alias}", ids + f"{artifact.workspace}/{artifact.alias}", ids ) logger.info(f"Removed vectors from artifact with ID: {artifact_id}") except Exception as e: @@ -2092,7 +2086,7 @@ async def get_vector( artifact.type == "vector-collection" ), "Artifact must be a vector collection." return await self._vector_engine.get_vector( - f"{artifact.workspace}^{artifact.alias}", id + f"{artifact.workspace}/{artifact.alias}", id ) except Exception as e: raise e @@ -2102,12 +2096,11 @@ async def get_vector( async def list_vectors( self, artifact_id: str, - query_filter: dict = None, offset: int = 0, limit: int = 10, + return_fields: List[str] = None, order_by: str = None, - with_payload: bool = True, - with_vectors: bool = False, + pagination: bool = False, context: dict = None, ): user_info = UserInfo.model_validate(context["user"]) @@ -2121,13 +2114,12 @@ async def list_vectors( artifact.type == "vector-collection" ), "Artifact must be a vector collection." 
return await self._vector_engine.list_vectors( - f"{artifact.workspace}^{artifact.alias}", - query_filter=query_filter, + f"{artifact.workspace}/{artifact.alias}", offset=offset, limit=limit, + return_fields=return_fields, order_by=order_by, - with_payload=with_payload, - with_vectors=with_vectors, + pagination=pagination, ) except Exception as e: @@ -2724,7 +2716,6 @@ def get_artifact_service(self): "list": self.list_children, "list_files": self.list_files, "add_vectors": self.add_vectors, - "add_documents": self.add_documents, "search_vectors": self.search_vectors, "remove_vectors": self.remove_vectors, "get_vector": self.get_vector, diff --git a/hypha/core/store.py b/hypha/core/store.py index d9c15018..540e2c3b 100644 --- a/hypha/core/store.py +++ b/hypha/core/store.py @@ -96,7 +96,6 @@ def __init__( local_base_url=None, redis_uri=None, database_uri=None, - vectordb_uri=None, ollama_host=None, openai_config=None, cache_dir=None, @@ -136,13 +135,6 @@ def __init__( } logger.info("Server info: %s", self._server_info) - self._vectordb_uri = vectordb_uri - if self._vectordb_uri is not None: - from qdrant_client import AsyncQdrantClient - - self._vectordb_client = AsyncQdrantClient(self._vectordb_uri) - else: - self._vectordb_client = None self._database_uri = database_uri if self._database_uri is None: @@ -205,9 +197,6 @@ def get_redis(self): def get_sql_engine(self): return self._sql_engine - def get_vectordb_client(self): - return self._vectordb_client - def get_openai_client(self): return self._openai_client @@ -553,33 +542,9 @@ async def _register_root_services(self): "list_servers": self.list_servers, "kickout_client": self.kickout_client, "list_workspaces": self.list_all_workspaces, - "list_vector_collections": self.list_vector_collections, - "delete_vector_collection": self.delete_vector_collection, } ) - @schema_method - async def list_vector_collections(self): - """List all vector collections.""" - if self._vectordb_client is None: - raise Exception("Vector database is not configured") - # get_collections - collections = await self._vectordb_client.get_collections() - return collections - - @schema_method - async def delete_vector_collection( - self, - collection_name: str = Field( - ..., description="The name of the vector collection to delete." 
- ), - ): - """Delete a vector collection.""" - if self._vectordb_client is None: - raise Exception("Vector database is not configured") - # delete_collection - await self._vectordb_client.delete_collection(collection_name) - @schema_method async def list_servers(self): """List all servers.""" diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index 3e66dfcb..bf9766a5 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -34,6 +34,7 @@ UserPermission, ServiceTypeInfo, ) +from hypha.vectors import VectorSearchEngine from hypha.core.auth import generate_presigned_token, create_scope, valid_token from hypha.utils import EventBus, random_id @@ -181,47 +182,44 @@ async def setup( logger.info("Database tables created successfully.") self._embedding_model = None - self._search_fields = None if self._enable_service_search: from fastembed import TextEmbedding self._embedding_model = TextEmbedding( model_name="BAAI/bge-small-en-v1.5", cache_dir=self._cache_dir ) - - from redis.commands.search.field import VectorField, TextField, TagField - from redis.commands.search.indexDefinition import IndexDefinition, IndexType - - # Define vector field for RedisSearch (assuming cosine similarity) - # Manually define Redis fields for each ServiceInfo attribute - self._search_fields = [ - TagField(name="id"), # id as tag - TextField(name="name"), # name as text - TagField(name="type"), # type as tag (enum-like) - TextField(name="description"), # description as text - TextField(name="docs"), # docs as text - TagField(name="app_id"), # app_id as tag - TextField( - name="service_schema" - ), # service_schema as text (you can store a serialized JSON or string representation) - VectorField( - "service_embedding", - "FLAT", - {"TYPE": "FLOAT32", "DIM": 384, "DISTANCE_METRIC": "COSINE"}, - ), - TagField(name="visibility"), # visibility as tag - TagField(name="require_context"), # require_context as tag - TagField(name="workspace"), # workspace as tag - TagField(name="flags", separator=","), # flags as tag - TagField(name="singleton"), # singleton as tag - TextField(name="created_by"), # created_by as text - ] - # Create the index with vector field and additional fields for metadata (e.g., title) - await self._redis.ft("service_info_index").create_index( - fields=self._search_fields, - definition=IndexDefinition( - prefix=["services:"], index_type=IndexType.HASH - ), + self._vector_search = VectorSearchEngine( + self._redis, + prefix=None, + cache_dir=self._cache_dir, + ) + await self._vector_search.create_collection( + collection_name="services", + vector_fields=[ + {"type": "TAG", "name": "id"}, + {"type": "TEXT", "name": "name"}, + {"type": "TAG", "name": "type"}, + {"type": "TEXT", "name": "description"}, + {"type": "TEXT", "name": "docs"}, + {"type": "TAG", "name": "app_id"}, + {"type": "TEXT", "name": "service_schema"}, + { + "type": "VECTOR", + "name": "service_embedding", + "algorithm": "FLAT", + "attributes": { + "TYPE": "FLOAT32", + "DIM": 384, + "DISTANCE_METRIC": "COSINE", + }, + }, + {"type": "TAG", "name": "visibility"}, + {"type": "TAG", "name": "require_context"}, + {"type": "TAG", "name": "workspace"}, + {"type": "TAG", "name": "flags", "separator": ","}, + {"type": "TAG", "name": "singleton"}, + {"type": "TEXT", "name": "created_by"}, + ], ) self._initialized = True return rpc @@ -841,12 +839,9 @@ def _convert_filters_to_hybrid_query(self, filters: dict) -> str: @schema_method async def search_services( self, - text_query: Optional[str] = Field( - None, description="Text 
query for semantic search." - ), - vector_query: Optional[Any] = Field( + query: Optional[Union[str, Any]] = Field( None, - description="Precomputed embedding vector for vector search in numpy format.", + description="Text query or precomputed embedding vector for vector search in numpy format.", ), filters: Optional[Dict[str, Any]] = Field( None, description="Filter dictionary for hybrid search." @@ -855,7 +850,9 @@ async def search_services( 5, description="Maximum number of results to return." ), offset: Optional[int] = Field(0, description="Offset for pagination."), - fields: Optional[List[str]] = Field(None, description="Fields to return."), + return_fields: Optional[List[str]] = Field( + None, description="Fields to return." + ), order_by: Optional[str] = Field( None, description="Order by field, default is score if embedding or text_query is provided.", @@ -870,91 +867,40 @@ async def search_services( """ if not self._enable_service_search: raise RuntimeError("Service search is not enabled.") - from redis.commands.search.query import Query current_workspace = context["ws"] # Generate embedding if text_query is provided - if text_query and not vector_query: + if isinstance(query, str): loop = asyncio.get_event_loop() embeddings = list( - await loop.run_in_executor( - None, self._embedding_model.embed, [text_query] - ) + await loop.run_in_executor(None, self._embedding_model.embed, [query]) ) vector_query = embeddings[0] - - auth_filter = f"@visibility:{{public}} | @workspace:{{{sanitize_search_value(current_workspace)}}}" - # If service_embedding is provided, prepare KNN search query - if vector_query is not None: - query_vector = vector_query.astype("float32").tobytes() - query_params = {"vector": query_vector} - knn_query = f"[KNN {limit} @service_embedding $vector AS score]" - # Combine filters into the query string - if filters: - filter_query = self._convert_filters_to_hybrid_query(filters) - query_string = f"(({filter_query}) ({auth_filter}))=>{knn_query}" - else: - query_string = f"({auth_filter})=>{knn_query}" else: - query_params = {} - if filters: - filter_query = self._convert_filters_to_hybrid_query(filters) - query_string = f"({filter_query}) ({auth_filter})" - else: - query_string = auth_filter - - all_fields = [field.name for field in self._search_fields] + ["score"] - if fields is None: - # exclude embedding - fields = [field for field in all_fields if field != "service_embedding"] - else: - for field in fields: - if field not in all_fields: - raise ValueError(f"Invalid field: {field}") - if order_by is None: - order_by = "score" if vector_query is not None else "id" - else: - if order_by not in all_fields: - raise ValueError(f"Invalid order_by field: {order_by}") - - # Build the RedisSearch query - query = ( - Query(query_string) - .return_fields(*fields) - .sort_by(order_by, asc=True) - .paging(offset, limit) - .dialect(2) - ) + vector_query = query - # Perform the search using the RedisSearch index - results = await self._redis.ft("service_info_index").search( - query, query_params=query_params + auth_filter = f"@visibility:{{public}} | @workspace:{{{sanitize_search_value(current_workspace)}}}" + results = await self._vector_search.search_vectors( + "services", + query={"service_embedding": vector_query}, + filters=filters, + extra_filter=auth_filter, + limit=limit, + offset=offset, + return_fields=return_fields, + order_by=order_by, + pagination=True, ) - # Handle pagination - if pagination: - count_query = Query(query_string).paging(0, 0).dialect(2) - 
count_results = await self._redis.ft("service_info_index").search( - count_query, query_params=query_params - ) - total_count = count_results.total - else: - total_count = None - # Convert results to dictionaries and return - services = [ - ServiceInfo.from_redis_dict(vars(doc), in_bytes=False) - for doc in results.docs + results["items"] = [ + ServiceInfo.from_redis_dict(doc, in_bytes=False).model_dump() + for doc in results["items"] ] if pagination: - return { - "items": [service.model_dump() for service in services], - "total": total_count, - "offset": offset, - "limit": limit, - } + return results else: - return [service.model_dump() for service in services] + return results["items"] @schema_method async def list_services( diff --git a/hypha/server.py b/hypha/server.py index 25d202f4..45463694 100644 --- a/hypha/server.py +++ b/hypha/server.py @@ -221,7 +221,6 @@ async def lifespan(app: FastAPI): local_base_url=local_base_url, redis_uri=args.redis_uri, database_uri=args.database_uri, - vectordb_uri=args.vectordb_uri, ollama_host=args.ollama_host, cache_dir=args.cache_dir, openai_config={ @@ -399,12 +398,6 @@ def get_argparser(add_help=True): default=None, help="set OpenAI API key", ) - parser.add_argument( - "--vectordb-uri", - type=str, - default=None, - help="set URI for the vector database", - ) parser.add_argument( "--database-uri", type=str, diff --git a/hypha/templates/apps/web-python.index.html b/hypha/templates/apps/web-python.index.html index dbd41c6d..570e56d8 100644 --- a/hypha/templates/apps/web-python.index.html +++ b/hypha/templates/apps/web-python.index.html @@ -121,9 +121,9 @@ async function setupPyodide() { if(self.pyodide) return; - importScripts('https://cdn.jsdelivr.net/pyodide/v0.26.1/full/pyodide.js'); + importScripts('https://cdn.jsdelivr.net/pyodide/v0.26.4/full/pyodide.js'); self.pyodide = await loadPyodide({ - indexURL : 'https://cdn.jsdelivr.net/pyodide/v0.26.1/full/', + indexURL : 'https://cdn.jsdelivr.net/pyodide/v0.26.4/full/', stdout: (text) => { self.postMessage({"type": "stdout", "content": text}) }, diff --git a/hypha/templates/hypha-core-app/hypha-app-webpython.js b/hypha/templates/hypha-core-app/hypha-app-webpython.js index 041dc0fb..8c17945e 100644 --- a/hypha/templates/hypha-core-app/hypha-app-webpython.js +++ b/hypha/templates/hypha-core-app/hypha-app-webpython.js @@ -1,4 +1,4 @@ -importScripts("https://cdn.jsdelivr.net/pyodide/v0.26.1/full/pyodide.js"); +importScripts("https://cdn.jsdelivr.net/pyodide/v0.26.4/full/pyodide.js"); const startupScript = ` import sys diff --git a/hypha/vectors.py b/hypha/vectors.py index e8ffa836..a0f43ec4 100644 --- a/hypha/vectors.py +++ b/hypha/vectors.py @@ -1,213 +1,480 @@ import asyncio -from qdrant_client.models import PointStruct -from qdrant_client.models import Filter -from qdrant_client.models import Distance, VectorParams -import uuid import numpy as np -from typing import Any +import uuid import logging import sys +import re +from typing import Any, List, Dict, Optional +from fakeredis import aioredis +from redis.commands.search.field import ( + TagField, + TextField, + NumericField, + GeoField, + VectorField, +) +from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.query import Query +from redis.commands.search.query import Query # Logger setup logging.basicConfig(stream=sys.stdout) logger = logging.getLogger("vectors") logger.setLevel(logging.INFO) +FIELD_TYPE_MAPPING = { + "TAG": TagField, + "TEXT": TextField, + "NUMERIC": NumericField, + "GEO": 
GeoField, + "VECTOR": VectorField, +} + + +def escape_redis_syntax(value: str) -> str: + """Escape Redis special characters in a query string, except '*'.""" + # Escape all special characters except '*' + return re.sub(r"([{}|@$\\\-\[\]\(\)\!&~:\"])", r"\\\1", value) + + +def sanitize_search_value(value: str) -> str: + """Sanitize a value to prevent injection attacks, allowing '*' for wildcard support.""" + # Allow alphanumeric characters, spaces, underscores, hyphens, dots, slashes, and '*' + value = re.sub( + r"[^a-zA-Z0-9 _\-./*]", "", value + ) # Remove unwanted characters except '*' + return escape_redis_syntax(value.strip()) + + +def parse_attributes(attributes): + """ + Parse the attributes list into a structured dictionary. + """ + parsed_attributes = [] + for attr in attributes: + attr_dict = {} + for i in range(0, len(attr), 2): + key = attr[i].decode("utf-8") if isinstance(attr[i], bytes) else attr[i] + value = attr[i + 1] + if isinstance(value, bytes): + value = value.decode("utf-8") + attr_dict[key] = value + parsed_attributes.append(attr_dict) + return {attr["identifier"]: attr for attr in parsed_attributes} + class VectorSearchEngine: - def __init__(self, store): - self.store = store - self._vectordb_client = self.store.get_vectordb_client() - self._cache_dir = self.store.get_cache_dir() - - def _ensure_client(self): - assert ( - self._vectordb_client - ), "The server is not configured to use a VectorDB client." - - async def create_collection(self, collection_name: str, vectors_config: dict): - self._ensure_client() - await self._vectordb_client.create_collection( - collection_name=collection_name, - vectors_config=VectorParams( - size=vectors_config.get("size", 128), - distance=Distance(vectors_config.get("distance", "Cosine")), + def __init__(self, redis, prefix: str = "vs", cache_dir=None): + self._redis: aioredis.FakeRedis = redis + self._index_name_prefix = prefix + self._cache_dir = cache_dir + + def _get_index_name(self, collection_name: str) -> str: + return ( + f"{self._index_name_prefix}:{collection_name}" + if self._index_name_prefix + else collection_name + ) + + async def _get_fields(self, collection_name: str): + index_name = self._get_index_name(collection_name) + info = await self._redis.ft(index_name).info() + fields = parse_attributes(info["attributes"]) + return fields + + async def create_collection( + self, + collection_name: str, + vector_fields: List[Dict[str, Any]], + ): + """ + Creates a RedisSearch index collection. + + Args: + collection_name (str): Name of the collection. + fields (List[Dict[str, Any]]): A list of dictionaries defining fields. + Example: + [ + {"type": "TAG", "name": "id"}, + {"type": "VECTOR", "name": "vector", "algorithm": "FLAT", + "attributes": {"TYPE": "FLOAT32", "DIM": 384, "DISTANCE_METRIC": "COSINE"}} + ] + """ + index_name = self._get_index_name(collection_name) + assert vector_fields, "At least one field must be provided." + + redis_fields = [] + for field in vector_fields: + field_type = field.get("type") + field_name = field.get("name") + if not field_type or not field_name: + raise ValueError( + f"Invalid field definition: {field}. Each field must have 'name' and 'type'." + ) + + field_class = FIELD_TYPE_MAPPING.get(field_type) + if not field_class: + raise ValueError( + f"Unsupported field type: {field_type}. 
Supported types: {list(FIELD_TYPE_MAPPING.keys())}" + ) + + if field_type == "VECTOR": + algorithm = field.get("algorithm", "FLAT") + attributes = field.get("attributes", {}) + redis_fields.append(field_class(field_name, algorithm, attributes)) + else: + redis_fields.append(field_class(name=field_name)) + + # Create the index + await self._redis.ft(index_name).create_index( + fields=redis_fields, + definition=IndexDefinition( + prefix=[f"{index_name}:"], index_type=IndexType.HASH ), ) - logger.info(f"Collection {collection_name} created.") + fields_info = [f"{field.name} ({type(field)})" for field in redis_fields] + logger.info(f"Collection {collection_name} created with fields: {fields_info}.") async def count(self, collection_name: str): - self._ensure_client() - return ( - await self._vectordb_client.count(collection_name=collection_name) - ).count + index_name = self._get_index_name(collection_name) + count_query = Query("*").paging(0, 0).dialect(2) + results = await self._redis.ft(index_name).search(count_query) + return results.total async def delete_collection(self, collection_name: str): - self._ensure_client() - await self._vectordb_client.delete_collection(collection_name=collection_name) + index_name = self._get_index_name(collection_name) + await self._redis.ft(index_name).dropindex() logger.info(f"Collection {collection_name} deleted.") - async def _embed_texts(self, embedding_model, texts: list): - self._ensure_client() - assert ( - embedding_model - ), "Embedding model must be provided, e.g. 'fastembed:BAAI/bge-small-en-v1.5', 'openai:text-embedding-3-small' for openai embeddings." - if embedding_model.startswith("fastembed"): + async def list_collections(self): + # use redis.ft to list all collections + collections = await self._redis.execute_command("FT._LIST") + return [c.decode("utf-8") for c in collections] + + async def _embed_texts( + self, + texts: List[str], + embedding_model: str = "fastembed:BAAI/bge-small-en-v1.5", + ) -> List[np.ndarray]: + if not embedding_model: + raise ValueError("Embedding model is not configured.") + if embedding_model.startswith("fastembed:"): from fastembed import TextEmbedding - assert ":" in embedding_model, "Embedding model must be provided." - model_name = embedding_model.split(":")[-1] + model_name = embedding_model.split(":")[1] embedding_model = TextEmbedding( model_name=model_name, cache_dir=self._cache_dir ) - loop = asyncio.get_event_loop() - embeddings = list( - await loop.run_in_executor(None, embedding_model.embed, texts) - ) - elif embedding_model.startswith("openai"): - assert ( - self._openai_client - ), "The server is not configured to use an OpenAI client." - assert ":" in embedding_model, "Embedding model must be provided." - embedding_model = embedding_model.split(":")[-1] - result = await self._openai_client.embeddings.create( - input=texts, model=embedding_model - ) - embeddings = [data.embedding for data in result.data] - else: - raise ValueError( - f"Unsupported embedding model: {embedding_model}, supported models: 'fastembed:*', 'openai:*'" - ) + loop = asyncio.get_event_loop() + embeddings = list( + await loop.run_in_executor(None, embedding_model.embed, texts) + ) return embeddings - async def add_vectors(self, collection_name: str, vectors: list): - self._ensure_client() - assert isinstance(vectors, list), "Vectors must be a list of dictionaries." - assert all( - isinstance(v, dict) for v in vectors - ), "Vectors must be a list of dictionaries." 
- - _points = [] - for p in vectors: - p["id"] = p.get("id") or str(uuid.uuid4()) - _points.append(PointStruct(**p)) - await self._vectordb_client.upsert( - collection_name=collection_name, - points=_points, - ) + async def add_vectors( + self, + collection_name: str, + vectors: List[Dict[str, Any]], + embedding_models: Optional[Dict[str, str]] = None, + ): + index_name = self._get_index_name(collection_name) + fields = await self._get_fields(collection_name) + ids = [] + for vector in vectors: + if "_id" in vector: + _id = vector["_id"] + del vector["_id"] + else: + _id = str(uuid.uuid4()) + ids.append(_id) + # convert numpy arrays to bytes + for key, value in vector.items(): + if key in fields: + if fields[key]["type"] == "VECTOR": + if isinstance(value, np.ndarray) or isinstance(value, list): + vector[key] = np.array(value).astype("float32").tobytes() + else: + assert ( + key in embedding_models + ), f"Embedding model not provided for field {key}." + embeddings = await self._embed_texts( + [vector[key]], embedding_models[key] + ) + vector[key] = ( + np.array(embeddings[0]).astype("float32").tobytes() + ) + + await self._redis.hset( + f"{index_name}:{_id}", + mapping=vector, + ) + logger.info(f"Added {len(vectors)} vectors to collection {collection_name}.") + return ids + + def _format_docs(self, docs, index_name): + formated = [] + for doc in docs: + d = vars(doc) + if "id" in d: + d["id"] = d["id"].replace(f"{index_name}:", "") + formated.append(d) + return formated - async def add_documents(self, collection_name, documents, embedding_model): - self._ensure_client() - texts = [doc["text"] for doc in documents] - embeddings = await self._embed_texts(embedding_model, texts) - points = [ - PointStruct( - id=doc.get("id") or str(uuid.uuid4()), - vector=embedding, - payload=doc, + def _convert_filters_to_hybrid_query(self, filters: dict, fields: dict) -> str: + """ + Convert a filter dictionary to a Redis hybrid query string. + + Args: + filters (dict): Dictionary of filters, e.g., {"type": "my-type", "year": [2011, 2012]}. + + Returns: + str: Redis hybrid query string, e.g., "(@type:{my-type} @year:[2011 2012])". + """ + + conditions = [] + + for field_name, value in filters.items(): + # Find the field type in the schema + field_type = fields.get(field_name) + field_type = ( + FIELD_TYPE_MAPPING.get(field_type["type"]) if field_type else None ) - for embedding, doc in zip(embeddings, documents) - ] - await self._vectordb_client.upsert( - collection_name=collection_name, - points=points, - ) - logger.info( - f"Added {len(documents)} documents to collection {collection_name}." - ) + + if not field_type: + raise ValueError( + f"Unknown field '{field_name}' in filters, available fields: {list(fields.keys())}" + ) + + # Sanitize the field name + sanitized_field_name = sanitize_search_value(field_name) + + if field_type == TagField: + # Use `{value}` for TagField + if not isinstance(value, str): + raise ValueError( + f"TagField '{field_name}' requires a string value." + ) + sanitized_value = sanitize_search_value(value) + conditions.append(f"@{sanitized_field_name}:{{{sanitized_value}}}") + + elif field_type == NumericField: + # Use `[min max]` for NumericField + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError( + f"NumericField '{field_name}' requires a list or tuple with two elements." 
+ ) + min_val, max_val = value + conditions.append(f"@{sanitized_field_name}:[{min_val} {max_val}]") + + elif field_type == TextField: + # Use `"value"` for TextField + if not isinstance(value, str): + raise ValueError( + f"TextField '{field_name}' requires a string value." + ) + if "*" in value: + assert value.endswith("*"), "Wildcard '*' must be at the end." + sanitized_value = sanitize_search_value(value) + conditions.append(f"@{sanitized_field_name}:{sanitized_value}") + else: + sanitized_value = escape_redis_syntax(value) + conditions.append(f'@{sanitized_field_name}:"{sanitized_value}"') + + else: + raise ValueError(f"Unsupported field type for '{field_name}'.") + + return " ".join(conditions) async def search_vectors( self, collection_name: str, - embedding_model: str, - query_text: str, - query_vector: Any, - query_filter: dict = None, - offset: int = 0, - limit: int = 10, - with_payload: bool = True, - with_vectors: bool = False, - pagination: bool = False, + query: Optional[Dict[str, Any]] = None, + embedding_models: Optional[Dict[str, str]] = None, + filters: Optional[Dict[str, Any]] = None, + extra_filter: Optional[str] = None, + limit: Optional[int] = 5, + offset: Optional[int] = 0, + return_fields: Optional[List[str]] = None, + order_by: Optional[str] = None, + pagination: Optional[bool] = False, ): - self._ensure_client() - # if it's a numpy array, convert it to a list - if isinstance(query_vector, np.ndarray): - query_vector = query_vector.tolist() - from qdrant_client.models import Filter + """ + Search vectors in a collection. + """ + index_name = self._get_index_name(collection_name) + collection_fields = await self._get_fields(collection_name) + all_fields = [field for field in collection_fields.keys()] + ["score"] + # Generate embedding if text_query is provided + if query: + assert ( + len(query) == 1 + ), "Only one of text_query or vector_query can be provided." + use_vector = list(query.keys())[0] + assert ( + collection_fields[use_vector]["type"] == "VECTOR" + ), f"Field {use_vector} is not a vector field." - if query_filter: - query_filter = Filter.model_validate(query_filter) + vector_query = query[use_vector] + if isinstance(vector_query, np.ndarray): + vector_query = vector_query.astype("float32").tobytes() + elif isinstance(vector_query, list): + vector_query = np.array(vector_query).astype("float32").tobytes() + elif vector_query is not None: + assert ( + embedding_models and use_vector in embedding_models + ), f"Embedding model not provided for field {use_vector}." + embeddings = await self._embed_texts( + [vector_query], + embedding_models[use_vector], + ) + vector_query = np.array(embeddings[0]).astype("float32").tobytes() - if query_text: - assert ( - not query_vector - ), "Either query_text or query_vector must be provided." 
- embeddings = await self._embed_texts(embedding_model, [query_text]) - query_vector = embeddings[0] - - search_results = await self._vectordb_client.search( - collection_name=collection_name, - query_vector=query_vector, - query_filter=query_filter, - limit=limit, - offset=offset, - with_payload=with_payload, - with_vectors=with_vectors, + # If service_embedding is provided, prepare KNN search query + if vector_query is not None: + query_params = {"vector": vector_query} + knn_query = f"[KNN {limit} @{use_vector} $vector AS score]" + # Combine filters into the query string + if filters: + filter_query = self._convert_filters_to_hybrid_query( + filters, collection_fields + ) + query_string = ( + f"(({filter_query}) ({extra_filter}))=>{knn_query}" + if extra_filter + else f"({filter_query})=>{knn_query}" + ) + else: + query_string = ( + f"({extra_filter})=>{knn_query}" + if extra_filter + else f"*=>{knn_query}" + ) + else: + query_params = {} + if filters: + filter_query = self._convert_filters_to_hybrid_query( + filters, collection_fields + ) + query_string = ( + f"({filter_query}) ({extra_filter})" + if extra_filter + else f"{filter_query}" + ) + else: + query_string = extra_filter if extra_filter else "*" + + if return_fields is None: + # exclude embedding + return_fields = [ + field + for field in all_fields + if field == "score" or collection_fields[field]["type"] != "VECTOR" + ] + else: + for field in return_fields: + if field not in all_fields: + raise ValueError(f"Invalid field: {field}") + if order_by is None: + order_by = "score" if vector_query is not None else all_fields[0] + else: + if order_by not in all_fields: + raise ValueError(f"Invalid order_by field: {order_by}") + + # Build the RedisSearch query + query = ( + Query(query_string) + .return_fields(*return_fields) + .sort_by(order_by, asc=True) + .paging(offset, limit) + .dialect(2) ) + + # Perform the search using the RedisSearch index + results = await self._redis.ft(index_name).search( + query, query_params=query_params + ) + + # Handle pagination + if pagination: + count_query = Query(query_string).paging(0, 0).dialect(2) + count_results = await self._redis.ft(index_name).search( + count_query, query_params=query_params + ) + total_count = count_results.total + else: + total_count = None + + docs = self._format_docs(results.docs, index_name) + if pagination: - count = await self._vectordb_client.count(collection_name=collection_name) return { - "total": count.count, - "items": search_results, + "items": docs, + "total": total_count, "offset": offset, "limit": limit, } - logger.info(f"Performed semantic search in collection {collection_name}.") - return search_results - - async def remove_vectors(self, collection_name: str, ids: list): - self._ensure_client() - await self._vectordb_client.delete( - collection_name=collection_name, - points_selector=ids, - ) + else: + return docs + + async def remove_vectors(self, collection_name: str, ids: List[str]): + index_name = self._get_index_name(collection_name) + for id in ids: + await self._redis.delete(f"{index_name}:{id}") logger.info(f"Removed {len(ids)} vectors from collection {collection_name}.") async def get_vector(self, collection_name: str, id: str): - self._ensure_client() - points = await self._vectordb_client.retrieve( - collection_name=collection_name, - ids=[id], - with_payload=True, - with_vectors=True, - ) - if not points: + index_name = self._get_index_name(collection_name) + vector = await self._redis.hgetall(f"{index_name}:{id}") + if not vector: raise 
ValueError(f"Vector {id} not found in collection {collection_name}.") - logger.info(f"Retrieved vector {id} from collection {collection_name}.") - return points[0] + + return {key.decode("utf-8"): value for key, value in vector.items()} async def list_vectors( self, collection_name: str, - query_filter: dict = None, offset: int = 0, limit: int = 10, - order_by: str = None, - with_payload: bool = True, - with_vectors: bool = False, + return_fields: Optional[List[str]] = None, + order_by: Optional[str] = None, + pagination: Optional[bool] = False, ): - self._ensure_client() - if query_filter: - query_filter = Filter.model_validate(query_filter) - points, _ = await self._vectordb_client.scroll( - collection_name=collection_name, - scroll_filter=query_filter, - limit=limit, - offset=offset, - order_by=order_by, - with_payload=with_payload, - with_vectors=with_vectors, + index_name = self._get_index_name(collection_name) + # read the index settings + fields = await self._get_fields(collection_name) + return_fields = return_fields or [ + field["identifier"] + for field in fields.values() + if field["type"] != "VECTOR" + ] + if order_by is None: + order_by = list(fields.keys())[0] + else: + if order_by not in fields.keys(): + raise ValueError(f"Invalid order_by field: {order_by}") + + query_string = "*" + + query = ( + Query(query_string) + .return_fields(*return_fields) + .paging(offset, limit) + .dialect(2) + .sort_by(order_by, asc=True) ) - logger.info(f"Listed vectors in collection {collection_name}.") - return points + + # Perform the search using the RedisSearch index + results = await self._redis.ft(index_name).search(query) + docs = self._format_docs(results.docs, index_name) + + if pagination: + return { + "items": docs, + "duration": results.duration, + "total": results.total, + "offset": offset, + "limit": limit, + } + else: + return docs diff --git a/requirements.txt b/requirements.txt index 294ae15b..bc5fd643 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,7 +34,6 @@ asyncpg>=0.30.0 sqlmodel==0.0.22 alembic==1.14.0 hrid==0.2.4 -qdrant-client==1.12.1 ollama==0.3.3 fastembed==0.4.2 asgiproxy==0.1.1 diff --git a/setup.py b/setup.py index 40ded274..c91a77e5 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,6 @@ "db": [ "psycopg2-binary>=2.9.10", "asyncpg>=0.30.0", - "qdrant-client>=1.12.1", "fastembed>=0.4.2", ], }, diff --git a/tests/__init__.py b/tests/__init__.py index edb0fa75..af49fa69 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -28,10 +28,6 @@ POSTGRES_URI = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@localhost:{POSTGRES_PORT}/{POSTGRES_DB}" -QDRANT_PORT = 6333 -QDRANT_URL = "http://127.0.0.1:6333" - - def find_item(items, key, value): """Find an item with key or attributes in an object list.""" filtered = [ diff --git a/tests/conftest.py b/tests/conftest.py index 7d24b65b..445cf87c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,6 @@ from hypha.core.auth import generate_presigned_token, create_scope from hypha.minio import setup_minio_executables from redis import Redis -from qdrant_client import QdrantClient from . 
import ( MINIO_PORT, @@ -39,8 +38,6 @@ POSTGRES_PASSWORD, POSTGRES_DB, POSTGRES_URI, - QDRANT_PORT, - QDRANT_URL, ) JWT_SECRET = str(uuid.uuid4()) @@ -262,67 +259,8 @@ def redis_server(): time.sleep(1) -@pytest_asyncio.fixture(name="qdrant_server", scope="session") -def qdrant_server(): - """Start a Qdrant server as test fixture and tear down after test.""" - try: - # Check if Qdrant is already running - client = QdrantClient(QDRANT_URL) - client.info() - print(f"Qdrant is running, using vectordb at {QDRANT_URL}") - yield QDRANT_URL - except Exception: - # Pull the Qdrant image - print("Pulling Qdrant Docker image...") - subprocess.run( - ["docker", "pull", "qdrant/qdrant:v1.12.4-unprivileged"], check=True - ) - # Start Qdrant using Docker - print("Qdrant is not running, starting a Qdrant server using Docker...") - dirpath = tempfile.mkdtemp() - subprocess.Popen( - [ - "docker", - "run", - "-d", - "--name", - "qdrant", - "-p", - f"{QDRANT_PORT}:{QDRANT_PORT}", - "-p", - "6334:6334", - "-v", - f"{dirpath}:/qdrant/storage:z", - "qdrant/qdrant:v1.12.4-unprivileged", - ] - ) - - # Wait for Qdrant to be ready - timeout = 10 - while timeout > 0: - try: - client = QdrantClient(QDRANT_URL) - client.info() - print(f"Qdrant is running at {QDRANT_URL}") - break - except Exception: - pass - timeout -= 0.1 - time.sleep(0.1) - - if timeout <= 0: - raise RuntimeError("Failed to start Qdrant server.") - - yield QDRANT_URL - - # Stop and remove the Docker container after the test - subprocess.Popen(["docker", "stop", "qdrant"]) - subprocess.Popen(["docker", "rm", "qdrant"]) - time.sleep(1) - - @pytest_asyncio.fixture(name="fastapi_server", scope="session") -def fastapi_server_fixture(minio_server, postgres_server, qdrant_server): +def fastapi_server_fixture(minio_server, postgres_server): """Start server as test fixture and tear down after test.""" with subprocess.Popen( [ @@ -342,7 +280,6 @@ def fastapi_server_fixture(minio_server, postgres_server, qdrant_server): "--enable-s3-proxy", f"--workspace-bucket=my-workspaces", "--s3-admin-type=minio", - f"--vectordb-uri={qdrant_server}", "--cache-dir=./bin/cache", f"--triton-servers=http://127.0.0.1:{TRITON_PORT}", "--static-mounts=/tests:./tests", diff --git a/tests/test_artifact.py b/tests/test_artifact.py index f211c01d..1e5476b7 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -6,6 +6,7 @@ from io import BytesIO from zipfile import ZipFile import httpx +import yaml from . 
import SERVER_URL, SERVER_URL_SQLITE, find_item @@ -737,8 +738,10 @@ async def test_publish_artifact(minio_server, fastapi_server, test_user_token): alias="{zenodo_conceptrecid}", parent_id=collection.id, manifest=dataset_manifest, + config={ + "publish_to": "sandbox_zenodo" + }, version="stage", - publish_to="sandbox_zenodo", ) assert ( @@ -895,11 +898,13 @@ async def test_http_artifact_endpoint(minio_server, fastapi_server, test_user_to config={"permissions": {"*": "r", "@": "r+"}}, ) - # Create an artifact within the collection - dataset_manifest = { - "name": "Test Dataset", - "description": "A test dataset for HTTP endpoint", - } + # Create a string in yaml with infinite float + yaml_str = """ + name: Test Dataset + description: A test dataset for HTTP endpoint + inf_float: [-.inf, .inf] + """ + dataset_manifest = yaml.safe_load(yaml_str) dataset = await artifact_manager.create( type="dataset", parent_id=collection.id, diff --git a/tests/test_service_search.py b/tests/test_service_search.py index f42f70ae..b166fa56 100644 --- a/tests/test_service_search.py +++ b/tests/test_service_search.py @@ -63,14 +63,14 @@ async def test_service_search(fastapi_server_redis_1, test_user_token): # Test semantic search using `text_query` text_query = "NLP" - services = await api.search_services(text_query=text_query, limit=3) + services = await api.search_services(query=text_query, limit=3) assert isinstance(services, list) assert len(services) <= 3 # The top hit should be the service with "natural language processing" in the `docs` field assert "natural language processing" in services[0]["docs"] assert services[0]["score"] < services[1]["score"] - results = await api.search_services(text_query=text_query, limit=3, pagination=True) + results = await api.search_services(query=text_query, limit=3, pagination=True) assert results["total"] >= 1 embedding = np.ones(384).astype(np.float32) @@ -88,7 +88,7 @@ async def test_service_search(fastapi_server_redis_1, test_user_token): ) # Test vector query with the exact embedding - services = await api.search_services(vector_query=embedding, limit=3) + services = await api.search_services(query=embedding, limit=3) assert isinstance(services, list) assert len(services) <= 3 assert "service-88" in services[0]["id"] @@ -103,9 +103,7 @@ async def test_service_search(fastapi_server_redis_1, test_user_token): # Test hybrid search (text query + filters) filters = {"type": "my-type"} text_query = "genomics workflows" - services = await api.search_services( - text_query=text_query, filters=filters, limit=3 - ) + services = await api.search_services(query=text_query, filters=filters, limit=3) assert isinstance(services, list) assert all(service["type"] == "my-type" for service in services) # The top hit should be the service with "genomics" in the `docs` field @@ -114,7 +112,7 @@ async def test_service_search(fastapi_server_redis_1, test_user_token): # Test hybrid search (embedding + filters) filters = {"type": "my-type"} services = await api.search_services( - vector_query=np.random.rand(384), filters=filters, limit=3 + query=np.random.rand(384), filters=filters, limit=3 ) assert isinstance(services, list) assert all(service["type"] == "my-type" for service in services) diff --git a/tests/test_vectors.py b/tests/test_vectors.py index 3593235b..a1c1fee0 100644 --- a/tests/test_vectors.py +++ b/tests/test_vectors.py @@ -5,14 +5,14 @@ from hypha_rpc import connect_to_server -from . import SERVER_URL, find_item +from . 
 
 # All test coroutines will be treated as marked.
 pytestmark = pytest.mark.asyncio
 
 
 async def test_artifact_vector_collection(
-    minio_server, fastapi_server, test_user_token
+    minio_server, fastapi_server_redis_1, test_user_token
 ):
     """Test vector-related functions within a vector-collection artifact."""
@@ -20,7 +20,7 @@ async def test_artifact_vector_collection(
     api = await connect_to_server(
         {
             "name": "test deploy client",
-            "server_url": SERVER_URL,
+            "server_url": SERVER_URL_REDIS_1,
             "token": test_user_token,
         }
     )
@@ -32,11 +32,24 @@
         "description": "A test vector collection",
     }
     vector_collection_config = {
-        "vectors_config": {
-            "size": 384,
-            "distance": "Cosine",
+        "vector_fields": [
+            {
+                "type": "VECTOR",
+                "name": "vector",
+                "algorithm": "FLAT",
+                "attributes": {
+                    "TYPE": "FLOAT32",
+                    "DIM": 384,
+                    "DISTANCE_METRIC": "COSINE",
+                },
+            },
+            {"type": "TEXT", "name": "text"},
+            {"type": "TAG", "name": "label"},
+            {"type": "NUMERIC", "name": "rand_number"},
+        ],
+        "embedding_models": {
+            "vector": "fastembed:BAAI/bge-small-en-v1.5",
         },
-        "embedding_model": "fastembed:BAAI/bge-small-en-v1.5",
     }
     vector_collection = await artifact_manager.create(
         type="vector-collection",
@@ -47,27 +60,21 @@
     vectors = [
         {
             "vector": [random.random() for _ in range(384)],
-            "payload": {
-                "text": "This is a test document.",
-                "label": "doc1",
-                "rand_number": random.randint(0, 10),
-            },
+            "text": "This is a test document.",
+            "label": "doc1",
+            "rand_number": random.randint(0, 10),
         },
         {
             "vector": np.random.rand(384),
-            "payload": {
-                "text": "Another document.",
-                "label": "doc2",
-                "rand_number": random.randint(0, 10),
-            },
+            "text": "Another document.",
+            "label": "doc2",
+            "rand_number": random.randint(0, 10),
         },
         {
             "vector": np.random.rand(384),
-            "payload": {
-                "text": "Yet another document.",
-                "label": "doc3",
-                "rand_number": random.randint(0, 10),
-            },
+            "text": "Yet another document.",
+            "label": "doc3",
+            "rand_number": random.randint(0, 10),
         },
     ]
     await artifact_manager.add_vectors(
@@ -82,57 +89,48 @@
     query_vector = [random.random() for _ in range(384)]
     search_results = await artifact_manager.search_vectors(
         artifact_id=vector_collection.id,
-        query_vector=query_vector,
+        query={"vector": query_vector},
         limit=2,
     )
     assert len(search_results) <= 2
 
     results = await artifact_manager.search_vectors(
         artifact_id=vector_collection.id,
-        query_vector=query_vector,
-        limit=2,
+        query={"vector": query_vector},
+        limit=3,
         pagination=True,
     )
     assert results["total"] == 3
 
-    query_filter = {
-        "should": None,
-        "min_should": None,
-        "must": [
-            {
-                "key": "rand_number",
-                "match": None,
-                "range": {"lt": None, "gt": None, "gte": 3.0, "lte": None},
-                "geo_bounding_box": None,
-                "geo_radius": None,
-                "geo_polygon": None,
-                "values_count": None,
-            }
-        ],
-        "must_not": None,
-    }
+    search_results = await artifact_manager.search_vectors(
+        artifact_id=vector_collection.id,
+        filters={"rand_number": [-2, -1]},
+        query={"vector": np.random.rand(384)},
+        limit=2,
+    )
+    assert len(search_results) == 0
 
     search_results = await artifact_manager.search_vectors(
         artifact_id=vector_collection.id,
-        query_filter=query_filter,
-        query_vector=np.random.rand(384),
+        filters={"rand_number": [0, 10]},
+        query={"vector": np.random.rand(384)},
         limit=2,
     )
-    assert len(search_results) <= 2
+    assert len(search_results) > 0
 
     # Search for vectors by text
-    documents = [
-        {"text": "This is a test document.", "label": "doc1"},
-        {"text": "Another test document.", "label": "doc2"},
+    vectors = [
+        {"vector": "This is a test document.", "label": "doc1"},
+        {"vector": "Another test document.", "label": "doc2"},
     ]
-    await artifact_manager.add_documents(
+    await artifact_manager.add_vectors(
         artifact_id=vector_collection.id,
-        documents=documents,
+        vectors=vectors,
     )
-    text_query = "test document"
+
     text_search_results = await artifact_manager.search_vectors(
         artifact_id=vector_collection.id,
-        query_text=text_query,
+        query={"vector": "test document"},
         limit=2,
     )
     assert len(text_search_results) <= 2
@@ -142,7 +140,7 @@
         artifact_id=vector_collection.id,
         id=text_search_results[0]["id"],
     )
-    assert retrieved_vector.id == text_search_results[0]["id"]
+    assert "label" in retrieved_vector
 
     # List vectors in the collection
     vector_list = await artifact_manager.list_vectors(