Skip to content

Commit

Permalink
Remove unused api system and support qdrant valid
Browse files Browse the repository at this point in the history
  • Loading branch information
MaskerPRC committed Sep 10, 2024
1 parent 228e8a9 commit c4fe45a
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 44 deletions.
79 changes: 67 additions & 12 deletions infra_ai_service/api/ai_enhance/embedding.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,88 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List
from fastembed.embedding import DefaultEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, PointStruct
import numpy as np
import uuid

router = APIRouter()


class EmbeddingRequest(BaseModel):
text: str
class TextInput(BaseModel):
content: str


class EmbeddingResponse(BaseModel):
embedding: list[float]
class EmbeddingOutput(BaseModel):
id: str
embedding: List[float]


# Load a FastEmbed model
# 初始化FastEmbed模型
fastembed_model = DefaultEmbedding()

# 初始化Qdrant客户端
qdrant_client = QdrantClient(url="http://localhost:6333")
collection_name = 'test_simi'

# Ensure the target collection exists before any request is served.
# qdrant_client does not raise FastAPI's HTTPException, so the previous
# `except HTTPException` branch could never run for a missing collection;
# ask the server explicitly whether the collection exists instead.
if qdrant_client.collection_exists(collection_name):
    print(f"Collection {collection_name} already exists")
else:
    # Embed a throwaway sample to learn the model's vector dimension.
    sample_embedding = next(iter(fastembed_model.embed(["Sample text"])))
    vector_size = len(sample_embedding)

    # Create the collection sized for that dimension, using cosine distance.
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
    print(f"Created collection: {collection_name}")


@router.post("/embed/", response_model=EmbeddingOutput)
async def embed_text(input_data: TextInput):
    """Embed the submitted text, store the vector in Qdrant, and return it.

    Args:
        input_data: Request body; ``content`` is the text to embed.

    Returns:
        EmbeddingOutput carrying the generated point id and the embedding
        as a plain list of floats.

    Raises:
        HTTPException: status 400 with the underlying error message when
            embedding generation or the Qdrant upsert fails.
    """
    try:
        # Generate the embedding for the single input text.
        embeddings = list(fastembed_model.embed([input_data.content]))
        if not embeddings:
            raise ValueError("Failed to generate embedding")

        # Convert the numpy array to a plain Python list once; both the
        # stored point and the response use the same list.
        embedding_list = embeddings[0].tolist()

        # Random UUID so every submission is stored as its own point.
        point_id = str(uuid.uuid4())

        # Persist the vector together with the source text as payload.
        qdrant_client.upsert(
            collection_name=collection_name,
            points=[
                PointStruct(
                    id=point_id,
                    vector=embedding_list,
                    payload={"text": input_data.content},
                )
            ],
        )

        return EmbeddingOutput(id=point_id, embedding=embedding_list)
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(
            status_code=400,
            detail=f"Error processing embedding: {str(e)}",
        ) from e


@router.get("/status/")
async def get_collection_status():
    """Report the vector count and readiness of the Qdrant collection.

    Returns:
        dict with ``collection_name``, ``vectors_count`` and ``status``;
        ``status`` is "ready" when the collection status compares equal to
        "green" and "not ready" otherwise.

    Raises:
        HTTPException: status 400 with the underlying error message when
            the collection information cannot be retrieved.
    """
    try:
        collection_info = qdrant_client.get_collection(collection_name)
        return {
            "collection_name": collection_name,
            "vectors_count": collection_info.vectors_count,
            "status": "ready" if collection_info.status == "green" else "not ready",
        }
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(
            status_code=400,
            detail=f"Error getting collection status: {str(e)}",
        ) from e
17 changes: 8 additions & 9 deletions infra_ai_service/api/ai_enhance/text_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,22 @@
router = APIRouter()


class TextRequest(BaseModel):
class TextInput(BaseModel):
content: str


class TextResponse(BaseModel):
processed_content: str
class TextOutput(BaseModel):
modified_content: str


