Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetryHandlerSkeleton #152

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from deltacat.utils.ray_utils.retry_handler.retryable_error import RetryableError

class AWSSecurityTokenException(RetryableError):
    """Retryable error type named for AWS security token failures.

    Adds no state or behaviour of its own; it exists so the failure can be
    told apart from other ``RetryableError`` subclasses by type.
    NOTE(review): the code that raises this is not visible here.
    """

    def __init__(self, *args: object) -> None:
        # Forward everything to the base class so ``args``/``str()`` behave
        # like any other exception.
        super().__init__(*args)
14 changes: 14 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from logging import Logger
import logging


def configure_logger(logger: Logger) -> Logger:
    """Set up process-wide logging and hand back the given logger.

    Calls ``logging.basicConfig`` with level INFO and a fixed message/date
    format, then explicitly sets the level of two named loggers.

    Args:
        logger: The logger to return; it is not modified directly.

    Returns:
        The same ``logger`` object that was passed in.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s [%(name)s;%(filename)s.%(funcName)s:%(lineno)d] %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )

    # These modules were not configured to honor the log level specified,
    # hence the explicit level for each of them.
    logging.getLogger("deltacat.utils.pyarrow").setLevel(logging.INFO)
    logging.getLogger("amazoncerts.cacerts_helpers").setLevel(logging.ERROR)
    return logger
7 changes: 7 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/non_retryable_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class NonRetryableError(RuntimeError):
    """Error type marking a failure that must not be retried.

    Carries no extra state; callers distinguish it from retryable failures
    by type only.
    """

    def __init__(self, *args: object) -> None:
        # Plain pass-through so ``args`` and ``str()`` match RuntimeError.
        super().__init__(*args)
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from deltacat.utils.ray_utils.retry_handler.task_constants import DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR, DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS, DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE
from dataclasses import dataclass

class RayRemoteTasksBatchScalingParams():
    """Batch scaling parameters for a group of Ray remote tasks.

    A plain value holder: every constructor argument is stored unchanged on
    the instance under the same name.

    Args:
        initial_batch_size: Number of tasks to submit in the first batch.
        negative_feedback_back_off_in_ms: Defaults to
            ``DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS``.
        positive_feedback_batch_size_additive_increase: Defaults to
            ``DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE``.
        negative_feedback_batch_size_multiplicative_decrease_factor: Defaults to
            ``DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR``.

    NOTE(review): only ``initial_batch_size`` is read by the retry handler
    visible in this change; the feedback parameters are stored but their
    consumer is not shown.
    """

    def __init__(self,
                 initial_batch_size: int,
                 negative_feedback_back_off_in_ms: int = DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS,
                 positive_feedback_batch_size_additive_increase: int = DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE,
                 negative_feedback_batch_size_multiplicative_decrease_factor: int = DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR):
        self.initial_batch_size = initial_batch_size
        self.negative_feedback_back_off_in_ms = negative_feedback_back_off_in_ms
        self.positive_feedback_batch_size_additive_increase = positive_feedback_batch_size_additive_increase
        self.negative_feedback_batch_size_multiplicative_decrease_factor = negative_feedback_batch_size_multiplicative_decrease_factor
175 changes: 175 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations
from typing import Any, Dict, List, cast
from deltacat.utils.ray_utils.retry_handler.ray_remote_tasks_batch_scaling_params import RayRemoteTasksBatchScalingParams
#import necessary errors here
import ray
import time
import logging
from deltacat.utils.ray_utils.retry_handler.logger import configure_logger
from deltacat.utils.ray_utils.retry_handler.task_execution_error import RayRemoteTaskExecutionError
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskInfoObject
from deltacat.utils.ray_utils.retry_handler.retry_strategy_config import get_retry_strategy_config_for_known_exception

logger = configure_logger(logging.getLogger(__name__))

import ray
import time
import logging
from typing import Any, Dict, List, cast
from ray.types import ObjectRef
from RetryExceptions.retryable_exception import RetryableException
from RetryExceptions.non_retryable_exception import NonRetryableException
from RetryExceptions.TaskInfoObject import TaskInfoObject

# Inputs: task_callable, task_input, ray_remote_task_options, exception_retry_strategy_configs.
# Errors get a separate class hierarchy, split into retryable and non-retryable.
# A separate class (the Ray retry task info) packages task data in a form the retry class can handle.

# This is the function that runs, and is re-submitted to retry, a single task.
@ray.remote
def submit_single_task(taskObj: TaskInfoObject, progressNotifier: Optional[NotificationInterface] = None) -> Any:
    """Run one attempt of a task as a Ray remote function.

    Increments ``taskObj.num_of_attempts``, then calls
    ``taskObj.task_callable(taskObj.task_input)``.

    Args:
        taskObj: The task to run, with its callable, input and the retry
            configs used to classify failures.
        progressNotifier: Optional notifier; currently accepted but unused.

    Returns:
        The callable's return value on success. If the callable raises an
        exception that matches one of ``taskObj.exception_retry_strategy_configs``,
        a ``RayRemoteTaskExecutionError`` is *returned* (not raised) so the
        caller can decide whether to retry.

    Raises:
        Exception: Any exception from the callable that has no matching
            retry config is re-raised unchanged.
    """
    taskObj.num_of_attempts += 1
    current_attempt_number = taskObj.num_of_attempts
    if progressNotifier is not None:
        # TODO: call into straggler detection through the notifier.
        pass
    logger.debug(f"Executing the submitted Ray remote task as part of attempt number: {current_attempt_number}")
    try:
        return taskObj.task_callable(taskObj.task_input)
    except Exception as exception:
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            exception, taskObj.exception_retry_strategy_configs)
        if exception_retry_strategy_config is None:
            # Unknown failure: surface it instead of returning None, which the
            # caller would otherwise mistake for a successful result.
            raise
        # NOTE(review): this hands back the config's template exception, not
        # the exception actually raised, so the original message is dropped.
        return RayRemoteTaskExecutionError(exception_retry_strategy_config.exception, taskObj)




class RetryHandler:
    """Runs Ray remote tasks and resubmits those that fail with a known exception.

    Usage is a start call (``start_tasks_execution`` or
    ``start_tasks_execution_in_batches``) followed by
    ``wait_and_get_all_task_results``; ``execute_task`` does both for a single
    task. All bookkeeping attributes are created by
    ``start_tasks_execution_in_batches``, so results can only be collected
    after a start call.
    """

    def execute_task(self, ray_remote_task_info: TaskInfoObject) -> Any:
        """Run a single task to completion and return its result."""
        self.start_tasks_execution([ray_remote_task_info])
        return self.wait_and_get_all_task_results()[0]

    def start_tasks_execution(self, ray_remote_task_infos: List[TaskInfoObject]) -> None:
        """Starts execution of all given Ray remote tasks"""
        self.start_tasks_execution_in_batches(
            ray_remote_task_infos,
            RayRemoteTasksBatchScalingParams(initial_batch_size=len(ray_remote_task_infos)))

    def start_tasks_execution_in_batches(self, ray_remote_task_infos: List[TaskInfoObject], batch_scaling_params: RayRemoteTasksBatchScalingParams) -> None:
        """Starts execution of given Ray remote tasks in batches depending on the given Batch scaling params

        Resets all bookkeeping on the instance, submits the first
        ``min(initial_batch_size, len(tasks))`` tasks and keeps the rest queued
        until earlier tasks succeed.
        """
        self.num_of_submitted_tasks = len(ray_remote_task_infos)
        self.current_batch_size = min(batch_scaling_params.initial_batch_size, self.num_of_submitted_tasks)
        self.num_of_submitted_tasks_completed = 0
        self.remaining_ray_remote_task_infos = ray_remote_task_infos
        self.batch_scaling_params = batch_scaling_params
        # Keyed by str(object ref) so a finished promise can be mapped back to its task.
        self.task_promise_obj_ref_to_task_info_map: Dict[Any, TaskInfoObject] = {}

        self.unfinished_promises: List[Any] = []
        logger.info(f"Starting the execution of {len(ray_remote_task_infos)} Ray remote tasks. Concurrency of tasks execution: {self.current_batch_size}")
        self.__submit_tasks(self.remaining_ray_remote_task_infos[:self.current_batch_size])
        self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[self.current_batch_size:]

    def wait_and_get_all_task_results(self) -> List[Any]:
        """Block until every submitted task has succeeded and return all results."""
        return self.wait_and_get_task_results(self.num_of_submitted_tasks)

    def wait_and_get_task_results(self, num_of_results: int) -> List[Any]:
        """Block until ``num_of_results`` tasks have succeeded and return their results.

        Failed tasks with a matching retry config are resubmitted as long as
        their attempt count allows. Each success frees a slot that is refilled
        from the queue of not-yet-submitted tasks.

        Args:
            num_of_results: Number of successful results wanted.

        Returns:
            Successful results in completion order. May be shorter than
            requested if no in-flight or queued tasks are left.

        Raises:
            Exception: The task's exception once it has no retry config or
                has exceeded ``max_retry_attempts``; also whatever
                ``__handle_ray_exception`` raises for unexpected Ray failures.
        """
        successful_results: List[Any] = []
        while len(successful_results) < num_of_results and self.unfinished_promises:
            # ray.wait rejects num_returns larger than the number of refs given,
            # which happens whenever only part of the work is in flight.
            num_to_wait_for = min(num_of_results - len(successful_results), len(self.unfinished_promises))
            finished_promises, self.unfinished_promises = ray.wait(
                self.unfinished_promises, num_returns=num_to_wait_for)

            newly_successful_results: List[Any] = []
            for finished_promise in finished_promises:
                # The promise is done either way, so its map entry is dropped
                # here; a retry registers a fresh entry for its new promise.
                task_info = self.task_promise_obj_ref_to_task_info_map.pop(str(finished_promise))
                try:
                    finished_result = ray.get(finished_promise)
                except Exception as exception:
                    # Translate Ray-level failures into a retryable error or raise.
                    finished_result = self.__handle_ray_exception(
                        exception=exception, ray_remote_task_info=task_info)

                if isinstance(finished_result, RayRemoteTaskExecutionError):
                    self.__retry_failed_task(finished_result)
                else:
                    newly_successful_results.append(finished_result)

            num_of_successful_results = len(newly_successful_results)
            self.num_of_submitted_tasks_completed += num_of_successful_results
            self.current_batch_size -= num_of_successful_results
            self.__enqueue_new_tasks(num_of_successful_results)
            successful_results.extend(newly_successful_results)
        return successful_results

    def get_task_results(self, num_of_results: int) -> List[Any]:
        """Alias of ``wait_and_get_task_results`` kept for callers using this name."""
        return self.wait_and_get_task_results(num_of_results)

    def __retry_failed_task(self, finished_result: RayRemoteTaskExecutionError) -> None:
        """Resubmit a failed task, or raise its exception when retries are exhausted."""
        failed_task_info = finished_result.ray_remote_task_info
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            finished_result.exception, failed_task_info.exception_retry_strategy_configs)
        if (exception_retry_strategy_config is None
                or failed_task_info.num_of_attempts > exception_retry_strategy_config.max_retry_attempts):
            logger.error(f"The submitted task has exhausted all the maximum retries configured and finally throws exception - {finished_result.exception}")
            raise finished_result.exception
        self.__update_ray_remote_task_options_on_exception(finished_result.exception, failed_task_info)
        self.unfinished_promises.append(self.__invoke_ray_remote_task(ray_remote_task_info=failed_task_info))

    def __enqueue_new_tasks(self, num_of_tasks: int) -> None:
        """Submit up to ``num_of_tasks`` queued tasks and update the concurrency counter."""
        new_tasks_submitted = self.remaining_ray_remote_task_infos[:num_of_tasks]
        num_of_new_tasks_submitted = len(new_tasks_submitted)
        self.__submit_tasks(new_tasks_submitted)
        self.remaining_ray_remote_task_infos = self.remaining_ray_remote_task_infos[num_of_tasks:]
        self.current_batch_size += num_of_new_tasks_submitted
        logger.info(f"Enqueued {num_of_new_tasks_submitted} new tasks. Current concurrency of tasks execution: {self.current_batch_size}, Current Task progress: {self.num_of_submitted_tasks_completed}/{self.num_of_submitted_tasks}")

    def __submit_tasks(self, ray_remote_task_infos: List[TaskInfoObject]) -> None:
        """Submit each task to Ray, pausing 5 ms before every submission."""
        for ray_remote_task_info in ray_remote_task_infos:
            time.sleep(0.005)
            self.unfinished_promises.append(self.__invoke_ray_remote_task(ray_remote_task_info=ray_remote_task_info))

    def __invoke_ray_remote_task(self, ray_remote_task_info: TaskInfoObject) -> Any:
        """Launch ``submit_single_task`` for one task and record its object ref.

        Only the task options that are set (truthy) are forwarded to Ray.

        Returns:
            The object ref of the launched remote task.
        """
        task_options = ray_remote_task_info.ray_remote_task_options
        ray_remote_task_options_arguments = {}

        if task_options.memory:
            ray_remote_task_options_arguments['memory'] = task_options.memory

        if task_options.num_cpus:
            ray_remote_task_options_arguments['num_cpus'] = task_options.num_cpus

        if task_options.placement_group:
            ray_remote_task_options_arguments['placement_group'] = task_options.placement_group

        # Passed positionally: the remote function has no parameter named
        # ``ray_remote_task_info``.
        ray_remote_task_promise_obj_ref = submit_single_task.options(
            **ray_remote_task_options_arguments).remote(ray_remote_task_info)
        self.task_promise_obj_ref_to_task_info_map[str(ray_remote_task_promise_obj_ref)] = ray_remote_task_info

        return ray_remote_task_promise_obj_ref

    def __update_ray_remote_task_options_on_exception(self, exception: Exception, ray_remote_task_info: TaskInfoObject):
        """Scale the task's memory option by the matching config's multiply factor.

        Does nothing when no config matches or the task has no memory option set.
        """
        exception_retry_strategy_config = get_retry_strategy_config_for_known_exception(
            exception, ray_remote_task_info.exception_retry_strategy_configs)
        if exception_retry_strategy_config and ray_remote_task_info.ray_remote_task_options.memory:
            logger.info(f"Updating the Ray remote task options after encountering exception: {exception}")
            ray_remote_task_memory_multiply_factor = exception_retry_strategy_config.ray_remote_task_memory_multiply_factor
            ray_remote_task_info.ray_remote_task_options.memory *= ray_remote_task_memory_multiply_factor
            logger.info(f"Updated ray remote task options Memory: {ray_remote_task_info.ray_remote_task_options.memory}")

    def __handle_ray_exception(self, exception: Exception, ray_remote_task_info: TaskInfoObject) -> RayRemoteTaskExecutionError:
        """Map an exception raised by ``ray.get`` to a retryable error, or raise.

        Returns:
            A ``RayRemoteTaskExecutionError`` wrapping a handler-level
            exception for the recognised Ray failure kinds.

        Raises:
            UnexpectedRayTaskError: For ``RayTaskError(UnexpectedRayTaskError)``.
            UnexpectedRayPlatformError: For every other unrecognised exception.

        NOTE(review): the wrapper/error classes referenced below
        (UnexpectedRayTaskError, RayOutOfMemoryError, RayOwnerDiedError,
        RayWorkerCrashedError, RayLocalRayletDiedError, RaySystemError,
        UnexpectedRayPlatformError) are not imported in the visible code and
        must be brought into scope for this method to work.
        """
        logger.error(f"Ray remote task failed with {type(exception)} Ray exception: {exception}")
        exception_type = type(exception)
        # Ray wraps task errors in dynamically named classes, hence the
        # comparison on the class name for the first two cases.
        if exception_type.__name__ == "RayTaskError(UnexpectedRayTaskError)":
            raise UnexpectedRayTaskError(str(exception))
        elif exception_type.__name__ == "RayTaskError(RayOutOfMemoryError)":
            return RayRemoteTaskExecutionError(exception=RayOutOfMemoryError(str(exception)), ray_remote_task_info=ray_remote_task_info)
        elif exception_type is ray.exceptions.OwnerDiedError:
            return RayRemoteTaskExecutionError(exception=RayOwnerDiedError(str(exception)), ray_remote_task_info=ray_remote_task_info)
        elif exception_type is ray.exceptions.WorkerCrashedError:
            return RayRemoteTaskExecutionError(exception=RayWorkerCrashedError(str(exception)), ray_remote_task_info=ray_remote_task_info)
        elif exception_type is ray.exceptions.LocalRayletDiedError:
            return RayRemoteTaskExecutionError(exception=RayLocalRayletDiedError(str(exception)), ray_remote_task_info=ray_remote_task_info)
        elif exception_type is ray.exceptions.RaySystemError:
            return RayRemoteTaskExecutionError(exception=RaySystemError(str(exception)), ray_remote_task_info=ray_remote_task_info)

        raise UnexpectedRayPlatformError(str(exception))
14 changes: 14 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/retry_strategy_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List, Optional

from ray_manager.models.ray_remote_task_exception_retry_strategy_config import RayRemoteTaskExceptionRetryConfig

def get_retry_strategy_config_for_known_exception(exception: Exception, exception_retry_strategy_configs: List[RayRemoteTaskExceptionRetryConfig]) -> Optional[RayRemoteTaskExceptionRetryConfig]:
    """Find the retry config that applies to ``exception``.

    A config whose template exception has exactly the same type as
    ``exception`` wins, wherever it sits in the list. Failing that, the first
    config whose template type is a base class of ``exception`` is used.

    Args:
        exception: The exception to classify.
        exception_retry_strategy_configs: Candidate configs, each exposing an
            ``exception`` attribute holding a template exception instance.
            ``None`` is treated as an empty list.

    Returns:
        The matching config, or ``None`` when nothing matches.
    """
    exception_type = type(exception)
    first_base_class_match = None
    for candidate in exception_retry_strategy_configs or ():
        configured_type = type(candidate.exception)
        if exception_type is configured_type:
            return candidate
        # Remember only the earliest base-class match; an exact match later
        # in the list must still take priority over it.
        if first_base_class_match is None and isinstance(exception, configured_type):
            first_base_class_match = candidate
    return first_base_class_match
7 changes: 7 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/retryable_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class RetryableError(RuntimeError):
    """Base error type for failures that may be retried.

    Carries no extra state; subclasses and callers rely on the type alone.
    """

    def __init__(self, *args: object) -> None:
        # Plain pass-through so ``args`` and ``str()`` match RuntimeError.
        super().__init__(*args)
32 changes: 32 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/task_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Default constants for Ray remote task retries and batch scaling."""

