From fcfde43b503414ed718d274bbc7d3ce78cbbcd9f Mon Sep 17 00:00:00 2001 From: Garrett Michael Flynn Date: Wed, 15 May 2024 11:42:48 -0700 Subject: [PATCH] Allow manual updates --- src/tqdm_publisher/_subscriber.py | 4 ++-- src/tqdm_publisher/testing.py | 1 - tests/test_subscriber.py | 35 ++++++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/src/tqdm_publisher/_subscriber.py b/src/tqdm_publisher/_subscriber.py index e5715b3..69f10a7 100644 --- a/src/tqdm_publisher/_subscriber.py +++ b/src/tqdm_publisher/_subscriber.py @@ -1,10 +1,10 @@ -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, Optional from ._publisher import TQDMProgressPublisher class TQDMProgressSubscriber(TQDMProgressPublisher): - def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs): + def __init__(self, on_progress_update: callable, iterable: Optional[Iterable[Any]]=None, **tqdm_kwargs): super().__init__(iterable=iterable, **tqdm_kwargs) def run_on_progress_update(format_dict: Dict[str, Any]): diff --git a/src/tqdm_publisher/testing.py b/src/tqdm_publisher/testing.py index 1b499e2..4baae0b 100644 --- a/src/tqdm_publisher/testing.py +++ b/src/tqdm_publisher/testing.py @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float: def create_tasks(number_of_tasks: int = 10**5): - number_of_tasks = 10**5 sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)] tasks = list() diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py index 7ac3a0d..e187ce9 100644 --- a/tests/test_subscriber.py +++ b/tests/test_subscriber.py @@ -30,10 +30,43 @@ def test_callback(data): assert "n" in format and "total" in format tasks = create_tasks() - subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks)) + subscriber = TQDMProgressSubscriber(test_callback, asyncio.as_completed(tasks), total=len(tasks)) # Simulate an update to trigger the callback for f in subscriber: await f assert n_callback_executions > 1 + + +# Test manual update management +@pytest.mark.asyncio +async def test_manual_updates(): + n_callback_executions = 0 + + def test_callback(data): + nonlocal n_callback_executions + n_callback_executions += 1 + + print(data) + + assert "progress_bar_id" in data + identifier = data["progress_bar_id"] + assert str(UUID(identifier, version=4)) == identifier + + assert "format_dict" in data + format = data["format_dict"] + assert "n" in format and "total" in format + + tasks = create_tasks(10) + total = len(tasks) + subscriber = TQDMProgressSubscriber(test_callback, total=total) + + # Simulate updatse to trigger the callback + for task in asyncio.as_completed(tasks): + await task + subscriber.update(1) # Update by 1 iteration + + assert n_callback_executions > 1 + + subscriber.close() \ No newline at end of file