Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: handle the failed task message #600

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions petercat_utils/rag_helper/git_doc_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,
):
super().__init__(
type=TaskType.GIT_DOC,
from_id=from_id,
id=id,
status=status,
repo_name=repo_name,
retry_count=retry_count,

Check warning on line 79 in petercat_utils/rag_helper/git_doc_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_doc_task.py#L79

Added line #L79 was not covered by tests
)
self.commit_id = commit_id
self.node_type = GitDocTaskNodeType(node_type)
Expand Down
66 changes: 38 additions & 28 deletions petercat_utils/rag_helper/git_issue_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
g.get_repo(config.repo_name)

issue_task = GitIssueTask(
issue_id='',
issue_id="",
node_type=GitIssueTaskNodeType.REPO,
bot_id=config.bot_id,
repo_name=config.repo_name
repo_name=config.repo_name,
)
res = issue_task.save()
issue_task.send()
Expand All @@ -26,17 +26,26 @@
issue_id: str
node_type: GitIssueTaskNodeType

def __init__(self,
issue_id,
node_type: GitIssueTaskNodeType,
bot_id,
repo_name,
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None
):
super().__init__(bot_id=bot_id, type=TaskType.GIT_ISSUE, from_id=from_id, id=id, status=status,
repo_name=repo_name)
def __init__(
self,
issue_id,
node_type: GitIssueTaskNodeType,
bot_id,
repo_name,
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,

Check warning on line 38 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L38

Added line #L38 was not covered by tests
):
super().__init__(
bot_id=bot_id,

Check warning on line 41 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L40-L41

Added lines #L40 - L41 were not covered by tests
type=TaskType.GIT_ISSUE,
from_id=from_id,
id=id,

Check warning on line 44 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L44

Added line #L44 was not covered by tests
status=status,
repo_name=repo_name,
retry_count=retry_count,
)
self.issue_id = issue_id
self.node_type = GitIssueTaskNodeType(node_type)

Expand Down Expand Up @@ -75,27 +84,28 @@
if len(task_list) > 0:
result = self.get_table().insert(task_list).execute()
for record in result.data:
issue_task = GitIssueTask(id=record["id"],
issue_id=record["issue_id"],
repo_name=record["repo_name"],
node_type=record["node_type"],
bot_id=record["bot_id"],
status=record["status"],
from_id=record["from_task_id"]
)
issue_task = GitIssueTask(
id=record["id"],

Check warning on line 88 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L88

Added line #L88 was not covered by tests
issue_id=record["issue_id"],
repo_name=record["repo_name"],
node_type=record["node_type"],
bot_id=record["bot_id"],
status=record["status"],
from_id=record["from_task_id"],

Check warning on line 94 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L94

Added line #L94 was not covered by tests
)
issue_task.send()

return (self.get_table().update(
{"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)
.execute())
return (
self.get_table()
.update({"status": TaskStatus.COMPLETED.value})
.eq("id", self.id)

Check warning on line 101 in petercat_utils/rag_helper/git_issue_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_issue_task.py#L101

Added line #L101 was not covered by tests
.execute()
)

def handle_issue_node(self):
issue_retrieval.add_knowledge_by_issue(
RAGGitIssueConfig(
repo_name=self.repo_name,
bot_id=self.bot_id,
issue_id=self.issue_id
repo_name=self.repo_name, bot_id=self.bot_id, issue_id=self.issue_id
)
)
return self.update_status(TaskStatus.COMPLETED)
12 changes: 10 additions & 2 deletions petercat_utils/rag_helper/git_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
status=TaskStatus.NOT_STARTED,
from_id=None,
id=None,
retry_count=0,
):
self.type = type
self.id = id
self.from_id = from_id
self.status = status
self.repo_name = repo_name
self.retry_count = retry_count

@staticmethod
def get_table_name(type: TaskType):
Expand Down Expand Up @@ -82,11 +84,17 @@
QueueUrl=SQS_QUEUE_URL,
DelaySeconds=10,
MessageBody=(
json.dumps({"task_id": self.id, "task_type": self.type.value})
json.dumps(
{
"task_id": self.id,

Check warning on line 89 in petercat_utils/rag_helper/git_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_task.py#L88-L89

Added lines #L88 - L89 were not covered by tests
"task_type": self.type.value,
"retry_count": self.retry_count,
}

Check warning on line 92 in petercat_utils/rag_helper/git_task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/git_task.py#L92

Added line #L92 was not covered by tests
)
),
)
message_id = response["MessageId"]
print(
f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}"
f"task_id={self.id}, task_type={self.type.value}, message_id={message_id}, retry_count={self.retry_count}"
)
return message_id
20 changes: 5 additions & 15 deletions petercat_utils/rag_helper/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,6 @@
SQS_QUEUE_URL = get_env_variable("SQS_QUEUE_URL")


def send_task_message(task_id: str):
response = sqs.send_message(
QueueUrl=SQS_QUEUE_URL,
DelaySeconds=10,
MessageBody=(json.dumps({"task_id": task_id})),
)
return response["MessageId"]


def get_oldest_task():
supabase = get_client()

Expand All @@ -54,10 +45,7 @@
return response.data[0] if (len(response.data) > 0) else None