# Default max retry attempts of a Ray remote task.
DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS = 3

# Default initial back-off before a Ray remote task retry, in milliseconds.
DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS = 5000

# Default Ray remote task retry back-off factor.
DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR = 2

# Default Ray remote task memory multiplication factor (1 leaves memory unchanged).
DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR = 1

# Ray remote task memory multiplication factor for the Ray out-of-memory error.
RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR_FOR_OUT_OF_MEMORY_ERROR = 2

# Default Ray remote task batch negative-feedback back-off, in milliseconds.
DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS = 0

# Default Ray remote task batch positive-feedback batch size additive increase.
DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE = 0

# Default Ray remote task batch negative-feedback batch size multiplicative decrease factor.
DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR = 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List

from deltacat.utils.ray_utils.retry_handler.retryable_error import RetryableError
from deltacat.utils.ray_utils.retry_handler.task_constants import (
    DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS,
    DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BACK_OFF_IN_MS,
    DEFAULT_RAY_REMOTE_TASK_BATCH_NEGATIVE_FEEDBACK_BATCH_SIZE_MULTIPLICATIVE_DECREASE_FACTOR,
    DEFAULT_RAY_REMOTE_TASK_BATCH_POSITIVE_FEEDBACK_BATCH_SIZE_ADDITIVE_INCREASE,
    DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR,
    DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR,
    DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS,
    RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR_FOR_OUT_OF_MEMORY_ERROR,
)

