Skip to content

Commit

Permalink
Retry policies implementation (#11)
Browse files Browse the repository at this point in the history
Signed-off-by: Deepanshu Agarwal <[email protected]>
Co-authored-by: Chris Gillum <[email protected]>
  • Loading branch information
DeepanshuA and cgillum authored Oct 26, 2023
1 parent c999097 commit 80de6c2
Show file tree
Hide file tree
Showing 5 changed files with 721 additions and 80 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### New

- Retry policies for activities and sub-orchestrations ([#11](https://github.com/microsoft/durabletask-python/pull/11)) - contributed by [@DeepanshuA](https://github.com/DeepanshuA)

## v0.1.0a5

### New
Expand Down
193 changes: 157 additions & 36 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
# See https://peps.python.org/pep-0563/
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import Any, Callable, Generator, Generic, List, TypeVar, Union
from typing import (Any, Callable, Generator, Generic, List, Optional, TypeVar,
Union)

import durabletask.internal.helpers as pbh
import durabletask.internal.orchestrator_service_pb2 as pb
Expand Down Expand Up @@ -87,17 +89,18 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task:

@abstractmethod
def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,
input: Union[TInput, None] = None) -> Task[TOutput]:
input: Optional[TInput] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
"""Schedule an activity for execution.
Parameters
----------
activity: Union[Activity[TInput, TOutput], str]
A reference to the activity function to call.
input: Union[TInput, None]
input: Optional[TInput]
The JSON-serializable input (or None) to pass to the activity.
return_type: task.Task[TOutput]
The JSON-serializable output type to expect from the activity result.
retry_policy: Optional[RetryPolicy]
The retry policy to use for this activity call.
Returns
-------
Expand All @@ -108,19 +111,22 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *,

@abstractmethod
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
input: Union[TInput, None] = None,
instance_id: Union[str, None] = None) -> Task[TOutput]:
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None) -> Task[TOutput]:
"""Schedule sub-orchestrator function for execution.
Parameters
----------
orchestrator: Orchestrator[TInput, TOutput]
A reference to the orchestrator function to call.
input: Union[TInput, None]
input: Optional[TInput]
The optional JSON-serializable input to pass to the orchestrator function.
instance_id: Union[str, None]
instance_id: Optional[str]
A unique ID to use for the sub-orchestration instance. If not specified, a
random UUID will be used.
retry_policy: Optional[RetryPolicy]
The retry policy to use for this sub-orchestrator call.
Returns
-------
Expand Down Expand Up @@ -162,7 +168,7 @@ def continue_as_new(self, new_input: Any, *, save_events: bool = False) -> None:


class FailureDetails:
def __init__(self, message: str, error_type: str, stack_trace: Union[str, None]):
def __init__(self, message: str, error_type: str, stack_trace: Optional[str]):
self._message = message
self._error_type = error_type
self._stack_trace = stack_trace
Expand All @@ -176,7 +182,7 @@ def error_type(self) -> str:
return self._error_type

@property
def stack_trace(self) -> Union[str, None]:
def stack_trace(self) -> Optional[str]:
return self._stack_trace


Expand Down Expand Up @@ -206,8 +212,8 @@ class OrchestrationStateError(Exception):
class Task(ABC, Generic[T]):
"""Abstract base class for asynchronous tasks in a durable orchestration."""
_result: T
_exception: Union[TaskFailedError, None]
_parent: Union[CompositeTask[T], None]
_exception: Optional[TaskFailedError]
_parent: Optional[CompositeTask[T]]

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -261,29 +267,6 @@ def get_tasks(self) -> List[Task]:
def on_child_completed(self, task: Task[T]):
pass


class CompletableTask(Task[T]):
    """A task whose outcome is set explicitly through `complete` or `fail`."""

    def __init__(self):
        super().__init__()

    def complete(self, result: T):
        """Store a successful result and mark the task as complete.

        Parameters
        ----------
        result: T
            The value to record as this task's result.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')

        self._result = result
        self._is_complete = True

        # Let a parent task (if one is attached) react to this child finishing.
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)

    def fail(self, message: str, details: pb.TaskFailureDetails):
        """Store a failure and mark the task as complete.

        Parameters
        ----------
        message: str
            The failure message passed to the `TaskFailedError`.
        details: pb.TaskFailureDetails
            The failure details passed to the `TaskFailedError`.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')

        self._exception = TaskFailedError(message, details)
        self._is_complete = True

        # A failed child is still reported to the parent as completed.
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)


class WhenAllTask(CompositeTask[List[T]]):
"""A task that completes when all of its child tasks complete."""

Expand Down Expand Up @@ -313,6 +296,76 @@ def get_completed_tasks(self) -> int:
return self._completed_tasks


class CompletableTask(Task[T]):
    """A task whose outcome is set explicitly through `complete` or `fail`."""

    def __init__(self):
        super().__init__()
        # Unset by default; `TimerTask.set_retryable_parent` assigns it.
        self._retryable_parent = None

    def complete(self, result: T):
        """Store a successful result and mark the task as complete.

        Parameters
        ----------
        result: T
            The value to record as this task's result.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')

        self._result = result
        self._is_complete = True

        # Let a parent task (if one is attached) react to this child finishing.
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)

    def fail(self, message: str, details: pb.TaskFailureDetails):
        """Store a failure and mark the task as complete.

        Parameters
        ----------
        message: str
            The failure message passed to the `TaskFailedError`.
        details: pb.TaskFailureDetails
            The failure details passed to the `TaskFailedError`.

        Raises
        ------
        ValueError
            If the task has already completed.
        """
        if self._is_complete:
            raise ValueError('The task has already completed.')

        self._exception = TaskFailedError(message, details)
        self._is_complete = True

        # A failed child is still reported to the parent as completed.
        parent = self._parent
        if parent is not None:
            parent.on_child_completed(self)


class RetryableTask(CompletableTask[T]):
    """A task that can be retried according to a retry policy."""

    def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction,
                 start_time: datetime, is_sub_orch: bool) -> None:
        """Creates a new RetryableTask instance.

        Parameters
        ----------
        retry_policy: RetryPolicy
            The policy that drives `compute_next_delay`.
        action: pb.OrchestratorAction
            The orchestrator action stored on this task (presumably the action
            to re-schedule on retry -- confirm against the caller).
        start_time: datetime
            The time from which `retry_policy.retry_timeout` is measured. It is
            compared against `datetime.utcnow()`, so it should be a naive UTC
            datetime.
        is_sub_orch: bool
            Stored flag (presumably True when the task is a sub-orchestration
            rather than an activity -- confirm against the caller).
        """
        super().__init__()
        self._action = action
        self._retry_policy = retry_policy
        self._attempt_count = 1  # the initial attempt counts as attempt #1
        self._start_time = start_time
        self._is_sub_orch = is_sub_orch

    def increment_attempt_count(self) -> None:
        """Increase the recorded number of attempts by one."""
        self._attempt_count += 1

    def compute_next_delay(self) -> Optional[timedelta]:
        """Compute how long to wait before the next retry attempt.

        The delay is `first_retry_interval * backoff_coefficient ** (attempt_count - 1)`,
        capped at `max_retry_interval` when the policy defines one.

        Returns
        -------
        Optional[timedelta]
            The delay before the next attempt, or None when no further retry
            is allowed (attempt limit reached or retry timeout elapsed).

        Raises
        ------
        OverflowError
            If the computed delay cannot be represented as a timedelta and the
            policy defines no `max_retry_interval` to cap it.
        """
        policy = self._retry_policy
        if self._attempt_count >= policy.max_number_of_attempts:
            return None

        if policy.retry_timeout is not None:
            # A very large timeout (e.g. timedelta.max) pushes the expiration
            # past datetime.max; treat that as "never expires" instead of
            # letting the addition raise OverflowError.
            try:
                retry_expiration = self._start_time + policy.retry_timeout
            except OverflowError:
                retry_expiration = datetime.max
            # NOTE(review): datetime.utcnow() is deprecated since Python 3.12;
            # it is kept so the comparison stays between naive UTC datetimes.
            if datetime.utcnow() >= retry_expiration:
                return None

        backoff_coefficient = policy.backoff_coefficient
        if backoff_coefficient is None:
            backoff_coefficient = 1.0

        first_interval_s = policy.first_retry_interval.total_seconds()
        try:
            next_delay_s = math.pow(backoff_coefficient, self._attempt_count - 1) * first_interval_s
        except OverflowError:
            # math.pow raises for results beyond float range. Keep going so
            # that the max_retry_interval cap below can still apply.
            next_delay_s = math.inf if first_interval_s > 0 else 0.0

        if policy.max_retry_interval is not None:
            next_delay_s = min(next_delay_s, policy.max_retry_interval.total_seconds())

        return timedelta(seconds=next_delay_s)


class TimerTask(CompletableTask[T]):
    """A completable task that can be linked to a `RetryableTask`.

    NOTE(review): presumably this is the timer that delays a retry attempt --
    confirm against the code that creates it.
    """

    def __init__(self) -> None:
        super().__init__()

    def set_retryable_parent(self, retryable_task: RetryableTask):
        """Remember the retryable task this timer is associated with.

        Parameters
        ----------
        retryable_task: RetryableTask
            The task to store in `_retryable_parent`.
        """
        self._retryable_parent = retryable_task


class WhenAnyTask(CompositeTask[Task]):
"""A task that completes when any of its child tasks complete."""

Expand Down Expand Up @@ -376,6 +429,74 @@ def task_id(self) -> int:
Activity = Callable[[ActivityContext, TInput], TOutput]


class RetryPolicy:
    """Immutable-by-convention settings describing how a failed call is retried."""

    def __init__(self, *,
                 first_retry_interval: timedelta,
                 max_number_of_attempts: int,
                 backoff_coefficient: Optional[float] = 1.0,
                 max_retry_interval: Optional[timedelta] = None,
                 retry_timeout: Optional[timedelta] = None):
        """Validates the settings and stores them.

        Parameters
        ----------
        first_retry_interval : timedelta
            The delay before the first retry. Must not be negative.
        max_number_of_attempts : int
            The attempt limit. Must be at least 1.
        backoff_coefficient : Optional[float]
            The multiplier used to grow the retry interval. Must be at least 1
            when given; defaults to 1.0.
        max_retry_interval : Optional[timedelta]
            The upper bound for any single retry interval. Must not be
            negative when given.
        retry_timeout : Optional[timedelta]
            The total time budget for retrying. Must not be negative when given.

        Raises
        ------
        ValueError
            If any argument violates the constraints above. Arguments are
            checked in declaration order and the first violation is reported.
        """
        no_time = timedelta(seconds=0)

        if first_retry_interval < no_time:
            raise ValueError('first_retry_interval must be >= 0')
        if max_number_of_attempts < 1:
            raise ValueError('max_number_of_attempts must be >= 1')
        if backoff_coefficient is not None:
            if backoff_coefficient < 1:
                raise ValueError('backoff_coefficient must be >= 1')
        if max_retry_interval is not None:
            if max_retry_interval < no_time:
                raise ValueError('max_retry_interval must be >= 0')
        if retry_timeout is not None:
            if retry_timeout < no_time:
                raise ValueError('retry_timeout must be >= 0')

        self._first_retry_interval = first_retry_interval
        self._max_number_of_attempts = max_number_of_attempts
        self._backoff_coefficient = backoff_coefficient
        self._max_retry_interval = max_retry_interval
        self._retry_timeout = retry_timeout

    @property
    def first_retry_interval(self) -> timedelta:
        """The delay before the first retry."""
        return self._first_retry_interval

    @property
    def max_number_of_attempts(self) -> int:
        """The attempt limit (always >= 1)."""
        return self._max_number_of_attempts

    @property
    def backoff_coefficient(self) -> Optional[float]:
        """The multiplier used to grow the retry interval, or None if unset."""
        return self._backoff_coefficient

    @property
    def max_retry_interval(self) -> Optional[timedelta]:
        """The upper bound for any single retry interval, or None if unbounded."""
        return self._max_retry_interval

    @property
    def retry_timeout(self) -> Optional[timedelta]:
        """The total time budget for retrying, or None if unlimited."""
        return self._retry_timeout


def get_name(fn: Callable) -> str:
"""Returns the name of the provided function"""
name = fn.__name__
Expand Down
Loading

0 comments on commit 80de6c2

Please sign in to comment.