def clean_text(text: str) -> str:
    """Return *text* with a fixed set of punctuation characters removed.

    Every occurrence of ``{ } [ ] ( ) @ . # \\ _ ' : / -`` is deleted; all
    other characters, including whitespace, are left untouched.
    """
    return re.sub(r'[{}[\]()@.#\\_\':\/-]', '', text)


@router.post("/process/", response_model=TextOutput)
async def process_text(input_data: TextInput):
    """Clean the submitted text and return the result.

    Args:
        input_data: Request body; ``content`` is the text to clean.

    Returns:
        TextOutput whose ``modified_content`` is ``clean_text(content)``.

    Raises:
        HTTPException: status 400 with the underlying error message when
            cleaning fails.
    """
    try:
        modified_text = clean_text(input_data.content)
        return TextOutput(modified_content=modified_text)
    except Exception as e:
        # Chain the cause so the original traceback stays available in logs.
        raise HTTPException(
            status_code=400,
            detail=f"Error processing text: {str(e)}",
        ) from e
48 changes: 42 additions & 6 deletions infra_ai_service/api/ai_enhance/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List
from fastembed.embedding import DefaultEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
import logging

router = APIRouter()

# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class SearchRequest(BaseModel):

class SearchInput(BaseModel):
query_text: str
top_n: int = 5
score_threshold: float = 0.7
Expand All @@ -17,15 +24,39 @@ class SearchResult(BaseModel):
score: float


class SearchOutput(BaseModel):
results: List[SearchResult]


# 初始化FastEmbed模型和Qdrant客户端
fastembed_model = DefaultEmbedding()
qdrant_client = QdrantClient(url="http://localhost:6333")
collection_name = 'test_simi'

# Ensure the target collection exists before any request is served.
# qdrant_client does not raise FastAPI's HTTPException, so the previous
# `except HTTPException` branch could never run for a missing collection;
# ask the server explicitly whether the collection exists instead.
if qdrant_client.collection_exists(collection_name):
    print(f"Collection {collection_name} already exists")
else:
    # Embed a throwaway sample to learn the model's vector dimension.
    sample_embedding = next(iter(fastembed_model.embed(["Sample text"])))
    vector_size = len(sample_embedding)

    # Create the collection sized for that dimension, using cosine distance.
    qdrant_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
    )
    print(f"Created collection: {collection_name}")


@router.post("/query/", response_model=SearchOutput)
async def vector_search(input_data: SearchInput):
try:
# 检查集合是否存在
if not qdrant_client.get_collection(collection_name):
raise ValueError(f"Collection {collection_name} does not exist")

# 生成查询文本的嵌入
query_vector = list(fastembed_model.embed([input_data.query_text]))
if not query_vector:
Expand All @@ -41,10 +72,15 @@ async def search_vectors(input_data: SearchRequest) -> SearchResult:

# 转换搜索结果为输出格式
results = [
SearchResult(id=str(result.id), score=result.score)
SearchResult(
id=str(result.id),
score=result.score,
text=result.payload.get('text', 'No text available')
)
for result in search_results
]

return SearchResult(results=results)
return SearchOutput(results=results)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Error performing vector search: {str(e)}")
logger.error(f"Error in vector search: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Error performing vector search: {str(e)}")
8 changes: 3 additions & 5 deletions infra_ai_service/api/router.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from fastapi.routing import APIRouter

from infra_ai_service.api.system.views import router as system_router
from infra_ai_service.api.ai_enhance.text_process import router as text_process_router
from infra_ai_service.api.ai_enhance.embedding import router as embedding_router
from infra_ai_service.api.ai_enhance.vector_search import router as vector_search_router

# Aggregate router: mounts each feature router once under its own URL prefix.
api_router = APIRouter()
api_router.include_router(text_process_router, prefix="/text", tags=["text"])
api_router.include_router(embedding_router, prefix="/embedding", tags=["embedding"])
api_router.include_router(vector_search_router, prefix="/search", tags=["search"])
Empty file.
12 changes: 0 additions & 12 deletions infra_ai_service/api/system/views.py

This file was deleted.

0 comments on commit c4fe45a

Please sign in to comment.