Skip to content

Commit

Permalink
Use qdrant server for testing, add house keeping tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Nov 21, 2024
1 parent f5a7203 commit 6ffb36c
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 135 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,16 @@ jobs:
redis:
image: redis/redis-stack:7.2.0-v13
ports:
- 6333:6379
- 6338:6379
options: --health-cmd "redis-cli ping" --health-interval 10s --health-timeout 5s --health-retries 5
qdrant:
image: qdrant/qdrant:latest
ports:
- 6333:6333
- 6334:6334
options: --health-cmd "bash -c ':> /dev/tcp/127.0.0.1/6333' || exit 1" --health-interval 10s --health-timeout 5s --health-retries 5
volumes:
- ./qdrant_storage:/qdrant/storage:z
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion hypha/VERSION
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "0.20.39.post13"
"version": "0.20.39.post14"
}
41 changes: 23 additions & 18 deletions hypha/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ async def create(

if alias:
alias = alias.strip()
assert "^" not in alias, "Alias cannot contain the '^' character."
if "/" in alias:
ws, alias = alias.split("/")
if workspace and ws != workspace:
Expand Down Expand Up @@ -1068,7 +1069,7 @@ async def create(

vectors_config = config.get("vectors_config", {})
await self._vectordb_client.create_collection(
collection_name=f"{new_artifact.workspace}/{new_artifact.alias}",
collection_name=f"{new_artifact.workspace}^{new_artifact.alias}",
vectors_config=VectorParams(
size=vectors_config.get("size", 128),
distance=Distance(vectors_config.get("distance", "Cosine")),
Expand Down Expand Up @@ -1277,7 +1278,7 @@ async def read(
artifact_data["config"] = artifact_data.get("config", {})
artifact_data["config"]["vector_count"] = (
await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
collection_name=f"{artifact.workspace}^{artifact.alias}"
)
).count

Expand Down Expand Up @@ -1433,7 +1434,7 @@ async def delete(
self._vectordb_client
), "The server is not configured to use a VectorDB client."
await self._vectordb_client.delete_collection(
collection_name=f"{artifact.workspace}/{artifact.alias}"
collection_name=f"{artifact.workspace}^{artifact.alias}"
)

s3_config = self._get_s3_config(artifact, parent_artifact)
Expand Down Expand Up @@ -1522,7 +1523,7 @@ async def add_vectors(
p["id"] = p.get("id") or str(uuid.uuid4())
_points.append(PointStruct(**p))
await self._vectordb_client.upsert(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
points=_points,
)
# TODO: Update file_count
Expand All @@ -1536,29 +1537,33 @@ async def _embed_texts(self, config, texts):
embedding_model = config.get("embedding_model") # "text-embedding-3-small"
assert (
embedding_model
), "Embedding model must be provided, e.g. 'fastembed', 'text-embedding-3-small' for openai or 'all-minilm' for ollama."
), "Embedding model must be provided, e.g. 'fastembed:BAAI/bge-small-en-v1.5', 'openai:text-embedding-3-small' for openai embeddings."
if embedding_model.startswith("fastembed"):
from fastembed import TextEmbedding

if ":" in embedding_model:
model_name = embedding_model.split(":")[-1]
else:
model_name = "BAAI/bge-small-en-v1.5"
assert ":" in embedding_model, "Embedding model must be provided."
model_name = embedding_model.split(":")[-1]
embedding_model = TextEmbedding(
model_name=model_name, cache_dir=self._cache_dir
)
loop = asyncio.get_event_loop()
embeddings = list(
await loop.run_in_executor(None, embedding_model.embed, texts)
)
else:
elif embedding_model.startswith("openai"):
assert (
self._openai_client
), "The server is not configured to use an OpenAI client."
assert ":" in embedding_model, "Embedding model must be provided."
embedding_model = embedding_model.split(":")[-1]
result = await self._openai_client.embeddings.create(
input=texts, model=embedding_model
)
embeddings = [data.embedding for data in result.data]
else:
raise ValueError(
f"Unsupported embedding model: {embedding_model}, supported models: 'fastembed:*', 'openai:*'"
)
return embeddings

async def add_documents(
Expand Down Expand Up @@ -1593,7 +1598,7 @@ async def add_documents(
for embedding, doc in zip(embeddings, documents)
]
await self._vectordb_client.upsert(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
points=points,
)
logger.info(f"Upserted documents to artifact with ID: {artifact_id}")
Expand Down Expand Up @@ -1632,7 +1637,7 @@ async def search_by_vector(
if query_filter:
query_filter = Filter.model_validate(query_filter)
search_results = await self._vectordb_client.search(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
query_vector=query_vector,
query_filter=query_filter,
limit=limit,
Expand All @@ -1642,7 +1647,7 @@ async def search_by_vector(
)
if pagination:
count = await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
collection_name=f"{artifact.workspace}^{artifact.alias}"
)
return {
"total": count.count,
Expand Down Expand Up @@ -1684,7 +1689,7 @@ async def search_by_text(
if query_filter:
query_filter = Filter.model_validate(query_filter)
search_results = await self._vectordb_client.search(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
query_vector=query_vector,
query_filter=query_filter,
limit=limit,
Expand All @@ -1694,7 +1699,7 @@ async def search_by_text(
)
if pagination:
count = await self._vectordb_client.count(
collection_name=f"{artifact.workspace}/{artifact.alias}"
collection_name=f"{artifact.workspace}^{artifact.alias}"
)
return {
"total": count.count,
Expand Down Expand Up @@ -1728,7 +1733,7 @@ async def remove_vectors(
self._vectordb_client
), "The server is not configured to use a VectorDB client."
await self._vectordb_client.delete(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
points_selector=ids,
)
logger.info(f"Removed vectors from artifact with ID: {artifact_id}")
Expand Down Expand Up @@ -1757,7 +1762,7 @@ async def get_vector(
self._vectordb_client
), "The server is not configured to use a VectorDB client."
points = await self._vectordb_client.retrieve(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
ids=[id],
with_payload=True,
with_vectors=True,
Expand Down Expand Up @@ -1797,7 +1802,7 @@ async def list_vectors(
if query_filter:
query_filter = Filter.model_validate(query_filter)
points, _ = await self._vectordb_client.scroll(
collection_name=f"{artifact.workspace}/{artifact.alias}",
collection_name=f"{artifact.workspace}^{artifact.alias}",
scroll_filter=query_filter,
limit=limit,
offset=offset,
Expand Down
47 changes: 47 additions & 0 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from hypha_rpc import RPC
from hypha_rpc.utils.schema import schema_method
from starlette.routing import Mount
from pydantic.fields import Field

from hypha import __version__
from hypha.core import (
Expand Down Expand Up @@ -278,6 +279,26 @@ async def _run_startup_functions(self, startup_functions):
# Stop the entire event loop if an error occurs
asyncio.get_running_loop().stop()

async def housekeeping(self):
    """Run periodic workspace cleanup in the background.

    After an initial two-minute delay, loop forever: open a workspace
    interface on ``ws-user-root`` as the root user, list the workspaces,
    and call ``cleanup`` on each one. Passes are spaced one hour apart.

    Any exception raised during a pass is logged and the loop carries on;
    the coroutine never returns normally and is meant to be scheduled as
    a background task.
    """
    # Delay the first pass by 2 minutes.
    logger.info("Starting housekeeping task in 2 minutes...")
    await asyncio.sleep(120)
    while True:
        try:
            logger.info("Running housekeeping task...")
            async with self.get_workspace_interface(
                self._root_user, "ws-user-root", client_id="housekeeping"
            ) as api:
                workspaces = await api.list_workspaces()
                for workspace in workspaces:
                    await api.cleanup(workspace.id)
        except Exception as e:
            logger.exception(f"Error in housekeeping: {e}")
        # Sleep outside the try block so that a failed pass also waits for
        # the full interval; otherwise a persistent error would make this
        # loop retry immediately and spin without pause.
        await asyncio.sleep(3600)

async def upgrade(self):
"""Upgrade the store."""
current_version = await self._redis.get("hypha_version")
Expand Down Expand Up @@ -503,6 +524,8 @@ async def init(self, reset_redis, startup_functions=None):
logger.info("Server initialized with server id: %s", self._server_id)
logger.info("Currently connected hypha servers: %s", servers)

asyncio.create_task(self.housekeeping())

async def _register_root_services(self):
"""Register root services."""
self._root_workspace_interface = await self.get_workspace_interface(
Expand All @@ -522,9 +545,33 @@ async def _register_root_services(self):
"list_servers": self.list_servers,
"kickout_client": self.kickout_client,
"list_workspaces": self.list_all_workspaces,
"list_vector_collections": self.list_vector_collections,
"delete_vector_collection": self.delete_vector_collection,
}
)

@schema_method
async def list_vector_collections(self):
    """List all vector collections."""
    # Refuse early when no vector database client has been configured.
    vectordb = self._vectordb_client
    if vectordb is None:
        raise Exception("Vector database is not configured")
    # Hand back whatever the client's get_collections() call yields, unmodified.
    return await vectordb.get_collections()

@schema_method
async def delete_vector_collection(
    self,
    collection_name: str = Field(
        ..., description="The name of the vector collection to delete."
    ),
):
    """Delete a vector collection."""
    # Refuse early when no vector database client has been configured.
    vectordb = self._vectordb_client
    if vectordb is None:
        raise Exception("Vector database is not configured")
    # The client's result is intentionally discarded; this method returns None.
    await vectordb.delete_collection(collection_name)

@schema_method
async def list_servers(self):
"""List all servers."""
Expand Down
6 changes: 5 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MINIO_SERVER_URL_PUBLIC = f"http://localhost:{MINIO_PORT}"
MINIO_ROOT_USER = "minio"
MINIO_ROOT_PASSWORD = str(uuid.uuid4())
REDIS_PORT = 6333
REDIS_PORT = 6338

POSTGRES_PORT = 5432
POSTGRES_USER = "postgres"
Expand All @@ -28,6 +28,10 @@
POSTGRES_URI = f"postgresql+asyncpg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@localhost:{POSTGRES_PORT}/{POSTGRES_DB}"


QDRANT_PORT = 6333
QDRANT_URL = "http://127.0.0.1:6333"


def find_item(items, key, value):
"""Find an item with key or attributes in an object list."""
filtered = [
Expand Down
Loading

0 comments on commit 6ffb36c

Please sign in to comment.