From cadde518257bb75369179ee036c095c3d5a831bd Mon Sep 17 00:00:00 2001
From: loopsaaage
Date: Sat, 14 Sep 2024 15:00:59 +0800
Subject: [PATCH] Add support for pgvector

Replace the Qdrant-based vector store with PostgreSQL + pgvector.
Embeddings are generated with a sentence-transformers model and stored
in a pgvector table that is created on application startup; the Qdrant
SDK, the collection status endpoint, and the stale egg-info build
artifacts are removed.
---
 .env-example                                  | 17 +++-
 .gitignore                                    |  1 +
 infra_ai_service.egg-info/PKG-INFO            |  3 -
 infra_ai_service.egg-info/SOURCES.txt         | 10 ---
 .../dependency_links.txt                      |  1 -
 infra_ai_service.egg-info/top_level.txt       |  2 -
 infra_ai_service/api/ai_enhance/embedding.py  | 10 +--
 .../api/ai_enhance/vector_search.py           |  2 +-
 infra_ai_service/api/common/utils.py          | 29 ------
 infra_ai_service/api/router.py                |  4 +-
 infra_ai_service/{api => }/common/__init__.py |  0
 infra_ai_service/common/utils.py              | 29 ++++++
 infra_ai_service/config/config.py             | 46 +++++++---
 infra_ai_service/core/app.py                  |  8 +-
 infra_ai_service/model/model.py               |  3 +-
 infra_ai_service/sdk/pgvector.py              | 28 ++++++
 infra_ai_service/sdk/qdrant.py                |  3 -
 infra_ai_service/server.py                    |  1 +
 infra_ai_service/service/embedding_service.py | 66 +++++++-------
 infra_ai_service/service/search_service.py    | 90 +++++++++++--------
 infra_ai_service/service/text_service.py      |  5 +-
 requirements.txt                              |  4 +-
 setup.py                                      |  2 +-
 test-demos/async_demo.py                      |  3 +-
 test-demos/demo.py                            |  6 +-
 tox.ini                                       |  2 +-
 26 files changed, 215 insertions(+), 160 deletions(-)
 delete mode 100644 infra_ai_service.egg-info/PKG-INFO
 delete mode 100644 infra_ai_service.egg-info/SOURCES.txt
 delete mode 100644 infra_ai_service.egg-info/dependency_links.txt
 delete mode 100644 infra_ai_service.egg-info/top_level.txt
 delete mode 100644 infra_ai_service/api/common/utils.py
 rename infra_ai_service/{api => }/common/__init__.py (100%)
 create mode 100644 infra_ai_service/common/utils.py
 create mode 100644 infra_ai_service/sdk/pgvector.py
 delete mode 100644 infra_ai_service/sdk/qdrant.py

diff --git a/.env-example b/.env-example
index 3966ffb..9c55262 100644
--- a/.env-example
+++ b/.env-example
@@ -1,5 +1,20 @@
 # create new with next to this one, with name .env
 RELOAD=True
-
 BASE_URL=http://localhost:8000
 
+
+# Database configuration
+DB_NAME=
+DB_USER=
+DB_PASSWORD=
+DB_HOST=
+DB_PORT=
+
+# Model name configuration
+MODEL_NAME=
+
+# Other settings
+VECTOR_EXTENSION=
+TABLE_NAME=
+VECTOR_DIMENSION=
+LANGUAGE=
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 9d38591..0e4576d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,3 +142,4 @@ dmypy.json
 cython_debug/
 
 .python-version
