Skip to content

Commit

Permalink
100% coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettmflynn committed Apr 25, 2024
1 parent 9364169 commit 1272fc1
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 13 deletions.
2 changes: 2 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[run]
omit = */_demos/*
35 changes: 26 additions & 9 deletions src/tqdm_publisher/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@


class TQDMProgressHandler:
def __init__(self):
self.listeners: List[queue.Queue] = []
def __init__(
self,
queue_cls: queue.Queue = queue.Queue # Can provide different queue implementations (e.g. asyncio.Queue)
):
self._queue = queue_cls
self.listeners: List[self._queue] = []

def listen(self) -> queue.Queue:
new_queue = queue.Queue(maxsize=25)
new_queue = self._queue(maxsize=0)
self.listeners.append(new_queue)
return new_queue

Expand Down Expand Up @@ -41,9 +45,22 @@ def announce(self, message: Dict[Any, Any]):
"""
number_of_listeners = len(self.listeners)
listener_indices = range(number_of_listeners)
listener_indices_from_newest_to_oldest = reversed(listener_indices)
for listener_index in listener_indices_from_newest_to_oldest:
if not self.listeners[listener_index].full():
self.listeners[listener_index].put_nowait(item=message)
else: # When full, remove the newest listener in the stack
del self.listeners[listener_index]
for listener_index in listener_indices:
self.listeners[listener_index].put_nowait(item=message)


def unsubscribe(self, listener: queue.Queue) -> bool:
"""
Unsubscribe a listener from the handler.
Args:
listener: The listener to unsubscribe.
Returns:
bool: True if the listener was successfully unsubscribed, False otherwise.
"""
try:
self.listeners.remove(listener)
return True
except ValueError:
return False
2 changes: 1 addition & 1 deletion src/tqdm_publisher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ async def sleep_func(sleep_duration: float = 1) -> float:
await asyncio.sleep(delay=sleep_duration)


def create_tasks():
def create_tasks(number_of_tasks: int = 10**5):
number_of_tasks = 10**5
sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)]

Expand Down
68 changes: 68 additions & 0 deletions tests/test_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import asyncio

import pytest

from uuid import UUID
from tqdm_publisher import TQDMProgressHandler
from tqdm_publisher.testing import create_tasks

N_SUBSCRIBERS = 3
N_TASKS_PER_SUBSCRIBER = 3

# Test concurrent callback execution
@pytest.mark.asyncio
async def test_subscription_and_callback_execution():

handler = TQDMProgressHandler(asyncio.Queue)

n_callback_executions = dict()

def test_callback(data):

nonlocal n_callback_executions

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

if identifier not in n_callback_executions:
n_callback_executions[identifier] = 0

n_callback_executions[identifier] += 1

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

queue = handler.listen()

for _ in range(N_SUBSCRIBERS):

subscriber = handler.create_progress_subscriber(
asyncio.as_completed(create_tasks(number_of_tasks=N_TASKS_PER_SUBSCRIBER)),
total=N_TASKS_PER_SUBSCRIBER
)

for f in subscriber:
await f

while not queue.empty():
message = await queue.get()
test_callback(message)
queue.task_done()

assert len(n_callback_executions) == N_SUBSCRIBERS

for _, n_executions in n_callback_executions.items():
assert n_executions > 1


def test_unsubscription():
handler = TQDMProgressHandler(asyncio.Queue)
queue = handler.listen()
assert len(handler.listeners) == 1
result = handler.unsubscribe(queue)
assert result == True
assert len(handler.listeners) == 0
result = handler.unsubscribe(queue)
assert result == False
7 changes: 4 additions & 3 deletions tests/test_basic.py → tests/test_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tqdm_publisher import TQDMProgressPublisher
from tqdm_publisher.testing import create_tasks

N_SUBSCRIPTIONS = 10

def test_initialization():
publisher = TQDMProgressPublisher()
Expand All @@ -29,8 +30,8 @@ def test_callback(identifier, data):
tasks = create_tasks()
publisher = TQDMProgressPublisher(asyncio.as_completed(tasks), total=len(tasks))

n_subscriptions = 10
for i in range(n_subscriptions):
N_SUBSCRIPTIONS = 10
for i in range(N_SUBSCRIPTIONS):
callback_id = publisher.subscribe(
lambda data, i=i: test_callback(i, data)
) # Creates a new scoped i value for subscription
Expand All @@ -40,7 +41,7 @@ def test_callback(identifier, data):
for f in publisher:
await f

assert len(n_callback_executions) == n_subscriptions
assert len(n_callback_executions) == N_SUBSCRIPTIONS

for identifier, n_executions in n_callback_executions.items():
assert n_executions > 1
Expand Down
38 changes: 38 additions & 0 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import asyncio

import pytest

from uuid import UUID
from tqdm_publisher import TQDMProgressSubscriber
from tqdm_publisher.testing import create_tasks

def test_initialization():
subscriber = TQDMProgressSubscriber(iterable=[], on_progress_update=lambda x: x)
assert len(subscriber.callbacks) == 1


# Test concurrent callback execution
@pytest.mark.asyncio
async def test_subscription_and_callback_execution():
n_callback_executions = 0

def test_callback(data):
nonlocal n_callback_executions
n_callback_executions += 1

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

tasks = create_tasks()
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks))

# Simulate an update to trigger the callback
for f in subscriber:
await f

assert n_callback_executions > 1

0 comments on commit 1272fc1

Please sign in to comment.