-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove unused api system and support qdrant valid
- Loading branch information
Showing
6 changed files
with
120 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,88 @@ | ||
from fastapi import APIRouter, HTTPException | ||
from pydantic import BaseModel | ||
from typing import List | ||
from fastembed.embedding import DefaultEmbedding | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.models import Distance, VectorParams, PointStruct | ||
import numpy as np | ||
import uuid | ||
|
||
router = APIRouter() | ||
|
||
|
||
class EmbeddingRequest(BaseModel): | ||
text: str | ||
class TextInput(BaseModel): | ||
content: str | ||
|
||
|
||
class EmbeddingResponse(BaseModel): | ||
embedding: list[float] | ||
class EmbeddingOutput(BaseModel): | ||
id: str | ||
embedding: List[float] | ||
|
||
|
||
# Load a FastEmbed model | ||
# 初始化FastEmbed模型 | ||
fastembed_model = DefaultEmbedding() | ||
|
||
# 初始化Qdrant客户端 | ||
qdrant_client = QdrantClient(url="http://localhost:6333") | ||
collection_name = 'test_simi' | ||
|
||
@router.post("/embed_text/", response_model=EmbeddingResponse) | ||
async def embed_text(request: EmbeddingRequest) -> EmbeddingResponse: | ||
# 检查集合是否存在,如果不存在则创建 | ||
try: | ||
qdrant_client.get_collection(collection_name) | ||
print(f"Collection {collection_name} already exists") | ||
except HTTPException as e: | ||
# 获取向量维度 | ||
sample_embedding = next(fastembed_model.embed(["Sample text"])) | ||
vector_size = len(sample_embedding) | ||
|
||
# 创建集合 | ||
qdrant_client.create_collection( | ||
collection_name=collection_name, | ||
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), | ||
) | ||
print(f"Created collection: {collection_name}") | ||
|
||
|
||
@router.post("/embed/", response_model=EmbeddingOutput) | ||
async def embed_text(input_data: TextInput): | ||
try: | ||
# 生成嵌入 | ||
embeddings = list(fastembed_model.embed([request.text])) | ||
embeddings = list(fastembed_model.embed([input_data.content])) | ||
if not embeddings: | ||
raise ValueError("Failed to generate embedding") | ||
|
||
# 将numpy数组转换为普通Python列表 | ||
embedding_list = embeddings[0].tolist() | ||
embedding_vector = embeddings[0] | ||
|
||
# 生成唯一ID | ||
point_id = str(uuid.uuid4()) | ||
|
||
# 将嵌入存储到Qdrant | ||
qdrant_client.upsert( | ||
collection_name=collection_name, | ||
points=[ | ||
PointStruct( | ||
id=point_id, | ||
vector=embedding_vector.tolist(), | ||
payload={"text": input_data.content} | ||
) | ||
] | ||
) | ||
|
||
return EmbeddingResponse(embedding=embedding_list) | ||
return EmbeddingOutput(id=point_id, embedding=embedding_vector.tolist()) | ||
except Exception as e: | ||
raise HTTPException(status_code=400, | ||
detail=f"Error processing embedding: {str(e)}") | ||
|
||
|
||
@router.get("/status/") | ||
async def get_collection_status(): | ||
try: | ||
collection_info = qdrant_client.get_collection(collection_name) | ||
return { | ||
"collection_name": collection_name, | ||
"vectors_count": collection_info.vectors_count, | ||
"status": "ready" if collection_info.status == "green" else "not ready" | ||
} | ||
except Exception as e: | ||
raise HTTPException(status_code=400, detail=f"Error generating embedding: {str(e)}") | ||
raise HTTPException(status_code=400, | ||
detail=f"Error getting collection status: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,10 @@ | ||
from fastapi.routing import APIRouter | ||
|
||
from infra_ai_service.api.system.views import router as system_router | ||
from infra_ai_service.api.ai_enhance.text_process import router as text_process_router | ||
from infra_ai_service.api.ai_enhance.embedding import router as embedding_router | ||
from infra_ai_service.api.ai_enhance.vector_search import router as vector_search_router | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(system_router, prefix="/system", tags=["system"]) | ||
api_router.include_router(text_process_router, prefix="/text", tags=["Text Processing"]) | ||
api_router.include_router(embedding_router, prefix="/embed", tags=["Embedding"]) | ||
api_router.include_router(vector_search_router, prefix="/search", tags=["Vector Search"]) | ||
api_router.include_router(text_process_router, prefix="/text", tags=["text"]) | ||
api_router.include_router(embedding_router, prefix="/embedding", tags=["embedding"]) | ||
api_router.include_router(vector_search_router, prefix="/search", tags=["search"]) |
Empty file.
This file was deleted.
Oops, something went wrong.