Skip to content

Commit

Permalink
Merge pull request #12 from MaskerPRC/sjh_merge
Browse files Browse the repository at this point in the history
Add support for pgvector
  • Loading branch information
jlcoo authored Sep 23, 2024
2 parents c0d1a08 + cadde51 commit 44ef8d4
Show file tree
Hide file tree
Showing 26 changed files with 215 additions and 160 deletions.
17 changes: 16 additions & 1 deletion .env-example
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Create a new file next to this one, named .env

RELOAD=True

BASE_URL=http://localhost:8000

# 数据库配置
DB_NAME=
DB_USER=
DB_PASSWORD=
DB_HOST=
DB_PORT=

# 模型名称配置
MODEL_NAME=

# 配置项
VECTOR_EXTENSION=
TABLE_NAME=
VECTOR_DIMENSION=
LANGUAGE=
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,4 @@ dmypy.json
cython_debug/

.python-version
infra_ai_service.egg-info
3 changes: 0 additions & 3 deletions infra_ai_service.egg-info/PKG-INFO

This file was deleted.

10 changes: 0 additions & 10 deletions infra_ai_service.egg-info/SOURCES.txt

This file was deleted.

1 change: 0 additions & 1 deletion infra_ai_service.egg-info/dependency_links.txt

This file was deleted.

2 changes: 0 additions & 2 deletions infra_ai_service.egg-info/top_level.txt

This file was deleted.

10 changes: 1 addition & 9 deletions infra_ai_service/api/ai_enhance/embedding.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from fastapi import APIRouter

from infra_ai_service.api.ai_enhance.text_process import TextInput
from infra_ai_service.api.common.utils import setup_qdrant_environment

from infra_ai_service.model.model import EmbeddingOutput
from infra_ai_service.service.embedding_service import create_embedding, \
get_collection_status
from infra_ai_service.service.embedding_service import create_embedding

router = APIRouter()


@router.post("/embed/", response_model=EmbeddingOutput)
async def embed_text(input_data: TextInput):
return await create_embedding(input_data.content)


@router.get("/status/")
async def status():
return await get_collection_status()
2 changes: 1 addition & 1 deletion infra_ai_service/api/ai_enhance/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from fastapi import APIRouter

from infra_ai_service.model.model import SearchOutput, SearchInput
from infra_ai_service.model.model import SearchInput, SearchOutput
from infra_ai_service.service.search_service import perform_vector_search

router = APIRouter()
Expand Down
29 changes: 0 additions & 29 deletions infra_ai_service/api/common/utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions infra_ai_service/api/router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fastapi.routing import APIRouter

from infra_ai_service.api.ai_enhance.text_process import \
router as text_process_router
from infra_ai_service.api.ai_enhance.embedding import \
router as embedding_router
from infra_ai_service.api.ai_enhance.text_process import \
router as text_process_router
from infra_ai_service.api.ai_enhance.vector_search import \
router as vector_search_router

Expand Down
File renamed without changes.
29 changes: 29 additions & 0 deletions infra_ai_service/common/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# infra_ai_service/common/utils.py

from infra_ai_service.config.config import settings