+infra_ai_service.egg-info
diff --git a/infra_ai_service.egg-info/PKG-INFO b/infra_ai_service.egg-info/PKG-INFO
deleted file mode 100644
index 6e745ac..0000000
--- a/infra_ai_service.egg-info/PKG-INFO
+++ /dev/null
@@ -1,3 +0,0 @@
-Metadata-Version: 2.1
-Name: infra_ai_service
-Version: 0.1
diff --git a/infra_ai_service.egg-info/SOURCES.txt b/infra_ai_service.egg-info/SOURCES.txt
deleted file mode 100644
index 1e67938..0000000
--- a/infra_ai_service.egg-info/SOURCES.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-README.md
-setup.py
-infra_ai_service/__init__.py
-infra_ai_service/demo.py
-infra_ai_service.egg-info/PKG-INFO
-infra_ai_service.egg-info/SOURCES.txt
-infra_ai_service.egg-info/dependency_links.txt
-infra_ai_service.egg-info/top_level.txt
-tests/__init__.py
-tests/test_demo.py
\ No newline at end of file
diff --git a/infra_ai_service.egg-info/dependency_links.txt b/infra_ai_service.egg-info/dependency_links.txt
deleted file mode 100644
index 8b13789..0000000
--- a/infra_ai_service.egg-info/dependency_links.txt
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/infra_ai_service.egg-info/top_level.txt b/infra_ai_service.egg-info/top_level.txt
deleted file mode 100644
index 7ca9f2e..0000000
--- a/infra_ai_service.egg-info/top_level.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-infra_ai_service
-tests
diff --git a/infra_ai_service/api/ai_enhance/embedding.py b/infra_ai_service/api/ai_enhance/embedding.py
index 268a9cc..8aa4936 100644
--- a/infra_ai_service/api/ai_enhance/embedding.py
+++ b/infra_ai_service/api/ai_enhance/embedding.py
@@ -1,11 +1,8 @@
 from fastapi import APIRouter
 
 from infra_ai_service.api.ai_enhance.text_process import TextInput
-from infra_ai_service.api.common.utils import setup_qdrant_environment
-
 from infra_ai_service.model.model import EmbeddingOutput
-from infra_ai_service.service.embedding_service import create_embedding, \
-    get_collection_status
+from infra_ai_service.service.embedding_service import create_embedding
 
 router = APIRouter()
 
@@ -13,8 +10,3 @@
 @router.post("/embed/", response_model=EmbeddingOutput)
 async def embed_text(input_data: TextInput):
     return await create_embedding(input_data.content)
-
-
-@router.get("/status/")
-async def status():
-    return await get_collection_status()
diff --git a/infra_ai_service/api/ai_enhance/vector_search.py b/infra_ai_service/api/ai_enhance/vector_search.py
index 1b0040a..1bbfab0 100644
--- a/infra_ai_service/api/ai_enhance/vector_search.py
+++ b/infra_ai_service/api/ai_enhance/vector_search.py
@@ -1,6 +1,6 @@
 from fastapi import APIRouter
 
-from infra_ai_service.model.model import SearchOutput, SearchInput
+from infra_ai_service.model.model import SearchInput, SearchOutput
 from infra_ai_service.service.search_service import perform_vector_search
 
 router = APIRouter()
diff --git a/infra_ai_service/api/common/utils.py b/infra_ai_service/api/common/utils.py
deleted file mode 100644
index b249b2c..0000000
--- a/infra_ai_service/api/common/utils.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from fastapi import HTTPException
-from fastembed.embedding import DefaultEmbedding
-from qdrant_client import QdrantClient
-from qdrant_client.http.models import Distance, VectorParams
-
-
-def setup_qdrant_environment():
-    # 初始化FastEmbed模型和Qdrant客户端
-    fastembed_model = DefaultEmbedding()
-    qdrant_client = QdrantClient(url="http://localhost:6333")
-    collection_name = 'test_simi'
-
-    # 检查集合是否存在,如果不存在则创建
-    try:
-        qdrant_client.get_collection(collection_name)
-        print(f"Collection {collection_name} already exists")
-    except HTTPException as e:
-        # 获取向量维度
-        sample_embedding = next(fastembed_model.embed(["Sample text"]))
-        vector_size = len(sample_embedding)
-
-        # 创建集合
-        qdrant_client.create_collection(
-            collection_name=collection_name,
-            vectors_config=VectorParams(size=vector_size,
-                                        distance=Distance.COSINE),
-        )
-        print(f"Created collection: {collection_name}")
-    return fastembed_model, qdrant_client, collection_name
diff --git a/infra_ai_service/api/router.py b/infra_ai_service/api/router.py
index a355b81..e262c62 100644
--- a/infra_ai_service/api/router.py
+++ b/infra_ai_service/api/router.py
@@ -1,9 +1,9 @@
 from fastapi.routing import APIRouter
 
-from infra_ai_service.api.ai_enhance.text_process import \
-    router as text_process_router
 from infra_ai_service.api.ai_enhance.embedding import \
     router as embedding_router
