diff --git a/src/aap_eda/core/management/commands/scheduler.py b/src/aap_eda/core/management/commands/scheduler.py index d837d2471..ea72258f6 100644 --- a/src/aap_eda/core/management/commands/scheduler.py +++ b/src/aap_eda/core/management/commands/scheduler.py @@ -70,16 +70,15 @@ https://github.com/rq/rq-scheduler/blob/master/README.rst """ import logging -import re +import typing from datetime import datetime -from time import sleep import django_rq -import redis -from ansible_base.lib.redis.client import DABRedisCluster +import rq_scheduler from django.conf import settings from django_rq.management.commands import rqscheduler -from rq_scheduler import Scheduler + +from aap_eda.core import tasking logger = logging.getLogger(__name__) @@ -89,19 +88,24 @@ RQ_CRON_JOBS = getattr(settings, "RQ_CRON_JOBS", None) -def delete_scheduled_jobs(scheduler: Scheduler): +@tasking.redis_connect_retry() +def delete_scheduled_jobs(scheduler: rq_scheduler.Scheduler) -> None: """Cancel any existing jobs in the scheduler when the app starts up.""" for job in scheduler.get_jobs(): logging.info("Deleting scheduled job: %s", job) job.delete() -def add_startup_jobs(scheduler: Scheduler) -> None: +def add_startup_jobs(scheduler: rq_scheduler.Scheduler) -> None: if not RQ_STARTUP_JOBS: logger.info("No scheduled jobs. Skipping.") return - for entry in RQ_STARTUP_JOBS: + @tasking.redis_connect_retry() + def _add_startup_job( + scheduler: rq_scheduler.Scheduler, + entry: dict[str, typing.Any], + ) -> None: logger.info('Adding startup job "%s"', entry["func"]) scheduled_time = entry.pop("scheduled_time", None) if scheduled_time is None: @@ -111,13 +115,20 @@ def add_startup_jobs(scheduler: Scheduler) -> None: **entry, ) + for entry in RQ_STARTUP_JOBS: + _add_startup_job(scheduler, entry) + -def add_periodic_jobs(scheduler: Scheduler) -> None: +def add_periodic_jobs(scheduler: rq_scheduler.Scheduler) -> None: if not RQ_PERIODIC_JOBS: logger.info("No periodic jobs. Skipping.") return - for entry in RQ_PERIODIC_JOBS: + @tasking.redis_connect_retry() + def _add_periodic_job( + scheduler: rq_scheduler.Scheduler, + entry: dict[str, typing.Any], + ) -> None: logger.info('Adding periodic job "%s"', entry["func"]) scheduled_time = entry.pop("scheduled_time", None) if scheduled_time is None: @@ -127,17 +138,27 @@ def add_periodic_jobs(scheduler: Scheduler) -> None: **entry, ) + for entry in RQ_PERIODIC_JOBS: + _add_periodic_job(scheduler, entry) -def add_cron_jobs(scheduler: Scheduler) -> None: + +def add_cron_jobs(scheduler: rq_scheduler.Scheduler) -> None: """Schedule cron jobs.""" if not RQ_CRON_JOBS: logger.info("No cron jobs. Skipping.") return - for entry in RQ_CRON_JOBS: + @tasking.redis_connect_retry() + def _add_cron_job( + scheduler: rq_scheduler.Scheduler, + entry: dict[str, typing.Any], + ) -> None: logger.info('Adding cron job "%s"', entry["func"]) scheduler.cron(**entry) + for entry in RQ_CRON_JOBS: + _add_cron_job(scheduler, entry) + class Command(rqscheduler.Command): help = "Runs RQ scheduler with configured jobs." @@ -153,73 +174,6 @@ def handle(self, *args, **options) -> None: add_startup_jobs(scheduler) add_periodic_jobs(scheduler) add_cron_jobs(scheduler) - # We are going to start our own loop here to catch exceptions which - # might be coming from a redis cluster and retrying things. 
- while True: - try: - super().handle(*args, **options) - except ( - redis.exceptions.TimeoutError, - redis.exceptions.ClusterDownError, - redis.exceptions.ConnectionError, - ) as e: - # If we got one of these exceptions but are not on a Cluster go - # ahead and raise it normally. - if not isinstance(scheduler.connection, DABRedisCluster): - raise - - # There are a lot of different exceptions that inherit from - # ConnectionError. So we need to make sure if we got that its - # an actual ConnectionError. If not, go ahead and raise it. - # Note: ClusterDownError and TimeoutError are not subclasses - # of ConnectionError. - if ( - isinstance(e, redis.exceptions.ConnectionError) - and type(e) is not redis.exceptions.ConnectionError - ): - raise - - downed_node_ip = re.findall( - r"[0-9]+(?:\.[0-9]+){3}:[0-9]+", str(e) - ) - - # If we got a cluster issue we will loop here until we can ping - # the server again. - max_backoff = 60 - current_backoff = 1 - while True: - if current_backoff > max_backoff: - # Maybe we just got a network glitch and are waiting - # for a cluster member to fail when its not going to. - # At this point we've waited for 60 secs so lets go - # ahead and let the scheduler try and restart. - logger.error( - "Connection to redis is still down " - "going to attempt to restart scheduler" - ) - break - - backoff = min(current_backoff, max_backoff) - logger.error( - f"Connection to redis cluster failed. Attempting to " - f"reconnect in {backoff}" - ) - sleep(backoff) - current_backoff = 2 * current_backoff - try: - if downed_node_ip: - cluster_nodes = ( - scheduler.connection.cluster_nodes() - ) - for ip in downed_node_ip: - if "fail" not in cluster_nodes[ip]["flags"]: - raise Exception( - "Failed node is not yet in a failed " - "state" - ) - else: - scheduler.connection.ping() - break - # We could tighten this exception up - except Exception: - pass + super().handle(*args, **options) + + handle = tasking.redis_connect_retry()(handle) diff --git a/src/aap_eda/core/tasking/__init__.py b/src/aap_eda/core/tasking/__init__.py index 7aebb5710..f0aff97e6 100644 --- a/src/aap_eda/core/tasking/__init__.py +++ b/src/aap_eda/core/tasking/__init__.py @@ -1,23 +1,17 @@ """Tools for running background tasks.""" from __future__ import annotations +import functools import logging +import time +import typing from datetime import datetime, timedelta -from time import sleep from types import MethodType -from typing import ( - Any, - Callable, - Iterable, - List, - Optional, - Protocol, - Type, - Union, -) +import django_rq import redis import rq +import rq_scheduler from ansible_base.lib import constants from ansible_base.lib.redis.client import ( DABRedis, @@ -26,18 +20,7 @@ get_redis_status as _get_redis_status, ) from django.conf import settings -from django_rq import enqueue, get_queue, get_scheduler, job -from django_rq.queues import Queue as _Queue -from rq import Connection, Worker as _Worker, results -from rq.defaults import ( - DEFAULT_JOB_MONITORING_INTERVAL, - DEFAULT_RESULT_TTL, - DEFAULT_WORKER_TTL, -) -from rq.job import Job as _Job, JobStatus -from rq.registry import StartedJobRegistry -from rq.serializers import JSONSerializer -from rq_scheduler import Scheduler as _Scheduler +from rq import results as rq_results from aap_eda.settings import default @@ -46,18 +29,15 @@ "Queue", "ActivationWorker", "DefaultWorker", - "enqueue", - "job", - "get_queue", "unique_enqueue", "job_from_queue", ] logger = logging.getLogger(__name__) -ErrorHandlerType = Callable[[_Job], None] 
+ErrorHandlerType = typing.Callable[[rq.job.Job], None] -_ErrorHandlersArgType = Union[ +_ErrorHandlersArgType = typing.Union[ list[ErrorHandlerType], tuple[ErrorHandlerType], ErrorHandlerType, @@ -65,6 +45,68 @@ ] +def redis_connect_retry( + max_delay: int = 60, + loop_exit: typing.Optional[typing.Callable[[Exception], bool]] = None, +) -> typing.Callable: + max_delay = max(max_delay, 1) + + def decorator(func: typing.Callable) -> typing.Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> typing.Optional[typing.Any]: + value = None + delay = 1 + while True: + try: + value = func(*args, **kwargs) + if delay > 1: + logger.info("Connection to redis re-established.") + break + except ( + redis.exceptions.ClusterDownError, + redis.exceptions.ConnectionError, + redis.exceptions.RedisClusterException, + redis.exceptions.TimeoutError, + ) as e: + # There are a lot of different exceptions that inherit from + # ConnectionError. So we need to make sure if we got that + # its an actual ConnectionError. If not, go ahead and raise + # it. + # Note: ClusterDownError and TimeoutError are not + # subclasses of ConnectionError. + if ( + isinstance(e, redis.exceptions.ConnectionError) + and type(e) is not redis.exceptions.ConnectionError + ): + raise + + # RedisClusterException is used as a catch-all for various + # faults. The only one we should tolerate is that which + # includes "Redis Cluster cannot be connected." which is + # experienced when there are zero cluster hosts that can be + # reached. + if isinstance( + e, redis.exceptions.RedisClusterException + ) and ("Redis Cluster cannot be connected." not in str(e)): + raise + + if (loop_exit is not None) and loop_exit(e): + break + + delay = min(delay, max_delay) + logger.error( + f"Connection to redis failed; retrying in {delay}s." + ) + time.sleep(delay) + + delay *= 2 + return value + + return wrapper + + return decorator + + def _create_url_from_parameters(**kwargs) -> str: # Make the URL that DAB will expect for instantiation. schema = "unix" @@ -80,7 +122,7 @@ def _create_url_from_parameters(**kwargs) -> str: return url -def _prune_redis_kwargs(**kwargs) -> dict[str, Any]: +def _prune_redis_kwargs(**kwargs) -> dict[str, typing.Any]: """Prunes the kwargs of unsupported parameters for RedisCluster.""" # HA cluster does not support an alternate redis db and will generate an # exception if we pass a value (even the default). If we're in that @@ -97,7 +139,7 @@ def _prune_redis_kwargs(**kwargs) -> dict[str, Any]: return kwargs -def get_redis_client(**kwargs) -> Union[DABRedis, DABRedisCluster]: +def get_redis_client(**kwargs) -> typing.Union[DABRedis, DABRedisCluster]: """Instantiate a Redis client via DAB. 
DAB will return an appropriate client for HA based on the passed @@ -122,7 +164,7 @@ def is_redis_failed() -> bool: return status == constants.STATUS_FAILED -class Scheduler(_Scheduler): +class Scheduler(rq_scheduler.Scheduler): """Custom scheduler class.""" def __init__( @@ -198,12 +240,12 @@ def enable_redis_prefix(): def eda_get_key(job_id): return f"{redis_prefix}:results:{job_id}" - results.get_key = eda_get_key + rq_results.get_key = eda_get_key def cls_get_key(cls, job_id): return f"{redis_prefix}:results:{job_id}" - results.Result.get_key = MethodType(cls_get_key, results.Result) + rq_results.Result.get_key = MethodType(cls_get_key, rq_results.Result) def property_registry_cleaning_key(self): return f"{redis_prefix}:clean_registries:{self.name}" @@ -218,17 +260,17 @@ def property_registry_cleaning_key(self): enable_redis_prefix() -class SerializerProtocol(Protocol): +class SerializerProtocol(typing.Protocol): @staticmethod - def dumps(obj: Any) -> bytes: + def dumps(obj: typing.Any) -> bytes: ... @staticmethod - def loads(data: bytes) -> Any: + def loads(data: bytes) -> typing.Any: ... -class Queue(_Queue): +class Queue(django_rq.queues.Queue): """Custom queue class. Uses JSONSerializer as a default one. @@ -238,14 +280,14 @@ def __init__( self, name: str = "default", default_timeout: int = -1, - connection: Optional[Connection] = None, + connection: typing.Optional[rq.Connection] = None, is_async: bool = True, - job_class: Optional[_Job] = None, - serializer: Optional[SerializerProtocol] = None, - **kwargs: Any, + job_class: typing.Optional[rq.job.Job] = None, + serializer: typing.Optional[SerializerProtocol] = None, + **kwargs: typing.Any, ): if serializer is None: - serializer = JSONSerializer + serializer = rq.serializers.JSONSerializer super().__init__( name=name, @@ -258,7 +300,7 @@ def __init__( ) -class Job(_Job): +class Job(rq.job.Job): """Custom job class. Uses JSONSerializer as a default one. @@ -266,12 +308,12 @@ class Job(_Job): def __init__( self, - id: Optional[str] = None, - connection: Optional[Connection] = None, - serializer: Optional[SerializerProtocol] = None, + id: typing.Optional[str] = None, + connection: typing.Optional[rq.Connection] = None, + serializer: typing.Optional[SerializerProtocol] = None, ): if serializer is None: - serializer = JSONSerializer + serializer = rq.serializers.JSONSerializer connection = _get_necessary_client_connection(connection) super().__init__(id, connection, serializer) @@ -282,7 +324,9 @@ def __init__( # couldn't use it as DAB requires a url parameter that Redis does not. # If the connection a worker is given is not from DAB we replace it # with one that is. -def _get_necessary_client_connection(connection: Connection) -> Connection: +def _get_necessary_client_connection( + connection: rq.Connection, +) -> rq.Connection: if not isinstance(connection, (DABRedis, DABRedisCluster)): connection = get_redis_client( **default.rq_redis_client_instantiation_parameters() @@ -290,7 +334,7 @@ def _get_necessary_client_connection(connection: Connection) -> Connection: return connection -class Worker(_Worker): +class Worker(rq.Worker): """Custom worker class. 
Provides establishment of DAB Redis client and work arounds for various @@ -299,20 +343,22 @@ class Worker(_Worker): def __init__( self, - queues: Iterable[Union[Queue, str]], - name: Optional[str] = None, - default_result_ttl: int = DEFAULT_RESULT_TTL, - connection: Optional[Connection] = None, - exc_handler: Any = None, + queues: typing.Iterable[typing.Union[Queue, str]], + name: typing.Optional[str] = None, + default_result_ttl: int = rq.defaults.DEFAULT_RESULT_TTL, + connection: typing.Optional[rq.Connection] = None, + exc_handler: typing.Any = None, exception_handlers: _ErrorHandlersArgType = None, - default_worker_ttl: int = DEFAULT_WORKER_TTL, - job_class: Type[_Job] = None, - queue_class: Type[_Queue] = None, + default_worker_ttl: int = rq.defaults.DEFAULT_WORKER_TTL, + job_class: typing.Type[rq.job.Job] = None, + queue_class: typing.Type[django_rq.queues.Queue] = None, log_job_description: bool = True, - job_monitoring_interval: int = DEFAULT_JOB_MONITORING_INTERVAL, + job_monitoring_interval: int = ( + rq.defaults.DEFAULT_JOB_MONITORING_INTERVAL + ), disable_default_exception_handler: bool = False, prepare_for_work: bool = True, - serializer: Optional[SerializerProtocol] = None, + serializer: typing.Optional[SerializerProtocol] = None, ): connection = _get_necessary_client_connection(connection) super().__init__( @@ -329,14 +375,14 @@ def __init__( job_monitoring_interval=job_monitoring_interval, disable_default_exception_handler=disable_default_exception_handler, # noqa: E501 prepare_for_work=prepare_for_work, - serializer=JSONSerializer, + serializer=rq.serializers.JSONSerializer, ) self.is_shutting_down = False def _set_connection( self, - connection: Union[DABRedis, DABRedisCluster], - ) -> Union[DABRedis, DABRedisCluster]: + connection: typing.Union[DABRedis, DABRedisCluster], + ) -> typing.Union[DABRedis, DABRedisCluster]: # A DABRedis connection doesn't need intervention. if isinstance(connection, DABRedis): return super()._set_connection(connection) @@ -366,12 +412,14 @@ def _set_connection( @classmethod def all( cls, - connection: Optional[Union[DABRedis, DABRedisCluster]] = None, - job_class: Optional[Type[Job]] = None, - queue_class: Optional[Type[Queue]] = None, - queue: Optional[Queue] = None, + connection: typing.Optional[ + typing.Union[DABRedis, DABRedisCluster] + ] = None, + job_class: typing.Optional[typing.Type[Job]] = None, + queue_class: typing.Optional[typing.Type[Queue]] = None, + queue: typing.Optional[Queue] = None, serializer=None, - ) -> List[Worker]: + ) -> typing.List[Worker]: # If we don't have a queue (whose connection would be used) make # certain that we have an appropriate connection and pass it # to the superclass. @@ -386,7 +434,10 @@ def all( ) def handle_job_success( - self, job: Job, queue: Queue, started_job_registry: StartedJobRegistry + self, + job: Job, + queue: Queue, + started_job_registry: rq.registry.StartedJobRegistry, ): # A DABRedis connection doesn't need intervention. if isinstance(self.connection, DABRedis): @@ -417,86 +468,42 @@ def handle_warm_shutdown_request(self): self.is_shutting_down = True super().handle_warm_shutdown_request() - # We are going to override the work function to create our own loop. - # This will allow us to catch exceptions that the default work method will - # not handle and restart our worker process if we hit them. + # We are overriding the work function to utilize our own common + # Redis connection looping. 
def work( self, burst: bool = False, logging_level: str = "INFO", date_format: str = rq.defaults.DEFAULT_LOGGING_DATE_FORMAT, log_format: str = rq.defaults.DEFAULT_LOGGING_FORMAT, - max_jobs: Optional[int] = None, + max_jobs: typing.Optional[int] = None, with_scheduler: bool = False, ) -> bool: + value = None while True: - # super.work() returns a value that we want to return on a normal - # exit. - return_value = None - try: - return_value = super().work( - burst, - logging_level, - date_format, - log_format, - max_jobs, - with_scheduler, - ) - except ( - redis.exceptions.TimeoutError, - redis.exceptions.ClusterDownError, - redis.exceptions.ConnectionError, - ) as e: - # If we got one of these exceptions but are not on a Cluster go - # ahead and raise it normally. - if not isinstance(self.connection, DABRedisCluster): - raise - - # There are a lot of different exceptions that inherit from - # ConnectionError. So we need to make sure if we got that its - # an actual ConnectionError. If not, go ahead and raise it. - # Note: ClusterDownError and TimeoutError are not subclasses - # of ConnectionError. - if ( - isinstance(e, redis.exceptions.ConnectionError) - and type(e) is not redis.exceptions.ConnectionError - ): - raise - - # If we got a cluster issue we will loop here until we can ping - # the server again. - max_backoff = 60 - current_backoff = 1 - while True: - backoff = min(current_backoff, max_backoff) - logger.error( - f"Connection to redis cluster failed. Attempting to " - f"reconnect in {backoff}" - ) - sleep(backoff) - current_backoff = 2 * current_backoff - try: - self.connection.ping() - break - # We could tighten this exception up. - except Exception: - pass - # At this point return value is none so we are going to go - # ahead and fall through to the loop to restart. - - # We are outside of the work function with either: - # a "normal exist" - # an exit that did not raise an exception - if return_value: - logger.debug(f"Working exited normally with {return_value}") - return return_value - elif self.is_shutting_down: - # Get got a warm shutdown request, lets respect it - return return_value - else: - logger.error( - "Work exited no return value, going to restart the worker" - ) + value = redis_connect_retry( + loop_exit=lambda e: self.is_shutting_down + )(super().work)( + burst, + logging_level, + date_format, + log_format, + max_jobs, + with_scheduler, + ) + + # If there's a return value or the worker is shutting down + # break out of the loop. 
+ if (value is not None) or self.is_shutting_down: + if value is not None: + logger.debug(f"Working exited normally with {value}") + break + + logger.error( + "Work exited no return value, going to restart the worker" + ) + + return value class DefaultWorker(Worker): @@ -507,20 +514,12 @@ class DefaultWorker(Worker): def __init__( self, - queues: Iterable[Union[Queue, str]], - name: Optional[str] = "default", - default_result_ttl: int = DEFAULT_RESULT_TTL, - connection: Optional[Connection] = None, - exc_handler: Any = None, - exception_handlers: _ErrorHandlersArgType = None, - default_worker_ttl: int = DEFAULT_WORKER_TTL, - job_class: Type[_Job] = None, - queue_class: Type[_Queue] = None, - log_job_description: bool = True, - job_monitoring_interval: int = DEFAULT_JOB_MONITORING_INTERVAL, - disable_default_exception_handler: bool = False, - prepare_for_work: bool = True, - serializer: Optional[SerializerProtocol] = None, + queues: typing.Iterable[typing.Union[Queue, str]], + name: typing.Optional[str] = "default", + job_class: typing.Type[rq.job.Job] = None, + queue_class: typing.Type[django_rq.queues.Queue] = None, + serializer: typing.Optional[SerializerProtocol] = None, + **kwargs, ): if job_class is None: job_class = Job @@ -530,18 +529,10 @@ def __init__( super().__init__( queues=queues, name=name, - default_result_ttl=default_result_ttl, - connection=connection, - exc_handler=exc_handler, - exception_handlers=exception_handlers, - default_worker_ttl=default_worker_ttl, job_class=job_class, queue_class=queue_class, - log_job_description=log_job_description, - job_monitoring_interval=job_monitoring_interval, - disable_default_exception_handler=disable_default_exception_handler, # noqa: E501 - prepare_for_work=prepare_for_work, - serializer=JSONSerializer, + serializer=rq.serializers.JSONSerializer, + **kwargs, ) @@ -553,20 +544,14 @@ class ActivationWorker(Worker): def __init__( self, - queues: Iterable[Union[Queue, str]], - name: Optional[str] = "activation", - default_result_ttl: int = DEFAULT_RESULT_TTL, - connection: Optional[Connection] = None, - exc_handler: Any = None, - exception_handlers: _ErrorHandlersArgType = None, - default_worker_ttl: int = DEFAULT_WORKER_TTL, - job_class: Type[_Job] = None, - queue_class: Type[_Queue] = None, - log_job_description: bool = True, - job_monitoring_interval: int = DEFAULT_JOB_MONITORING_INTERVAL, - disable_default_exception_handler: bool = False, - prepare_for_work: bool = True, - serializer: Optional[SerializerProtocol] = None, + queues: typing.Iterable[typing.Union[Queue, str]], + name: typing.Optional[str] = "activation", + connection: typing.Optional[rq.Connection] = None, + default_worker_ttl: int = rq.defaults.DEFAULT_WORKER_TTL, + job_class: typing.Type[rq.job.Job] = None, + queue_class: typing.Type[django_rq.queues.Queue] = None, + serializer: typing.Optional[SerializerProtocol] = None, + **kwargs, ): if job_class is None: job_class = Job @@ -577,26 +562,21 @@ def __init__( super().__init__( queues=[Queue(name=queue_name, connection=connection)], name=name, - default_result_ttl=default_result_ttl, connection=connection, - exc_handler=exc_handler, - exception_handlers=exception_handlers, default_worker_ttl=settings.DEFAULT_WORKER_TTL, job_class=job_class, queue_class=queue_class, - log_job_description=log_job_description, - job_monitoring_interval=job_monitoring_interval, - disable_default_exception_handler=disable_default_exception_handler, # noqa: E501 - prepare_for_work=prepare_for_work, - serializer=JSONSerializer, + 
serializer=rq.serializers.JSONSerializer, + **kwargs, ) +@redis_connect_retry() def enqueue_delay( queue_name: str, job_id: str, delay: int, *args, **kwargs ) -> Job: """Enqueue a job to run after specific seconds.""" - scheduler = get_scheduler(name=queue_name) + scheduler = django_rq.get_scheduler(name=queue_name) return scheduler.enqueue_at( datetime.utcnow() + timedelta(seconds=delay), job_id=job_id, @@ -605,11 +585,13 @@ def enqueue_delay( ) +@redis_connect_retry() def queue_cancel_job(queue_name: str, job_id: str) -> None: - scheduler = get_scheduler(name=queue_name) + scheduler = django_rq.get_scheduler(name=queue_name) scheduler.cancel(job_id) +@redis_connect_retry() def unique_enqueue(queue_name: str, job_id: str, *args, **kwargs) -> Job: """Enqueue a new job if it is not already enqueued. @@ -624,22 +606,25 @@ def unique_enqueue(queue_name: str, job_id: str, *args, **kwargs) -> Job: ) return job - queue = get_queue(name=queue_name) + queue = django_rq.get_queue(name=queue_name) kwargs["job_id"] = job_id logger.info(f"Enqueing unique job: {job_id}") return queue.enqueue(*args, **kwargs) -def job_from_queue(queue: Union[Queue, str], job_id: str) -> Optional[Job]: +@redis_connect_retry() +def job_from_queue( + queue: typing.Union[Queue, str], job_id: str +) -> typing.Optional[Job]: """Return queue job if it not canceled or finished else None.""" if type(queue) is str: - queue = get_queue(name=queue) + queue = django_rq.get_queue(name=queue) job = queue.fetch_job(job_id) if job and job.get_status(refresh=True) in [ - JobStatus.QUEUED, - JobStatus.STARTED, - JobStatus.DEFERRED, - JobStatus.SCHEDULED, + rq.job.JobStatus.QUEUED, + rq.job.JobStatus.STARTED, + rq.job.JobStatus.DEFERRED, + rq.job.JobStatus.SCHEDULED, ]: return job return None diff --git a/src/aap_eda/services/activation/activation_manager.py b/src/aap_eda/services/activation/activation_manager.py index a7077c019..1457107f3 100644 --- a/src/aap_eda/services/activation/activation_manager.py +++ b/src/aap_eda/services/activation/activation_manager.py @@ -17,16 +17,16 @@ import typing as tp from datetime import timedelta +import rq from django.conf import settings from django.core.exceptions import ObjectDoesNotExist from django.db import transaction from django.db.utils import IntegrityError from django.utils import timezone from pydantic import ValidationError -from rq import get_current_job from aap_eda.api.serializers.activation import is_activation_valid -from aap_eda.core import models +from aap_eda.core import models, tasking from aap_eda.core.enums import ActivationStatus, RestartPolicy from aap_eda.services.activation import exceptions from aap_eda.services.activation.engine import exceptions as engine_exceptions @@ -1023,8 +1023,9 @@ def _create_activation_instance(self): queue_name=queue_name, ) + @tasking.redis_connect_retry() def _get_queue_name(self) -> str: - this_job = get_current_job() + this_job = rq.get_current_job() return this_job.origin def _get_container_request(self) -> ContainerRequest: diff --git a/src/aap_eda/services/activation/engine/podman.py b/src/aap_eda/services/activation/engine/podman.py index c2a86d7e5..7a4a8f0e3 100644 --- a/src/aap_eda/services/activation/engine/podman.py +++ b/src/aap_eda/services/activation/engine/podman.py @@ -15,13 +15,13 @@ import logging import os +import rq from dateutil import parser from django.conf import settings from podman import PodmanClient from podman.domain.images import Image from podman.errors import ContainerError, ImageNotFound from 
podman.errors.exceptions import APIError, NotFound -from rq.timeouts import JobTimeoutException from aap_eda.core.enums import ActivationStatus from aap_eda.utils import str_to_bool @@ -373,7 +373,7 @@ def _pull_image( except APIError as e: LOGGER.error(f"Failed to pull image {request.image_url}: {e}") raise exceptions.ContainerStartError(str(e)) - except JobTimeoutException as e: + except rq.timeouts.JobTimeoutException as e: msg = f"Timeout: {e}" LOGGER.error(msg) log_handler.write(msg, True) diff --git a/src/aap_eda/tasks/analytics.py b/src/aap_eda/tasks/analytics.py index c53ea6eae..2fd78181d 100644 --- a/src/aap_eda/tasks/analytics.py +++ b/src/aap_eda/tasks/analytics.py @@ -16,11 +16,11 @@ from datetime import datetime, timezone import django_rq -from rq.exceptions import NoSuchJobError +import rq from aap_eda.analytics import collector from aap_eda.conf import application_settings -from aap_eda.core.tasking import Job, unique_enqueue +from aap_eda.core import tasking logger = logging.getLogger(__name__) @@ -30,6 +30,7 @@ ANALYTICS_TASKS_QUEUE = "default" +@tasking.redis_connect_retry() def schedule_gather_analytics() -> None: scheduler = django_rq.get_scheduler() func = "aap_eda.tasks.analytics.gather_analytics" @@ -47,10 +48,13 @@ def schedule_gather_analytics() -> None: ) +@tasking.redis_connect_retry() def reschedule_gather_analytics(new_interval: int, serializer=None) -> None: try: - job = Job.fetch(ANALYTICS_SCHEDULE_JOB_ID, serializer=serializer) - except NoSuchJobError: + job = tasking.Job.fetch( + ANALYTICS_SCHEDULE_JOB_ID, serializer=serializer + ) + except rq.exceptions.NoSuchJobError: logger.warning(f"Job {ANALYTICS_SCHEDULE_JOB_ID} does not exist") return job.meta["interval"] = new_interval @@ -63,7 +67,7 @@ def reschedule_gather_analytics(new_interval: int, serializer=None) -> None: def gather_analytics(queue_name: str = ANALYTICS_TASKS_QUEUE) -> None: logger.info("Queue EDA analytics") - unique_enqueue(queue_name, ANALYTICS_JOB_ID, _gather_analytics) + tasking.unique_enqueue(queue_name, ANALYTICS_JOB_ID, _gather_analytics) def _gather_analytics() -> None: diff --git a/src/aap_eda/tasks/orchestrator.py b/src/aap_eda/tasks/orchestrator.py index 1c51b7f40..341b5dde2 100644 --- a/src/aap_eda/tasks/orchestrator.py +++ b/src/aap_eda/tasks/orchestrator.py @@ -18,19 +18,18 @@ from datetime import datetime, timedelta from typing import Optional +import django_rq from django.conf import settings from django.core.exceptions import ObjectDoesNotExist -from django_rq import get_queue import aap_eda.tasks.activation_request_queue as requests_queue -from aap_eda.core import models +from aap_eda.core import models, tasking from aap_eda.core.enums import ( ActivationRequest, ActivationStatus, ProcessParentType, ) from aap_eda.core.models import Activation, ActivationRequestQueue -from aap_eda.core.tasking import Worker, unique_enqueue from aap_eda.services.activation import exceptions from aap_eda.services.activation.activation_manager import ( ActivationManager, @@ -300,7 +299,7 @@ def dispatch( ) return - unique_enqueue( + tasking.unique_enqueue( queue_name, job_id, _manage, @@ -365,16 +364,17 @@ def get_queue_name_by_parent_id( return process.rulebookprocessqueue.queue_name +@tasking.redis_connect_retry() def check_rulebook_queue_health(queue_name: str) -> bool: """Check for the state of the queue. Returns True if the queue is healthy, False otherwise. Clears the queue if all workers are dead to avoid stuck processes. 
""" - queue = get_queue(queue_name) + queue = django_rq.get_queue(queue_name) all_workers_dead = True - for worker in Worker.all(queue=queue): + for worker in tasking.Worker.all(queue=queue): last_heartbeat = worker.last_heartbeat if last_heartbeat is None: continue @@ -482,7 +482,7 @@ def monitor_rulebook_processes() -> None: def enqueue_monitor_rulebook_processes() -> None: """Wrap monitor_rulebook_processes to ensure only one task is enqueued.""" - unique_enqueue( + tasking.unique_enqueue( "default", "monitor_rulebook_processes", monitor_rulebook_processes, diff --git a/src/aap_eda/tasks/project.py b/src/aap_eda/tasks/project.py index e76a7d33b..edffefeb1 100644 --- a/src/aap_eda/tasks/project.py +++ b/src/aap_eda/tasks/project.py @@ -14,15 +14,19 @@ import logging +import django_rq from django.conf import settings -from aap_eda.core import models -from aap_eda.core.tasking import get_queue, job, unique_enqueue +from aap_eda.core import models, tasking from aap_eda.services.project import ProjectImportError, ProjectImportService logger = logging.getLogger(__name__) PROJECT_TASKS_QUEUE = "default" +# Wrap the django_rq job decorator so its processing is within our retry +# code. +job = tasking.redis_connect_retry()(django_rq.job) + @job(PROJECT_TASKS_QUEUE) def import_project(project_id: int): @@ -54,9 +58,16 @@ def sync_project(project_id: int): # default is the default queue def monitor_project_tasks(queue_name: str = PROJECT_TASKS_QUEUE): job_id = "monitor_project_tasks" - unique_enqueue(queue_name, job_id, _monitor_project_tasks, queue_name) + tasking.unique_enqueue( + queue_name, job_id, _monitor_project_tasks, queue_name + ) +# Although this is a periodically run task and that could be viewed as +# providing resilience to Redis connection issues we decorate it with the +# redis_connect_retry to maintain the model that anything directly dependent on +# a Redis connection is wrapped by retries. +@tasking.redis_connect_retry() def _monitor_project_tasks(queue_name: str) -> None: """Handle project tasks that are stuck. @@ -66,7 +77,7 @@ def _monitor_project_tasks(queue_name: str) -> None: """ logger.info("Task started: Monitor project tasks") - queue = get_queue(queue_name) + queue = django_rq.get_queue(queue_name) # Filter projects that doesn't have any related job pending_projects = models.Project.objects.filter( diff --git a/tests/conftest.py b/tests/conftest.py index 1b51e316e..733b762f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,11 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + import pytest from aap_eda.settings import default +################################################################# +# Log capture factory +################################################################# +@pytest.fixture +def caplog_factory(caplog): + def _factory(logger): + logger.setLevel(logging.INFO) + logger.handlers += [caplog.handler] + return caplog + + return _factory + + ################################################################# # Redis ################################################################# diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index aa8c25aa2..798d08279 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -13,7 +13,6 @@ # limitations under the License. 
import copy -import logging import uuid from typing import Any, Dict, List from unittest import mock @@ -1128,16 +1127,6 @@ def default_queue(test_queue_name, redis_external) -> Queue: return Queue(test_queue_name, connection=redis_external) -@pytest.fixture -def caplog_factory(caplog): - def _factory(logger): - logger.setLevel(logging.INFO) - logger.handlers += [caplog.handler] - return caplog - - return _factory - - @pytest.fixture def container_engine_mock() -> MagicMock: return create_autospec(ContainerEngine, instance=True) diff --git a/tests/integration/services/activation/test_manager.py b/tests/integration/services/activation/test_manager.py index 946fb35b1..015090ed0 100644 --- a/tests/integration/services/activation/test_manager.py +++ b/tests/integration/services/activation/test_manager.py @@ -299,7 +299,7 @@ def test_start_first_run( ) container_engine_mock.start.return_value = "test-pod-id" with mock.patch( - "aap_eda.services.activation.activation_manager.get_current_job", + "rq.get_current_job", return_value=job_mock, ): activation_manager.start() @@ -335,7 +335,7 @@ def test_start_restart( ) container_engine_mock.start.return_value = "test-pod-id" with mock.patch( - "aap_eda.services.activation.activation_manager.get_current_job", + "rq.get_current_job", return_value=job_mock, ): activation_manager.start(is_restart=True) @@ -690,7 +690,7 @@ def test_start_max_running_activations( ) with pytest.raises(exceptions.MaxRunningProcessesError), mock.patch( - "aap_eda.services.activation.activation_manager.get_current_job", + "rq.get_current_job", return_value=job_mock, ): activation_manager.start() diff --git a/tests/integration/tasks/test_orchestrator.py b/tests/integration/tasks/test_orchestrator.py index 9fbf2d03a..e60aecb44 100644 --- a/tests/integration/tasks/test_orchestrator.py +++ b/tests/integration/tasks/test_orchestrator.py @@ -164,7 +164,7 @@ def test_manage_not_start( return_value=container_engine_mock, ): with mock.patch( - "aap_eda.services.activation.activation_manager.get_current_job", + "rq.get_current_job", return_value=job_mock, ): orchestrator._manage(ProcessParentType.ACTIVATION, activation.id) @@ -186,7 +186,7 @@ def test_manage_not_start( (orchestrator.restart_rulebook_process, ActivationRequest.RESTART), ], ) -@mock.patch("aap_eda.tasks.orchestrator.unique_enqueue") +@mock.patch("aap_eda.tasks.orchestrator.tasking.unique_enqueue") @mock.patch("aap_eda.tasks.orchestrator.get_least_busy_queue_name") def test_activation_requests( get_queue_name_mock, @@ -216,7 +216,7 @@ def test_activation_requests( @pytest.mark.django_db -@mock.patch("aap_eda.tasks.orchestrator.unique_enqueue") +@mock.patch("aap_eda.tasks.orchestrator.tasking.unique_enqueue") @mock.patch("aap_eda.tasks.orchestrator.get_least_busy_queue_name") def test_monitor_rulebook_processes( get_queue_name_mock, enqueue_mock, activation, max_running_processes @@ -297,7 +297,7 @@ def side_effect(*args, **kwargs): return_value=container_engine_mock, ): with mock.patch( - "aap_eda.services.activation.activation_manager.get_current_job", + "rq.get_current_job", return_value=job_mock, ): orchestrator._manage(ProcessParentType.ACTIVATION, activation.id) @@ -309,7 +309,7 @@ def side_effect(*args, **kwargs): @pytest.mark.django_db -@mock.patch("aap_eda.tasks.orchestrator.unique_enqueue") +@mock.patch("aap_eda.tasks.orchestrator.tasking.unique_enqueue") def test_monitor_rulebook_processes_unique(enqueue_mock): orchestrator.enqueue_monitor_rulebook_processes() enqueue_mock.assert_called_once_with( diff --git 
a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index f663918ce..8c49af324 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -16,6 +16,7 @@ from datetime import datetime, timedelta from unittest import mock +import django_rq import pytest from django.conf import settings @@ -74,12 +75,12 @@ def _process_object_filter(**kwargs) -> mock.Mock: def _get_queue(name: str) -> mock.Mock: return mock_queues[name] - monkeypatch.setattr(orchestrator, "get_queue", _get_queue) + monkeypatch.setattr(django_rq, "get_queue", _get_queue) def _worker_all(queue=None) -> list: return queue.workers - monkeypatch.setattr(orchestrator.Worker, "all", _worker_all) + monkeypatch.setattr(orchestrator.tasking.Worker, "all", _worker_all) def _process_count(queue: mock.Mock) -> None: return queue.process_count @@ -394,9 +395,7 @@ def setup_queue_health(): timedelta_mock.return_value = timedelta(seconds=60) patches = { - "get_queue": mock.patch( - "aap_eda.tasks.orchestrator.get_queue", get_queue_mock - ), + "get_queue": mock.patch("django_rq.get_queue", get_queue_mock), "datetime": mock.patch( "aap_eda.tasks.orchestrator.datetime", datetime_mock ), @@ -434,7 +433,9 @@ def test_check_rulebook_queue_health_all_workers_dead(setup_queue_health): all_workers_mock = mock.Mock(return_value=[worker_mock]) datetime_mock.now.return_value = datetime(2022, 1, 1, minute=5) - with mock.patch("aap_eda.tasks.orchestrator.Worker.all", all_workers_mock): + with mock.patch( + "aap_eda.tasks.orchestrator.tasking.Worker.all", all_workers_mock + ): result = check_rulebook_queue_health(queue_name) get_queue_mock.assert_called_once_with(queue_name) @@ -460,7 +461,9 @@ def test_check_rulebook_queue_health_some_workers_alive(setup_queue_health): all_workers_mock = mock.Mock(return_value=[worker_mock1, worker_mock2]) datetime_mock.now.return_value = datetime(2022, 1, 1, hour=6, second=30) - with mock.patch("aap_eda.tasks.orchestrator.Worker.all", all_workers_mock): + with mock.patch( + "aap_eda.tasks.orchestrator.tasking.Worker.all", all_workers_mock + ): result = check_rulebook_queue_health(queue_name) get_queue_mock.assert_called_once_with(queue_name) diff --git a/tests/unit/test_redis_connect_retry.py b/tests/unit/test_redis_connect_retry.py new file mode 100644 index 000000000..3d82b8763 --- /dev/null +++ b/tests/unit/test_redis_connect_retry.py @@ -0,0 +1,117 @@ +# Copyright 2024 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import redis + +from aap_eda.core import tasking + + +@pytest.fixture +def tasking_caplog(caplog_factory): + return caplog_factory(tasking.logger) + + +def test_max_delay(tasking_caplog): + """Tests that the specification of a maximum delay value is respected and + used.""" + + # The sleep between retries grows exponentially as 1, 2, 4, 8, 16, 32 and + # is then capped at a default of 60. + # + # This test will loop sufficiently to reach 60 multiple times but will + # specify a cap of 5. 
+    loop_count = 0
+    loop_limit = 8
+
+    @tasking.redis_connect_retry(max_delay=5)
+    def _test_function():
+        nonlocal loop_count
+
+        loop_count += 1
+        if loop_count >= (loop_limit + 1):
+            return
+        raise redis.exceptions.ConnectionError
+
+    _test_function()
+
+    assert "Connection to redis failed; retrying in 5s." in tasking_caplog.text
+    assert (
+        "Connection to redis failed; retrying in 60s."
+        not in tasking_caplog.text
+    )
+
+
+class SubclassConnectionError(redis.exceptions.ConnectionError):
+    pass
+
+
+@pytest.mark.parametrize(
+    ("tolerate", "exception"),
+    [
+        [True, redis.exceptions.ClusterDownError("cluster down")],
+        [False, SubclassConnectionError("not tolerated")],
+        [True, redis.exceptions.ConnectionError("connection error")],
+        [
+            True,
+            redis.exceptions.RedisClusterException(
+                "Redis Cluster cannot be connected."
+            ),
+        ],
+        [False, redis.exceptions.RedisClusterException("not tolerated")],
+        [True, redis.exceptions.TimeoutError("timeout error")],
+    ],
+)
+def test_retry_exceptions(tolerate, exception):
+    """Tests that exceptions to be tolerated for retry are tolerated and
+    those particular instances that should not be are not."""
+    loop_count = 0
+    loop_limit = 2
+
+    @tasking.redis_connect_retry()
+    def _test_function():
+        nonlocal loop_count
+
+        loop_count += 1
+        if loop_count >= (loop_limit + 1):
+            return
+        raise exception
+
+    if tolerate:
+        _test_function()
+    else:
+        with pytest.raises(type(exception)):
+            _test_function()
+
+
+def test_loop_exit():
+    """Tests that the specification of a loop exit function is respected and
+    used."""
+
+    loop_count = 0
+    loop_limit = 2
+
+    class LoopLimitExceeded(Exception):
+        pass
+
+    @tasking.redis_connect_retry(loop_exit=lambda e: loop_count >= loop_limit)
+    def _test_function():
+        nonlocal loop_count
+
+        loop_count += 1
+        if loop_count >= (loop_limit + 1):
+            raise LoopLimitExceeded
+        raise redis.exceptions.ConnectionError
+
+    _test_function()
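
For reference, a minimal usage sketch of the `tasking.redis_connect_retry` decorator introduced by this patch; it is illustrative only (not part of the diff), and the `count_pending` function and the "pending" key below are hypothetical.

```python
# Illustrative sketch only -- not part of the patch. The function name and
# Redis key are hypothetical; the decorator and its retry semantics come from
# aap_eda.core.tasking as added in this change.
import redis

from aap_eda.core import tasking


@tasking.redis_connect_retry(max_delay=30)
def count_pending(connection: redis.Redis) -> int:
    # Tolerated faults (redis.exceptions.ConnectionError itself, TimeoutError,
    # ClusterDownError, and the "Redis Cluster cannot be connected."
    # RedisClusterException) are retried with exponential backoff capped at
    # max_delay seconds; any other exception is re-raised.
    return connection.llen("pending")


# An existing callable can be wrapped the same way, as Command.handle and
# Worker.work do in the patch, optionally with a loop_exit predicate that
# breaks out of the retry loop, e.g.:
#
#     value = tasking.redis_connect_retry(
#         loop_exit=lambda exc: shutting_down,
#     )(some_callable)()
```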