From 6fa1a2561f5763d33959b0734791da7b03d49814 Mon Sep 17 00:00:00 2001
From: abram axel booth
Date: Mon, 2 Dec 2024 16:58:21 -0500
Subject: [PATCH] wip: indexer recover from connection errors

---
 share/search/daemon.py | 68 ++++++++++++++++++++++++++++++++----------
 1 file changed, 52 insertions(+), 16 deletions(-)

diff --git a/share/search/daemon.py b/share/search/daemon.py
index 90aedb855..a3b3ba09a 100644
--- a/share/search/daemon.py
+++ b/share/search/daemon.py
@@ -1,5 +1,6 @@
 import contextlib
 import collections
+from collections.abc import Callable
 import dataclasses
 import logging
 import queue
@@ -7,7 +8,9 @@
 import threading
 import time
 
+import amqp.exceptions
 from django.conf import settings
+import kombu
 from kombu.mixins import ConsumerMixin
 import sentry_sdk
 
@@ -27,6 +30,7 @@
 MINIMUM_BACKOFF_FACTOR = 1.6  # unitless ratio
 MAXIMUM_BACKOFF_FACTOR = 2.0  # unitless ratio
 MAXIMUM_BACKOFF_TIMEOUT = 60  # seconds
+CONNECTION_HEARTBEAT = 20  # seconds
 
 
 class TooFastSlowDown(Exception):
@@ -35,7 +39,10 @@ class TooFastSlowDown(Exception):
 
 class IndexerDaemonControl:
     def __init__(self, celery_app, *, daemonthread_context=None, stop_event=None):
-        self.celery_app = celery_app
+        self.kombu_connection = kombu.Connection(
+            celery_app.conf.broker_url,  # use celery_app.conf for consistent config
+            heartbeat=CONNECTION_HEARTBEAT,
+        )
         self.daemonthread_context = daemonthread_context
         self._daemonthreads = []
         # shared stop_event for all threads below
@@ -50,10 +57,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
         )
         # spin up daemonthreads, ready for messages
         self._daemonthreads.extend(_daemon.start())
-        # assign a thread to pass messages to this daemon
-        threading.Thread(
-            target=CeleryMessageConsumer(self.celery_app, _daemon).run,
-        ).start()
+        _consumer = KombuMessageConsumer(
+            kombu_connection=self.kombu_connection.clone(),
+            stop_event=self.stop_event,
+            index_strategy=index_strategy,
+            message_callback=_daemon.on_message,
+        )
+        # give the daemon direct access to the connection, for acking purposes
+        _daemon.ack_callback = _consumer.ensure_ack
+        # assign a thread for the consumer to receive and enqueue messages to this daemon
+        threading.Thread(target=_consumer.run).start()
         return _daemon
 
     def start_all_daemonthreads(self):
@@ -67,18 +80,16 @@ def stop_daemonthreads(self, *, wait=False):
                 _thread.join()
 
 
-class CeleryMessageConsumer(ConsumerMixin):
+class KombuMessageConsumer(ConsumerMixin):
     PREFETCH_COUNT = 7500
 
-    # (from ConsumerMixin)
-    # should_stop: bool
+    should_stop: bool  # (from ConsumerMixin)
 
-    def __init__(self, celery_app, indexer_daemon):
-        self.connection = celery_app.pool.acquire(block=True)
-        self.celery_app = celery_app
-        self.__stop_event = indexer_daemon.stop_event
-        self.__message_callback = indexer_daemon.on_message
-        self.__index_strategy = indexer_daemon.index_strategy
+    def __init__(self, *, kombu_connection, stop_event, message_callback, index_strategy):
+        self.connection = kombu_connection
+        self.__stop_event = stop_event
+        self.__message_callback = message_callback
+        self.__index_strategy = index_strategy
 
     # overrides ConsumerMixin.run
     def run(self):
@@ -112,9 +123,31 @@ def get_consumers(self, Consumer, channel):
     def __repr__(self):
         return '<{}({})>'.format(self.__class__.__name__, self.__index_strategy.name)
 
+    def consume(self, *args, **kwargs):
+        # wrap `consume` in `kombu.Connection.ensure`, following guidance from
+        # https://docs.celeryq.dev/projects/kombu/en/stable/userguide/failover.html#consumer
+        consume = self.connection.ensure(self.connection, super().consume)
+        return consume(*args, **kwargs)
+
+    def ensure_ack(self, daemon_message: messages.DaemonMessage):
+        # if the connection the message came thru is no longer usable,
+        # use `kombu.Connection.autoretry` and `kombu.Channel.basic_ack`
+        # to ensure the ack goes thru
+        try:
+            daemon_message.ack()
+        except (ConnectionError, amqp.exceptions.ConnectionError):
+            @self.connection.autoretry
+            def _do_ack(*, channel):
+                try:
+                    channel.basic_ack(daemon_message.kombu_message.delivery_tag)
+                finally:
+                    channel.close()
+            _do_ack()
+
 
 class IndexerDaemon:
     MAX_LOCAL_QUEUE_SIZE = 5000
+    ack_callback: Callable[[messages.DaemonMessage], None] | None = None
 
     def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None):
         self.stop_event = (
@@ -154,6 +187,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
             local_message_queue=_queue_from_rabbit_to_daemon,
             log_prefix=f'{repr(self)} MessageHandlingLoop: ',
             daemonthread_context=self.__daemonthread_context,
+            ack_callback=self.ack_callback,
         )
         return _handling_loop.start_thread()
 
@@ -186,7 +220,8 @@ class MessageHandlingLoop:
     stop_event: threading.Event
     local_message_queue: queue.Queue
     log_prefix: str
-    daemonthread_context: contextlib.AbstractContextManager
+    daemonthread_context: Callable[[], contextlib.AbstractContextManager]
+    ack_callback: Callable[[messages.DaemonMessage], None] | None = None
     _leftover_daemon_messages_by_target_id = None
 
     def __post_init__(self):
@@ -248,6 +283,7 @@ def _get_daemon_messages(self):
         return daemon_messages_by_target_id
 
     def _handle_some_messages(self):
+        assert self.ack_callback is not None
         start_time = time.time()
         doc_count, error_count = 0, 0
         daemon_messages_by_target_id = self._get_daemon_messages()
@@ -270,7 +306,7 @@ def _handle_some_messages(self):
                     sentry_sdk.capture_message('error handling message', extras={'message_response': message_response})
                 target_id = message_response.index_message.target_id
                 for daemon_message in daemon_messages_by_target_id.pop(target_id, ()):
-                    daemon_message.ack()  # finally set it free
+                    self.ack_callback(daemon_message)
             if daemon_messages_by_target_id:  # should be empty by now
                 logger.error('%sUnhandled messages?? %s', self.log_prefix, len(daemon_messages_by_target_id))
                 sentry_sdk.capture_message(