async def setup_database(pool):
    """Initialize the pgvector schema: extension, table, and text index.

    Idempotent (every statement uses IF NOT EXISTS), so it is safe to run
    on every application startup.

    Args:
        pool: an open psycopg AsyncConnectionPool; a single connection is
            borrowed for the three DDL statements.
    """
    async with pool.connection() as conn:
        # Install the configured extension (e.g. "vector" for pgvector).
        # NOTE(review): extension/table names from settings are interpolated
        # directly into the SQL; identifiers cannot be bound as parameters,
        # so these config values must come from a trusted source.
        await conn.execute(
            f"CREATE EXTENSION IF NOT EXISTS {settings.VECTOR_EXTENSION}"
        )
        # Create the documents table; the embedding column width comes
        # from VECTOR_DIMENSION and must match the model's output size.
        await conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {settings.TABLE_NAME} (
            id bigserial PRIMARY KEY,
            content text,
            embedding vector({settings.VECTOR_DIMENSION})
            )
            """
        )
        # Full-text (GIN over tsvector) index on the raw content column.
        # NOTE(review): this accelerates text search only — no vector index
        # (ivfflat/hnsw) is created for the embedding column; confirm
        # whether similarity queries need one.
        await conn.execute(
            f"""
            CREATE INDEX IF NOT EXISTS {settings.TABLE_NAME}_content_idx
            ON {settings.TABLE_NAME}
            USING GIN (to_tsvector('{settings.LANGUAGE}', content))
            """
        )
46 changes: 34 additions & 12 deletions infra_ai_service/config/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# infra_ai_service/config/config.py

from pathlib import Path
from sys import modules

from pydantic import BaseSettings

Expand All @@ -12,27 +13,48 @@ class Settings(BaseSettings):
ENV: str = "dev"
HOST: str = "0.0.0.0"
PORT: int = 8000
_BASE_URL: str = f"https://{HOST}:{PORT}"
# quantity of workers for uvicorn
_BASE_URL: str = f"http://{HOST}:{PORT}"
WORKERS_COUNT: int = 1
# Enable uvicorn reloading
RELOAD: bool = False

# 数据库配置项
DB_NAME: str = ""
DB_USER: str = ""
DB_PASSWORD: str = ""
DB_HOST: str = ""
DB_PORT: int = 0

# 模型名称配置项
MODEL_NAME: str = ""

# 新增的配置项
VECTOR_EXTENSION: str = ""
TABLE_NAME: str = ""
VECTOR_DIMENSION: int = 0
LANGUAGE: str = ""

@property
def BASE_URL(self) -> str:
return self._BASE_URL if self._BASE_URL.endswith(
"/") else f"{self._BASE_URL}/"
if self._BASE_URL.endswith("/"):
return self._BASE_URL
else:
return f"{self._BASE_URL}/"

class Config:
env_file = f"{BASE_DIR}/.env"
env_file_encoding = "utf-8"
fields = {
"_BASE_URL": {
"env": "BASE_URL",
},
"_DB_BASE": {
"env": "DB_BASE",
},
"_BASE_URL": {"env": "BASE_URL"},
"DB_NAME": {"env": "DB_NAME"},
"DB_USER": {"env": "DB_USER"},
"DB_PASSWORD": {"env": "DB_PASSWORD"},
"DB_HOST": {"env": "DB_HOST"},
"DB_PORT": {"env": "DB_PORT"},
"MODEL_NAME": {"env": "MODEL_NAME"},
"VECTOR_EXTENSION": {"env": "VECTOR_EXTENSION"},
"TABLE_NAME": {"env": "TABLE_NAME"},
"VECTOR_DIMENSION": {"env": "VECTOR_DIMENSION"},
"LANGUAGE": {"env": "LANGUAGE"},
}


Expand Down
8 changes: 7 additions & 1 deletion infra_ai_service/core/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from infra_ai_service.api.router import api_router
from fastapi import FastAPI
from fastapi.responses import UJSONResponse

from infra_ai_service.api.router import api_router
from infra_ai_service.sdk.pgvector import setup_model_and_pool


def get_app() -> FastAPI:
"""
Expand All @@ -23,4 +25,8 @@ def get_app() -> FastAPI:

app.include_router(router=api_router, prefix="/api")

@app.on_event("startup")
async def startup_event():
await setup_model_and_pool()

return app
3 changes: 2 additions & 1 deletion infra_ai_service/model/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import List

from pydantic import BaseModel


class SearchInput(BaseModel):
query_text: str
Expand Down
28 changes: 28 additions & 0 deletions infra_ai_service/sdk/pgvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from psycopg_pool import AsyncConnectionPool
from sentence_transformers import SentenceTransformer

