Skip to content

Commit

Permalink
Refine callbackAfterSeconds and update retryable error worker
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinFoka committed Oct 29, 2024
1 parent 70da240 commit ac008dd
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
31 changes: 13 additions & 18 deletions examples/simple_worker/workers/test_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import random
import time
from typing import Optional
Expand Down Expand Up @@ -225,30 +226,24 @@ class WorkerOutput(TaskOutput):

def execute(self, worker_input: WorkerInput) -> TaskResult[WorkerOutput]:

def fail_once() -> None:
def fail_three_times() -> None:
"""
Simulates a failure on the first execution by creating a temporary file and raising an exception.
If the temporary file exists, it is removed, indicating a previous failure.
Raises:
RetryOnExceptionError: Raised to trigger a retry with a specified delay.
Simulates three consecutive failures by using an environment variable to track attempts.
Raises an exception on the first three executions and succeeds on the fourth.
"""
import tempfile
from pathlib import Path

temp_file_path: Path = Path(tempfile.gettempdir()) / f'{self.__class__.__name__}.txt'

if temp_file_path.exists():
temp_file_path.unlink()
return
env_var_name: str = f"{self.__class__.__name__}_attempt_count"
attempt_count: int = int(os.getenv(env_var_name, 0))

try:
temp_file_path.write_text('Hello World!')
raise ZeroDivisionError('Simulated failure')
except ZeroDivisionError as e:
if attempt_count < 3:
os.environ[env_var_name]: str = str(attempt_count + 1)
raise RuntimeError("Simulated failure")
else:
os.environ.pop(env_var_name, None)
except RuntimeError as e:
raise RetryOnExceptionError(e, max_retries=5, retry_delay_seconds=5)

fail_once()
fail_three_times()

return TaskResult(
status=TaskResultStatus.COMPLETED,
Expand Down
12 changes: 7 additions & 5 deletions frinx/client/v2/frinx_conductor_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def task_not_found_anymore(self, task_not_found: NextWorkerTask) -> None:

class FrinxConductorWrapper:
def __init__(
self, server_url: str, max_thread_count: int, polling_interval: float = 0.1,
self, server_url: str, max_thread_count: int, polling_interval: float = 0.1,
worker_id: str | None = None, headers: dict[str, Any] | None = None
) -> None:
# Synchronizes access to self.queues by producer thread (in read_queue) and consumer threads (in tasks_in_queue)
Expand Down Expand Up @@ -247,10 +247,12 @@ def execute(self, task: RawTaskIO, task_blueprint: WorkerImpl) -> None:
error_msg = 'Task execution function MUST return a response as a dict with status and output fields'
raise Exception(error_msg)

task['status'] = resp['status']
task['outputData'] = resp.get('output', {})
task['logs'] = resp.get('logs', [])
task['logs'].extend(root_log_handler.get_logs())
task.update({
'status': resp['status'],
'callbackAfterSeconds': resp.get('callback_after_seconds', 0),
'outputData': resp.get('output', {}),
'logs': resp.get('logs', []) + root_log_handler.get_logs()
})

logger.debug('Executing a task %s, response: %s', task['taskId'], resp)
logger.debug('Executing a task %s, task body: %s', task['taskId'], task)
Expand Down
6 changes: 4 additions & 2 deletions frinx/common/worker/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

RawTaskIO: TypeAlias = dict[str, Any]


class RetryOnExceptionError(Exception):
"""
Exception class representing a task that needs to be retried.
Expand Down Expand Up @@ -41,7 +42,7 @@ def update_task_result(self, task: RawTaskIO, task_result: TaskResult[Any]) -> T

if self._should_retry(current_poll_count):
task_result.status = TaskResultStatus.IN_PROGRESS
task['callbackAfterSeconds'] = self.retry_delay_seconds
task_result.callback_after_seconds = self.retry_delay_seconds
else:
task_result.status = TaskResultStatus.FAILED

Expand All @@ -56,7 +57,8 @@ def _log_task_status(self, task_result: TaskResult[Any], current_poll_count: int
"""Logs the task status with the current poll count and exception details."""
error_name: str = type(self.caught_exception).__name__
error_info: str = str(self.caught_exception)
log_message = f'{RetryOnExceptionError.__name__}({current_poll_count}): {error_name} - {error_info}'
log_message = (f'{RetryOnExceptionError.__name__}({current_poll_count}/{self.max_retries}): '
f'{error_name} - {error_info}')

if isinstance(task_result.logs, str):
task_result.logs = [task_result.logs]
Expand Down
1 change: 1 addition & 0 deletions frinx/common/worker/task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TaskResult(BaseModel, Generic[TO]):
status: TaskResultStatus
output: TO | None = None
logs: typing.Union[list[str], str] = Field(default=[])
callback_after_seconds: int | None = Field(default=0)

model_config = ConfigDict(
validate_assignment=True,
Expand Down

0 comments on commit ac008dd

Please sign in to comment.