+from infra_ai_service.api.ai_enhance.text_process import \
+    router as text_process_router
 from infra_ai_service.api.ai_enhance.vector_search import \
     router as vector_search_router
 
diff --git a/infra_ai_service/api/common/__init__.py b/infra_ai_service/common/__init__.py
similarity index 100%
rename from infra_ai_service/api/common/__init__.py
rename to infra_ai_service/common/__init__.py
diff --git a/infra_ai_service/common/utils.py b/infra_ai_service/common/utils.py
new file mode 100644
index 0000000..b1c60ab
--- /dev/null
+++ b/infra_ai_service/common/utils.py
@@ -0,0 +1,29 @@
+# infra_ai_service/common/utils.py
+
+from infra_ai_service.config.config import settings
+
+
+async def setup_database(pool):
+    async with pool.connection() as conn:
+        # Create the vector extension, using the configured name
+        await conn.execute(
+            f"CREATE EXTENSION IF NOT EXISTS {settings.VECTOR_EXTENSION}"
+        )
+        # Create the table, using the configured settings
+        await conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {settings.TABLE_NAME} (
+                id bigserial PRIMARY KEY,
+                content text,
+                embedding vector({settings.VECTOR_DIMENSION})
+            )
+            """
+        )
+        # Create the index, using the configured settings
+        await conn.execute(
+            f"""
+            CREATE INDEX IF NOT EXISTS {settings.TABLE_NAME}_content_idx
+            ON {settings.TABLE_NAME}
+            USING GIN (to_tsvector('{settings.LANGUAGE}', content))
+            """
+        )
diff --git a/infra_ai_service/config/config.py b/infra_ai_service/config/config.py
index 41f0398..c2f51c0 100644
--- a/infra_ai_service/config/config.py
+++ b/infra_ai_service/config/config.py
@@ -1,5 +1,6 @@
+# infra_ai_service/config/config.py
+
 from pathlib import Path
-from sys import modules
 
 from pydantic import BaseSettings
 
@@ -12,27 +13,48 @@ class Settings(BaseSettings):
     ENV: str = "dev"
     HOST: str = "0.0.0.0"
     PORT: int = 8000
-    _BASE_URL: str = f"https://{HOST}:{PORT}"
-    # quantity of workers for uvicorn
+    _BASE_URL: str = f"http://{HOST}:{PORT}"
     WORKERS_COUNT: int = 1
-    # Enable uvicorn reloading
     RELOAD: bool = False
 
+    # Database settings
+    DB_NAME: str = ""
+    DB_USER: str = ""
+    DB_PASSWORD: str = ""
+    DB_HOST: str = ""
+    DB_PORT: int = 0
+
+    # Model name setting
+    MODEL_NAME: str = ""
+
+    # Additional settings
+    VECTOR_EXTENSION: str = ""
+    TABLE_NAME: str = ""
+    VECTOR_DIMENSION: int = 0
+    LANGUAGE: str = ""
+
     @property
     def BASE_URL(self) -> str:
-        return self._BASE_URL if self._BASE_URL.endswith(
-            "/") else f"{self._BASE_URL}/"
+        if self._BASE_URL.endswith("/"):
+            return self._BASE_URL
+        else:
+            return f"{self._BASE_URL}/"
 
     class Config:
         env_file = f"{BASE_DIR}/.env"
         env_file_encoding = "utf-8"
         fields = {
-            "_BASE_URL": {
-                "env": "BASE_URL",
-            },
-            "_DB_BASE": {
-                "env": "DB_BASE",
-            },
+            "_BASE_URL": {"env": "BASE_URL"},
+            "DB_NAME": {"env": "DB_NAME"},
+            "DB_USER": {"env": "DB_USER"},
+            "DB_PASSWORD": {"env": "DB_PASSWORD"},
+            "DB_HOST": {"env": "DB_HOST"},
+            "DB_PORT": {"env": "DB_PORT"},
+            "MODEL_NAME": {"env": "MODEL_NAME"},
+            "VECTOR_EXTENSION": {"env": "VECTOR_EXTENSION"},
+            "TABLE_NAME": {"env": "TABLE_NAME"},
+            "VECTOR_DIMENSION": {"env": "VECTOR_DIMENSION"},
+            "LANGUAGE": {"env": "LANGUAGE"},
         }
 
 settings = Settings()
diff --git a/infra_ai_service/core/app.py b/infra_ai_service/core/app.py
index 3060a98..d43b75a 100644
--- a/infra_ai_service/core/app.py
+++ b/infra_ai_service/core/app.py
@@ -1,7 +1,9 @@
-from infra_ai_service.api.router import api_router
 from fastapi import FastAPI
 from fastapi.responses import UJSONResponse
 
+from infra_ai_service.api.router import api_router
+from infra_ai_service.sdk.pgvector import setup_model_and_pool
+
 
 def get_app() -> FastAPI:
     """
