Skip to content

Commit

Permalink
wip: indexer recover from connection errors
Browse files Browse the repository at this point in the history
  • Loading branch information
aaxelb committed Dec 3, 2024
1 parent 8455a7e commit 6fa1a25
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions share/search/daemon.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import contextlib
import collections
from collections.abc import Callable
import dataclasses
import logging
import queue
import random
import threading
import time

import amqp.exceptions
from django.conf import settings
import kombu
from kombu.mixins import ConsumerMixin
import sentry_sdk

Expand All @@ -27,6 +30,7 @@
MINIMUM_BACKOFF_FACTOR = 1.6 # unitless ratio
MAXIMUM_BACKOFF_FACTOR = 2.0 # unitless ratio
MAXIMUM_BACKOFF_TIMEOUT = 60 # seconds
CONNECTION_HEARTBEAT = 20 # seconds


class TooFastSlowDown(Exception):
Expand All @@ -35,7 +39,10 @@ class TooFastSlowDown(Exception):

class IndexerDaemonControl:
def __init__(self, celery_app, *, daemonthread_context=None, stop_event=None):
self.celery_app = celery_app
self.kombu_connection = kombu.Connection(
celery_app.conf.broker_url, # use celery_app.conf for consistent config
heartbeat=CONNECTION_HEARTBEAT,
)
self.daemonthread_context = daemonthread_context
self._daemonthreads = []
# shared stop_event for all threads below
Expand All @@ -50,10 +57,16 @@ def start_daemonthreads_for_strategy(self, index_strategy):
)
# spin up daemonthreads, ready for messages
self._daemonthreads.extend(_daemon.start())
# assign a thread to pass messages to this daemon
threading.Thread(
target=CeleryMessageConsumer(self.celery_app, _daemon).run,
).start()
_consumer = KombuMessageConsumer(
kombu_connection=self.kombu_connection.clone(),
stop_event=self.stop_event,
index_strategy=index_strategy,
message_callback=_daemon.on_message,
)
# give the daemon direct access to the connection, for acking purposes
_daemon.ack_callback = _consumer.ensure_ack
# assign a thread for the consumer to receive and enqueue messages to this daemon
threading.Thread(target=_consumer.run).start()
return _daemon

def start_all_daemonthreads(self):
Expand All @@ -67,18 +80,16 @@ def stop_daemonthreads(self, *, wait=False):
_thread.join()


class CeleryMessageConsumer(ConsumerMixin):
class KombuMessageConsumer(ConsumerMixin):
PREFETCH_COUNT = 7500

# (from ConsumerMixin)
# should_stop: bool
should_stop: bool # (from ConsumerMixin)

def __init__(self, celery_app, indexer_daemon):
self.connection = celery_app.pool.acquire(block=True)
self.celery_app = celery_app
self.__stop_event = indexer_daemon.stop_event
self.__message_callback = indexer_daemon.on_message
self.__index_strategy = indexer_daemon.index_strategy
def __init__(self, *, kombu_connection, stop_event, message_callback, index_strategy):
    """Store the consumer's collaborators (all arguments are keyword-only).

    kombu_connection: kept as the public `connection` attribute
    stop_event, message_callback, index_strategy: kept as private attributes
    """
    # private (name-mangled) state
    self.__index_strategy = index_strategy
    self.__message_callback = message_callback
    self.__stop_event = stop_event
    # public attribute: `consume` and `ensure_ack` reach the broker thru this
    self.connection = kombu_connection

# overrides ConsumerMixin.run
def run(self):
Expand Down Expand Up @@ -112,9 +123,31 @@ def get_consumers(self, Consumer, channel):
def __repr__(self):
    """Render as `<ClassName(index-strategy-name)>`."""
    _class_name = self.__class__.__name__
    _strategy_name = self.__index_strategy.name
    return f'<{_class_name}({_strategy_name})>'

def consume(self, *args, **kwargs):
    """Call the inherited `consume` thru `kombu.Connection.ensure`.

    Follows the consumer guidance at
    https://docs.celeryq.dev/projects/kombu/en/stable/userguide/failover.html#consumer

    All arguments are forwarded unchanged; returns whatever the wrapped
    call returns.
    """
    # NOTE(review): `ensure` guards only the call itself -- if the inherited
    # `consume` is a generator function, errors raised while iterating its
    # result would not be retried by this wrapper -- TODO confirm
    _connection = self.connection
    _ensured_consume = _connection.ensure(_connection, super().consume)
    return _ensured_consume(*args, **kwargs)

def ensure_ack(self, daemon_message: messages.DaemonMessage):
    """Ack `daemon_message`, with a retried fallback on connection errors.

    First tries `daemon_message.ack()`. If that raises a connection error
    (builtin `ConnectionError` or `amqp.exceptions.ConnectionError`), sends
    `basic_ack` for the same delivery tag on a channel supplied by
    `kombu.Connection.autoretry`. Any other exception from the first
    attempt propagates to the caller.
    """
    def _ack_on_channel(*, channel):
        # ack by delivery tag on the given channel; always close that channel after
        try:
            channel.basic_ack(daemon_message.kombu_message.delivery_tag)
        finally:
            channel.close()

    try:
        daemon_message.ack()
    except (ConnectionError, amqp.exceptions.ConnectionError):
        # the connection the message came thru is no longer usable,
        # so retry the ack on whatever channel `autoretry` provides
        # NOTE(review): AMQP delivery tags are normally scoped to the channel
        # that delivered the message -- confirm the broker accepts this ack
        # on a different channel (it may reject an unknown delivery tag)
        self.connection.autoretry(_ack_on_channel)()


class IndexerDaemon:
MAX_LOCAL_QUEUE_SIZE = 5000
ack_callback: Callable[[messages.DaemonMessage], None] | None = None

def __init__(self, index_strategy, *, stop_event=None, daemonthread_context=None):
self.stop_event = (
Expand Down Expand Up @@ -154,6 +187,7 @@ def start_typed_loop_and_queue(self, message_type) -> threading.Thread:
local_message_queue=_queue_from_rabbit_to_daemon,
log_prefix=f'{repr(self)} MessageHandlingLoop: ',
daemonthread_context=self.__daemonthread_context,
ack_callback=self.ack_callback,
)
return _handling_loop.start_thread()

Expand Down Expand Up @@ -186,7 +220,8 @@ class MessageHandlingLoop:
stop_event: threading.Event
local_message_queue: queue.Queue
log_prefix: str
daemonthread_context: contextlib.AbstractContextManager
daemonthread_context: Callable[[], contextlib.AbstractContextManager]
ack_callback: Callable[[messages.DaemonMessage], None] | None = None
_leftover_daemon_messages_by_target_id = None

def __post_init__(self):
Expand Down Expand Up @@ -248,6 +283,7 @@ def _get_daemon_messages(self):
return daemon_messages_by_target_id

def _handle_some_messages(self):
assert self.ack_callback is not None
start_time = time.time()
doc_count, error_count = 0, 0
daemon_messages_by_target_id = self._get_daemon_messages()
Expand All @@ -270,7 +306,7 @@ def _handle_some_messages(self):
sentry_sdk.capture_message('error handling message', extras={'message_response': message_response})
target_id = message_response.index_message.target_id
for daemon_message in daemon_messages_by_target_id.pop(target_id, ()):
daemon_message.ack() # finally set it free
self.ack_callback(daemon_message)
if daemon_messages_by_target_id: # should be empty by now
logger.error('%sUnhandled messages?? %s', self.log_prefix, len(daemon_messages_by_target_id))
sentry_sdk.capture_message(
Expand Down

0 comments on commit 6fa1a25

Please sign in to comment.