from infra_ai_service.common.utils import setup_database
from infra_ai_service.config.config import settings

# Embedding model; loaded once at application startup.
model = None
# Async DB connection pool; created at application startup.
pool = None


async def setup_model_and_pool():
    """Load the sentence-transformer model and open the database pool.

    Called once from the FastAPI startup hook. Populates the module-level
    ``model`` and ``pool`` globals, then runs the idempotent schema setup.
    """
    global model, pool
    # Load the embedding model named in config (may download on first use).
    model = SentenceTransformer(settings.MODEL_NAME)
    # Build the libpq connection string from the DB_* settings.
    conn_str = (
        f"dbname={settings.DB_NAME} "
        f"user={settings.DB_USER} "
        f"password={settings.DB_PASSWORD} "
        f"host={settings.DB_HOST} "
        f"port={settings.DB_PORT}"
    )
    # Create the pool closed, then open it explicitly: opening an async
    # pool from the constructor (open=True) is deprecated in psycopg_pool
    # because the synchronous constructor cannot await connection attempts.
    pool = AsyncConnectionPool(conn_str, open=False)
    await pool.open()

    # Create extension/table/index if they do not already exist.
    await setup_database(pool)
3 changes: 0 additions & 3 deletions infra_ai_service/sdk/qdrant.py

This file was deleted.

1 change: 1 addition & 0 deletions infra_ai_service/server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uvicorn

from infra_ai_service.config.config import settings


Expand Down
66 changes: 30 additions & 36 deletions infra_ai_service/service/embedding_service.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,42 @@
from fastapi import HTTPException
import asyncio
import uuid
from concurrent.futures import ThreadPoolExecutor

from fastapi import HTTPException

from infra_ai_service.model.model import PointStruct, EmbeddingOutput
from infra_ai_service.sdk.qdrant import fastembed_model, qdrant_client, \
collection_name
from infra_ai_service.model.model import EmbeddingOutput
from infra_ai_service.sdk import pgvector


async def create_embedding(content):
try:
embeddings = list(fastembed_model.embed([content]))
if not embeddings:
# 确保模型已初始化
if pgvector.model is None:
raise HTTPException(status_code=500,
detail="Failed to generate embedding")

embedding_vector = embeddings[0]
point_id = str(uuid.uuid4())

qdrant_client.upsert(
collection_name=collection_name,
points=[
PointStruct(
id=point_id,
vector=embedding_vector.tolist(),
payload={"text": content}
detail="Model is not initialized")

# 使用线程池执行同步的嵌入计算
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as pool_executor:
embedding_vector = await loop.run_in_executor(
pool_executor, pgvector.model.encode, [content]
)
embedding_vector = embedding_vector[0]

# 将 ndarray 转换为列表
embedding_vector_list = embedding_vector.tolist()

# 从连接池获取连接
async with pgvector.pool.connection() as conn:
async with conn.cursor() as cur:
await cur.execute(
"INSERT INTO documents (content, embedding) "
"VALUES (%s, %s) RETURNING id",
(content, embedding_vector_list), # 使用转换后的列表
)
]
)
point_id = (await cur.fetchone())[0]

return EmbeddingOutput(id=point_id,
embedding=embedding_vector.tolist())
return EmbeddingOutput(id=point_id, embedding=embedding_vector_list)
except Exception as e:
raise HTTPException(status_code=400,
detail=f"Error processing embedding: {e}")


async def get_collection_status():
try:
collection_info = qdrant_client.get_collection(collection_name)
return {
"collection_name": collection_name,
"vectors_count": collection_info.vectors_count,
"status": "ready" if collection_info.status == "green"
else "not ready"
}
except Exception as e:
raise HTTPException(status_code=400,
detail=f"Error getting collection status: {e}")
Loading

0 comments on commit 44ef8d4

Please sign in to comment.