@@ -23,4 +25,8 @@ def get_app() -> FastAPI:
 
     app.include_router(router=api_router, prefix="/api")
 
+    @app.on_event("startup")
+    async def startup_event():
+        await setup_model_and_pool()
+
     return app
diff --git a/infra_ai_service/model/model.py b/infra_ai_service/model/model.py
index 1bd7bbb..943cb16 100644
--- a/infra_ai_service/model/model.py
+++ b/infra_ai_service/model/model.py
@@ -1,6 +1,7 @@
-from pydantic import BaseModel
 from typing import List
 
+from pydantic import BaseModel
+
 
 class SearchInput(BaseModel):
     query_text: str
diff --git a/infra_ai_service/sdk/pgvector.py b/infra_ai_service/sdk/pgvector.py
new file mode 100644
index 0000000..8fdf23b
--- /dev/null
+++ b/infra_ai_service/sdk/pgvector.py
@@ -0,0 +1,28 @@
+from psycopg_pool import AsyncConnectionPool
+from sentence_transformers import SentenceTransformer
+
+from infra_ai_service.common.utils import setup_database
+from infra_ai_service.config.config import settings
+
+# Model (initialized at startup)
+model = None
+# Connection pool (created at startup)
+pool = None
+
+
+async def setup_model_and_pool():
+    global model, pool
+    # Initialize the model
+    model = SentenceTransformer(settings.MODEL_NAME)
+    # Create the async connection pool
+    conn_str = (
+        f"dbname={settings.DB_NAME} "
+        f"user={settings.DB_USER} "
+        f"password={settings.DB_PASSWORD} "
+        f"host={settings.DB_HOST} "
+        f"port={settings.DB_PORT}"
+    )
+    pool = AsyncConnectionPool(conn_str, open=True)
+
+    # Set up the database
+    await setup_database(pool)
diff --git a/infra_ai_service/sdk/qdrant.py b/infra_ai_service/sdk/qdrant.py
deleted file mode 100644
index 2478e18..0000000
--- a/infra_ai_service/sdk/qdrant.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from infra_ai_service.api.common.utils import setup_qdrant_environment
-
-fastembed_model, qdrant_client, collection_name = setup_qdrant_environment()
diff --git a/infra_ai_service/server.py b/infra_ai_service/server.py
index 72063d6..917cde0 100644
--- a/infra_ai_service/server.py
+++ b/infra_ai_service/server.py
@@ -1,4 +1,5 @@
 import uvicorn
+
 from infra_ai_service.config.config import settings
 
 
diff --git a/infra_ai_service/service/embedding_service.py b/infra_ai_service/service/embedding_service.py
index 401f06a..84a3111 100644
--- a/infra_ai_service/service/embedding_service.py
+++ b/infra_ai_service/service/embedding_service.py
@@ -1,48 +1,42 @@
-from fastapi import HTTPException
+import asyncio
 import uuid
+from concurrent.futures import ThreadPoolExecutor
+
+from fastapi import HTTPException
 
-from infra_ai_service.model.model import PointStruct, EmbeddingOutput
-from infra_ai_service.sdk.qdrant import fastembed_model, qdrant_client, \
-    collection_name
+from infra_ai_service.model.model import EmbeddingOutput
+from infra_ai_service.sdk import pgvector
 
 
 async def create_embedding(content):
     try:
