From 1272fc16615454fff1d51a97b9dd80da85f9bbe5 Mon Sep 17 00:00:00 2001 From: Garrett Michael Flynn Date: Thu, 25 Apr 2024 14:44:16 -0700 Subject: [PATCH] 100% coverage --- .coveragerc | 2 + src/tqdm_publisher/_handler.py | 35 ++++++++--- src/tqdm_publisher/testing.py | 2 +- tests/test_handler.py | 68 ++++++++++++++++++++++ tests/{test_basic.py => test_publisher.py} | 7 ++- tests/test_subscriber.py | 38 ++++++++++++ 6 files changed, 139 insertions(+), 13 deletions(-) create mode 100644 .coveragerc create mode 100644 tests/test_handler.py rename tests/{test_basic.py => test_publisher.py} (91%) create mode 100644 tests/test_subscriber.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..89535d0 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[run] +omit = */_demos/* \ No newline at end of file diff --git a/src/tqdm_publisher/_handler.py b/src/tqdm_publisher/_handler.py index 9cabad4..cc90939 100644 --- a/src/tqdm_publisher/_handler.py +++ b/src/tqdm_publisher/_handler.py @@ -5,11 +5,15 @@ class TQDMProgressHandler: - def __init__(self): - self.listeners: List[queue.Queue] = [] + def __init__( + self, + queue_cls: queue.Queue = queue.Queue # Can provide different queue implementations (e.g. asyncio.Queue) + ): + self._queue = queue_cls + self.listeners: List[self._queue] = [] def listen(self) -> queue.Queue: - new_queue = queue.Queue(maxsize=25) + new_queue = self._queue(maxsize=0) self.listeners.append(new_queue) return new_queue @@ -41,9 +45,22 @@ def announce(self, message: Dict[Any, Any]): """ number_of_listeners = len(self.listeners) listener_indices = range(number_of_listeners) - listener_indices_from_newest_to_oldest = reversed(listener_indices) - for listener_index in listener_indices_from_newest_to_oldest: - if not self.listeners[listener_index].full(): - self.listeners[listener_index].put_nowait(item=message) - else: # When full, remove the newest listener in the stack - del self.listeners[listener_index] + for listener_index in listener_indices: + self.listeners[listener_index].put_nowait(item=message) + + + def unsubscribe(self, listener: queue.Queue) -> bool: + """ + Unsubscribe a listener from the handler. + + Args: + listener: The listener to unsubscribe. + + Returns: + bool: True if the listener was successfully unsubscribed, False otherwise. + """ + try: + self.listeners.remove(listener) + return True + except ValueError: + return False diff --git a/src/tqdm_publisher/testing.py b/src/tqdm_publisher/testing.py index ea932fb..1b499e2 100644 --- a/src/tqdm_publisher/testing.py +++ b/src/tqdm_publisher/testing.py @@ -6,7 +6,7 @@ async def sleep_func(sleep_duration: float = 1) -> float: await asyncio.sleep(delay=sleep_duration) -def create_tasks(): +def create_tasks(number_of_tasks: int = 10**5): number_of_tasks = 10**5 sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)] diff --git a/tests/test_handler.py b/tests/test_handler.py new file mode 100644 index 0000000..7a41828 --- /dev/null +++ b/tests/test_handler.py @@ -0,0 +1,68 @@ +import asyncio + +import pytest + +from uuid import UUID +from tqdm_publisher import TQDMProgressHandler +from tqdm_publisher.testing import create_tasks + +N_SUBSCRIBERS = 3 +N_TASKS_PER_SUBSCRIBER = 3 + +# Test concurrent callback execution +@pytest.mark.asyncio +async def test_subscription_and_callback_execution(): + + handler = TQDMProgressHandler(asyncio.Queue) + + n_callback_executions = dict() + + def test_callback(data): + + nonlocal n_callback_executions + + assert "progress_bar_id" in data + identifier = data["progress_bar_id"] + assert str(UUID(identifier, version=4)) == identifier + + if identifier not in n_callback_executions: + n_callback_executions[identifier] = 0 + + n_callback_executions[identifier] += 1 + + assert "format_dict" in data + format = data["format_dict"] + assert "n" in format and "total" in format + + queue = handler.listen() + + for _ in range(N_SUBSCRIBERS): + + subscriber = handler.create_progress_subscriber( + asyncio.as_completed(create_tasks(number_of_tasks=N_TASKS_PER_SUBSCRIBER)), + total=N_TASKS_PER_SUBSCRIBER + ) + + for f in subscriber: + await f + + while not queue.empty(): + message = await queue.get() + test_callback(message) + queue.task_done() + + assert len(n_callback_executions) == N_SUBSCRIBERS + + for _, n_executions in n_callback_executions.items(): + assert n_executions > 1 + + +def test_unsubscription(): + handler = TQDMProgressHandler(asyncio.Queue) + queue = handler.listen() + assert len(handler.listeners) == 1 + result = handler.unsubscribe(queue) + assert result == True + assert len(handler.listeners) == 0 + result = handler.unsubscribe(queue) + assert result == False diff --git a/tests/test_basic.py b/tests/test_publisher.py similarity index 91% rename from tests/test_basic.py rename to tests/test_publisher.py index a511080..09dfe32 100644 --- a/tests/test_basic.py +++ b/tests/test_publisher.py @@ -5,6 +5,7 @@ from tqdm_publisher import TQDMProgressPublisher from tqdm_publisher.testing import create_tasks +N_SUBSCRIPTIONS = 10 def test_initialization(): publisher = TQDMProgressPublisher() @@ -29,8 +30,8 @@ def test_callback(identifier, data): tasks = create_tasks() publisher = TQDMProgressPublisher(asyncio.as_completed(tasks), total=len(tasks)) - n_subscriptions = 10 - for i in range(n_subscriptions): + N_SUBSCRIPTIONS = 10 + for i in range(N_SUBSCRIPTIONS): callback_id = publisher.subscribe( lambda data, i=i: test_callback(i, data) ) # Creates a new scoped i value for subscription @@ -40,7 +41,7 @@ def test_callback(identifier, data): for f in publisher: await f - assert len(n_callback_executions) == n_subscriptions + assert len(n_callback_executions) == N_SUBSCRIPTIONS for identifier, n_executions in n_callback_executions.items(): assert n_executions > 1 diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py new file mode 100644 index 0000000..09a6dc9 --- /dev/null +++ b/tests/test_subscriber.py @@ -0,0 +1,38 @@ +import asyncio + +import pytest + +from uuid import UUID +from tqdm_publisher import TQDMProgressSubscriber +from tqdm_publisher.testing import create_tasks + +def test_initialization(): + subscriber = TQDMProgressSubscriber(iterable=[], on_progress_update=lambda x: x) + assert len(subscriber.callbacks) == 1 + + +# Test concurrent callback execution +@pytest.mark.asyncio +async def test_subscription_and_callback_execution(): + n_callback_executions = 0 + + def test_callback(data): + nonlocal n_callback_executions + n_callback_executions += 1 + + assert "progress_bar_id" in data + identifier = data["progress_bar_id"] + assert str(UUID(identifier, version=4)) == identifier + + assert "format_dict" in data + format = data["format_dict"] + assert "n" in format and "total" in format + + tasks = create_tasks() + subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks)) + + # Simulate an update to trigger the callback + for f in subscriber: + await f + + assert n_callback_executions > 1 \ No newline at end of file