Skip to content

Commit

Permalink
feat: impl union rag git task (#210)
Browse files Browse the repository at this point in the history
Refactored the entire task processing and scheduling logic
  • Loading branch information
xingwanying authored Aug 18, 2024
2 parents b8b9e70 + 95c7493 commit 40472f4
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 214 deletions.
45 changes: 27 additions & 18 deletions petercat_utils/data_class.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from enum import Enum, auto
from dataclasses import dataclass
from typing import List, Optional
from datetime import datetime
from typing import Literal, Optional, List, TypeAlias, Union
from pydantic import BaseModel
from typing import Literal, Optional, List, TypeAlias
from typing import Union

from pydantic import BaseModel


class ImageURL(BaseModel):
url: str
Expand Down Expand Up @@ -59,18 +57,6 @@ class S3Config(BaseModel):
file_path: Optional[str] = None


class GitIssueConfig(BaseModel):
    # Repository name plus pagination and state-filter settings for an
    # issue query; the field notes below refer to the GitHub API.
    repo_name: str
    page: Optional[int] = None
    """The page number for paginated results.
    Defaults to 1 in the GitHub API."""
    per_page: Optional[int] = 30
    """Number of items per page.
    Defaults to 30 in the GitHub API."""
    state: Optional[Literal["open", "closed", "all"]] = "all"
    """Filter on issue state. Can be one of: 'open', 'closed', 'all'."""


class GitDocConfig(BaseModel):
repo_name: str
"""File path of the documentation file. eg:'docs/blog/build-ghost.zh-CN.md'"""
Expand All @@ -83,10 +69,33 @@ class RAGGitDocConfig(GitDocConfig):
bot_id: str


class TaskStatus(Enum):
class AutoNameEnum(Enum):
    """Enum base class whose ``auto()`` members use their own name as value."""

    # Hook consulted by ``auto()``: returning ``name`` makes every
    # auto-valued member's value equal to its member name.
    def _generate_next_value_(name, start, count, last_values):
        return name


class TaskStatus(AutoNameEnum):
    """Status values a task can carry.

    Because of ``AutoNameEnum`` each member's value is its own name
    (e.g. ``TaskStatus.COMPLETED.value == "COMPLETED"``).
    """

    NOT_STARTED = auto()
    IN_PROGRESS = auto()
    COMPLETED = auto()
    ON_HOLD = auto()
    CANCELLED = auto()
    ERROR = auto()


class TaskType(AutoNameEnum):
    """Kinds of task; each member's value is its own name."""

    GIT_DOC = auto()


class GitDocTaskNodeType(AutoNameEnum):
    """Node kinds of a git-doc task; values are the strings "TREE" and "BLOB"."""

    TREE = auto()
    BLOB = auto()


class GitIssueConfig(BaseModel):
    """Identifies one issue by repository name and issue id."""

    repo_name: str
    # The issue identifier is carried as a string.
    issue_id: str


class RAGGitIssueConfig(GitIssueConfig):
    """``GitIssueConfig`` extended with a ``bot_id`` field."""

    bot_id: str
148 changes: 148 additions & 0 deletions petercat_utils/rag_helper/git_doc_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Optional

from github import Github, Repository

from .git_task import GitTask
from ..data_class import RAGGitDocConfig, TaskStatus, TaskType, GitDocTaskNodeType
from ..rag_helper import retrieval

# Module-level GitHub client shared by the functions and tasks in this module.
# NOTE(review): constructed without credentials — confirm that anonymous API
# access (and its rate limit) is acceptable for this workload.
g = Github()


def get_path_sha(repo: "Repository.Repository", sha: str, path: Optional[str] = None):
    """Resolve the SHA of the git object found at ``path`` below ``sha``.

    The path is walked one segment at a time: for each segment the tree of
    the current SHA is fetched with ``repo.get_git_tree`` and the entry whose
    ``path`` equals the segment supplies the next SHA.

    Args:
        repo: Repository object used to fetch trees.
        sha: SHA the walk starts from (callers in this module pass a commit
            id).
        path: Slash-separated path relative to ``sha``. ``None`` or an empty
            string refers to the starting object itself.

    Returns:
        The SHA of the object at ``path``; ``sha`` itself when ``path`` is
        empty; ``None`` when any segment cannot be found.
    """
    current_sha = sha
    # Empty segments are dropped so that "/docs//guide/" resolves exactly
    # like "docs/guide" instead of failing on a "" component.
    for segment in filter(None, (path or "").split("/")):
        tree_data = repo.get_git_tree(current_sha)
        current_sha = next(
            (item.sha for item in tree_data.tree if item.path == segment),
            None,
        )
        if current_sha is None:
            # Segment not present in this tree: report "not found" explicitly.
            return None
    return current_sha


def add_rag_git_doc_task(config: RAGGitDocConfig, extra=None):
    """Create, persist and enqueue a git-doc task described by ``config``.

    Args:
        config: Repository, bot and location of the node to ingest. When
            ``config.commit_id`` is falsy the head commit of
            ``config.branch`` is used.
        extra: Optional dict with the keys ``"node_type"`` (a
            ``GitDocTaskNodeType`` value, looked up on GitHub when missing)
            and ``"from_task_id"`` (id of the task that spawned this one).
            The dict is not modified.

    Returns:
        The result of saving the new task row.
    """
    # Work on a copy so the caller's dict is never mutated.
    extra = dict(extra) if extra else {}
    node_type = extra.get("node_type")

    repo = g.get_repo(config.repo_name)
    commit_id = config.commit_id or repo.get_branch(config.branch).commit.sha

    if not config.file_path:
        # No path means the repository root, which is always a tree.
        node_type = GitDocTaskNodeType.TREE.value
    elif not node_type:
        # get_contents returns a list for a directory, a single item for a file.
        content = repo.get_contents(config.file_path, ref=commit_id)
        if isinstance(content, list):
            node_type = GitDocTaskNodeType.TREE.value
        else:
            node_type = GitDocTaskNodeType.BLOB.value

    sha = get_path_sha(repo, commit_id, config.file_path)

    doc_task = GitDocTask(
        commit_id=commit_id,
        sha=sha,
        repo_name=config.repo_name,
        node_type=node_type,
        bot_id=config.bot_id,
        path=config.file_path,
        # Propagate the parent link; it was previously accepted but dropped.
        from_id=extra.get("from_task_id"),
    )
    # save() assigns the task id that send() requires.
    res = doc_task.save()
    doc_task.send()
    return res


class GitDocTask(GitTask):
    """Task that processes one git node (tree or blob) of a repository.

    A TREE task fans out into one child task per relevant entry of the tree;
    a BLOB task hands the file to ``retrieval.add_knowledge_by_doc``.
    """

    def __init__(
        self,
        commit_id,
        node_type: GitDocTaskNodeType,
        sha,
        bot_id,
        path,
        repo_name,
        status=TaskStatus.NOT_STARTED,
        from_id=None,
        id=None,
    ):
        """Initialise the task.

        Args:
            commit_id: Commit the node belongs to.
            node_type: ``GitDocTaskNodeType`` member or its string value.
            sha: SHA of the node itself.
            bot_id: Bot the task is created for.
            path: Path of the node inside the repository.
            repo_name: Name of the repository.
            status: Initial task status.
            from_id: Id of the task that spawned this one, if any.
            id: Id of the task row, if it has already been saved.
        """
        super().__init__(
            bot_id=bot_id,
            type=TaskType.GIT_DOC,
            from_id=from_id,
            id=id,
            status=status,
            repo_name=repo_name,
        )
        self.commit_id = commit_id
        # Store the plain string value: handle() compares against ``.value``
        # and extra_save_data() must stay serialisable, whichever form the
        # caller passed in.
        self.node_type = (
            node_type.value
            if isinstance(node_type, GitDocTaskNodeType)
            else node_type
        )
        self.sha = sha
        self.path = path

    def extra_save_data(self):
        """Return the git-doc specific columns merged into the saved row."""
        return {
            "commit_id": self.commit_id,
            "node_type": self.node_type,
            "path": self.path,
            "sha": self.sha,
        }

    def handle_tree_node(self):
        """Create and enqueue child tasks for the entries of this tree.

        Sub-trees and entries whose path ends in ``.md`` become child tasks;
        everything else is skipped. Afterwards this task's row is updated
        with the raw tree listing as metadata and marked COMPLETED.

        Returns:
            The result of the final update call.
        """
        repo = g.get_repo(self.repo_name)
        tree_data = repo.get_git_tree(self.sha)

        task_list = []
        for item in tree_data.tree:
            node_type = item.type.upper()
            child_path = "/".join(filter(None, [self.path, item.path]))
            # node_type is upper-cased, so it must be compared with the enum
            # value ("TREE"); comparing with "tree" never matched and
            # sub-directories were silently dropped.
            is_tree = node_type == GitDocTaskNodeType.TREE.value
            if not (is_tree or child_path.endswith(".md")):
                continue
            task_list.append(
                {
                    "repo_name": self.repo_name,
                    "commit_id": self.commit_id,
                    "status": TaskStatus.NOT_STARTED.name,
                    "node_type": node_type,
                    "from_task_id": self.id,
                    "path": child_path,
                    "sha": item.sha,
                    "bot_id": self.bot_id,
                }
            )

        if task_list:
            result = self.get_table().insert(task_list).execute()

            # Enqueue every inserted child so it gets processed in turn.
            for record in result.data:
                doc_task = GitDocTask(
                    id=record["id"],
                    commit_id=record["commit_id"],
                    sha=record["sha"],
                    repo_name=record["repo_name"],
                    node_type=record["node_type"],
                    bot_id=record["bot_id"],
                    path=record["path"],
                    from_id=self.id,
                )
                doc_task.send()

        return (
            self.get_table()
            .update(
                {
                    "metadata": {
                        "tree": [item.raw_data for item in tree_data.tree]
                    },
                    "status": TaskStatus.COMPLETED.name,
                }
            )
            .eq("id", self.id)
            .execute()
        )

    def handle_blob_node(self):
        """Pass this file to ``retrieval.add_knowledge_by_doc`` and mark the task COMPLETED."""
        retrieval.add_knowledge_by_doc(
            RAGGitDocConfig(
                repo_name=self.repo_name,
                file_path=self.path,
                commit_id=self.commit_id,
                bot_id=self.bot_id,
            )
        )
        return self.update_status(TaskStatus.COMPLETED)

    def handle(self):
        """Mark the task IN_PROGRESS and dispatch on its node type.

        Raises:
            ValueError: If ``node_type`` is neither TREE nor BLOB.
        """
        self.update_status(TaskStatus.IN_PROGRESS)
        if self.node_type == GitDocTaskNodeType.TREE.value:
            return self.handle_tree_node()
        elif self.node_type == GitDocTaskNodeType.BLOB.value:
            return self.handle_blob_node()
        else:
            raise ValueError(f"Unsupported node type [{self.node_type}]")
83 changes: 83 additions & 0 deletions petercat_utils/rag_helper/git_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
from abc import ABC, abstractmethod

import boto3

from ..data_class import TaskStatus, TaskType
from ..db.client.supabase import get_client
from ..utils.env import get_env_variable

# SQS client created once at import time and used by GitTask.send().
sqs = boto3.client("sqs")

# Maps each task type to the name of the table that stores its rows.
TABLE_NAME_MAP = {
    TaskType.GIT_DOC: 'rag_tasks'
}
# Queue URL read from the environment at import time.
SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL")


# Base GitTask Class
class GitTask(ABC):
    """Abstract base for tasks that are stored in a table and queued via SQS.

    Subclasses provide their extra columns through ``extra_save_data`` and
    their processing logic through ``handle``.
    """

    type: TaskType

    def __init__(self, type, repo_name, bot_id, status=TaskStatus.NOT_STARTED, from_id=None, id=None):
        """Store the common task fields.

        Args:
            type: ``TaskType`` of the task; selects the backing table.
            repo_name: Name of the repository the task works on.
            bot_id: Bot the task is created for.
            status: Current ``TaskStatus``.
            from_id: Id of the task that spawned this one, if any.
            id: Id of the task row, or ``None`` until ``save`` is called.
        """
        self.type = type
        self.id = id
        self.from_id = from_id
        self.status = status
        self.repo_name = repo_name
        self.bot_id = bot_id

    @staticmethod
    def get_table_name(type: TaskType):
        """Return the table name registered for ``type`` in ``TABLE_NAME_MAP``."""
        return TABLE_NAME_MAP[type]

    @property
    def table_name(self):
        """Table name for this task's type."""
        return GitTask.get_table_name(self.type)

    @property
    def raw_data(self):
        """Row to insert: subclass-specific columns plus the common ones.

        The common keys are written last, so they win over same-named keys
        returned by ``extra_save_data``.
        """
        return {
            **self.extra_save_data(),
            "repo_name": self.repo_name,
            "bot_id": self.bot_id,
            "from_task_id": self.from_id,
            "status": self.status.name,
        }

    def get_table(self):
        """Return a query builder for this task's table."""
        supabase = get_client()
        return supabase.table(self.table_name)

    def update_status(self, status: TaskStatus):
        """Write ``status`` to this task's row and mirror it on the instance.

        Returns:
            The result of the update call.
        """
        res = (
            self.get_table()
            .update({"status": status.name})
            .eq("id", self.id)
            .execute()
        )
        # Keep the in-memory state in step with the stored row.
        self.status = status
        return res

    def save(self):
        """Insert the task row and remember the id of the first returned row.

        Returns:
            The result of the insert call.
        """
        res = self.get_table().insert(self.raw_data).execute()
        self.id = res.data[0]['id']
        return res

    @abstractmethod
    def extra_save_data(self):
        """Return the subclass-specific columns to store with the task."""
        pass

    @abstractmethod
    def handle(self):
        """Process the task."""
        pass

    def send(self):
        """Queue the task on SQS with a 10 second delivery delay.

        Returns:
            The ``MessageId`` reported by SQS.

        Raises:
            ValueError: If the task has no id (not saved yet) or no type.
        """
        # Explicit checks instead of ``assert``: assertions are stripped under
        # ``python -O``, which would let a message with a null task id through.
        if not self.id:
            raise ValueError("Task ID needed, save it first")
        if not self.type:
            raise ValueError("Task type needed, set it first")

        response = sqs.send_message(
            QueueUrl=SQS_QUEUE_URL,
            DelaySeconds=10,
            MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type.value})),
        )
        message_id = response["MessageId"]
        print(f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}")
        return message_id
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings
from petercat_utils import get_client

from .data_class import RAGIssueDocConfig
from petercat_utils.data_class import RAGGitIssueConfig

g = Github()

TABLE_NAME = "issue_docs"
QUERY_NAME = "match_issue_docs"
TABLE_NAME = "rag_issues"
QUERY_NAME = "match_rag_issues"
CHUNK_SIZE = 2000
CHUNK_OVERLAP = 200

Expand Down Expand Up @@ -77,7 +76,7 @@ def get_issue_document_list(issue: Issue):
return document_list


def add_knowledge_by_issue(config: RAGIssueDocConfig):
def add_knowledge_by_issue(config: RAGGitIssueConfig):
supabase = get_client()
is_added_query = (
supabase.table(TABLE_NAME)
Expand Down
Loading

0 comments on commit 40472f4

Please sign in to comment.