Skip to content

Commit

Permalink
feat: impl issue task and release petercat utils (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaoHai authored Aug 18, 2024
2 parents d11a50d + 9dfcea1 commit 51880cb
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 28 deletions.
18 changes: 11 additions & 7 deletions petercat_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from .db.client.supabase import get_client
from .rag_helper import github_file_loader, retrieval, issue_retrieval, task, git_task, git_issue_task, git_doc_task
from .utils.env import get_env_variable
from .rag_helper import github_file_loader, retrieval, task

__all__ = [
"get_client",
"get_env_variable",
"github_file_loader",
"retrieval",
"task"
]
"get_client",
"get_env_variable",
"github_file_loader",
"retrieval",
"issue_retrieval",
"task",
"git_task",
"git_issue_task",
"git_doc_task"
]
20 changes: 13 additions & 7 deletions petercat_utils/data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ class RAGGitDocConfig(GitDocConfig):
bot_id: str


class GitIssueConfig(BaseModel):
repo_name: str
issue_id: str


class RAGGitIssueConfig(GitIssueConfig):
bot_id: str


class AutoNameEnum(Enum):
def _generate_next_value_(name, start, count, last_values):
return name
Expand All @@ -85,17 +94,14 @@ class TaskStatus(AutoNameEnum):

class TaskType(AutoNameEnum):
GIT_DOC = auto()
GIT_ISSUE = auto()


class GitDocTaskNodeType(AutoNameEnum):
TREE = auto()
BLOB = auto()


class GitIssueConfig(BaseModel):
repo_name: str
issue_id: str


class RAGGitIssueConfig(GitIssueConfig):
bot_id: str
class GitIssueTaskNodeType(AutoNameEnum):
REPO = auto()
ISSUE = auto()
6 changes: 3 additions & 3 deletions petercat_utils/rag_helper/git_doc_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self,
def extra_save_data(self):
data = {
"commit_id": self.commit_id,
"node_type": self.node_type,
"node_type": self.node_type.value,
"path": self.path,
"sha": self.sha,
}
Expand All @@ -96,7 +96,7 @@ def handle_tree_node(self):
lambda item: {
"repo_name": self.repo_name,
"commit_id": self.commit_id,
"status": TaskStatus.NOT_STARTED.name,
"status": TaskStatus.NOT_STARTED.value,
"node_type": (item.type + '').upper(),
"from_task_id": self.id,
"path": "/".join(filter(lambda s: s, [self.path, item.path])),
Expand All @@ -123,7 +123,7 @@ def handle_tree_node(self):

return (self.get_table().update(
{"metadata": {"tree": list(map(lambda item: item.raw_data, tree_data.tree))},
"status": TaskStatus.COMPLETED.name})
"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)
.execute())

Expand Down
101 changes: 101 additions & 0 deletions petercat_utils/rag_helper/git_issue_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from github import Github

from .git_task import GitTask
from ..data_class import GitIssueTaskNodeType, TaskStatus, TaskType, RAGGitIssueConfig
from ..rag_helper import issue_retrieval

g = Github()


def add_rag_git_issue_task(config: RAGGitIssueConfig):
g.get_repo(config.repo_name)

issue_task = GitIssueTask(
issue_id='',
node_type=GitIssueTaskNodeType.REPO,
bot_id=config.bot_id,
repo_name=config.repo_name
)
res = issue_task.save()
issue_task.send()

return res


class GitIssueTask(GitTask):
issue_id: str
node_type: GitIssueTaskNodeType

def __init__(self,
issue_id,
node_type: GitIssueTaskNodeType,
bot_id,
repo_name,
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None
):
super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status,
repo_name=repo_name)
self.issue_id = issue_id
self.node_type = node_type

def extra_save_data(self):
return {
"issue_id": self.issue_id,
"node_type": self.node_type.value,
}

def handle(self):
self.update_status(TaskStatus.IN_PROGRESS)
if self.node_type == GitIssueTaskNodeType.REPO.value:
return self.handle_repo_node()
elif self.node_type == GitIssueTaskNodeType.ISSUE.value:
return self.handle_issue_node()
else:
raise ValueError(f"Unsupported node type [{self.node_type}]")

def handle_repo_node(self):
repo = g.get_repo(self.repo_name)
repo.get_issues()
issues = [issue for issue in repo.get_issues()]
task_list = list(
map(
lambda item: {
"repo_name": self.repo_name,
"issue_id": str(item.number),
"status": TaskStatus.NOT_STARTED.value,
"node_type": GitIssueTaskNodeType.ISSUE.value,
"from_task_id": self.id,
"bot_id": self.bot_id,
},
issues,
),
)
if len(task_list) > 0:
result = self.get_table().insert(task_list).execute()
for record in result.data:
issue_task = GitIssueTask(id=record["id"],
issue_id=record["issue_id"],
repo_name=record["repo_name"],
node_type=record["node_type"],
bot_id=record["bot_id"],
status=record["status"],
from_id=record["from_task_id"]
)
issue_task.send()

