diff --git a/chainfury/components/qdrant/__init__.py b/chainfury/components/qdrant/__init__.py deleted file mode 100644 index dd29d1f..0000000 --- a/chainfury/components/qdrant/__init__.py +++ /dev/null @@ -1,353 +0,0 @@ -# Copyright © 2023- Frello Technology Private Limited - -from uuid import uuid4 -from functools import lru_cache -from typing import List, Dict, Tuple, Optional, Union - -try: - from qdrant_client import models, QdrantClient - - QDRANT_CLIENT_INSTALLED = True -except ImportError: - QDRANT_CLIENT_INSTALLED = False - -from chainfury import Secret, memory_registry, logger -from chainfury.components.const import Env, ComponentMissingError - -# https://qdrant.tech/documentation/concepts/filtering -# Must : "must" : AND -# Should : "should" : OR -# Must Not: "must_not" : NOT -# Match: = -# Match Any: IN -# Match Except: NOT IN - - -@lru_cache(maxsize=1) -def _get_qdrant_client( - qdrant_url: Secret = Secret(), qdrant_api_key: Secret = Secret() -): - """Create a qdrant client and cache it - - Args: - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - - Returns: - qdrant_client.QdrantClient: qdrant client - """ - qdrant_url = Secret(Env.QDRANT_API_URL(qdrant_url.value)).value # type: ignore - qdrant_api_key = Secret(Env.QDRANT_API_KEY(qdrant_api_key.value)).value # type: ignore - if not qdrant_url: - raise Exception( - "Qdrant URL is not set. Please pass `qdrant_url` or env var `QDRANT_API_URL=`" - ) - if not qdrant_api_key: - raise Exception( - "Qdrant API KEY is not set. Please pass `qdrant_api_key` or env var `QDRANT_API_KEY=`" - ) - logger.info("Creating Qdrant client") - return QdrantClient(url=qdrant_url, api_key=qdrant_api_key) # type: ignore - - -def qdrant_write( - embeddings: List[List[float]], - collection_name: str, - qdrant_url: Secret = Secret(""), - qdrant_api_key: Secret = Secret(""), - extra_payload: List[Dict[str, str]] = [], - wait: bool = True, - create_if_not_present: bool = True, - distance: str = "cosine", -) -> Tuple[str, Optional[Exception]]: - """ - Write to the Qdrant DB using the Qdrant client. In order to use this, access via the `memory_registry`: - - Example: - >>> from chainfury import memory_registry - >>> mem = memory_registry.get_write("qdrant") - >>> sentence = "C.P. Cavafy is widely considered the most distinguished Greek poet of the 20th century." - >>> out, err = mem( - { - "items": [sentence], - "extra_payload": [ - {"data": sentence}, - ], - "collection_name": "my_test_collection", - "embedding_model": "openai-embedding", - "create_if_not_present": True, - } - ) - >>> if err: - print("TRACE:", out) - else: - print(out) - - Args: - embeddings (List[List[float]]): list of embeddings - collection_name (str): collection name - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - extra_payload (List[Dict[str, str]], optional): extra payload. Defaults to []. - wait (bool, optional): wait for the response. Defaults to True. - create_if_not_present (bool, optional): create collection if not present. Defaults to True. - distance (str, optional): distance metric. Defaults to "cosine". - - Returns: - Tuple[str, Optional[Exception]]: status and error - """ - # client check - if not QDRANT_CLIENT_INSTALLED: - raise ComponentMissingError( - "Qdrant client is not installed. Please install it with `pip install qdrant-client`" - ) - - # checks - if not (len(embeddings) and len(embeddings[0]) and type(embeddings[0][0]) == float): - raise Exception("Embeddings should be a list of lists of floats") - if extra_payload and len(extra_payload) != len(embeddings): - raise Exception("Length of extra_payload should be equal to embeddings") - - client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key) # type: ignore - - # next we create points and upsert them into the DB - points = [] - for i, embedding in enumerate(embeddings): - payload = {} - if extra_payload: - payload = extra_payload[i] - points.append(models.PointStruct(id=str(uuid4()), payload=payload, vector=embedding)) # type: ignore - batch = models.Batch( # type: ignore - ids=[point.id for point in points], - vectors=[point.vector for point in points], - payloads=[point.payload for point in points], - ) - - def _insert(): - try: - result = client.upsert( - collection_name=collection_name, - points=batch, - wait=wait, - ) - except Exception as e: - return e.content, e # type: ignore - return result.status.lower(), None - - status, err = _insert() - if err and err.status_code == 404 and create_if_not_present: # type: ignore - collection = client.recreate_collection( - collection_name=collection_name, - vectors_config=models.VectorParams( # type: ignore - size=len(embeddings[0]), - distance=getattr(models.Distance, distance.upper()), # type: ignore - ), - ) - logger.info(f"Created collection {collection}") - status, err = _insert() - return status, err - - -memory_registry.register_write( - component_name="qdrant", - fn=qdrant_write, - outputs={"status": 0}, - vector_key="embeddings", - description="Write to the Qdrant DB using the Qdrant client", -) - - -def qdrant_read( - embeddings: List[List[float]], - collection_name: str, - cutoff_score: float = 0.0, - top: int = 5, - limit: int = 0, - offset: int = 0, - filters: Dict[str, Dict[str, str]] = {}, - qdrant_url: Secret = Secret(""), - qdrant_api_key: Secret = Secret(""), - qdrant_search_hnsw_ef: int = 0, - qdrant_search_exact: bool = False, - batch_search: bool = False, -) -> Tuple[Dict[str, List[Dict[str, Union[float, int]]]], Optional[Exception]]: - """ - Read from the Qdrant DB using the Qdrant client. In order to use this access via the `memory_registry`: - - Example: - >>> from chainfury import memory_registry - >>> mem = memory_registry.get_read("qdrant") - >>> sentence = "Who was the Cafavy?" - >>> out, err = mem( - { - "items": [sentence], - "collection_name": "my_test_collection", - "embedding_model": "openai-embedding" - } - ) - >>> if err: - print("TRACE:", out) - else: - print(out) - - Note: - `batch_search` is not implemented yet. There's some issues from the `qdrant_client` library. - - Args: - embeddings (List[List[float]]): list of embeddings - collection_name (str): collection name - cutoff_score (float, optional): cutoff score. Defaults to 0.0. - limit (int, optional): limit. Defaults to 3. - offset (int, optional): offset. Defaults to 0. - qdrant_url (Secret, optional): qdrant url or set env var `QDRANT_API_URL`. - qdrant_api_key (Secret, optional): qdrant api key or set env var `QDRANT_API_KEY`. - qdrant_search_hnsw_ef (int, optional): qdrant search beam size, the larger the beam size the more accurate the search, - if not set uses default value. - qdrant_search_exact (bool, optional): qdrant search exact. Defaults to False. - batch_search (bool, optional): batch search. Defaults to False. - - Returns: - Tuple[List[Dict[str, Union[float, int]]], Optional[Exception]]: list of results and error - """ - # client check - if not QDRANT_CLIENT_INSTALLED: - raise ComponentMissingError( - "Qdrant client is not installed. Please install it with `pip install qdrant-client`" - ) - - # checks - if not (len(embeddings) and len(embeddings[0]) and type(embeddings[0][0]) == float): - raise Exception("Embeddings should be a list of lists of floats") - if batch_search: - raise NotImplementedError("Batch search is not implemented yet") - if not batch_search and len(embeddings) > 1: - raise Exception( - "Batch search is not enabled, but multiple embeddings are passed" - ) - if not top and not limit: - raise Exception("Either top or limit should be set") - - client: QdrantClient = _get_qdrant_client(qdrant_url, qdrant_api_key) # type: ignore - - search_params = models.SearchParams() # type: ignore - if qdrant_search_hnsw_ef: - search_params.hnsw_ef = qdrant_search_hnsw_ef - if qdrant_search_exact: - search_params.exact = qdrant_search_exact - - if batch_search: - # this is not implemented, this fails when we try to pass a list of vectors - search_queries = [ - models.SearchRequest( - vector=x, limit=limit, offset=offset, params=search_params - ) - for x in embeddings - ] - out = client.search_batch( - collection_name=collection_name, - requests=search_queries, - ) - res = [[_x.dict(skip_defaults=False) for _x in x] for x in out] # type: ignore - - query_filter = None - if filters: - query_filter = models.Filter(**filters) # type: ignore - - out = client.search( - collection_name=collection_name, - query_vector=embeddings[0], - query_filter=query_filter, - limit=max(limit, top), - offset=offset, - search_params=search_params, - ) - out = [x for x in out if x.score > cutoff_score] - res = [_x.dict(skip_defaults=False) for _x in out] # type: ignore - return {"data": res}, None - - -memory_registry.register_read( - component_name="qdrant", - fn=qdrant_read, - outputs={"items": 0}, - vector_key="embeddings", - description="Function to read from the Qdrant DB using the Qdrant client", -) - - -# helper functions - - -def recreate_collection(collection_name: str, embedding_dim: int) -> bool: - """ - Deletes and recreates a collection - - Note: - This will delete all the data in the collection, use with caution - - Args: - collection_name (str): collection name - embedding_dim (int): embedding dimension - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.recreate_collection( - collection_name=collection_name, - vectors_config=models.VectorParams( # type: ignore - size=embedding_dim, - distance=models.Distance.COSINE, # type: ignore - ), - optimizers_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=0, - ), - ) - - -def enable_indexing(collection_name: str, indexing_threshold: int = 20000) -> bool: - """ - Enable indexing for a collection, use this in conjunction with `disable_indexing`. Read more - `here `_. - - Example: - >>> from chainfury.components.qdrant import enable_indexing, disable_indexing, qdrant_write - >>> disable_indexing("my_collection") - >>> qdrant_write([[1, 2, 3] for _ in range(100)], "my_collection") - >>> enable_indexing("my_collection") - - - Args: - collection_name (str): collection name - indexing_threshold (int, optional): indexing threshold. Defaults to 20000. - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.update_collection( - collection_name=collection_name, - optimizer_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=indexing_threshold, - ), - ) - - -def disable_indexing(collection_name: str): - """ - Disable indexing for a collection, use this in conjunction with `enable_indexing`. Read more - `here `_. - - Args: - collection_name (str): collection name - - Returns: - bool: success - """ - client: QdrantClient = _get_qdrant_client() # type: ignore - return client.update_collection( - collection_name=collection_name, - optimizer_config=models.OptimizersConfigDiff( # type: ignore - indexing_threshold=0, - ), - ) diff --git a/pyproject.toml b/pyproject.toml index 05bedbe..2e44725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,15 +18,11 @@ requests = "^2.31.0" python-dotenv = "1.0.0" urllib3 = ">=1.26.18" tabulate = "0.9.0" -"cryptography" = ">=41.0.6" -stability-sdk = { version = "0.8.3", optional = true } -qdrant-client = { version = "1.5.4", optional = true } +cryptography = ">=41.0.6" boto3 = { version = "1.29.6", optional = true } [tool.poetry.extras] -all = ["stability-sdk", "qdrant-client", "boto3"] -stability = ["stability-sdk"] -qdrant = ["qdrant-client"] +all = ["boto3"] [tool.poetry.group.dev.dependencies] sphinx = "7.2.5" diff --git a/server/chainfury_server/api/chains.py b/server/chainfury_server/api/chains.py index 16ac631..7575ab1 100644 --- a/server/chainfury_server/api/chains.py +++ b/server/chainfury_server/api/chains.py @@ -16,8 +16,6 @@ def create_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], chatbot_data: T.ApiCreateChainRequest, db: Session = Depends(DB.fastapi_db_session), @@ -27,8 +25,7 @@ def create_chain( # validate chatbot if not chatbot_data.name: - resp.status_code = 400 - return T.ApiResponse(message="Name not specified") + raise HTTPException(status_code=400, detail="Name not specified") if chatbot_data.dag: for n in chatbot_data.dag.nodes: if len(n.id) > Env.CFS_MAXLEN_CF_NODE(): @@ -55,8 +52,6 @@ def create_chain( def get_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, tag_id: str = "", @@ -75,16 +70,13 @@ def get_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") # return return chatbot.to_ApiChain() def update_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, chatbot_data: T.ApiChain, @@ -96,14 +88,14 @@ def update_chain( # validate chatbot update if not len(chatbot_data.update_keys): - resp.status_code = 400 - return T.ApiResponse(message="No keys to update") + raise HTTPException(status_code=400, detail="No keys to update") unq_keys = set(chatbot_data.update_keys) valid_keys = {"name", "description", "dag"} if not unq_keys.issubset(valid_keys): - resp.status_code = 400 - return T.ApiResponse(message=f"Invalid keys {unq_keys.difference(valid_keys)}") + raise HTTPException( + status_code=400, detail=f"Invalid keys {unq_keys.difference(valid_keys)}" + ) # DB Call filters = [ @@ -115,8 +107,7 @@ def update_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") for field in unq_keys: if field == "name": @@ -133,8 +124,6 @@ def update_chain( def delete_chain( - req: Request, - resp: Response, token: Annotated[str, Header()], id: str, tag_id: str = "", @@ -153,8 +142,7 @@ def delete_chain( filters.append(DB.ChatBot.tag_id == tag_id) chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") chatbot.deleted_at = datetime.now() db.commit() @@ -163,8 +151,6 @@ def delete_chain( def list_chains( - req: Request, - resp: Response, token: Annotated[str, Header()], skip: int = 0, limit: int = 10, @@ -190,8 +176,6 @@ def list_chains( def run_chain( - req: Request, - resp: Response, id: str, token: Annotated[str, Header()], prompt: T.ApiPromptBody, @@ -234,8 +218,7 @@ def run_chain( ] chatbot = db.query(DB.ChatBot).filter(*filters).first() # type: ignore if not chatbot: - resp.status_code = 404 - return T.ApiResponse(message="ChatBot not found") + raise HTTPException(status_code=404, detail="Chain not found") # call the engine engine = FuryEngine() @@ -292,8 +275,6 @@ def _get_streaming_response(result): def get_chain_metrics( - req: Request, - resp: Response, id: str, token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), diff --git a/server/chainfury_server/api/user.py b/server/chainfury_server/api/user.py index 633124c..77aacc8 100644 --- a/server/chainfury_server/api/user.py +++ b/server/chainfury_server/api/user.py @@ -85,8 +85,8 @@ def change_password( def create_secret( - token: Annotated[str, Header()], inputs: T.ApiToken, + token: Annotated[str, Header()], db: Session = Depends(DB.fastapi_db_session), ) -> T.ApiResponse: # validate user