From 66928f5a25f9317030a287565e033b8e98f9c389 Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sat, 17 Aug 2024 13:13:11 +0800 Subject: [PATCH 1/6] chore: move issue impl to petercat utils --- petercat_utils/data_class.py | 8 +++++++- .../rag_helper}/issue_retrieval.py | 3 +-- server/cats/data_class.py | 10 ---------- server/routers/rag.py | 7 ++----- 4 files changed, 10 insertions(+), 18 deletions(-) rename {server/cats => petercat_utils/rag_helper}/issue_retrieval.py (98%) delete mode 100644 server/cats/data_class.py diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index 4f62364e..77eeb568 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -82,7 +82,6 @@ class GitDocConfig(BaseModel): class RAGGitDocConfig(GitDocConfig): bot_id: str - class TaskStatus(Enum): NOT_STARTED = auto() IN_PROGRESS = auto() @@ -90,3 +89,10 @@ class TaskStatus(Enum): ON_HOLD = auto() CANCELLED = auto() ERROR = auto() + +class GitIssueConfig(BaseModel): + repo_name: str + issue_id: str + +class RAGIssueDocConfig(GitIssueConfig): + bot_id: str diff --git a/server/cats/issue_retrieval.py b/petercat_utils/rag_helper/issue_retrieval.py similarity index 98% rename from server/cats/issue_retrieval.py rename to petercat_utils/rag_helper/issue_retrieval.py index 0c68b4e5..08e72875 100644 --- a/server/cats/issue_retrieval.py +++ b/petercat_utils/rag_helper/issue_retrieval.py @@ -6,8 +6,7 @@ from langchain_core.documents import Document from langchain_openai import OpenAIEmbeddings from petercat_utils import get_client - -from .data_class import RAGIssueDocConfig +from petercat_utils.data_class import RAGIssueDocConfig g = Github() diff --git a/server/cats/data_class.py b/server/cats/data_class.py deleted file mode 100644 index 8feae606..00000000 --- a/server/cats/data_class.py +++ /dev/null @@ -1,10 +0,0 @@ -from pydantic import BaseModel - - -class GitIssueConfig(BaseModel): - repo_name: str - issue_id: str - - -class RAGIssueDocConfig(GitIssueConfig): - bot_id: str diff --git a/server/routers/rag.py b/server/routers/rag.py index ae16954e..ea193d2d 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -2,11 +2,8 @@ from typing import Optional from fastapi import APIRouter, Depends -from petercat_utils.data_class import RAGGitDocConfig -from petercat_utils.rag_helper import retrieval, task - -from cats import issue_retrieval -from cats.data_class import RAGIssueDocConfig +from petercat_utils.data_class import RAGGitDocConfig, RAGIssueDocConfig +from petercat_utils.rag_helper import retrieval, task, issue_retrieval from verify.rate_limit import verify_rate_limit router = APIRouter( From 3dbe11c3a0a6b5a056434ea1ebc7d4c48cd6fa83 Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sat, 17 Aug 2024 13:38:00 +0800 Subject: [PATCH 2/6] chore: rename issue_docs to rag_issues --- petercat_utils/rag_helper/issue_retrieval.py | 4 ++-- server/sql/rag_docs.sql | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/petercat_utils/rag_helper/issue_retrieval.py b/petercat_utils/rag_helper/issue_retrieval.py index 08e72875..ca6c72c0 100644 --- a/petercat_utils/rag_helper/issue_retrieval.py +++ b/petercat_utils/rag_helper/issue_retrieval.py @@ -10,8 +10,8 @@ g = Github() -TABLE_NAME = "issue_docs" -QUERY_NAME = "match_issue_docs" +TABLE_NAME = "rag_issues" +QUERY_NAME = "match_rag_issues" CHUNK_SIZE = 2000 CHUNK_OVERLAP = 200 diff --git a/server/sql/rag_docs.sql b/server/sql/rag_docs.sql index 75528710..8bb667ef 100644 --- a/server/sql/rag_docs.sql +++ b/server/sql/rag_docs.sql @@ -23,7 +23,7 @@ create table rag_docs file_path varchar ); -create table issue_docs +create table rag_issues ( id uuid primary key, content text, @@ -72,9 +72,9 @@ $$; -- Drop the existing function if it already exists -drop function if exists match_issue_docs; +drop function if exists match_rag_issues; -create function match_issue_docs +create function match_rag_issues ( query_embedding vector (1536), filter jsonb default '{}' @@ -94,11 +94,11 @@ begin content, metadata, embedding, - 1 - (issue_docs.embedding <=> query_embedding + 1 - (rag_issues.embedding <=> query_embedding ) as similarity - from issue_docs + from rag_issues where metadata @> jsonb_extract_path(filter, 'metadata') and bot_id = jsonb_extract_path_text(filter, 'bot_id') - order by issue_docs.embedding <=> query_embedding; + order by rag_issues.embedding <=> query_embedding; end; $$; From 0ad441f3e81f83a7df115dbb1609ea0749095bb6 Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sat, 17 Aug 2024 16:56:09 +0800 Subject: [PATCH 3/6] feat: split git doc task --- petercat_utils/data_class.py | 5 +- petercat_utils/rag_helper/git_doc_task.py | 157 +++++++++++++++ petercat_utils/rag_helper/issue_retrieval.py | 4 +- petercat_utils/rag_helper/task.py | 190 ++++++------------- server/bot/builder.py | 33 ++-- server/routers/rag.py | 8 +- 6 files changed, 245 insertions(+), 152 deletions(-) create mode 100644 petercat_utils/rag_helper/git_doc_task.py diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index 77eeb568..a942d78e 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -90,9 +90,12 @@ class TaskStatus(Enum): CANCELLED = auto() ERROR = auto() +class TaskType(Enum): + GitDoc = auto() + class GitIssueConfig(BaseModel): repo_name: str issue_id: str -class RAGIssueDocConfig(GitIssueConfig): +class RAGGitIssueConfig(GitIssueConfig): bot_id: str diff --git a/petercat_utils/rag_helper/git_doc_task.py b/petercat_utils/rag_helper/git_doc_task.py new file mode 100644 index 00000000..a477d978 --- /dev/null +++ b/petercat_utils/rag_helper/git_doc_task.py @@ -0,0 +1,157 @@ +from typing import Optional, Dict + +from github import Github, Repository +from petercat_utils.data_class import TaskType + +from .task import GitTask +from ..data_class import RAGGitDocConfig, TaskStatus, TaskType +from ..db.client.supabase import get_client + +g = Github() + +TABLE_NAME = "rag_tasks" + + +class GitDocTask(GitTask): + def __init__(self, + commit_id, + node_type, + sha, + bot_id, + path, + repo_name, + status=TaskStatus.NOT_STARTED, + from_id=None, + id=None + ): + super().__init__(bot_id=bot_id, type=TaskType.GitDoc, from_id=from_id, id=id, status=status, + repo_name=repo_name) + self.commit_id = commit_id + self.node_type = node_type + self.sha = sha + self.path = path + + def extra_save_data(self): + data = { + "commit_id": self.commit_id, + "node_type": self.node_type, + "path": self.path, + "sha": self.sha, + } + return data + + +def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None): + if not path: + return sha + else: + tree_data = repo.get_git_tree(sha) + for item in tree_data.tree: + if path.split("/")[0] == item.path: + return get_path_sha(repo, item.sha, "/".join(path.split("/")[1:])) + + +def add_rag_git_doc_task(config: RAGGitDocConfig, + extra: Optional[Dict[str, Optional[str]]] = { + "node_type": None, + "from_task_id": None, + } + ): + repo = g.get_repo(config.repo_name) + + commit_id = ( + config.commit_id + if config.commit_id + else repo.get_branch(config.branch).commit.sha + ) + if config.file_path == "" or config.file_path is None: + extra["node_type"] = "tree" + + if not extra.get("node_type"): + content = repo.get_contents(config.file_path, ref=commit_id) + if isinstance(content, list): + extra["node_type"] = "tree" + else: + extra["node_type"] = "blob" + + sha = get_path_sha(repo, commit_id, config.file_path) + + doc_task = GitDocTask(commit_id=commit_id, + sha=sha, + repo_name=config.repo_name, + node_type=extra["node_type"], + bot_id=config.bot_id, + path=config.file_path) + res = doc_task.save() + doc_task.send() + return res + +def handle_tree_task(task): + supabase = get_client() + ( + supabase.table(TABLE_NAME) + .update({"status": TaskStatus.IN_PROGRESS.name}) + .eq("id", task["id"]) + .execute() + ) + + repo = g.get_repo(task["repo_name"]) + tree_data = repo.get_git_tree(task["sha"]) + + task_list = list( + filter( + lambda item: item["path"].endswith(".md") or item["node_type"] == "tree", + map( + lambda item: { + "repo_name": task["repo_name"], + "commit_id": task["commit_id"], + "status": TaskStatus.NOT_STARTED.name, + "node_type": item.type, + "from_task_id": task["id"], + "path": "/".join(filter(lambda s: s, [task["path"], item.path])), + "sha": item.sha, + "bot_id": task["bot_id"], + }, + tree_data.tree, + ), + ) + ) + + if len(task_list) > 0: + result = supabase.table(TABLE_NAME).insert(task_list).execute() + + for record in result.data: + task_id = record["id"] + message_id = send_task_message(task_id=task_id) + print(f"record={record}, task_id={task_id}, message_id={message_id}") + + return (supabase.table(TABLE_NAME).update( + {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))}, + "status": TaskStatus.COMPLETED.name}) + .eq("id", task["id"]) + .execute()) + + +def handle_blob_task(task): + supabase = get_client() + ( + supabase.table(TABLE_NAME) + .update({"status": TaskStatus.IN_PROGRESS.name}) + .eq("id", task["id"]) + .execute() + ) + + retrieval.add_knowledge_by_doc( + RAGGitDocConfig( + repo_name=task["repo_name"], + file_path=task["path"], + commit_id=task["commit_id"], + bot_id=task["bot_id"], + ) + ) + return ( + supabase.table(TABLE_NAME) + .update({"status": TaskStatus.COMPLETED.name}) + .eq("id", task["id"]) + .execute() + ) diff --git a/petercat_utils/rag_helper/issue_retrieval.py b/petercat_utils/rag_helper/issue_retrieval.py index ca6c72c0..defb843d 100644 --- a/petercat_utils/rag_helper/issue_retrieval.py +++ b/petercat_utils/rag_helper/issue_retrieval.py @@ -6,7 +6,7 @@ from langchain_core.documents import Document from langchain_openai import OpenAIEmbeddings from petercat_utils import get_client -from petercat_utils.data_class import RAGIssueDocConfig +from petercat_utils.data_class import RAGGitIssueConfig g = Github() @@ -76,7 +76,7 @@ def get_issue_document_list(issue: Issue): return document_list -def add_knowledge_by_issue(config: RAGIssueDocConfig): +def add_knowledge_by_issue(config: RAGGitIssueConfig): supabase = get_client() is_added_query = ( supabase.table(TABLE_NAME) diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index 2a8732cd..0f770f31 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -1,13 +1,16 @@ import json from enum import Enum from typing import Optional, Dict + import boto3 +from petercat_utils.data_class import RAGGitIssueConfig, TaskType # Create SQS client sqs = boto3.client("sqs") from github import Github from github import Repository +from abc import ABC, abstractmethod from ..utils.env import get_env_variable from ..data_class import RAGGitDocConfig, TaskStatus @@ -17,10 +20,67 @@ g = Github() TABLE_NAME = "rag_tasks" +TABLE_NAME_MAP = { + TaskType.GitDoc: 'rag_tasks' +} SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") +# Base GitTask Class +class GitTask(ABC): + def __init__(self, type, repo_name, bot_id, status=TaskStatus.NOT_STARTED, from_id=None, id=None): + self.type = type + self.id = id + self.from_id = from_id + self.status = status + self.repo_name = repo_name + self.bot_id = bot_id + + @property + def table_name(self): + return TABLE_NAME_MAP[self.type] + + def get_table(self): + supabase = get_client() + supabase.table(self.table_name) + + def update_status(self, status: TaskStatus): + return (self.get_table() + .update({"status": status.name}) + .eq("id", self.id) + .execute()) + + def save(self): + data = { + **self.extra_save_data(), + "repo_name": self.repo_name, + "bot_id": self.bot_id, + "from_task_id": self.from_id, + "status": self.status.name, + } + res = self.get_table().insert(data).execute() + self.id = res.data[0]['id'] + return res + + @abstractmethod + def extra_save_data(self): + pass + + def send(self): + assert self.id, "Task ID needed, save it first" + assert self.type, "Task type needed, set it first" + + response = sqs.send_message( + QueueUrl=SQS_QUEUE_URL, + DelaySeconds=10, + MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type})), + ) + message_id = response["MessageId"] + print(f"task_id={task_id}, message_id={message_id}") + return message_id + + def send_task_message(task_id: str): response = sqs.send_message( QueueUrl=SQS_QUEUE_URL, @@ -30,65 +90,6 @@ def send_task_message(task_id: str): return response["MessageId"] -def add_task( - config: RAGGitDocConfig, - extra: Optional[Dict[str, Optional[str]]] = { - "node_type": None, - "from_task_id": None, - }, -): - repo = g.get_repo(config.repo_name) - commit_id = ( - config.commit_id - if config.commit_id - else repo.get_branch(config.branch).commit.sha - ) - - if config.file_path == "" or config.file_path is None: - extra["node_type"] = "tree" - - if not extra.get("node_type"): - content = repo.get_contents(config.file_path, ref=commit_id) - if isinstance(content, list): - extra["node_type"] = "tree" - else: - extra["node_type"] = "blob" - - sha = get_path_sha(repo, commit_id, config.file_path) - - supabase = get_client() - - data = { - "repo_name": config.repo_name, - "commit_id": commit_id, - "status": TaskStatus.NOT_STARTED.name, - "node_type": extra["node_type"], - "from_task_id": extra["from_task_id"], - "path": config.file_path, - "sha": sha, - "bot_id": config.bot_id, - } - - res = supabase.table(TABLE_NAME).insert(data).execute() - - record = res.data[0] - task_id = record["id"] - - message_id = send_task_message(task_id=task_id) - print(f"record={record}, task_id={task_id}, message_id={message_id}") - return res - - -def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None): - if not path: - return sha - else: - tree_data = repo.get_git_tree(sha) - for item in tree_data.tree: - if path.split("/")[0] == item.path: - return get_path_sha(repo, item.sha, "/".join(path.split("/")[1:])) - - def get_oldest_task(): supabase = get_client() @@ -111,77 +112,6 @@ def get_task_by_id(task_id): return response.data[0] if (len(response.data) > 0) else None -def handle_tree_task(task): - supabase = get_client() - ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.IN_PROGRESS.name}) - .eq("id", task["id"]) - .execute() - ) - - repo = g.get_repo(task["repo_name"]) - tree_data = repo.get_git_tree(task["sha"]) - - task_list = list( - filter( - lambda item: item["path"].endswith(".md") or item["node_type"] == "tree", - map( - lambda item: { - "repo_name": task["repo_name"], - "commit_id": task["commit_id"], - "status": TaskStatus.NOT_STARTED.name, - "node_type": item.type, - "from_task_id": task["id"], - "path": "/".join(filter(lambda s: s, [task["path"], item.path])), - "sha": item.sha, - "bot_id": task["bot_id"], - }, - tree_data.tree, - ), - ) - ) - - if len(task_list) > 0: - result = supabase.table(TABLE_NAME).insert(task_list).execute() - - for record in result.data: - task_id = record["id"] - message_id = send_task_message(task_id=task_id) - print(f"record={record}, task_id={task_id}, message_id={message_id}") - - return (supabase.table(TABLE_NAME).update( - {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))}, - "status": TaskStatus.COMPLETED.name}) - .eq("id", task["id"]) - .execute()) - - -def handle_blob_task(task): - supabase = get_client() - ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.IN_PROGRESS.name}) - .eq("id", task["id"]) - .execute() - ) - - retrieval.add_knowledge_by_doc( - RAGGitDocConfig( - repo_name=task["repo_name"], - file_path=task["path"], - commit_id=task["commit_id"], - bot_id=task["bot_id"], - ) - ) - return ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.COMPLETED.name}) - .eq("id", task["id"]) - .execute() - ) - - def trigger_task(task_id: Optional[str]): task = get_task_by_id(task_id) if task_id else get_oldest_task() if task is None: diff --git a/server/bot/builder.py b/server/bot/builder.py index 5acbe1b3..dd60e5b1 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -1,18 +1,19 @@ from typing import List, Optional + from github import Github from petercat_utils import get_client -from prompts.bot_template import generate_prompt_by_repo_name -from petercat_utils.rag_helper.task import add_task from petercat_utils.data_class import RAGGitDocConfig +from petercat_utils.rag_helper.git_doc_task import add_rag_git_doc_task +from prompts.bot_template import generate_prompt_by_repo_name g = Github() async def bot_info_generator( - uid: str, - repo_name: str, - starters: Optional[List[str]] = None, - hello_message: Optional[str] = None + uid: str, + repo_name: str, + starters: Optional[List[str]] = None, + hello_message: Optional[str] = None ): try: # Step1:Get the repository object @@ -20,16 +21,17 @@ async def bot_info_generator( # Step2: Generate the prompt prompt = generate_prompt_by_repo_name(repo_name) - + # Step3: Generate the bot data bot_data = { - "name": repo.name, + "name": repo.name, "description": repo.description, "avatar": repo.organization.avatar_url if repo.organization else None, "prompt": prompt, "uid": uid, "label": "Assistant", - "starters": starters if starters else [f"介绍一下 {repo.name} 这个项目", f"查看 {repo_name} 的贡献指南", "我该怎样快速上手"], + "starters": starters if starters else [f"介绍一下 {repo.name} 这个项目", f"查看 {repo_name} 的贡献指南", + "我该怎样快速上手"], "public": False, "hello_message": hello_message if hello_message else "我是你专属的答疑机器人,你可以问我关于当前项目的任何问题~" } @@ -39,7 +41,8 @@ async def bot_info_generator( print(f"An error occurred: {e}") return None -def trigger_rag_task (repo_name: str, bot_id: str): + +def trigger_rag_task(repo_name: str, bot_id: str): try: repo = g.get_repo(repo_name) default_branch = repo.default_branch @@ -50,16 +53,16 @@ def trigger_rag_task (repo_name: str, bot_id: str): file_path="", commit_id="", ) - add_task(config) + add_rag_git_doc_task(config) except Exception as e: print(f"trigger_rag_task error: {e}") async def bot_builder( - uid: str, - repo_name: str, - starters: Optional[List[str]] = None, - hello_message: Optional[str] = None + uid: str, + repo_name: str, + starters: Optional[List[str]] = None, + hello_message: Optional[str] = None ): """ create a bot based on the given github repository. diff --git a/server/routers/rag.py b/server/routers/rag.py index ea193d2d..a0965587 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -2,8 +2,8 @@ from typing import Optional from fastapi import APIRouter, Depends -from petercat_utils.data_class import RAGGitDocConfig, RAGIssueDocConfig -from petercat_utils.rag_helper import retrieval, task, issue_retrieval +from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig +from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task from verify.rate_limit import verify_rate_limit router = APIRouter( @@ -31,7 +31,7 @@ def add_knowledge_by_doc(config: RAGGitDocConfig): @router.post("/rag/add_knowledge_by_issue", dependencies=[Depends(verify_rate_limit)]) -def add_knowledge_by_issue(config: RAGIssueDocConfig): +def add_knowledge_by_issue(config: RAGGitIssueConfig): try: result = issue_retrieval.add_knowledge_by_issue(config) if result: @@ -56,7 +56,7 @@ def search_knowledge(query: str, bot_id: str, filter: dict = {}): @router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)]) def add_task(config: RAGGitDocConfig): try: - data = task.add_task(config) + data = git_doc_task.add_rag_git_doc_task(config) return data except Exception as e: return json.dumps({"success": False, "message": str(e)}) From 9e61ec0e4b9051e307f6969046b8dab7e56cd4ed Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sat, 17 Aug 2024 18:27:08 +0800 Subject: [PATCH 4/6] chore(rag): refactor task trigger --- petercat_utils/data_class.py | 39 +++--- petercat_utils/rag_helper/git_doc_task.py | 143 ++++++++++------------ petercat_utils/rag_helper/git_task.py | 83 +++++++++++++ petercat_utils/rag_helper/task.py | 100 +++++---------- server/routers/rag.py | 6 +- subscriber/handler.py | 12 +- 6 files changed, 205 insertions(+), 178 deletions(-) create mode 100644 petercat_utils/rag_helper/git_task.py diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index a942d78e..8510a2d6 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -1,11 +1,9 @@ from enum import Enum, auto -from dataclasses import dataclass -from typing import List, Optional -from datetime import datetime -from typing import Literal, Optional, List, TypeAlias, Union -from pydantic import BaseModel +from typing import Literal, Optional, List, TypeAlias from typing import Union +from pydantic import BaseModel + class ImageURL(BaseModel): url: str @@ -59,18 +57,6 @@ class S3Config(BaseModel): file_path: Optional[str] = None -class GitIssueConfig(BaseModel): - repo_name: str - page: Optional[int] = None - """The page number for paginated results. - Defaults to 1 in the GitHub API.""" - per_page: Optional[int] = 30 - """Number of items per page. - Defaults to 30 in the GitHub API.""" - state: Optional[Literal["open", "closed", "all"]] = "all" - """Filter on issue state. Can be one of: 'open', 'closed', 'all'.""" - - class GitDocConfig(BaseModel): repo_name: str """File path of the documentation file. eg:'docs/blog/build-ghost.zh-CN.md'""" @@ -82,7 +68,13 @@ class GitDocConfig(BaseModel): class RAGGitDocConfig(GitDocConfig): bot_id: str -class TaskStatus(Enum): + +class AutoNameEnum(Enum): + def _generate_next_value_(name, start, count, last_values): + return name + + +class TaskStatus(AutoNameEnum): NOT_STARTED = auto() IN_PROGRESS = auto() COMPLETED = auto() @@ -90,12 +82,19 @@ class TaskStatus(Enum): CANCELLED = auto() ERROR = auto() -class TaskType(Enum): - GitDoc = auto() + +class TaskType(AutoNameEnum): + GIT_DOC = auto() + +class GitDocTaskNodeType(AutoNameEnum): + TREE = auto() + BLOB = auto() + class GitIssueConfig(BaseModel): repo_name: str issue_id: str + class RAGGitIssueConfig(GitIssueConfig): bot_id: str diff --git a/petercat_utils/rag_helper/git_doc_task.py b/petercat_utils/rag_helper/git_doc_task.py index a477d978..c7274052 100644 --- a/petercat_utils/rag_helper/git_doc_task.py +++ b/petercat_utils/rag_helper/git_doc_task.py @@ -1,21 +1,18 @@ from typing import Optional, Dict from github import Github, Repository -from petercat_utils.data_class import TaskType -from .task import GitTask -from ..data_class import RAGGitDocConfig, TaskStatus, TaskType -from ..db.client.supabase import get_client +import retrieval +from petercat_utils.rag_helper.git_task import GitTask +from ..data_class import RAGGitDocConfig, TaskStatus, TaskType, GitDocTaskNodeType g = Github() -TABLE_NAME = "rag_tasks" - class GitDocTask(GitTask): def __init__(self, commit_id, - node_type, + node_type: GitDocTaskNodeType, sha, bot_id, path, @@ -24,7 +21,7 @@ def __init__(self, from_id=None, id=None ): - super().__init__(bot_id=bot_id, type=TaskType.GitDoc, from_id=from_id, id=id, status=status, + super().__init__(bot_id=bot_id, type=TaskType.GIT_DOC, from_id=from_id, id=id, status=status, repo_name=repo_name) self.commit_id = commit_id self.node_type = node_type @@ -40,6 +37,66 @@ def extra_save_data(self): } return data + def handle_tree_node(self): + repo = g.get_repo(self.repo_name) + tree_data = repo.get_git_tree(self.sha) + + task_list = list( + filter( + lambda item: item["path"].endswith(".md") or item["node_type"] == "tree", + map( + lambda item: { + "repo_name": self.repo_name, + "commit_id": self.commit_id, + "status": TaskStatus.NOT_STARTED.name, + "node_type": item.type, + "from_task_id": self.id, + "path": "/".join(filter(lambda s: s, [self.path, item.path])), + "sha": item.sha, + "bot_id": self.bot_id, + }, + tree_data.tree, + ), + ) + ) + + if len(task_list) > 0: + result = self.get_table().insert(task_list).execute() + + for record in result.data: + doc_task = GitDocTask(id=record["id"], + commit_id=record["commit_id"], + sha=record["sha"], + repo_name=record["repo_name"], + node_type=record["node_type"], + bot_id=record["bot_id"], + path=record["path"]) + doc_task.send() + + return (self.get_table().update( + {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))}, + "status": TaskStatus.COMPLETED.name}) + .eq("id", self.id) + .execute()) + + def handle_blob_task(self): + self.update_status(TaskStatus.IN_PROGRESS) + + retrieval.add_knowledge_by_doc( + RAGGitDocConfig( + repo_name=self.repo_name, + file_path=self.path, + commit_id=self.commit_id, + bot_id=self.bot_id, + ) + ) + return self.update_status(TaskStatus.COMPLETED) + + def handle(self): + self.update_status(TaskStatus.IN_PROGRESS) + if self.node_type == GitDocTaskNodeType.TREE.value: + return self.handle_tree_node() + def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None): if not path: @@ -85,73 +142,3 @@ def add_rag_git_doc_task(config: RAGGitDocConfig, res = doc_task.save() doc_task.send() return res - -def handle_tree_task(task): - supabase = get_client() - ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.IN_PROGRESS.name}) - .eq("id", task["id"]) - .execute() - ) - - repo = g.get_repo(task["repo_name"]) - tree_data = repo.get_git_tree(task["sha"]) - - task_list = list( - filter( - lambda item: item["path"].endswith(".md") or item["node_type"] == "tree", - map( - lambda item: { - "repo_name": task["repo_name"], - "commit_id": task["commit_id"], - "status": TaskStatus.NOT_STARTED.name, - "node_type": item.type, - "from_task_id": task["id"], - "path": "/".join(filter(lambda s: s, [task["path"], item.path])), - "sha": item.sha, - "bot_id": task["bot_id"], - }, - tree_data.tree, - ), - ) - ) - - if len(task_list) > 0: - result = supabase.table(TABLE_NAME).insert(task_list).execute() - - for record in result.data: - task_id = record["id"] - message_id = send_task_message(task_id=task_id) - print(f"record={record}, task_id={task_id}, message_id={message_id}") - - return (supabase.table(TABLE_NAME).update( - {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))}, - "status": TaskStatus.COMPLETED.name}) - .eq("id", task["id"]) - .execute()) - - -def handle_blob_task(task): - supabase = get_client() - ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.IN_PROGRESS.name}) - .eq("id", task["id"]) - .execute() - ) - - retrieval.add_knowledge_by_doc( - RAGGitDocConfig( - repo_name=task["repo_name"], - file_path=task["path"], - commit_id=task["commit_id"], - bot_id=task["bot_id"], - ) - ) - return ( - supabase.table(TABLE_NAME) - .update({"status": TaskStatus.COMPLETED.name}) - .eq("id", task["id"]) - .execute() - ) diff --git a/petercat_utils/rag_helper/git_task.py b/petercat_utils/rag_helper/git_task.py new file mode 100644 index 00000000..0dd8ba4b --- /dev/null +++ b/petercat_utils/rag_helper/git_task.py @@ -0,0 +1,83 @@ +import json +from abc import ABC, abstractmethod + +import boto3 + +from ..data_class import TaskStatus, TaskType +from ..db.client.supabase import get_client +from ..utils.env import get_env_variable + +sqs = boto3.client("sqs") + +TABLE_NAME_MAP = { + TaskType.GIT_DOC: 'rag_tasks' +} +SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") + + +# Base GitTask Class +class GitTask(ABC): + type: TaskType + + def __init__(self, type, repo_name, bot_id, status=TaskStatus.NOT_STARTED, from_id=None, id=None): + self.type = type + self.id = id + self.from_id = from_id + self.status = status + self.repo_name = repo_name + self.bot_id = bot_id + + @staticmethod + def get_table_name(type: TaskType): + return TABLE_NAME_MAP[type] + + @property + def table_name(self): + return GitTask.get_table_name(self.type) + + @property + def raw_data(self): + data = { + **self.extra_save_data(), + "repo_name": self.repo_name, + "bot_id": self.bot_id, + "from_task_id": self.from_id, + "status": self.status.name, + } + return data + + def get_table(self): + supabase = get_client() + return supabase.table(self.table_name) + + def update_status(self, status: TaskStatus): + return (self.get_table() + .update({"status": status.name}) + .eq("id", self.id) + .execute()) + + def save(self): + res = self.get_table().insert(self.raw_data).execute() + self.id = res.data[0]['id'] + return res + + @abstractmethod + def extra_save_data(self): + pass + + @abstractmethod + def handle(self): + pass + + def send(self): + assert self.id, "Task ID needed, save it first" + assert self.type, "Task type needed, set it first" + + response = sqs.send_message( + QueueUrl=SQS_QUEUE_URL, + DelaySeconds=10, + MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type})), + ) + message_id = response["MessageId"] + print(f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}") + return message_id diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index 0f770f31..31d986d0 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -1,86 +1,27 @@ import json -from enum import Enum -from typing import Optional, Dict +from typing import Optional import boto3 -from petercat_utils.data_class import RAGGitIssueConfig, TaskType + +from petercat_utils.rag_helper.git_task import GitTask +from .git_doc_task import GitDocTask # Create SQS client sqs = boto3.client("sqs") from github import Github -from github import Repository -from abc import ABC, abstractmethod from ..utils.env import get_env_variable -from ..data_class import RAGGitDocConfig, TaskStatus +from ..data_class import TaskStatus, TaskType, GitDocTaskNodeType from ..db.client.supabase import get_client -from ..rag_helper import retrieval g = Github() TABLE_NAME = "rag_tasks" -TABLE_NAME_MAP = { - TaskType.GitDoc: 'rag_tasks' -} SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") -# Base GitTask Class -class GitTask(ABC): - def __init__(self, type, repo_name, bot_id, status=TaskStatus.NOT_STARTED, from_id=None, id=None): - self.type = type - self.id = id - self.from_id = from_id - self.status = status - self.repo_name = repo_name - self.bot_id = bot_id - - @property - def table_name(self): - return TABLE_NAME_MAP[self.type] - - def get_table(self): - supabase = get_client() - supabase.table(self.table_name) - - def update_status(self, status: TaskStatus): - return (self.get_table() - .update({"status": status.name}) - .eq("id", self.id) - .execute()) - - def save(self): - data = { - **self.extra_save_data(), - "repo_name": self.repo_name, - "bot_id": self.bot_id, - "from_task_id": self.from_id, - "status": self.status.name, - } - res = self.get_table().insert(data).execute() - self.id = res.data[0]['id'] - return res - - @abstractmethod - def extra_save_data(self): - pass - - def send(self): - assert self.id, "Task ID needed, save it first" - assert self.type, "Task type needed, set it first" - - response = sqs.send_message( - QueueUrl=SQS_QUEUE_URL, - DelaySeconds=10, - MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type})), - ) - message_id = response["MessageId"] - print(f"task_id={task_id}, message_id={message_id}") - return message_id - - def send_task_message(task_id: str): response = sqs.send_message( QueueUrl=SQS_QUEUE_URL, @@ -112,15 +53,32 @@ def get_task_by_id(task_id): return response.data[0] if (len(response.data) > 0) else None -def trigger_task(task_id: Optional[str]): - task = get_task_by_id(task_id) if task_id else get_oldest_task() +def get_task( task_type: TaskType, task_id: str): + supabase = get_client() + response = (supabase.table(GitTask.get_table_name(task_type)) + .select("*") + .eq("id", task_id) + .execute()) + if len(response.data) > 0: + data = response.data[0] + if task_type is TaskType.GIT_DOC: + return GitDocTask( + id=data["id"], + commit_id=data["commit_id"], + sha=data["sha"], + repo_name=data["repo_name"], + node_type=data["node_type"], + bot_id=data["bot_id"], + path=data["path"], + status=data["status"] + ) + + +def trigger_task(task_type: TaskType, task_id: Optional[str]): + task = get_task(task_type, task_id) if task_id else get_oldest_task() if task is None: return task - if task["node_type"] == "tree": - return handle_tree_task(task) - else: - return handle_blob_task(task) - + return task.handle() def get_latest_task_by_bot_id(bot_id: str): supabase = get_client() diff --git a/server/routers/rag.py b/server/routers/rag.py index a0965587..7e92461b 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -2,7 +2,7 @@ from typing import Optional from fastapi import APIRouter, Depends -from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig +from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig, TaskType from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task from verify.rate_limit import verify_rate_limit @@ -63,8 +63,8 @@ def add_task(config: RAGGitDocConfig): @router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)]) -def trigger_task(task_id: Optional[str] = None): - data = task.trigger_task(task_id) +def trigger_task(task_type: TaskType, task_id: Optional[str] = None): + data = task.trigger_task(task_type, task_id) return data diff --git a/subscriber/handler.py b/subscriber/handler.py index 0a443f4d..eb819328 100644 --- a/subscriber/handler.py +++ b/subscriber/handler.py @@ -14,12 +14,12 @@ def lambda_handler(event, context): message_dict = json.loads(body) task_id = message_dict["task_id"] - task = task_helper.get_task_by_id(task_id) - if not (task is None): - if task['node_type'] == 'tree': - task_helper.handle_tree_task(task) - else: - task_helper.handle_blob_task(task) + task_type = message_dict["task_type"] + task = task_helper.get_task(task_type, task_id) + if task is None: + return task + return task.handle() + # process message print(f"message content: message={message_dict}, task_id={task_id}, task={task}") except Exception as e: From 159ad6902fe2d5c09d2223e8708f004570bdff1b Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sun, 18 Aug 2024 04:57:09 +0800 Subject: [PATCH 5/6] fix: node type save and handle --- petercat_utils/data_class.py | 1 + petercat_utils/rag_helper/git_doc_task.py | 108 +++++++++++----------- petercat_utils/rag_helper/git_task.py | 2 +- petercat_utils/rag_helper/task.py | 2 +- server/routers/rag.py | 9 +- subscriber/handler.py | 11 ++- 6 files changed, 71 insertions(+), 62 deletions(-) diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index 8510a2d6..4d5cdc5f 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -86,6 +86,7 @@ class TaskStatus(AutoNameEnum): class TaskType(AutoNameEnum): GIT_DOC = auto() + class GitDocTaskNodeType(AutoNameEnum): TREE = auto() BLOB = auto() diff --git a/petercat_utils/rag_helper/git_doc_task.py b/petercat_utils/rag_helper/git_doc_task.py index c7274052..cb59824a 100644 --- a/petercat_utils/rag_helper/git_doc_task.py +++ b/petercat_utils/rag_helper/git_doc_task.py @@ -1,14 +1,62 @@ -from typing import Optional, Dict +from typing import Optional from github import Github, Repository -import retrieval -from petercat_utils.rag_helper.git_task import GitTask +from .git_task import GitTask from ..data_class import RAGGitDocConfig, TaskStatus, TaskType, GitDocTaskNodeType +from ..rag_helper import retrieval g = Github() +def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None): + if not path: + return sha + else: + tree_data = repo.get_git_tree(sha) + for item in tree_data.tree: + if path.split("/")[0] == item.path: + return get_path_sha(repo, item.sha, "/".join(path.split("/")[1:])) + + +def add_rag_git_doc_task(config: RAGGitDocConfig, + extra=None + ): + if extra is None: + extra = { + "node_type": None, + "from_task_id": None, + } + repo = g.get_repo(config.repo_name) + + commit_id = ( + config.commit_id + if config.commit_id + else repo.get_branch(config.branch).commit.sha + ) + if config.file_path == "" or config.file_path is None: + extra["node_type"] = GitDocTaskNodeType.TREE.value + + if not extra.get("node_type"): + content = repo.get_contents(config.file_path, ref=commit_id) + if isinstance(content, list): + extra["node_type"] = GitDocTaskNodeType.TREE.value + else: + extra["node_type"] = GitDocTaskNodeType.BLOB.value + + sha = get_path_sha(repo, commit_id, config.file_path) + + doc_task = GitDocTask(commit_id=commit_id, + sha=sha, + repo_name=config.repo_name, + node_type=extra["node_type"], + bot_id=config.bot_id, + path=config.file_path) + res = doc_task.save() + doc_task.send() + return res + + class GitDocTask(GitTask): def __init__(self, commit_id, @@ -49,7 +97,7 @@ def handle_tree_node(self): "repo_name": self.repo_name, "commit_id": self.commit_id, "status": TaskStatus.NOT_STARTED.name, - "node_type": item.type, + "node_type": (item.type + '').upper(), "from_task_id": self.id, "path": "/".join(filter(lambda s: s, [self.path, item.path])), "sha": item.sha, @@ -79,9 +127,7 @@ def handle_tree_node(self): .eq("id", self.id) .execute()) - def handle_blob_task(self): - self.update_status(TaskStatus.IN_PROGRESS) - + def handle_blob_node(self): retrieval.add_knowledge_by_doc( RAGGitDocConfig( repo_name=self.repo_name, @@ -96,49 +142,7 @@ def handle(self): self.update_status(TaskStatus.IN_PROGRESS) if self.node_type == GitDocTaskNodeType.TREE.value: return self.handle_tree_node() - - -def get_path_sha(repo: Repository.Repository, sha: str, path: Optional[str] = None): - if not path: - return sha - else: - tree_data = repo.get_git_tree(sha) - for item in tree_data.tree: - if path.split("/")[0] == item.path: - return get_path_sha(repo, item.sha, "/".join(path.split("/")[1:])) - - -def add_rag_git_doc_task(config: RAGGitDocConfig, - extra: Optional[Dict[str, Optional[str]]] = { - "node_type": None, - "from_task_id": None, - } - ): - repo = g.get_repo(config.repo_name) - - commit_id = ( - config.commit_id - if config.commit_id - else repo.get_branch(config.branch).commit.sha - ) - if config.file_path == "" or config.file_path is None: - extra["node_type"] = "tree" - - if not extra.get("node_type"): - content = repo.get_contents(config.file_path, ref=commit_id) - if isinstance(content, list): - extra["node_type"] = "tree" + elif self.node_type == GitDocTaskNodeType.BLOB.value: + return self.handle_blob_node() else: - extra["node_type"] = "blob" - - sha = get_path_sha(repo, commit_id, config.file_path) - - doc_task = GitDocTask(commit_id=commit_id, - sha=sha, - repo_name=config.repo_name, - node_type=extra["node_type"], - bot_id=config.bot_id, - path=config.file_path) - res = doc_task.save() - doc_task.send() - return res + raise ValueError(f"Unsupported node type [{self.node_type}]") diff --git a/petercat_utils/rag_helper/git_task.py b/petercat_utils/rag_helper/git_task.py index 0dd8ba4b..eb6f465f 100644 --- a/petercat_utils/rag_helper/git_task.py +++ b/petercat_utils/rag_helper/git_task.py @@ -76,7 +76,7 @@ def send(self): response = sqs.send_message( QueueUrl=SQS_QUEUE_URL, DelaySeconds=10, - MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type})), + MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type.value})), ) message_id = response["MessageId"] print(f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}") diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index 31d986d0..1010127a 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -53,7 +53,7 @@ def get_task_by_id(task_id): return response.data[0] if (len(response.data) > 0) else None -def get_task( task_type: TaskType, task_id: str): +def get_task( task_type: TaskType, task_id: str) -> GitTask: supabase = get_client() response = (supabase.table(GitTask.get_table_name(task_type)) .select("*") diff --git a/server/routers/rag.py b/server/routers/rag.py index 7e92461b..d66efee8 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -2,9 +2,10 @@ from typing import Optional from fastapi import APIRouter, Depends +from verify.rate_limit import verify_rate_limit + from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig, TaskType from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task -from verify.rate_limit import verify_rate_limit router = APIRouter( prefix="/api", @@ -64,8 +65,10 @@ def add_task(config: RAGGitDocConfig): @router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)]) def trigger_task(task_type: TaskType, task_id: Optional[str] = None): - data = task.trigger_task(task_type, task_id) - return data + try: + task.trigger_task(task_type, task_id) + except Exception as e: + return json.dumps({"success": False, "message": str(e)}) @router.get("/rag/chunk/list", dependencies=[Depends(verify_rate_limit)]) diff --git a/subscriber/handler.py b/subscriber/handler.py index eb819328..7eee0244 100644 --- a/subscriber/handler.py +++ b/subscriber/handler.py @@ -2,28 +2,29 @@ from petercat_utils import task as task_helper + def lambda_handler(event, context): if event: batch_item_failures = [] sqs_batch_response = {} - + for record in event["Records"]: try: body = record["body"] print(f"receive message here: {body}") - + message_dict = json.loads(body) task_id = message_dict["task_id"] task_type = message_dict["task_type"] task = task_helper.get_task(task_type, task_id) if task is None: return task - return task.handle() + task.handle() # process message print(f"message content: message={message_dict}, task_id={task_id}, task={task}") except Exception as e: batch_item_failures.append({"itemIdentifier": record['messageId']}) - + sqs_batch_response["batchItemFailures"] = batch_item_failures - return sqs_batch_response \ No newline at end of file + return sqs_batch_response From 95c74931f7306651494da77ad4696b17d5a83dcc Mon Sep 17 00:00:00 2001 From: MadratJerry Date: Sun, 18 Aug 2024 05:12:39 +0800 Subject: [PATCH 6/6] feat: improve type hint --- petercat_utils/rag_helper/task.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index 1010127a..b17eca52 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -3,8 +3,8 @@ import boto3 -from petercat_utils.rag_helper.git_task import GitTask from .git_doc_task import GitDocTask +from .git_task import GitTask # Create SQS client sqs = boto3.client("sqs") @@ -12,7 +12,7 @@ from github import Github from ..utils.env import get_env_variable -from ..data_class import TaskStatus, TaskType, GitDocTaskNodeType +from ..data_class import TaskStatus, TaskType from ..db.client.supabase import get_client g = Github() @@ -53,7 +53,7 @@ def get_task_by_id(task_id): return response.data[0] if (len(response.data) > 0) else None -def get_task( task_type: TaskType, task_id: str) -> GitTask: +def get_task(task_type: TaskType, task_id: str) -> GitTask: supabase = get_client() response = (supabase.table(GitTask.get_table_name(task_type)) .select("*") @@ -80,6 +80,7 @@ def trigger_task(task_type: TaskType, task_id: Optional[str]): return task return task.handle() + def get_latest_task_by_bot_id(bot_id: str): supabase = get_client() response = (