diff --git a/src/zimscraperlib/executor.py b/src/zimscraperlib/executor.py new file mode 100644 index 0000000..a043f0a --- /dev/null +++ b/src/zimscraperlib/executor.py @@ -0,0 +1,173 @@ +import datetime +import queue +import threading +from collections.abc import Callable + +from zimscraperlib import logger + +_shutdown = False +# Lock that ensures that new workers are not created while the interpreter is +# shutting down. Must be held while mutating _threads_queues and _shutdown. +_global_shutdown_lock = threading.Lock() + + +def excepthook(args): # pragma: no cover + logger.error(f"UNHANDLED Exception in {args.thread.name}: {args.exc_type}") + logger.exception(args.exc_value) + + +threading.excepthook = excepthook + + +class ScraperExecutor(queue.Queue): + """Custom FIFO queue based Executor that's less generic than ThreadPoolExec one + + Providing more flexibility for the use cases we're interested about: + - halt immediately (sort of) upon exception (if requested) + - able to join() then restart later to accomodate successive steps + + See: https://github.com/python/cpython/blob/3.8/Lib/concurrent/futures/thread.py + """ + + def __init__( + self, + queue_size: int = 10, + nb_workers: int = 1, + executor_name: str = "executor", + thread_deadline_sec: int = 60, + ): + super().__init__(queue_size) + self.executor_name = executor_name + self._shutdown_lock = threading.Lock() + self.nb_workers = nb_workers + self.exceptions = [] + self.thread_deadline_sec = thread_deadline_sec + + @property + def exception(self): + """Exception raises in any thread, if any""" + try: + return self.exceptions[0:1].pop() + except IndexError: + return None + + @property + def alive(self): + """whether it should continue running""" + return not self._shutdown + + def submit(self, task: Callable, **kwargs): + """Submit a callable and its kwargs for execution in one of the workers""" + with self._shutdown_lock, _global_shutdown_lock: + if not self.alive: + raise RuntimeError("cannot submit task to dead executor") + if self.no_more: + raise RuntimeError( + "cannot submit task to a joined executor, restart it first" + ) + if _shutdown: + raise RuntimeError( # pragma: no cover + "cannot submit task after interpreter shutdown" + ) + + while True: + try: + self.put((task, kwargs), block=True, timeout=3.0) + except queue.Full: + if self.no_more: + # rarely happens except if submit and join are done in different + # threads, but we need this to escape the while loop + break # pragma: no cover + else: + break + + def start(self): + """Enable executor, starting requested amount of workers + + Workers are started always, not provisioned dynamically""" + self.drain() + self._workers: set[threading.Thread] = set() + self.no_more = False + self._shutdown = False + self.exceptions[:] = [] + + for n in range(self.nb_workers): + t = threading.Thread(target=self.worker, name=f"{self.executor_name}-{n}") + t.daemon = True + t.start() + self._workers.add(t) + + def worker(self): + while self.alive or self.no_more: + try: + func, kwargs = self.get(block=True, timeout=2.0) + except queue.Empty: + if self.no_more: + break + continue + except TypeError: # pragma: no cover + # received None from the queue. most likely shuting down + return + + raises = kwargs.pop("raises") if "raises" in kwargs.keys() else False + callback = kwargs.pop("callback") if "callback" in kwargs.keys() else None + dont_release = kwargs.pop("dont_release", False) + + try: + func(**kwargs) + except Exception as exc: + logger.error(f"Error processing {func} with {kwargs=}") + logger.exception(exc) + if raises: # to cover when raises = False + self.exceptions.append(exc) + self.shutdown() + finally: + # user will manually release the queue for this task. + # most likely in a libzim-written callback + if not dont_release: + self.task_done() + if callback: + callback.__call__() + + def drain(self): + """Empty the queue without processing the tasks (tasks will be lost)""" + while True: + try: + self.get_nowait() + except queue.Empty: + break + + def join(self): + """Await completion of workers, requesting them to stop taking new task""" + logger.debug(f"joining all threads for {self.executor_name}") + self.no_more = True + for num, t in enumerate(self._workers): + deadline = datetime.datetime.now(tz=datetime.UTC) + datetime.timedelta( + seconds=self.thread_deadline_sec + ) + logger.debug( + f"Giving {self.executor_name}-{num} {self.thread_deadline_sec}s to join" + ) + e = threading.Event() + while t.is_alive() and datetime.datetime.now(tz=datetime.UTC) < deadline: + t.join(1) + e.wait(timeout=2) + if t.is_alive(): + logger.debug( + f"Thread {self.executor_name}-{num} is not joining. Skipping…" + ) + else: + logger.debug(f"Thread {self.executor_name}-{num} joined") + logger.debug(f"all threads joined for {self.executor_name}") + + def shutdown(self, *, wait=True): + """stop the executor, either somewhat immediately or awaiting completion""" + logger.debug(f"shutting down {self.executor_name} with {wait=}") + with self._shutdown_lock: + self._shutdown = True + + # Drain all work items from the queue + if not wait: + self.drain() + if wait: + self.join() diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..67a39fb --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,286 @@ +import logging +from functools import partial +from time import sleep + +import pytest + +from zimscraperlib import logger +from zimscraperlib.executor import ScraperExecutor + +logger.setLevel(logging.DEBUG) + + +class SomethingBadError(Exception): + """Test exception""" + + pass + + +class Holder: + value = 1 + + +@pytest.mark.slow +def test_executor_ok(): + """Test basic standard case""" + + def increment(holder: Holder): + holder.value += 1 + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(99): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value == 100 + + +@pytest.mark.slow +def test_executor_with_one_failure(): + """Test case where the tasks are raising one failure and we want to stop asap""" + + def increment(holder: Holder): + holder.value += 1 + if holder.value == 20: + raise SomethingBadError() + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + with pytest.raises(SomethingBadError): + for _ in range(99): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value, raises=True) + assert len(executor.exceptions) == 1 + executor.shutdown() + # we have two workers, while one failed, the time it takes to raise the exception is + # significant and the other worker is still processing items and we are still + # enqueuing more items when the queue gets free, so we have many items processed + # before the code stops + assert test_value.value >= 21 + + +@pytest.mark.slow +def test_executor_with_many_failure_raised(): + """Test case where the tasks are raising many failures and we want to stop asap""" + + def increment(holder: Holder): + holder.value += 1 + if holder.value >= 20: + raise SomethingBadError() + + executor = ScraperExecutor(nb_workers=3) + executor.start() + test_value = Holder() + with pytest.raises(SomethingBadError): + for _ in range(99): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value, raises=True) + executor.shutdown() + # we have three workers, all failing once value is greater or equal to 20 + assert len(executor.exceptions) == 3 + assert test_value.value == 22 + + +@pytest.mark.slow +def test_executor_slow(): + """Test case where the tasks are slow to run""" + + def increment(holder: Holder): + holder.value += 1 + sleep(5) + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(19): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value == 20 + + +@pytest.mark.slow +def test_executor_stop_immediately(): + """Test case where we ask the executor to stop without waiting""" + + def increment(holder: Holder): + holder.value += 1 + sleep(1) + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(5): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown(wait=False) + # we stopped asap, but 1 task might have been done in every worker + assert test_value.value <= 3 + + +@pytest.mark.slow +def test_executor_stop_once_done(): + """Test case where we ask the executor to stop once all tasks are done""" + + def increment(holder: Holder): + holder.value += 1 + sleep(1) + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(4): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value == 5 # we waited for queue to be processed + + +@pytest.mark.slow +def test_executor_stop_thread_not_joining(): + """Test case where threads take longer to join than the thread_deadline_sec""" + + def increment(holder: Holder): + holder.value += 1 + sleep(5) + + executor = ScraperExecutor(nb_workers=2, thread_deadline_sec=1) + executor.start() + test_value = Holder() + for _ in range(4): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value >= 3 # threads finished their job before we stopped them + + +@pytest.mark.slow +def test_executor_already_shutdown(): + """Test case where we submit a task to an executor who is already shutdown""" + + def increment(holder: Holder): + holder.value += 1 + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(2): + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value == 3 + with pytest.raises(RuntimeError): + executor.submit(increment, holder=test_value) + + +@pytest.mark.slow +def test_executor_already_joined(): + """Test case where we submit a task to an executor who is already joined""" + + def increment(holder: Holder): + holder.value += 1 + + executor = ScraperExecutor(nb_workers=2, queue_size=2) + executor.start() + test_value = Holder() + for _ in range(2): + executor.submit(increment, holder=test_value) + executor.join() + assert test_value.value == 3 + with pytest.raises(RuntimeError): + executor.submit(increment, holder=test_value) + + +@pytest.mark.slow +def test_executor_join_and_restart(): + """Test case where we join an executor, and then restart it and submit tasks""" + + def increment(holder: Holder): + holder.value += 1 + + executor = ScraperExecutor(nb_workers=2, queue_size=2) + executor.start() + test_value = Holder() + for _ in range(2): + executor.submit(increment, holder=test_value) + executor.join() + assert test_value.value == 3 + executor.start() + for _ in range(5): + executor.submit(increment, holder=test_value) + executor.join() + assert test_value.value == 8 + + +@pytest.mark.slow +def test_executor_callback_and_custom_release(): + """Test custom callback and custom release of the queue""" + + def increment(holder: Holder): + holder.value += 1 + + def callback(executor: ScraperExecutor, holder: Holder): + holder.value += 1 + executor.task_done() + + executor = ScraperExecutor(nb_workers=2, queue_size=2) + executor.start() + test_value = Holder() + for _ in range(2): + executor.submit( + increment, + holder=test_value, + callback=partial(callback, executor=executor, holder=test_value), + dont_release=True, + ) + executor.join() + assert test_value.value == 5 + + +@pytest.mark.slow +def test_executor_with_many_failure_not_raised(): + """Test case where we do not mind about exceptions during async processing""" + + def increment(holder: Holder): + holder.value += 1 + if holder.value >= 20: + raise SomethingBadError() + + executor = ScraperExecutor(nb_workers=3) + executor.start() + test_value = Holder() + for _ in range(99): + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert len(executor.exceptions) == 0 + assert test_value.value == 100 + + +@pytest.mark.slow +def test_executor_slow_to_submit(): + """Check that executor does not care if tasks are submitted very slowly""" + + def increment(holder: Holder): + holder.value += 1 + + executor = ScraperExecutor(nb_workers=2) + executor.start() + test_value = Holder() + for _ in range(2): + sleep(5) + if exception := executor.exception: + raise exception + executor.submit(increment, holder=test_value) + executor.shutdown() + assert test_value.value == 3