-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from MaskerPRC/sjh_merge
Add support for pgvector
- Loading branch information
Showing
26 changed files
with
215 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,20 @@ | ||
# create new with next to this one, with name .env | ||
|
||
RELOAD=True | ||
|
||
BASE_URL=http://localhost:8000 | ||
|
||
# 数据库配置 | ||
DB_NAME= | ||
DB_USER= | ||
DB_PASSWORD= | ||
DB_HOST= | ||
DB_PORT= | ||
|
||
# 模型名称配置 | ||
MODEL_NAME= | ||
|
||
# 配置项 | ||
VECTOR_EXTENSION= | ||
TABLE_NAME= | ||
VECTOR_DIMENSION= | ||
LANGUAGE= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -142,3 +142,4 @@ dmypy.json | |
cython_debug/ | ||
|
||
.python-version | ||
infra_ai_service.egg-info |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,12 @@ | ||
from fastapi import APIRouter | ||
|
||
from infra_ai_service.api.ai_enhance.text_process import TextInput | ||
from infra_ai_service.api.common.utils import setup_qdrant_environment | ||
|
||
from infra_ai_service.model.model import EmbeddingOutput | ||
from infra_ai_service.service.embedding_service import create_embedding, \ | ||
get_collection_status | ||
from infra_ai_service.service.embedding_service import create_embedding | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.post("/embed/", response_model=EmbeddingOutput) | ||
async def embed_text(input_data: TextInput): | ||
return await create_embedding(input_data.content) | ||
|
||
|
||
@router.get("/status/") | ||
async def status(): | ||
return await get_collection_status() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# infra_ai_service/common/utils.py | ||
|
||
from infra_ai_service.config.config import settings | ||
|
||
|
||
async def setup_database(pool): | ||
async with pool.connection() as conn: | ||
# 创建扩展名,使用配置项 | ||
await conn.execute( | ||
f"CREATE EXTENSION IF NOT EXISTS {settings.VECTOR_EXTENSION}" | ||
) | ||
# 创建表,使用配置项 | ||
await conn.execute( | ||
f""" | ||
CREATE TABLE IF NOT EXISTS {settings.TABLE_NAME} ( | ||
id bigserial PRIMARY KEY, | ||
content text, | ||
embedding vector({settings.VECTOR_DIMENSION}) | ||
) | ||
""" | ||
) | ||
# 创建索引,使用配置项 | ||
await conn.execute( | ||
f""" | ||
CREATE INDEX IF NOT EXISTS {settings.TABLE_NAME}_content_idx | ||
ON {settings.TABLE_NAME} | ||
USING GIN (to_tsvector('{settings.LANGUAGE}', content)) | ||
""" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from psycopg_pool import AsyncConnectionPool | ||
from sentence_transformers import SentenceTransformer | ||
|
||
from infra_ai_service.common.utils import setup_database | ||
from infra_ai_service.config.config import settings | ||
|
||
# 初始化模型 | ||
model = None | ||
# 创建连接池(暂时不初始化) | ||
pool = None | ||
|
||
|
||
async def setup_model_and_pool(): | ||
global model, pool | ||
# 初始化模型 | ||
model = SentenceTransformer(settings.MODEL_NAME) | ||
# 创建异步连接池 | ||
conn_str = ( | ||
f"dbname={settings.DB_NAME} " | ||
f"user={settings.DB_USER} " | ||
f"password={settings.DB_PASSWORD} " | ||
f"host={settings.DB_HOST} " | ||
f"port={settings.DB_PORT}" | ||
) | ||
pool = AsyncConnectionPool(conn_str, open=True) | ||
|
||
# 设置数据库 | ||
await setup_database(pool) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import uvicorn | ||
|
||
from infra_ai_service.config.config import settings | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,42 @@ | ||
from fastapi import HTTPException | ||
import asyncio | ||
import uuid | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
from fastapi import HTTPException | ||
|
||
from infra_ai_service.model.model import PointStruct, EmbeddingOutput | ||
from infra_ai_service.sdk.qdrant import fastembed_model, qdrant_client, \ | ||
collection_name | ||
from infra_ai_service.model.model import EmbeddingOutput | ||
from infra_ai_service.sdk import pgvector | ||
|
||
|
||
async def create_embedding(content): | ||
try: | ||
embeddings = list(fastembed_model.embed([content])) | ||
if not embeddings: | ||
# 确保模型已初始化 | ||
if pgvector.model is None: | ||
raise HTTPException(status_code=500, | ||
detail="Failed to generate embedding") | ||
|
||
embedding_vector = embeddings[0] | ||
point_id = str(uuid.uuid4()) | ||
|
||
qdrant_client.upsert( | ||
collection_name=collection_name, | ||
points=[ | ||
PointStruct( | ||
id=point_id, | ||
vector=embedding_vector.tolist(), | ||
payload={"text": content} | ||
detail="Model is not initialized") | ||
|
||
# 使用线程池执行同步的嵌入计算 | ||
loop = asyncio.get_running_loop() | ||
with ThreadPoolExecutor() as pool_executor: | ||
embedding_vector = await loop.run_in_executor( | ||
pool_executor, pgvector.model.encode, [content] | ||
) | ||
embedding_vector = embedding_vector[0] | ||
|
||
# 将 ndarray 转换为列表 | ||
embedding_vector_list = embedding_vector.tolist() | ||
|
||
# 从连接池获取连接 | ||
async with pgvector.pool.connection() as conn: | ||
async with conn.cursor() as cur: | ||
await cur.execute( | ||
"INSERT INTO documents (content, embedding) " | ||
"VALUES (%s, %s) RETURNING id", | ||
(content, embedding_vector_list), # 使用转换后的列表 | ||
) | ||
] | ||
) | ||
point_id = (await cur.fetchone())[0] | ||
|
||
return EmbeddingOutput(id=point_id, | ||
embedding=embedding_vector.tolist()) | ||
return EmbeddingOutput(id=point_id, embedding=embedding_vector_list) | ||
except Exception as e: | ||
raise HTTPException(status_code=400, | ||
detail=f"Error processing embedding: {e}") | ||
|
||
|
||
async def get_collection_status(): | ||
try: | ||
collection_info = qdrant_client.get_collection(collection_name) | ||
return { | ||
"collection_name": collection_name, | ||
"vectors_count": collection_info.vectors_count, | ||
"status": "ready" if collection_info.status == "green" | ||
else "not ready" | ||
} | ||
except Exception as e: | ||
raise HTTPException(status_code=400, | ||
detail=f"Error getting collection status: {e}") |
Oops, something went wrong.