def get_task(
task_type: TaskType,
task_id: str,
) -> GitTask:
def get_task(task_type: TaskType, task_id: str, retry_count=0) -> GitTask:
supabase = get_client()
response = (
supabase.table(GitTask.get_table_name(task_type))
Expand All @@ -77,6 +65,7 @@
path=data["path"],
status=data["status"],
from_id=data["from_task_id"],
retry_count=retry_count,

Check warning on line 68 in petercat_utils/rag_helper/task.py

View check run for this annotation

Codecov / codecov/patch

petercat_utils/rag_helper/task.py#L68

Added line #L68 was not covered by tests
)
if task_type == TaskType.GIT_ISSUE:
return GitIssueTask(
Expand All @@ -87,11 +76,12 @@
bot_id=data["bot_id"],
status=data["status"],
from_id=data["from_task_id"],
retry_count=retry_count,
)


def trigger_task(task_type: TaskType, task_id: Optional[str]):
task = get_task(task_type, task_id) if task_id else get_oldest_task()
def trigger_task(task_type: TaskType, task_id: Optional[str], retry_count: int = 0):
task = get_task(task_type, task_id, retry_count) if task_id else get_oldest_task()
if task is None:
return task
return task.handle()
29 changes: 16 additions & 13 deletions server/aws/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,32 @@
STATIC_SECRET_NAME = get_env_variable("STATIC_SECRET_NAME")
STATIC_KEYPAIR_ID = get_env_variable("STATIC_KEYPAIR_ID")


def rsa_signer(message):
private_key_str = get_private_key(STATIC_SECRET_NAME)
private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode('utf-8'))
return rsa.sign(message, private_key, 'SHA-1')
private_key = rsa.PrivateKey.load_pkcs1(private_key_str.encode("utf-8"))

Check warning on line 20 in server/aws/service.py

View check run for this annotation

Codecov / codecov/patch

server/aws/service.py#L20

Added line #L20 was not covered by tests
return rsa.sign(message, private_key, "SHA-1")


def create_signed_url(url, expire_minutes=60) -> str:
cloudfront_signer = CloudFrontSigner(STATIC_KEYPAIR_ID, rsa_signer)

Check warning on line 26 in server/aws/service.py

View check run for this annotation

Codecov / codecov/patch

server/aws/service.py#L26

Added line #L26 was not covered by tests
# 设置过期时间
expire_date = datetime.now() + timedelta(minutes=expire_minutes)

Check warning on line 29 in server/aws/service.py

View check run for this annotation

Codecov / codecov/patch

server/aws/service.py#L29

Added line #L29 was not covered by tests
# 创建签名 URL
signed_url = cloudfront_signer.generate_presigned_url(
url=url,
date_less_than=expire_date
url=url, date_less_than=expire_date
)

Check warning on line 34 in server/aws/service.py

View check run for this annotation

Codecov / codecov/patch

server/aws/service.py#L34

Added line #L34 was not covered by tests
return signed_url


def upload_image_to_s3(file, metadata: ImageMetaData, s3_client):
try:
file_content = file.file.read()
md5_hash = hashlib.md5()
md5_hash.update(file.filename.encode('utf-8'))
md5_hash.update(file.filename.encode("utf-8"))
s3_key = md5_hash.hexdigest()
encoded_filename = (
base64.b64encode(metadata.title.encode("utf-8")).decode("utf-8")
Expand All @@ -62,11 +64,12 @@
ContentType=file.content_type,
Metadata=custom_metadata,
)
# you need to redirect your static domain to your s3 bucket domain
s3_url = f"{STATIC_URL}/{s3_key}"
signed_url = create_signed_url(url=s3_url, expire_minutes=60) \
if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID) \
else s3_url
return {"message": "File uploaded successfully", "url": signed_url }
signed_url = (
create_signed_url(url=s3_url, expire_minutes=60)
if (STATIC_SECRET_NAME and STATIC_KEYPAIR_ID)

Check warning on line 70 in server/aws/service.py

View check run for this annotation

Codecov / codecov/patch

server/aws/service.py#L70

Added line #L70 was not covered by tests
else s3_url
)
return {"message": "File uploaded successfully", "url": signed_url}
except Exception as e:
raise UploadError(detail=str(e))
30 changes: 19 additions & 11 deletions subscriber/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,38 @@
from petercat_utils import task as task_helper
from petercat_utils.data_class import TaskType

MAX_RETRY_COUNT = 5


def lambda_handler(event, context):
if event:
batch_item_failures = []
sqs_batch_response = {}

for record in event["Records"]:
try:
body = record["body"]
print(f"receive message here: {body}")
body = record["body"]
print(f"receive message here: {body}")

message_dict = json.loads(body)
task_id = message_dict["task_id"]
task_type = message_dict["task_type"]
task = task_helper.get_task(TaskType(task_type), task_id)
message_dict = json.loads(body)
task_id = message_dict["task_id"]
task_type = message_dict["task_type"]
retry_count = message_dict["retry_count"]
task = task_helper.get_task(TaskType(task_type), task_id)
try:
if task is None:
return task
task.handle()

# process message
print(f"message content: message={message_dict}, task_id={task_id}, task={task}")
print(
f"message content: message={message_dict}, task_id={task_id}, task={task}, retry_count={retry_count}"
)
except Exception as e:
print(f"message handle error: ${e}")
batch_item_failures.append({"itemIdentifier": record['messageId']})
if retry_count < MAX_RETRY_COUNT:
xingwanying marked this conversation as resolved.
Show resolved Hide resolved
retry_count += 1
task_helper.trigger_task(task_type, task_id, retry_count)
else:
print(f"message handle error: ${e}")
batch_item_failures.append({"itemIdentifier": record["messageId"]})

sqs_batch_response["batchItemFailures"] = batch_item_failures
return sqs_batch_response
Loading