From c4fe45aba87b543b05de8df7c89ea7b2381a3e1e Mon Sep 17 00:00:00 2001
From: loopsaaage
Date: Tue, 10 Sep 2024 15:12:22 +0800
Subject: [PATCH] Remove unused system API and add Qdrant-backed embedding
 storage and vector search

---
 infra_ai_service/api/ai_enhance/embedding.py  | 79 ++++++++++++++++---
 .../api/ai_enhance/text_process.py            | 17 ++--
 .../api/ai_enhance/vector_search.py           | 48 +++++++++--
 infra_ai_service/api/router.py                |  8 +-
 infra_ai_service/api/system/__init__.py       |  0
 infra_ai_service/api/system/views.py          | 12 ---
 6 files changed, 120 insertions(+), 44 deletions(-)
 delete mode 100644 infra_ai_service/api/system/__init__.py
 delete mode 100644 infra_ai_service/api/system/views.py

diff --git a/infra_ai_service/api/ai_enhance/embedding.py b/infra_ai_service/api/ai_enhance/embedding.py
index 217f71b..769f399 100644
--- a/infra_ai_service/api/ai_enhance/embedding.py
+++ b/infra_ai_service/api/ai_enhance/embedding.py
@@ -1,33 +1,88 @@
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
+from typing import List
 from fastembed.embedding import DefaultEmbedding
+from qdrant_client import QdrantClient
+from qdrant_client.http.models import Distance, VectorParams, PointStruct
+import numpy as np
+import uuid
 
 router = APIRouter()
 
 
-class EmbeddingRequest(BaseModel):
-    text: str
+class TextInput(BaseModel):
+    content: str
 
 
-class EmbeddingResponse(BaseModel):
-    embedding: list[float]
+class EmbeddingOutput(BaseModel):
+    id: str
+    embedding: List[float]
 
 
-# Load a FastEmbed model
+# Initialize the FastEmbed model
 fastembed_model = DefaultEmbedding()
+# Initialize the Qdrant client
+qdrant_client = QdrantClient(url="http://localhost:6333")
+collection_name = 'test_simi'
 
-@router.post("/embed_text/", response_model=EmbeddingResponse)
-async def embed_text(request: EmbeddingRequest) -> EmbeddingResponse:
+# Create the collection if it does not already exist
+try:
+    qdrant_client.get_collection(collection_name)
+    print(f"Collection {collection_name} already exists")
+except Exception:
+    # Determine the vector size from a sample embedding
+    sample_embedding = next(fastembed_model.embed(["Sample text"]))
+    vector_size = len(sample_embedding)
+
+    # Create the collection
+    qdrant_client.create_collection(
+        collection_name=collection_name,
+        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
+    )
+    print(f"Created collection: {collection_name}")
+
+
+@router.post("/embed/", response_model=EmbeddingOutput)
+async def embed_text(input_data: TextInput):
     try:
         # Generate the embedding
-        embeddings = list(fastembed_model.embed([request.text]))
+        embeddings = list(fastembed_model.embed([input_data.content]))
         if not embeddings:
             raise ValueError("Failed to generate embedding")
 
-        # Convert the numpy array to a plain Python list
-        embedding_list = embeddings[0].tolist()
+        embedding_vector = embeddings[0]
+
+        # Generate a unique point ID
+        point_id = str(uuid.uuid4())
+
+        # Store the embedding in Qdrant
+        qdrant_client.upsert(
+            collection_name=collection_name,
+            points=[
+                PointStruct(
+                    id=point_id,
+                    vector=embedding_vector.tolist(),
+                    payload={"text": input_data.content}
+                )
+            ]
+        )
 
-        return EmbeddingResponse(embedding=embedding_list)
+        return EmbeddingOutput(id=point_id, embedding=embedding_vector.tolist())
+    except Exception as e:
+        raise HTTPException(status_code=400,
+                            detail=f"Error processing embedding: {str(e)}")
+
+
+@router.get("/status/")
+async def get_collection_status():
+    try:
+        collection_info = qdrant_client.get_collection(collection_name)
+        return {
+            "collection_name": collection_name,
+            "vectors_count": collection_info.vectors_count,
+            "status": "ready" if collection_info.status == "green" else "not ready"
+        }
     except Exception as e:
-        raise HTTPException(status_code=400,
-                            detail=f"Error generating embedding: {str(e)}")
+        raise HTTPException(status_code=400,
+                            detail=f"Error getting collection status: {str(e)}")
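Review note: the collection bootstrap above is duplicated almost verbatim in vector_search.py further down. A minimal sketch of a shared helper both routers could import instead is shown here; the module path infra_ai_service/common/qdrant_setup.py and the name ensure_collection are hypothetical (not part of this patch), and the sketch only uses the qdrant_client and fastembed calls the patch itself already relies on.

    # infra_ai_service/common/qdrant_setup.py (hypothetical location)
    from fastembed.embedding import DefaultEmbedding
    from qdrant_client import QdrantClient
    from qdrant_client.http.models import Distance, VectorParams


    def ensure_collection(client: QdrantClient,
                          model: DefaultEmbedding,
                          name: str = "test_simi") -> None:
        """Create the Qdrant collection if it is missing, sized to the model output."""
        try:
            client.get_collection(name)
        except Exception:
            # Probe the model once to learn the embedding dimensionality
            vector_size = len(next(model.embed(["Sample text"])))
            client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(size=vector_size,
                                            distance=Distance.COSINE),
            )

Both embedding.py and vector_search.py could then call ensure_collection(qdrant_client, fastembed_model) once, ideally from a FastAPI startup hook rather than at import time.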
detail=f"Error generating embedding: {str(e)}") + raise HTTPException(status_code=400, + detail=f"Error getting collection status: {str(e)}") diff --git a/infra_ai_service/api/ai_enhance/text_process.py b/infra_ai_service/api/ai_enhance/text_process.py index ad146e8..d2082cb 100644 --- a/infra_ai_service/api/ai_enhance/text_process.py +++ b/infra_ai_service/api/ai_enhance/text_process.py @@ -5,23 +5,22 @@ router = APIRouter() -class TextRequest(BaseModel): +class TextInput(BaseModel): content: str -class TextResponse(BaseModel): - processed_content: str +class TextOutput(BaseModel): + modified_content: str def clean_text(text: str) -> str: - cleaned_text = re.sub(r'[{}[\]()@.#\\_\':\/-]', '', text) - return cleaned_text + return re.sub(r'[{}[\]()@.#\\_\':\/-]', '', text) -@router.post("/process_text/", response_model=TextResponse) -async def process_text(request: TextRequest) -> TextResponse: +@router.post("/process/", response_model=TextOutput) +async def process_text(input_data: TextInput): try: - processed_text = clean_text(request.content) - return TextResponse(processed_content=processed_text) + modified_text = clean_text(input_data.content) + return TextOutput(modified_content=modified_text) except Exception as e: raise HTTPException(status_code=400, detail=f"Error processing text: {str(e)}") diff --git a/infra_ai_service/api/ai_enhance/vector_search.py b/infra_ai_service/api/ai_enhance/vector_search.py index e14a171..3017243 100644 --- a/infra_ai_service/api/ai_enhance/vector_search.py +++ b/infra_ai_service/api/ai_enhance/vector_search.py @@ -1,12 +1,19 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel +from typing import List from fastembed.embedding import DefaultEmbedding from qdrant_client import QdrantClient +from qdrant_client.http.models import Distance, VectorParams +import logging router = APIRouter() +# 设置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) -class SearchRequest(BaseModel): + +class SearchInput(BaseModel): query_text: str top_n: int = 5 score_threshold: float = 0.7 @@ -17,15 +24,39 @@ class SearchResult(BaseModel): score: float +class SearchOutput(BaseModel): + results: List[SearchResult] + + # 初始化FastEmbed模型和Qdrant客户端 fastembed_model = DefaultEmbedding() qdrant_client = QdrantClient(url="http://localhost:6333") collection_name = 'test_simi' +# 检查集合是否存在,如果不存在则创建 +try: + qdrant_client.get_collection(collection_name) + print(f"Collection {collection_name} already exists") +except HTTPException as e: + # 获取向量维度 + sample_embedding = next(fastembed_model.embed(["Sample text"])) + vector_size = len(sample_embedding) -@router.post("/search_vectors/", response_model=SearchResult) -async def search_vectors(input_data: SearchRequest) -> SearchResult: + # 创建集合 + qdrant_client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), + ) + print(f"Created collection: {collection_name}") + + +@router.post("/query/", response_model=SearchOutput) +async def vector_search(input_data: SearchInput): try: + # 检查集合是否存在 + if not qdrant_client.get_collection(collection_name): + raise ValueError(f"Collection {collection_name} does not exist") + # 生成查询文本的嵌入 query_vector = list(fastembed_model.embed([input_data.query_text])) if not query_vector: @@ -41,10 +72,15 @@ async def search_vectors(input_data: SearchRequest) -> SearchResult: # 转换搜索结果为输出格式 results = [ - SearchResult(id=str(result.id), score=result.score) + SearchResult( + id=str(result.id), + 
diff --git a/infra_ai_service/api/router.py b/infra_ai_service/api/router.py
index 8777264..537e1a0 100644
--- a/infra_ai_service/api/router.py
+++ b/infra_ai_service/api/router.py
@@ -1,12 +1,10 @@
 from fastapi.routing import APIRouter
 
-from infra_ai_service.api.system.views import router as system_router
 from infra_ai_service.api.ai_enhance.text_process import router as text_process_router
 from infra_ai_service.api.ai_enhance.embedding import router as embedding_router
 from infra_ai_service.api.ai_enhance.vector_search import router as vector_search_router
 
 api_router = APIRouter()
-api_router.include_router(system_router, prefix="/system", tags=["system"])
-api_router.include_router(text_process_router, prefix="/text", tags=["Text Processing"])
-api_router.include_router(embedding_router, prefix="/embed", tags=["Embedding"])
-api_router.include_router(vector_search_router, prefix="/search", tags=["Vector Search"])
+api_router.include_router(text_process_router, prefix="/text", tags=["text"])
+api_router.include_router(embedding_router, prefix="/embedding", tags=["embedding"])
+api_router.include_router(vector_search_router, prefix="/search", tags=["search"])
diff --git a/infra_ai_service/api/system/__init__.py b/infra_ai_service/api/system/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/infra_ai_service/api/system/views.py b/infra_ai_service/api/system/views.py
deleted file mode 100644
index a2b55dc..0000000
--- a/infra_ai_service/api/system/views.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from fastapi import APIRouter
-
-router = APIRouter()
-
-
-@router.get("/health/")
-async def health() -> None:
-    """
-    Checks the health of a project.
-
-    It returns 200 if the project is healthy.
-    """
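With the /system/health/ endpoint removed, a quick way to confirm the remaining routers are still wired correctly is FastAPI's TestClient. This is only a sketch: it builds a throwaway app around api_router, needs httpx installed for TestClient, and requires a Qdrant instance at localhost:6333 because the embedding and search modules talk to it at import time.

    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    from infra_ai_service.api.router import api_router

    # Throwaway app; the real service mounts api_router in its own factory.
    app = FastAPI()
    app.include_router(api_router)

    client = TestClient(app)

    # /text/process/ has no external dependencies, so it makes a safe smoke test.
    response = client.post("/text/process/", json={"content": "foo(bar)"})
    assert response.status_code == 200
    assert response.json() == {"modified_content": "foobar"}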