return (self.get_table().update(
{"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)
.execute())

def handle_issue_node(self):
issue_retrieval.add_knowledge_by_issue(
RAGGitIssueConfig(
repo_name=self.repo_name,
bot_id=self.bot_id,
issue_id=self.issue_id
)
)
return self.update_status(TaskStatus.COMPLETED)
7 changes: 4 additions & 3 deletions petercat_utils/rag_helper/git_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
sqs = boto3.client("sqs")

TABLE_NAME_MAP = {
TaskType.GIT_DOC: 'rag_tasks'
TaskType.GIT_DOC: 'rag_tasks',
TaskType.GIT_ISSUE: 'git_issue_tasks'
}
SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL")

Expand Down Expand Up @@ -42,7 +43,7 @@ def raw_data(self):
"repo_name": self.repo_name,
"bot_id": self.bot_id,
"from_task_id": self.from_id,
"status": self.status.name,
"status": self.status.value,
}
return data

Expand All @@ -52,7 +53,7 @@ def get_table(self):

def update_status(self, status: TaskStatus):
return (self.get_table()
.update({"status": status.name})
.update({"status": status.value})
.eq("id", self.id)
.execute())

Expand Down
16 changes: 14 additions & 2 deletions petercat_utils/rag_helper/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import boto3

from .git_doc_task import GitDocTask
from .git_issue_task import GitIssueTask
from .git_task import GitTask

# Create SQS client
Expand Down Expand Up @@ -37,7 +38,7 @@ def get_oldest_task():
response = (
supabase.table(TABLE_NAME)
.select("*")
.eq("status", TaskStatus.NOT_STARTED.name)
.eq("status", TaskStatus.NOT_STARTED.value)
.order("created_at", desc=False)
.limit(1)
.execute()
Expand Down Expand Up @@ -70,7 +71,18 @@ def get_task(task_type: TaskType, task_id: str) -> GitTask:
node_type=data["node_type"],
bot_id=data["bot_id"],
path=data["path"],
status=data["status"]
status=data["status"],
from_id=data["from_task_id"]
)
if task_type is TaskType.GIT_ISSUE:
return GitIssueTask(
id=data["id"],
issue_id=data["issue_id"],
repo_name=data["repo_name"],
node_type=data["node_type"],
bot_id=data["bot_id"],
status=data["status"],
from_id=data["from_task_id"]
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "petercat_utils"
version = "0.1.22"
version = "0.1.23"
description = ""
authors = ["raoha.rh <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions server/bot/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from github import Github
from petercat_utils import get_client
from petercat_utils.data_class import RAGGitDocConfig
from petercat_utils.rag_helper.git_doc_task import add_rag_git_doc_task
from petercat_utils import git_doc_task
from prompts.bot_template import generate_prompt_by_repo_name

g = Github()
Expand Down Expand Up @@ -53,7 +53,7 @@ def trigger_rag_task(repo_name: str, bot_id: str):
file_path="",
commit_id="",
)
add_rag_git_doc_task(config)
git_doc_task.add_rag_git_doc_task(config)
except Exception as e:
print(f"trigger_rag_task error: {e}")

Expand Down
14 changes: 11 additions & 3 deletions server/routers/rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from verify.rate_limit import verify_rate_limit

from petercat_utils.data_class import RAGGitDocConfig, RAGGitIssueConfig, TaskType
from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task
from petercat_utils.rag_helper import retrieval, task, issue_retrieval, git_doc_task, git_issue_task

router = APIRouter(
prefix="/api",
Expand Down Expand Up @@ -54,14 +54,22 @@ def search_knowledge(query: str, bot_id: str, filter: dict = {}):
return data


@router.post("/rag/add_task", dependencies=[Depends(verify_rate_limit)])
def add_task(config: RAGGitDocConfig):
@router.post("/rag/add_git_doc_task", dependencies=[Depends(verify_rate_limit)])
def add_git_doc_task(config: RAGGitDocConfig):
try:
data = git_doc_task.add_rag_git_doc_task(config)
return data
except Exception as e:
return json.dumps({"success": False, "message": str(e)})

@router.post("/rag/add_git_issue_task", dependencies=[Depends(verify_rate_limit)])
def add_git_issue_task(config: RAGGitIssueConfig):
try:
data = git_issue_task.add_rag_git_issue_task(config)
return data
except Exception as e:
return json.dumps({"success": False, "message": str(e)})


@router.post("/rag/trigger_task", dependencies=[Depends(verify_rate_limit)])
def trigger_task(task_type: TaskType, task_id: Optional[str] = None):
Expand Down

0 comments on commit 51880cb

Please sign in to comment.