diff --git a/.gitignore b/.gitignore
index 40e0ea72..8ad9c20e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,11 +1,13 @@
 # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
 __pycache__/
-.pyc
+*.pyc
 
 # dependencies
 /node_modules
 /.pnp
 .pnp.js
+server/temp/
+
 # testing
 /coverage
@@ -39,3 +41,4 @@ next-env.d.ts
 .yarn
 /server/.aws-sam/*
 .aws-sam/*
+
diff --git a/app/api/chat/retrieval/route.ts b/app/api/chat/retrieval/route.ts
index 33094a3d..b0d5d4f1 100644
--- a/app/api/chat/retrieval/route.ts
+++ b/app/api/chat/retrieval/route.ts
@@ -7,10 +7,7 @@ import { ChatOpenAI } from 'langchain/chat_models/openai';
 import { PromptTemplate } from 'langchain/prompts';
 import { SupabaseVectorStore } from 'langchain/vectorstores/supabase';
 import { Document } from 'langchain/document';
-import {
-  RunnableSequence,
-  RunnablePassthrough,
-} from 'langchain/schema/runnable';
+import { RunnableSequence } from 'langchain/schema/runnable';
 import {
   BytesOutputParser,
   StringOutputParser,
diff --git a/server/Dockerfile.aws.lambda b/server/Dockerfile.aws.lambda
index 8e5f0560..77fc2e0b 100644
--- a/server/Dockerfile.aws.lambda
+++ b/server/Dockerfile.aws.lambda
@@ -8,4 +8,14 @@ COPY . ${LAMBDA_TASK_ROOT}
 COPY requirements.txt .
 RUN pip3 install -r requirements.txt --target "${LAMBDA_TASK_ROOT}" -U --no-cache-dir
 
+# Install NLTK again in the system path so nltk.downloader can run at build time
+RUN pip install nltk
+# Set up the directory for NLTK_DATA
+RUN mkdir -p /opt/nltk_data
+
+# Download the NLTK data into the build directory
+RUN python -W ignore -m nltk.downloader punkt -d /opt/nltk_data
+RUN python -W ignore -m nltk.downloader stopwords -d /opt/nltk_data
+RUN python -W ignore -m nltk.downloader averaged_perceptron_tagger -d /opt/nltk_data
+
 CMD ["python", "main.py"]
\ No newline at end of file
diff --git a/server/data_class.py b/server/data_class.py
index 8ea7fd8f..28c04f6a 100644
--- a/server/data_class.py
+++ b/server/data_class.py
@@ -1,3 +1,4 @@
+from typing import Optional
 from pydantic import BaseModel
 
 
@@ -17,4 +18,8 @@ class ChatData(BaseModel):
 class ExecuteMessage(BaseModel):
     type: str
     repo: str
-    path: str
\ No newline at end of file
+    path: str
+
+class S3Config(BaseModel):
+    s3_bucket: str
+    file_path: Optional[str] = None
diff --git a/server/main.py b/server/main.py
index 1b4e6785..f68220d9 100644
--- a/server/main.py
+++ b/server/main.py
@@ -1,5 +1,5 @@
 import os
-from rag import retrieval
+
 import uvicorn
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
@@ -11,7 +11,7 @@ from data_class import ChatData
 
 
 # Import fastapi routers
-from routers import health_checker, github
+from routers import health_checker, github, rag
 
 open_api_key = get_env_variable("OPENAI_API_KEY")
 is_dev = bool(get_env_variable("IS_DEV"))
@@ -33,25 +33,17 @@
 app.include_router(health_checker.router)
 app.include_router(github.router)
-
+app.include_router(rag.router)
 
 @app.post("/api/chat/stream", response_class=StreamingResponse)
 def run_agent_chat(input_data: ChatData):
     result = stream.agent_chat(input_data, open_api_key)
     return StreamingResponse(result, media_type="text/event-stream")
 
-@app.post("/api/rag/add_knowledge")
-def add_knowledge():
-    data=retrieval.add_knowledge()
-    return data
-
-@app.post("/api/rag/search_knowledge")
-def search_knowledge(query: str):
-    data=retrieval.search_knowledge(query)
-    return data
 
 
 if __name__ == "__main__":
     if is_dev:
         uvicorn.run("main:app", host="0.0.0.0", port=int(os.environ.get("PORT", "8080")), reload=True)
     else:
-        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
\ No newline at end of file
+        uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8080")))
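A note on the Dockerfile.aws.lambda change above: the new S3DirectoryLoader path (server/rag/retrieval.py below) parses files through the unstructured package, whose partitioners rely on NLTK data; the three resources downloaded here are the ones unstructured typically needs. A Lambda container's filesystem is read-only outside /tmp, so the corpora are baked into /opt/nltk_data at build time instead of being fetched lazily. Below is a minimal runtime check, assuming NLTK_DATA (or an explicit path entry, as in the snippet) points at that directory; the snippet is illustrative, not part of this change:

```python
import nltk

# The image stores the corpora in /opt/nltk_data; make sure NLTK searches there.
nltk.data.path.append("/opt/nltk_data")

# nltk.data.find raises LookupError if a resource is absent from every search path.
for resource in (
    "tokenizers/punkt",
    "corpora/stopwords",
    "taggers/averaged_perceptron_tagger",
):
    nltk.data.find(resource)

print("all NLTK resources present")
```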
diff --git a/server/rag/retrieval.py b/server/rag/retrieval.py
index 2c48f163..ed9a8ee3 100644
--- a/server/rag/retrieval.py
+++ b/server/rag/retrieval.py
@@ -1,17 +1,22 @@
-import os
 import json
-from langchain_community.document_loaders import TextLoader
+import boto3
 from langchain_openai import OpenAIEmbeddings
 from langchain_text_splitters import CharacterTextSplitter
 from langchain_community.vectorstores import SupabaseVectorStore
 from db.supabase.client import get_client
+from data_class import S3Config
 from uilts.env import get_env_variable
+from langchain_community.document_loaders import S3DirectoryLoader
+
 
 supabase_url = get_env_variable("SUPABASE_URL")
 supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")
+
+
 table_name="antd_knowledge"
 query_name="match_antd_knowledge"
-chunk_size=500
+chunk_size=2000
+
 
 def convert_document_to_dict(document):
     return {
@@ -32,36 +37,33 @@
     return db.as_retriever()
 
 
-def add_knowledge():
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    target_file_path = os.path.join(current_dir, "../docs/test.md")
-    loader = TextLoader(target_file_path)
-    documents = loader.load()
-    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
-    docs = text_splitter.split_documents(documents)
-    embeddings = OpenAIEmbeddings()
+def add_knowledge(config: S3Config):
     try:
-        SupabaseVectorStore.from_documents(
-            docs,
-            embeddings,
-            client=supabase,
-            table_name=table_name,
-            query_name=query_name,
-            chunk_size=chunk_size,
-        )
-        return json.dumps({
-            "success": True,
-            "message": "Knowledge added successfully!"
-        })
+        loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
+        documents = loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0)
+        docs = text_splitter.split_documents(documents)
+        embeddings = OpenAIEmbeddings()
+        SupabaseVectorStore.from_documents(
+            docs,
+            embeddings,
+            client=get_client(),
+            table_name=table_name,
+            query_name=query_name,
+            chunk_size=chunk_size,
+        )
+        return json.dumps({
+            "success": True,
+            "message": "Knowledge added successfully!",
+            "docs_len": len(documents)
+        })
     except Exception as e:
-        return json.dumps({
-            "success": False,
-            "message": str(e)
-        })
+        return json.dumps({
+            "success": False,
+            "message": str(e)
+        })
 
-
-
 def search_knowledge(query: str):
     retriever = init_retriever()
     docs = retriever.get_relevant_documents(query)
diff --git a/server/requirements.txt b/server/requirements.txt
index 854dfc4c..8b851929 100644
--- a/server/requirements.txt
+++ b/server/requirements.txt
@@ -15,4 +15,5 @@ load_dotenv
 supabase
 boto3>=1.34.84
 pyjwt>=2.4.0
-pydantic>=2.7.0
\ No newline at end of file
+pydantic>=2.7.0
+unstructured[md]
diff --git a/server/routers/github.py b/server/routers/github.py
index 474ddd26..647852e2 100644
--- a/server/routers/github.py
+++ b/server/routers/github.py
@@ -18,7 +18,7 @@
 
 router = APIRouter(
     prefix="/api/github",
-    tags=["health_checkers"],
+    tags=["github"],
     responses={404: {"description": "Not found"}},
 )
 
diff --git a/server/routers/health_checker.py b/server/routers/health_checker.py
index e4c078f3..243b8b26 100644
--- a/server/routers/health_checker.py
+++ b/server/routers/health_checker.py
@@ -8,4 +8,4 @@
 
 @router.get("/health_checker")
 def health_checker():
-    return {"Hello": "World"}
\ No newline at end of file
+    return { "Hello": "World" }
\ No newline at end of file
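One subtlety in the retrieval.py hunk above: chunk_size appears twice with different meanings. The CharacterTextSplitter value (2000) is the maximum number of characters per text chunk, while the chunk_size forwarded to SupabaseVectorStore.from_documents controls, as I read the langchain_community implementation, how many rows are embedded and upserted per request to Supabase. A condensed sketch of the ingestion path, run from the server/ directory so db.supabase.client resolves, assuming AWS and Supabase credentials are present in the environment; the bucket name and prefix are hypothetical:

```python
from langchain_community.document_loaders import S3DirectoryLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

from db.supabase.client import get_client

# Load every object under s3://my-docs-bucket/antd/ (hypothetical names).
documents = S3DirectoryLoader("my-docs-bucket", prefix="antd/").load()

# chunk_size here: characters per chunk handed to the embedding model.
docs = CharacterTextSplitter(chunk_size=2000, chunk_overlap=0).split_documents(documents)

# chunk_size here: rows embedded and inserted per batch on the Supabase side.
SupabaseVectorStore.from_documents(
    docs,
    OpenAIEmbeddings(),
    client=get_client(),
    table_name="antd_knowledge",
    query_name="match_antd_knowledge",
    chunk_size=2000,
)
```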
diff --git a/server/routers/rag.py b/server/routers/rag.py
new file mode 100644
index 00000000..fded286c
--- /dev/null
+++ b/server/routers/rag.py
@@ -0,0 +1,20 @@
+from fastapi import APIRouter
+from rag import retrieval
+from data_class import S3Config
+
+router = APIRouter(
+    prefix="/api",
+    tags=["rag"],
+    responses={404: {"description": "Not found"}},
+)
+
+
+@router.post("/rag/add_knowledge")
+def add_knowledge(config: S3Config):
+    data=retrieval.add_knowledge(config)
+    return data
+
+@router.post("/rag/search_knowledge")
+def search_knowledge(query: str):
+    data=retrieval.search_knowledge(query)
+    return data
diff --git a/server/tools/issue.py b/server/tools/issue.py
index 0a4342fc..525a55b4 100644
--- a/server/tools/issue.py
+++ b/server/tools/issue.py
@@ -2,7 +2,6 @@
 from typing import Optional
 from github import Github
 from langchain.tools import tool
-from uilts.env import get_env_variable
 
 DEFAULT_REPO_NAME = "ant-design/ant-design"
 
@@ -84,15 +83,15 @@
     :param state: The state of the issue, e.g: open, closed, all
     """
     try:
-        search_query = f'{keyword} in:title,body,comments repo:{repo_name}'
+        search_query = f"{keyword} in:title,body,comments repo:{repo_name}"
         # Retrieve a list of open issues from the repository
         issues = g.search_issues(query=search_query, sort=sort, order=order)[:max_num]
         print(f"issues: {issues}")
         issues_list = [
             {
-                'issue_name': f'Issue #{issue.number} - {issue.title}',
-                'issue_url': issue.html_url
+                "issue_name": f"Issue #{issue.number} - {issue.title}",
+                "issue_url": issue.html_url
             }
             for issue in issues
         ]
 
diff --git a/server/tools/sourcecode.py b/server/tools/sourcecode.py
index c0bf304f..e0a1cdd4 100644
--- a/server/tools/sourcecode.py
+++ b/server/tools/sourcecode.py
@@ -2,8 +2,6 @@
 from github import Github
 from github.ContentFile import ContentFile
 from langchain.tools import tool
-from uilts.env import get_env_variable
-
 
 DEFAULT_REPO_NAME = "ant-design/ant-design"
 
@@ -29,7 +27,8 @@
         # Perform the search for code files containing the keyword
         code_files = g.search_code(query=query)[:max_num]
 
-        return code_files
+
+        return code_files
     except Exception as e:
         print(f"An error occurred: {e}")
         return None
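Since routers/rag.py mounts under the bare /api prefix, the two endpoints keep the URLs they previously had in main.py. Below is a sketch of exercising them against a local dev server, assuming it listens on port 8080; the bucket name, prefix, and query string are placeholders. Note that search_knowledge declares a plain `query: str` rather than a Pydantic model, so FastAPI binds it as a URL query parameter even though the route is a POST:

```python
import requests

BASE = "http://localhost:8080"

# add_knowledge takes an S3Config JSON body; file_path is the optional S3 prefix.
resp = requests.post(
    f"{BASE}/api/rag/add_knowledge",
    json={"s3_bucket": "my-docs-bucket", "file_path": "antd/"},
)
# add_knowledge returns json.dumps(...), so the body is a JSON-encoded string.
print(resp.json())

# search_knowledge's `query` travels as a query parameter, not as a request body.
resp = requests.post(
    f"{BASE}/api/rag/search_knowledge",
    params={"query": "How do I customize the antd theme?"},
)
print(resp.json())
```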