-        embeddings = list(fastembed_model.embed([content]))
-        if not embeddings:
+        # Make sure the model is initialized
+        if pgvector.model is None:
             raise HTTPException(status_code=500,
-                                detail="Failed to generate embedding")
-
-        embedding_vector = embeddings[0]
-        point_id = str(uuid.uuid4())
-
-        qdrant_client.upsert(
-            collection_name=collection_name,
-            points=[
-                PointStruct(
-                    id=point_id,
-                    vector=embedding_vector.tolist(),
-                    payload={"text": content}
+                                detail="Model is not initialized")
+
+        # Run the synchronous embedding computation in a thread pool
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool_executor:
+            embedding_vector = await loop.run_in_executor(
+                pool_executor, pgvector.model.encode, [content]
+            )
+            embedding_vector = embedding_vector[0]
+
+        # Convert the ndarray to a list
+        embedding_vector_list = embedding_vector.tolist()
+
+        # Get a connection from the pool
+        async with pgvector.pool.connection() as conn:
+            async with conn.cursor() as cur:
+                await cur.execute(
+                    "INSERT INTO documents (content, embedding) "
+                    "VALUES (%s, %s) RETURNING id",
+                    (content, embedding_vector_list),  # the converted list
                 )
-            ]
-        )
+                point_id = (await cur.fetchone())[0]
 
-        return EmbeddingOutput(id=point_id,
-                               embedding=embedding_vector.tolist())
+        return EmbeddingOutput(id=point_id, embedding=embedding_vector_list)
     except Exception as e:
         raise HTTPException(status_code=400,
                             detail=f"Error processing embedding: {e}")
-
-
-async def get_collection_status():
-    try:
-        collection_info = qdrant_client.get_collection(collection_name)
-        return {
-            "collection_name": collection_name,
-            "vectors_count": collection_info.vectors_count,
-            "status": "ready" if collection_info.status == "green"
-            else "not ready"
-        }
-    except Exception as e:
-        raise HTTPException(status_code=400,
-                            detail=f"Error getting collection status: {e}")
diff --git a/infra_ai_service/service/search_service.py b/infra_ai_service/service/search_service.py
index c865612..54b10f9 100644
--- a/infra_ai_service/service/search_service.py
+++ b/infra_ai_service/service/search_service.py
@@ -1,53 +1,65 @@
-from fastapi import HTTPException
+# infraAIService/infra_ai_service/service/search_service.py
+
+import asyncio
 import logging
+from concurrent.futures import ThreadPoolExecutor
 
-from infra_ai_service.model.model import SearchOutput, SearchResult, \
-    SearchInput
-from infra_ai_service.sdk.qdrant import qdrant_client, collection_name, \
-    fastembed_model
+from fastapi import HTTPException
+
+from infra_ai_service.model.model import SearchInput, SearchOutput, \
+    SearchResult
+from infra_ai_service.sdk import pgvector
 
 logger = logging.getLogger(__name__)
 
 
 async def perform_vector_search(input_data: SearchInput):
     try:
-        # 检查集合是否存在
-        collection_info = qdrant_client.get_collection(collection_name)
-        if not collection_info:
-            logger.error(f"Collection {collection_name} does not exist")
-            raise HTTPException(status_code=404,
-                                detail=f"Collection {collection_name} does "
-                                       f"not exist")
-
-        # 生成查询文本的嵌入
-        query_vector = list(fastembed_model.embed([input_data.query_text]))
-        if not query_vector:
-            logger.error("Failed to generate query embedding")
-            raise HTTPException(status_code=500,
-                                detail="Failed to generate query embedding")
-
-        # 执行向量搜索
-        search_results = qdrant_client.search(
-            collection_name=collection_name,
-            query_vector=query_vector[0],
-            limit=input_data.top_n,
-            score_threshold=input_data.score_threshold
-        )
+        # Make sure the model is initialized
+        if pgvector.model is None:
+            logger.error("Model is not initialized")
+            raise HTTPException(status_code=500, detail="Model is not initialized")
 
