-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: impl union rag git task (#210)
Refactored the entire task processing and scheduling logic
- Loading branch information
Showing 10 changed files with 336 additions and 214 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
from typing import Optional | ||
|
||
from github import Github, Repository | ||
|
||
from .git_task import GitTask | ||
from ..data_class import RAGGitDocConfig, TaskStatus, TaskType, GitDocTaskNodeType | ||
from ..rag_helper import retrieval | ||
|
||
g = Github() | ||
|
||
|
||
def get_path_sha(
    repo: Repository.Repository,
    sha: str,
    path: Optional[str] = None,
) -> Optional[str]:
    """Resolve the git object SHA that ``path`` points to inside a tree.

    Starting from the tree identified by ``sha``, the path is walked one
    segment at a time: each segment is looked up among the entries of the
    current tree (``repo.get_git_tree(...).tree``) and the matching entry's
    ``sha`` becomes the starting point for the next segment.

    Args:
        repo: Repository object exposing ``get_git_tree(sha)``.
        sha: SHA of the tree (or commit) to start from.
        path: Slash-separated path relative to ``sha``. Empty or ``None``
            means "the starting object itself".

    Returns:
        The SHA of the object at ``path``; ``sha`` itself when ``path`` is
        empty; ``None`` when some segment does not exist in its parent tree.
    """
    if not path:
        return sha

    current_sha = sha
    # Empty segments are skipped so that "/docs", "docs/" and "docs//api"
    # resolve the same way as their normalized forms instead of failing to
    # match any tree entry.
    for segment in filter(None, path.split("/")):
        tree_data = repo.get_git_tree(current_sha)
        current_sha = next(
            (item.sha for item in tree_data.tree if item.path == segment),
            None,
        )
        if current_sha is None:
            # Segment not present in this tree: report "not found" explicitly.
            return None
    return current_sha
|
||
|
||
def add_rag_git_doc_task(config: RAGGitDocConfig,
                         extra=None
                         ):
    """Create, persist and enqueue a ``GitDocTask`` for a repository path.

    The commit is taken from ``config.commit_id`` or, when that is empty,
    from the head of ``config.branch``. The node type is decided as follows:

    * an empty ``config.file_path`` always means the repository root (TREE);
    * otherwise a truthy ``extra["node_type"]`` is used as given;
    * otherwise the path is fetched with ``repo.get_contents`` and a list
      result is treated as TREE, anything else as BLOB.

    Args:
        config: Repository, path, commit/branch and bot the task is for.
        extra: Optional mapping with the keys ``node_type`` and
            ``from_task_id``. It is only read, never modified.

    Returns:
        The result of ``GitDocTask.save()`` for the newly created task.
    """
    # Read the hints without touching the caller's dict: the original code
    # wrote the resolved node type back into ``extra``.
    hints = extra or {}
    node_type = hints.get("node_type")
    from_task_id = hints.get("from_task_id")

    repo = g.get_repo(config.repo_name)
    commit_id = config.commit_id or repo.get_branch(config.branch).commit.sha

    if not config.file_path:
        node_type = GitDocTaskNodeType.TREE.value
    elif not node_type:
        content = repo.get_contents(config.file_path, ref=commit_id)
        if isinstance(content, list):
            node_type = GitDocTaskNodeType.TREE.value
        else:
            node_type = GitDocTaskNodeType.BLOB.value

    sha = get_path_sha(repo, commit_id, config.file_path)

    doc_task = GitDocTask(commit_id=commit_id,
                          sha=sha,
                          repo_name=config.repo_name,
                          node_type=node_type,
                          bot_id=config.bot_id,
                          path=config.file_path,
                          # Forward the parent task so it is stored with the
                          # row; it used to be accepted but silently dropped.
                          from_id=from_task_id)
    # The task must be saved first: sending requires the id assigned on save.
    res = doc_task.save()
    doc_task.send()
    return res
|
||
|
||
class GitDocTask(GitTask):
    """Task that ingests one node (directory or file) of a git repository.

    A TREE node fans out into one child task per sub-directory and per
    Markdown file; a BLOB node is handed to the RAG retrieval helper.
    """

    def __init__(self,
                 commit_id,
                 node_type: GitDocTaskNodeType,
                 sha,
                 bot_id,
                 path,
                 repo_name,
                 status=TaskStatus.NOT_STARTED,
                 from_id=None,
                 id=None
                 ):
        """Store the git coordinates of the node on top of the base task.

        Args:
            commit_id: Commit the node is read at.
            node_type: Node kind; ``handle`` compares it with
                ``GitDocTaskNodeType.TREE.value`` / ``.BLOB.value``.
            sha: Git object SHA of the node.
            bot_id: Bot the ingested knowledge belongs to.
            path: Path of the node inside the repository ("" / None for root).
            repo_name: Repository name passed to ``g.get_repo``.
            status: Initial task status.
            from_id: Id of the task that spawned this one, if any.
            id: Database id when the task row already exists.
        """
        super().__init__(bot_id=bot_id, type=TaskType.GIT_DOC, from_id=from_id, id=id, status=status,
                         repo_name=repo_name)
        self.commit_id = commit_id
        self.node_type = node_type
        self.sha = sha
        self.path = path

    def extra_save_data(self):
        """Return the git-specific columns to persist with the task row."""
        return {
            "commit_id": self.commit_id,
            "node_type": self.node_type,
            "path": self.path,
            "sha": self.sha,
        }

    def handle_tree_node(self):
        """Queue child tasks for this directory and mark the task completed.

        Every entry of the tree that is a sub-directory, or whose path ends
        with ``.md``, is inserted as a new task row and sent to the queue.
        The raw tree listing is then stored in this task's ``metadata``.

        Returns:
            The result of the final update on this task's row.
        """
        repo = g.get_repo(self.repo_name)
        tree_data = repo.get_git_tree(self.sha)

        # Filter on the raw entry type (lowercase in the Git Trees API).
        # The previous code compared the already upper-cased ``node_type``
        # with "tree", which never matched, so sub-directories were never
        # queued and traversal stopped after the first level.
        task_list = [
            {
                "repo_name": self.repo_name,
                "commit_id": self.commit_id,
                "status": TaskStatus.NOT_STARTED.name,
                "node_type": item.type.upper(),
                "from_task_id": self.id,
                # Skip an empty parent path so root children get no leading "/".
                "path": "/".join(part for part in (self.path, item.path) if part),
                "sha": item.sha,
                "bot_id": self.bot_id,
            }
            for item in tree_data.tree
            if item.type == "tree" or item.path.endswith(".md")
        ]

        if task_list:
            result = self.get_table().insert(task_list).execute()

            # The inserted rows carry the generated ids needed to enqueue them.
            for record in result.data:
                doc_task = GitDocTask(id=record["id"],
                                      commit_id=record["commit_id"],
                                      sha=record["sha"],
                                      repo_name=record["repo_name"],
                                      node_type=record["node_type"],
                                      bot_id=record["bot_id"],
                                      path=record["path"])
                doc_task.send()

        return (self.get_table().update(
            {"metadata": {"tree": [item.raw_data for item in tree_data.tree]},
             "status": TaskStatus.COMPLETED.name})
                .eq("id", self.id)
                .execute())

    def handle_blob_node(self):
        """Add this file to the knowledge base and mark the task completed.

        Returns:
            The result of the status update.
        """
        retrieval.add_knowledge_by_doc(
            RAGGitDocConfig(
                repo_name=self.repo_name,
                file_path=self.path,
                commit_id=self.commit_id,
                bot_id=self.bot_id,
            )
        )
        return self.update_status(TaskStatus.COMPLETED)

    def handle(self):
        """Mark the task in progress and dispatch on its node type.

        Raises:
            ValueError: If ``node_type`` is neither the TREE nor the BLOB value.
        """
        self.update_status(TaskStatus.IN_PROGRESS)
        if self.node_type == GitDocTaskNodeType.TREE.value:
            return self.handle_tree_node()
        elif self.node_type == GitDocTaskNodeType.BLOB.value:
            return self.handle_blob_node()
        else:
            raise ValueError(f"Unsupported node type [{self.node_type}]")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import json | ||
from abc import ABC, abstractmethod | ||
|
||
import boto3 | ||
|
||
from ..data_class import TaskStatus, TaskType | ||
from ..db.client.supabase import get_client | ||
from ..utils.env import get_env_variable | ||
|
||
sqs = boto3.client("sqs") | ||
|
||
TABLE_NAME_MAP = { | ||
TaskType.GIT_DOC: 'rag_tasks' | ||
} | ||
SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL") | ||
|
||
|
||
# Base GitTask Class | ||
class GitTask(ABC):
    """Abstract base for tasks stored in a Supabase table and queued via SQS.

    Subclasses provide ``extra_save_data`` (additional columns to persist)
    and ``handle`` (the actual work).
    """

    type: TaskType

    def __init__(self, type, repo_name, bot_id, status=TaskStatus.NOT_STARTED, from_id=None, id=None):
        """Record the common task fields.

        Args:
            type: Task type; selects the table through ``TABLE_NAME_MAP``.
            repo_name: Repository the task operates on.
            bot_id: Bot the task belongs to.
            status: Current task status.
            from_id: Id of the parent task, stored as ``from_task_id``.
            id: Database id, or ``None`` until the task is saved.
        """
        self.type = type
        self.id = id
        self.from_id = from_id
        self.status = status
        self.repo_name = repo_name
        self.bot_id = bot_id

    @staticmethod
    def get_table_name(type: TaskType):
        """Return the table name registered for ``type`` in ``TABLE_NAME_MAP``."""
        return TABLE_NAME_MAP[type]

    @property
    def table_name(self):
        """Name of the table this task is stored in."""
        return GitTask.get_table_name(self.type)

    @property
    def raw_data(self):
        """Row to insert: subclass columns overlaid with the common ones."""
        # Common keys come last so they win over same-named subclass keys.
        return {
            **self.extra_save_data(),
            "repo_name": self.repo_name,
            "bot_id": self.bot_id,
            "from_task_id": self.from_id,
            "status": self.status.name,
        }

    def get_table(self):
        """Return a query builder for this task's table."""
        supabase = get_client()
        return supabase.table(self.table_name)

    def update_status(self, status: TaskStatus):
        """Write ``status.name`` to the row whose id equals ``self.id``."""
        return (self.get_table()
                .update({"status": status.name})
                .eq("id", self.id)
                .execute())

    def save(self):
        """Insert the task row and remember the id returned for it.

        Returns:
            The insert result.
        """
        res = self.get_table().insert(self.raw_data).execute()
        self.id = res.data[0]['id']
        return res

    @abstractmethod
    def extra_save_data(self):
        """Return the subclass-specific columns to store with the task."""
        pass

    @abstractmethod
    def handle(self):
        """Execute the task."""
        pass

    def send(self):
        """Enqueue the task on SQS with a 10 second delivery delay.

        The message body is a JSON object with ``task_id`` and ``task_type``.

        Returns:
            The ``MessageId`` reported by SQS.

        Raises:
            AssertionError: If the task has no id or no type yet.
        """
        # Raised explicitly rather than with ``assert`` so the checks still
        # run under ``python -O``; type and messages are unchanged.
        if not self.id:
            raise AssertionError("Task ID needed, save it first")
        if not self.type:
            raise AssertionError("Task type needed, set it first")

        response = sqs.send_message(
            QueueUrl=SQS_QUEUE_URL,
            DelaySeconds=10,
            MessageBody=(json.dumps({"task_id": self.id, "task_type": self.type.value})),
        )
        message_id = response["MessageId"]
        print(f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}")
        return message_id
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.