class TaskExceptionRetryConfig():
    """Retry policy tied to one kind of exception.

    The ``exception`` instance acts as a template: lookups compare the type
    of a raised exception against the type of this instance.

    Args:
        exception: Template exception instance this policy applies to.
        max_retry_attempts: Attempt limit checked by the retry handler.
        initial_back_off_in_ms: Initial back-off, in milliseconds.
        back_off_factor: Back-off growth factor.
        ray_remote_task_memory_multiplication_factor: Factor the task's memory
            option is multiplied by before a retry. Stored on the instance as
            ``ray_remote_task_memory_multiply_factor``.
        is_throttling_exception: Flag stored as given.

    NOTE(review): the back-off values and the throttling flag are stored but
    not read by the retry handler visible in this change.
    """

    def __init__(self, exception: Exception,
                 max_retry_attempts: int = DEFAULT_MAX_RAY_REMOTE_TASK_RETRY_ATTEMPTS,
                 initial_back_off_in_ms: int = DEFAULT_RAY_REMOTE_TASK_RETRY_INITIAL_BACK_OFF_IN_MS,
                 back_off_factor: int = DEFAULT_RAY_REMOTE_TASK_RETRY_BACK_OFF_FACTOR,
                 ray_remote_task_memory_multiplication_factor: float = DEFAULT_RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR,
                 is_throttling_exception: bool = False) -> None:
        self.exception = exception
        self.max_retry_attempts = max_retry_attempts
        self.initial_back_off_in_ms = initial_back_off_in_ms
        self.back_off_factor = back_off_factor
        # Attribute name differs from the parameter name; readers of the
        # config use the "multiply_factor" spelling.
        self.ray_remote_task_memory_multiply_factor = ray_remote_task_memory_multiplication_factor
        self.is_throttling_exception = is_throttling_exception

    @staticmethod
    def getDefaultConfig() -> List[TaskExceptionRetryConfig]:
        """Return the default policies: throttled retryable errors and out-of-memory errors.

        A new list with new config objects is built on every call.

        NOTE(review): ``RayOutOfMemoryError`` is not imported in the visible
        code; it must be brought into scope before this method can run.
        """
        return [
            TaskExceptionRetryConfig(exception=RetryableError(), is_throttling_exception=True),
            TaskExceptionRetryConfig(
                exception=RayOutOfMemoryError(),
                ray_remote_task_memory_multiplication_factor=RAY_REMOTE_TASK_MEMORY_MULTIPLICATION_FACTOR_FOR_OUT_OF_MEMORY_ERROR),
        ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class RayRemoteTaskExecutionError():
    """
    An error class that denotes the Ray Remote Task Execution Failure

    This is a plain result object, not an exception subclass: it pairs the
    exception with the task it belongs to so both can be handed back to the
    caller as a value.

    Args:
        exception: The exception describing the failure.
        ray_remote_task_info: The task that failed.
    """

    def __init__(self, exception: Exception, ray_remote_task_info: "TaskInfoObject") -> None:
        # The annotation is a string so this module does not need to import
        # the task info class just for typing.
        self.exception = exception
        self.ray_remote_task_info = ray_remote_task_info
17 changes: 17 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/task_info_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Any, Callable, List
from deltacat.utils.ray_utils.retry_handler.task_info_object import TaskExceptionRetryConfig
from deltacat.utils.ray_utils.retry_handler.task_options import RayRemoteTaskOptions

class TaskInfoObject:
    """Everything needed to run, and re-run, one Ray remote task.

    Deliberately a plain class rather than a ``@dataclass``: with a
    hand-written ``__init__`` and no declared fields, the decorator would
    generate an ``__eq__`` under which all instances compare equal.

    Args:
        task_callable: Callable invoked with ``task_input`` as its only argument.
        task_input: The argument passed to ``task_callable``.
        ray_remote_task_options: Ray options for the task. When omitted, a
            fresh ``RayRemoteTaskOptions()`` is created per instance, because
            the retry handler mutates ``memory`` on it.
        exception_retry_strategy_configs: Retry configs used to classify
            failures. When omitted, an empty list is used (no exception is
            treated as retryable).

    Attributes:
        num_of_attempts: Number of attempts made so far; starts at 0.

    NOTE(review): the module imports ``TaskExceptionRetryConfig`` from
    ``task_info_object`` (this module itself); that import needs to point at
    the module that actually defines the class.
    """

    def __init__(self,
                 task_callable: Callable[[Any], Any],
                 task_input: Any,
                 ray_remote_task_options: "Optional[RayRemoteTaskOptions]" = None,
                 exception_retry_strategy_configs: "Optional[List[TaskExceptionRetryConfig]]" = None):
        self.task_callable = task_callable
        self.task_input = task_input
        if ray_remote_task_options is None:
            ray_remote_task_options = RayRemoteTaskOptions()
        self.ray_remote_task_options = ray_remote_task_options
        if exception_retry_strategy_configs is None:
            exception_retry_strategy_configs = []
        self.exception_retry_strategy_configs = exception_retry_strategy_configs
        self.num_of_attempts = 0
ekaschaw marked this conversation as resolved.
Show resolved Hide resolved
15 changes: 15 additions & 0 deletions deltacat/utils/ray_utils/retry_handler/task_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class RayRemoteTaskOptions():
    """
    Represents the options corresponding to Ray remote task

    Declared as real dataclass fields so the generated ``__init__``,
    ``__repr__`` and ``__eq__`` reflect the actual values. The constructor
    accepts the same arguments, in the same order, with the same defaults.
    """

    # Memory option for the task; None means "not set".
    memory: Optional[float] = None
    # CPU count option for the task; None means "not set".
    num_cpus: Optional[int] = None
    # Placement group option for the task; None means "not set".
    placement_group: Optional[Any] = None