-        # 转换搜索结果为输出格式
-        results = [
-            SearchResult(
-                id=str(result.id),
-                score=result.score,
-                text=result.payload.get('text', 'No text available')
+        # Generate the embedding vector for the query text
+        loop = asyncio.get_running_loop()
+        with ThreadPoolExecutor() as pool_executor:
+            embedding_vector = await loop.run_in_executor(
+                pool_executor, pgvector.model.encode, [input_data.query_text]
             )
-            for result in search_results
-        ]
+            embedding_vector = embedding_vector[0]
+
+        # Convert the ndarray to a list
+        embedding_vector_list = embedding_vector.tolist()
+
+        # Get a connection from the pool
+        async with pgvector.pool.connection() as conn:
+            async with conn.cursor() as cur:
+                # Vector search query; cast the parameter to vector
+                await cur.execute(
+                    """
+                    SELECT id, content, embedding,
+                           1 - (embedding <#> %s::vector)
+                           AS similarity
+                    FROM documents
+                    ORDER BY similarity DESC
+                    LIMIT %s
+                    """,
+                    (embedding_vector_list, input_data.top_n),
+                )
+                rows = await cur.fetchall()
+
+        # Convert search results to the output format
+        results = []
+        for row in rows:
+            similarity = row[3]  # similarity score
+            if similarity >= input_data.score_threshold:
+                results.append(
+                    SearchResult(id=str(row[0]), score=similarity,
+                                 text=row[1])  # content
+                )
 
         return SearchOutput(results=results)
     except Exception as e:
-        logger.error(f"Error performing vector search: {str(e)}",
-                     exc_info=True)
+        logger.error(f"Error during vector search: {str(e)}", exc_info=True)
         raise HTTPException(status_code=500,
-                            detail=f"Error performing vector search: "
-                            f"{str(e)}")
+                            detail=f"Error during vector search: {str(e)}")
diff --git a/infra_ai_service/service/text_service.py b/infra_ai_service/service/text_service.py
index 9331a5c..832d8bd 100644
--- a/infra_ai_service/service/text_service.py
+++ b/infra_ai_service/service/text_service.py
@@ -1,6 +1,7 @@
+import logging
 import re
+
 from fastapi import HTTPException
-import logging
 
 from infra_ai_service.model.model import TextOutput
 
@@ -10,7 +11,7 @@
 def clean_text(text: str) -> str:
     try:
         # 正确转义正则表达式中的特殊字符
-        return re.sub(r'[{}\[\]()@.#\\_\':\/-]', '', text)
+        return re.sub(r"[{}\[\]()@.#\\_\':\/-]", "", text)
     except re.error as e:
         logger.error(f"Regex error: {str(e)}", exc_info=True)
         raise HTTPException(status_code=400, detail="Regex processing error")
diff --git a/requirements.txt b/requirements.txt
index 150a888..c585d4a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,6 @@ requests==2.31.0
 httpx==0.23.0
 pydantic==1.10.12
 fastembed==0.3.6
-qdrant-client==1.11.1
+setuptools~=74.1.2
+psycopg~=3.2.1
+pgvector~=0.3.3
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 7805aa3..e2b8e21 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,5 @@
 # setup.py
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 setup(
     name="infra_ai_service",
diff --git a/test-demos/async_demo.py b/test-demos/async_demo.py
index 38e858f..f196f2c 100644
--- a/test-demos/async_demo.py
+++ b/test-demos/async_demo.py
@@ -5,11 +5,12 @@ async def fetch_data():
     print("开始获取数据...")
     await asyncio.sleep(2)  # 模拟IO操作
     print("数据获取完成")
-    return {'data': 123}
+    return {"data": 123}
 
 
 async def main():
     result = await fetch_data()
     print(result)
 
+
 asyncio.run(main())
diff --git a/test-demos/demo.py b/test-demos/demo.py
index 9a68e2f..3714b49 100644
--- a/test-demos/demo.py
+++ b/test-demos/demo.py
@@ -1,8 +1,6 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
-from fastapi import Depends
-from fastapi import HTTPException
+from fastapi import Depends, FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
 
 
 class Item(BaseModel):
diff --git a/tox.ini b/tox.ini
index 1234e31..ad22c39 100644
--- a/tox.ini
+++ b/tox.ini
@@ -10,7 +10,6 @@ deps =
     pytest_asyncio
    asyncpg
    fastapi
-    qdrant_client
    fastembed
commands =
    pytest tests/ --cov=infra_ai_service --cov-report=term-missing
@@ -28,6 +27,7 @@ commands =
 [testenv:coverage]
 deps =
     coverage
+    {[testenv]deps}
 commands =
     coverage report --fail-under=0
     coverage html