From 61bd4ef576d47a4a8e0f730c4a4d3b3b663616c3 Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Mon, 16 Dec 2024 11:00:08 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9Asolve=20the=20dup=20retrieval=20iss?= =?UTF-8?q?ue=20(#575)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: merge partial env variable and skip validation in the dev * fix: fix dup retrieval * feat: update petercat-utils --- petercat_utils/rag_helper/retrieval.py | 27 +++-- pyproject.toml | 2 +- server/auth/middleware.py | 137 ++++++++++++------------- server/auth/router.py | 22 ++-- server/env.py | 6 ++ server/github_app/router.py | 21 ++-- server/main.py | 25 ++--- server/tests/test_main.py | 8 +- subscriber/requirements.txt | 2 +- 9 files changed, 129 insertions(+), 121 deletions(-) create mode 100644 server/env.py diff --git a/petercat_utils/rag_helper/retrieval.py b/petercat_utils/rag_helper/retrieval.py index 5a3207dd..1cb889e1 100644 --- a/petercat_utils/rag_helper/retrieval.py +++ b/petercat_utils/rag_helper/retrieval.py @@ -1,7 +1,6 @@ import json from typing import Any, Dict - from langchain_community.vectorstores import SupabaseVectorStore from langchain_openai import OpenAIEmbeddings @@ -9,7 +8,6 @@ from ..data_class import GitDocConfig, RAGGitDocConfig, S3Config from ..db.client.supabase import get_client - TABLE_NAME = "rag_docs" QUERY_NAME = "match_embedding_docs" CHUNK_SIZE = 2000 @@ -118,15 +116,16 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): supabase = get_client() is_doc_added_query = ( supabase.table(TABLE_NAME) - .select("id, repo_name, commit_id, file_path") + .select("id") .eq("repo_name", config.repo_name) .eq("commit_id", loader.commit_id) .eq("file_path", config.file_path) + .limit(1) .execute() ) if not is_doc_added_query.data: is_doc_equal_query = ( - supabase.table(TABLE_NAME).select("*").eq("file_sha", loader.file_sha) + supabase.table(TABLE_NAME).select("id").eq("file_sha", loader.file_sha).limit(1) ).execute() if not is_doc_equal_query.data: # If there is no file with the same file_sha, perform embedding. @@ -139,6 +138,18 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): ) return store else: + # Prioritize obtaining the minimal set of records to avoid overlapping with the original records. + minimum_repeat_result = supabase.rpc('count_rag_docs_by_sha', {'file_sha_input': loader.file_sha}).execute() + target_filter = minimum_repeat_result.data[0] + # Copy the minimal set + insert_docs = ( + supabase.table(TABLE_NAME) + .select("*") + .eq("repo_name", target_filter['repo_name']) + .eq("file_path", target_filter['file_path']) + .eq("file_sha", target_filter['file_sha']) + .execute() + ) new_commit_list = [ { **{k: v for k, v in item.items() if k != "id"}, @@ -146,7 +157,7 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): "commit_id": loader.commit_id, "file_path": config.file_path, } - for item in is_doc_equal_query.data + for item in insert_docs.data ] insert_result = supabase.table(TABLE_NAME).insert(new_commit_list).execute() return insert_result @@ -169,9 +180,9 @@ def reload_knowledge(config: RAGGitDocConfig): def search_knowledge( - query: str, - repo_name: str, - meta_filter: Dict[str, Any] = {}, + query: str, + repo_name: str, + meta_filter: Dict[str, Any] = {}, ): retriever = init_retriever( {"filter": {"metadata": meta_filter, "repo_name": repo_name}} diff --git a/pyproject.toml b/pyproject.toml index ef116755..6790e442 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "petercat_utils" -version = "0.1.39" +version = "0.1.40" description = "" authors = ["raoha.rh "] readme = "README.md" diff --git a/server/auth/middleware.py b/server/auth/middleware.py index 44a8912e..73244095 100644 --- a/server/auth/middleware.py +++ b/server/auth/middleware.py @@ -1,88 +1,87 @@ import traceback from typing import Awaitable, Callable + from fastapi import HTTPException, Request, status from fastapi.responses import JSONResponse -from petercat_utils import get_env_variable +from fastapi.security import OAuth2PasswordBearer from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import Response -from fastapi.security import OAuth2PasswordBearer from core.dao.botDAO import BotDAO - -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") +from env import ENVIRONMENT, WEB_URL ALLOW_LIST = [ - "/", - "/favicon.ico", - "/api/health_checker", - "/api/bot/list", - "/api/bot/detail", - "/api/github/app/webhook", - "/app/installation/callback", + "/", + "/favicon.ico", + "/api/health_checker", + "/api/bot/list", + "/api/bot/detail", + "/api/github/app/webhook", + "/app/installation/callback", ] ANONYMOUS_USER_ALLOW_LIST = [ - "/api/auth/userinfo", - "/api/chat/qa", - "/api/chat/stream_qa", + "/api/auth/userinfo", + "/api/chat/qa", + "/api/chat/stream_qa", ] oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/token") + class AuthMiddleWare(BaseHTTPMiddleware): - async def oauth(self, request: Request): - try: - referer = request.headers.get('referer') - origin = request.headers.get('origin') - if referer and referer.startswith(WEB_URL): - return True - - token = await oauth2_scheme(request=request) - if token: - bot_dao = BotDAO() - bot = bot_dao.get_bot(bot_id=token) - return bot and ( - "*" in bot.domain_whitelist - or - origin in bot.domain_whitelist - ) - except HTTPException: - return False - - async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: - try: - # if ENVRIMENT == "development": - # return await call_next(request) - - # Auth 相关的直接放过 - if request.url.path.startswith("/api/auth"): - return await call_next(request) - - if request.url.path in ALLOW_LIST: - return await call_next(request) - - if await self.oauth(request=request): - return await call_next(request) - - # 获取 session 中的用户信息 - user = request.session.get("user") - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") - - if user['sub'].startswith("client|"): - if request.url.path in ANONYMOUS_USER_ALLOW_LIST: - return await call_next(request) - else: - # 如果没有用户信息,返回 401 Unauthorized 错误 - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow") - - return await call_next(request) - except HTTPException as e: - print(traceback.format_exception(e)) - # 处理 HTTP 异常 - return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) - except Exception as e: - # 处理其他异常 - return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"}) + async def oauth(self, request: Request): + try: + referer = request.headers.get('referer') + origin = request.headers.get('origin') + if referer and referer.startswith(WEB_URL): + return True + + token = await oauth2_scheme(request=request) + if token: + bot_dao = BotDAO() + bot = bot_dao.get_bot(bot_id=token) + return bot and ( + "*" in bot.domain_whitelist + or + origin in bot.domain_whitelist + ) + except HTTPException: + return False + + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: + try: + if ENVIRONMENT == "development": + return await call_next(request) + + # Auth 相关的直接放过 + if request.url.path.startswith("/api/auth"): + return await call_next(request) + + if request.url.path in ALLOW_LIST: + return await call_next(request) + + if await self.oauth(request=request): + return await call_next(request) + + # 获取 session 中的用户信息 + user = request.session.get("user") + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized") + + if user['sub'].startswith("client|"): + if request.url.path in ANONYMOUS_USER_ALLOW_LIST: + return await call_next(request) + else: + # 如果没有用户信息,返回 401 Unauthorized 错误 + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Anonymous User Not Allow") + + return await call_next(request) + except HTTPException as e: + print(traceback.format_exception(e)) + # 处理 HTTP 异常 + return JSONResponse(status_code=e.status_code, content={"detail": e.detail}) + except Exception as e: + # 处理其他异常 + return JSONResponse(status_code=500, content={"detail": f"Internal Server Error: {e}"}) diff --git a/server/auth/router.py b/server/auth/router.py index 3d4e7305..240d9d40 100644 --- a/server/auth/router.py +++ b/server/auth/router.py @@ -1,19 +1,18 @@ -from github import Github -from core.dao.profilesDAO import ProfilesDAO +import secrets +from typing import Annotated, Optional + +from authlib.integrations.starlette_client import OAuth from fastapi import APIRouter, Request, HTTPException, status, Depends from fastapi.responses import RedirectResponse, JSONResponse -import secrets -from petercat_utils import get_client, get_env_variable +from github import Github from starlette.config import Config -from authlib.integrations.starlette_client import OAuth -from typing import Annotated, Optional +from auth.get_user_info import generateAnonymousUser, getUserInfoByToken, get_user_id from auth.get_user_info import ( - generateAnonymousUser, getUserAccessToken, - getUserInfoByToken, - get_user_id, ) +from core.dao.profilesDAO import ProfilesDAO +from petercat_utils import get_client, get_env_variable AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN") @@ -26,6 +25,7 @@ LOGIN_URL = f"{API_URL}/api/auth/login" WEB_URL = get_env_variable("WEB_URL") + WEB_LOGIN_SUCCESS_URL = f"{WEB_URL}/user/login" MARKET_URL = f"{WEB_URL}/market" @@ -133,8 +133,8 @@ async def get_agreement_status(user_id: Optional[str] = Depends(get_user_id)): @router.post("/accept/agreement", status_code=200) async def bot_generator( - request: Request, - user_id: Annotated[str | None, Depends(get_user_id)] = None, + request: Request, + user_id: Annotated[str | None, Depends(get_user_id)] = None, ): if not user_id: raise HTTPException(status_code=401, detail="User not found") diff --git a/server/env.py b/server/env.py new file mode 100644 index 00000000..72f79905 --- /dev/null +++ b/server/env.py @@ -0,0 +1,6 @@ +# list all env variables +from petercat_utils import get_env_variable + +WEB_URL = get_env_variable("WEB_URL") +ENVIRONMENT = get_env_variable("PETERCAT_ENV", "development") +API_URL = get_env_variable("API_URL") diff --git a/server/github_app/router.py b/server/github_app/router.py index 01e4ea94..981ddf94 100644 --- a/server/github_app/router.py +++ b/server/github_app/router.py @@ -1,4 +1,6 @@ +import logging from typing import Annotated + from fastapi import ( APIRouter, BackgroundTasks, @@ -8,27 +10,24 @@ Request, status, ) -import logging from fastapi.responses import RedirectResponse - from github import Auth, Github + from auth.get_user_info import get_user from core.dao.repositoryConfigDAO import RepositoryConfigDAO from core.models.bot import RepoBindBotRequest from core.models.user import User - +from env import WEB_URL from github_app.handlers import get_handler from github_app.purchased import PurchaseServer from github_app.utils import ( get_private_key, ) - from petercat_utils import get_env_variable REGIN_NAME = get_env_variable("AWS_REGION") AWS_GITHUB_SECRET_NAME = get_env_variable("AWS_GITHUB_SECRET_NAME") APP_ID = get_env_variable("X_GITHUB_APP_ID") -WEB_URL = get_env_variable("WEB_URL") logger = logging.getLogger() logger.setLevel("INFO") @@ -51,9 +50,9 @@ def github_app_callback(code: str, installation_id: str, setup_action: str): @router.post("/app/webhook") async def github_app_webhook( - request: Request, - background_tasks: BackgroundTasks, - x_github_event: str = Header(...), + request: Request, + background_tasks: BackgroundTasks, + x_github_event: str = Header(...), ): payload = await request.json() if x_github_event == "marketplace_purchase": @@ -86,7 +85,7 @@ async def github_app_webhook( @router.get("/user/repos_installed_app") def get_user_repos_installed_app( - user: Annotated[User | None, Depends(get_user)] = None + user: Annotated[User | None, Depends(get_user)] = None ): """ Get github user installed app repositories which saved in platform database. @@ -116,8 +115,8 @@ def get_user_repos_installed_app( @router.post("/repo/bind_bot", status_code=200) def bind_bot_to_repo( - request: RepoBindBotRequest, - user: Annotated[User | None, Depends(get_user)] = None, + request: RepoBindBotRequest, + user: Annotated[User | None, Depends(get_user)] = None, ): if user is None: raise HTTPException( diff --git a/server/main.py b/server/main.py index a5d747d6..0d4c82fc 100644 --- a/server/main.py +++ b/server/main.py @@ -1,34 +1,29 @@ import os -from fastapi.responses import RedirectResponse import uvicorn - from fastapi import FastAPI -from starlette.middleware.sessions import SessionMiddleware from fastapi.middleware.cors import CORSMiddleware -from auth.cors_middleware import AuthCORSMiddleWare -from i18n.translations import I18nConfig, I18nMiddleware - -from auth.middleware import AuthMiddleWare -from petercat_utils import get_env_variable - +from fastapi.responses import RedirectResponse +from starlette.middleware.sessions import SessionMiddleware # Import fastapi routers from auth import router as auth_router +from auth.cors_middleware import AuthCORSMiddleWare +from auth.middleware import AuthMiddleWare +from aws import router as aws_router from bot import router as bot_router from chat import router as chat_router +from env import ENVIRONMENT, API_URL, WEB_URL +from github_app import router as github_app_router +from i18n.translations import I18nConfig, I18nMiddleware +from petercat_utils import get_env_variable from rag import router as rag_router from task import router as task_router -from github_app import router as github_app_router -from aws import router as aws_router from user import router as user_router AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN") API_AUDIENCE = get_env_variable("API_IDENTIFIER") CLIENT_ID = get_env_variable("AUTH0_CLIENT_ID") -API_URL = get_env_variable("API_URL") -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") CALLBACK_URL = f"{API_URL}/api/auth/callback" is_dev = bool(get_env_variable("IS_DEV")) @@ -77,7 +72,7 @@ def home_page(): @app.get("/api/health_checker") def health_checker(): return { - "ENVRIMENT": ENVRIMENT, + "ENVIRONMENT": ENVIRONMENT, "API_URL": API_URL, "WEB_URL": WEB_URL, "CALLBACK_URL": CALLBACK_URL, diff --git a/server/tests/test_main.py b/server/tests/test_main.py index 3eb22630..e8d343e0 100644 --- a/server/tests/test_main.py +++ b/server/tests/test_main.py @@ -1,18 +1,16 @@ from fastapi.testclient import TestClient + +from env import ENVIRONMENT, WEB_URL, API_URL from petercat_utils import get_env_variable from main import app -API_URL = get_env_variable("API_URL") -WEB_URL = get_env_variable("WEB_URL") -ENVRIMENT = get_env_variable("PETERCAT_ENV", "development") - client = TestClient(app) def test_health_checker(): response = client.get("/api/health_checker") assert response.status_code == 200 assert response.json() == { - 'ENVRIMENT': ENVRIMENT, + 'ENVIRONMENT': ENVIRONMENT, 'API_URL': API_URL, 'CALLBACK_URL': f'{API_URL}/api/auth/callback', 'WEB_URL': WEB_URL, diff --git a/subscriber/requirements.txt b/subscriber/requirements.txt index 23c995c0..3ff51c19 100644 --- a/subscriber/requirements.txt +++ b/subscriber/requirements.txt @@ -1 +1 @@ -petercat_utils>=0.1.39 +petercat_utils>=0.1.40