diff --git a/petercat_utils/__init__.py b/petercat_utils/__init__.py index 9180b088..24c24049 100644 --- a/petercat_utils/__init__.py +++ b/petercat_utils/__init__.py @@ -1,11 +1,15 @@ from .db.client.supabase import get_client +from .rag_helper import github_file_loader, retrieval, issue_retrieval, task, git_task, git_issue_task, git_doc_task from .utils.env import get_env_variable -from .rag_helper import github_file_loader, retrieval, task __all__ = [ - "get_client", - "get_env_variable", - "github_file_loader", - "retrieval", - "task" -] \ No newline at end of file + "get_client", + "get_env_variable", + "github_file_loader", + "retrieval", + "issue_retrieval", + "task", + "git_task", + "git_issue_task", + "git_doc_task" +] diff --git a/petercat_utils/data_class.py b/petercat_utils/data_class.py index 4d5cdc5f..d7208e02 100644 --- a/petercat_utils/data_class.py +++ b/petercat_utils/data_class.py @@ -69,6 +69,15 @@ class RAGGitDocConfig(GitDocConfig): bot_id: str +class GitIssueConfig(BaseModel): + repo_name: str + issue_id: str + + +class RAGGitIssueConfig(GitIssueConfig): + bot_id: str + + class AutoNameEnum(Enum): def _generate_next_value_(name, start, count, last_values): return name @@ -85,6 +94,7 @@ class TaskStatus(AutoNameEnum): class TaskType(AutoNameEnum): GIT_DOC = auto() + GIT_ISSUE = auto() class GitDocTaskNodeType(AutoNameEnum): @@ -92,10 +102,6 @@ class GitDocTaskNodeType(AutoNameEnum): BLOB = auto() -class GitIssueConfig(BaseModel): - repo_name: str - issue_id: str - - -class RAGGitIssueConfig(GitIssueConfig): - bot_id: str +class GitIssueTaskNodeType(AutoNameEnum): + REPO = auto() + ISSUE = auto() diff --git a/petercat_utils/rag_helper/git_doc_task.py b/petercat_utils/rag_helper/git_doc_task.py index cb59824a..d3186638 100644 --- a/petercat_utils/rag_helper/git_doc_task.py +++ b/petercat_utils/rag_helper/git_doc_task.py @@ -79,7 +79,7 @@ def __init__(self, def extra_save_data(self): data = { "commit_id": self.commit_id, - "node_type": self.node_type, + "node_type": self.node_type.value, "path": self.path, "sha": self.sha, } @@ -96,7 +96,7 @@ def handle_tree_node(self): lambda item: { "repo_name": self.repo_name, "commit_id": self.commit_id, - "status": TaskStatus.NOT_STARTED.name, + "status": TaskStatus.NOT_STARTED.value, "node_type": (item.type + '').upper(), "from_task_id": self.id, "path": "/".join(filter(lambda s: s, [self.path, item.path])), @@ -123,7 +123,7 @@ def handle_tree_node(self): return (self.get_table().update( {"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))}, - "status": TaskStatus.COMPLETED.name}) + "status": TaskStatus.COMPLETED.value}) .eq("id", self.id) .execute()) diff --git a/petercat_utils/rag_helper/git_issue_task.py b/petercat_utils/rag_helper/git_issue_task.py new file mode 100644 index 00000000..7a9ce8ff --- /dev/null +++ b/petercat_utils/rag_helper/git_issue_task.py @@ -0,0 +1,101 @@ +from github import Github + +from .git_task import GitTask +from ..data_class import GitIssueTaskNodeType, TaskStatus, TaskType, RAGGitIssueConfig +from ..rag_helper import issue_retrieval + +g = Github() + + +def add_rag_git_issue_task(config: RAGGitIssueConfig): + g.get_repo(config.repo_name) + + issue_task = GitIssueTask( + issue_id='', + node_type=GitIssueTaskNodeType.REPO, + bot_id=config.bot_id, + repo_name=config.repo_name + ) + res = issue_task.save() + issue_task.send() + + return res + + +class GitIssueTask(GitTask): + issue_id: str + node_type: GitIssueTaskNodeType + + def __init__(self, + issue_id, + node_type: GitIssueTaskNodeType, + bot_id, + repo_name, + status=TaskStatus.NOT_STARTED, + from_id=None, + id=None + ): + super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status, + repo_name=repo_name) + self.issue_id = issue_id + self.node_type = node_type + + def extra_save_data(self): + return { + "issue_id": self.issue_id, + "node_type": self.node_type.value, + } + + def handle(self): + self.update_status(TaskStatus.IN_PROGRESS) + if self.node_type == GitIssueTaskNodeType.REPO.value: + return self.handle_repo_node() + elif self.node_type == GitIssueTaskNodeType.ISSUE.value: + return self.handle_issue_node() + else: + raise ValueError(f"Unsupported node type [{self.node_type}]") + + def handle_repo_node(self): + repo = g.get_repo(self.repo_name) + repo.get_issues() + issues = [issue for issue in repo.get_issues()] + task_list = list( + map( + lambda item: { + "repo_name": self.repo_name, + "issue_id": str(item.number), + "status": TaskStatus.NOT_STARTED.value, + "node_type": GitIssueTaskNodeType.ISSUE.value, + "from_task_id": self.id, + "bot_id": self.bot_id, + }, + issues, + ), + ) + if len(task_list) > 0: + result = self.get_table().insert(task_list).execute() + for record in result.data: + issue_task = GitIssueTask(id=record["id"], + issue_id=record["issue_id"], + repo_name=record["repo_name"], + node_type=record["node_type"], + bot_id=record["bot_id"], + status=record["status"], + from_id=record["from_task_id"] + ) + issue_task.send() + + return (self.get_table().update( + {"status": TaskStatus.COMPLETED.value}) + .eq("id", self.id) + .execute()) + + def handle_issue_node(self): + issue_retrieval.add_knowledge_by_issue( + RAGGitIssueConfig( + repo_name=self.repo_name, + bot_id=self.bot_id, + issue_id=self.issue_id + ) + ) + return self.update_status(TaskStatus.COMPLETED) diff --git a/petercat_utils/rag_helper/git_task.py b/petercat_utils/rag_helper/git_task.py index eb6f465f..ea96f968 100644 --- a/petercat_utils/rag_helper/git_task.py +++ b/petercat_utils/rag_helper/git_task.py @@ -10,7 +10,8 @@ sqs = boto3.client("sqs") TABLE_NAME_MAP = { - TaskType.GIT_DOC: 'rag_tasks' + TaskType.GIT_DOC: 'rag_tasks', + TaskType.GIT_ISSUE: 'git_issue_tasks' } SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") @@ -42,7 +43,7 @@ def raw_data(self): "repo_name": self.repo_name, "bot_id": self.bot_id, "from_task_id": self.from_id, - "status": self.status.name, + "status": self.status.value, } return data @@ -52,7 +53,7 @@ def get_table(self): def update_status(self, status: TaskStatus): return (self.get_table() - .update({"status": status.name}) + .update({"status": status.value}) .eq("id", self.id) .execute()) diff --git a/petercat_utils/rag_helper/task.py b/petercat_utils/rag_helper/task.py index b17eca52..162c1e83 100644 --- a/petercat_utils/rag_helper/task.py +++ b/petercat_utils/rag_helper/task.py @@ -4,6 +4,7 @@ import boto3 from .git_doc_task import GitDocTask +from .git_issue_task import GitIssueTask from .git_task import GitTask # Create SQS client @@ -37,7 +38,7 @@ def get_oldest_task(): response = ( supabase.table(TABLE_NAME) .select("*") - .eq("status", TaskStatus.NOT_STARTED.name) + .eq("status", TaskStatus.NOT_STARTED.value) .order("created_at", desc=False) .limit(1) .execute() @@ -70,7 +71,18 @@ def get_task(task_type: TaskType, task_id: str) -> GitTask: node_type=data["node_type"], bot_id=data["bot_id"], path=data["path"], - status=data["status"] + status=data["status"], + from_id=data["from_task_id"] + ) + if task_type is TaskType.GIT_ISSUE: + return GitIssueTask( + id=data["id"], + issue_id=data["issue_id"], + repo_name=data["repo_name"], + node_type=data["node_type"], + bot_id=data["bot_id"], + status=data["status"], + from_id=data["from_task_id"] ) diff --git a/pyproject.toml b/pyproject.toml index c2e69eee..f71b295f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "petercat_utils" -version = "0.1.22" +version = "0.1.23" description = "" authors = ["raoha.rh "] readme = "README.md" diff --git a/server/bot/builder.py b/server/bot/builder.py index dd60e5b1..c4599362 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -3,7 +3,7 @@ from github import Github from petercat_utils import get_client from petercat_utils.data_class import RAGGitDocConfig -from petercat_utils.rag_helper.git_doc_task import add_rag_git_doc_task +from petercat_utils import git_doc_task from prompts.bot_template import generate_prompt_by_repo_name g = Github() @@ -53,7 +53,7 @@ def trigger_rag_task(repo_name: str, bot_id: str): file_path="", commit_id="", ) - add_rag_git_doc_task(config) + git_doc_task.add_rag_git_doc_task(config) except Exception as e: print(f"trigger_rag_task error: {e}") diff --git a/server/routers/rag.py b/server/routers/rag.py index d66efee8..ee4627da 100644 --- a/server/routers/rag.py +++ b/server/routers/rag.py @@ -5,7 +5,7 @@ from verify.rate_limit import verify_rate_limit from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig, TaskType -from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task +from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task, git_issue_task router = APIRouter( prefix="/api", @@ -54,14 +54,22 @@ def search_knowledge(query: str, bot_id: str, filter: dict = {}): return data -@router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)]) -def add_task(config: RAGGitDocConfig): +@router.post("/rag/add_git_doc_task", dependencies=[Depends(verify_rate_limit)]) +def add_git_doc_task(config: RAGGitDocConfig): try: data = git_doc_task.add_rag_git_doc_task(config) return data except Exception as e: return json.dumps({"success": False, "message": str(e)}) +@router.post("/rag/add_git_issue_task", dependencies=[Depends(verify_rate_limit)]) +def add_git_issue_task(config: RAGGitIssueConfig): + try: + data = git_issue_task.add_rag_git_issue_task(config) + return data + except Exception as e: + return json.dumps({"success": False, "message": str(e)}) + @router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)]) def trigger_task(task_type: TaskType, task_id: Optional[str] = None):