From 8ab85048733c2d91e95bc8a7e499428af8ae5e1d Mon Sep 17 00:00:00 2001 From: Garrett Michael Flynn Date: Mon, 20 May 2024 09:25:43 -0700 Subject: [PATCH] Allow manual updates (#65) * Allow manual updates * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_subscriber.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update arguments and changelog * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update CHANGELOG.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> --- CHANGELOG.md | 3 +++ src/tqdm_publisher/_handler.py | 4 +-- src/tqdm_publisher/_publisher.py | 4 +-- src/tqdm_publisher/_subscriber.py | 7 ++--- src/tqdm_publisher/testing.py | 1 - tests/test_subscriber.py | 43 +++++++++++++++++++++++++++++-- 6 files changed, 52 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d9fa00..651a446 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,6 @@ +## Unreleased +* Update the arguments for all `tqdm_publisher` classes to mirror the `tqdm` constructor, adding additional parameters as required keyword arguments. [PR #65](https://github.com/catalystneuro/tqdm_publisher/pull/65) + ## v0.1.0 (April 26th, 2024) * The first alpha release of `tqdm_publisher`. diff --git a/src/tqdm_publisher/_handler.py b/src/tqdm_publisher/_handler.py index d39a29c..d8bee4e 100644 --- a/src/tqdm_publisher/_handler.py +++ b/src/tqdm_publisher/_handler.py @@ -17,7 +17,7 @@ def listen(self) -> queue.Queue: return new_queue def create_progress_subscriber( - self, iterable: Iterable[Any], additional_metadata: dict = dict(), **tqdm_kwargs + self, *tqdm_args, additional_metadata: dict = dict(), **tqdm_kwargs ) -> TQDMProgressSubscriber: def on_progress_update(progress_update: dict): @@ -31,7 +31,7 @@ def on_progress_update(progress_update: dict): """ self.announce(message=dict(**progress_update, **additional_metadata)) - return TQDMProgressSubscriber(iterable=iterable, on_progress_update=on_progress_update, **tqdm_kwargs) + return TQDMProgressSubscriber(*tqdm_args, on_progress_update=on_progress_update, **tqdm_kwargs) def announce(self, message: Dict[Any, Any]): """ diff --git a/src/tqdm_publisher/_publisher.py b/src/tqdm_publisher/_publisher.py index 7ff34c2..e9fdb4f 100644 --- a/src/tqdm_publisher/_publisher.py +++ b/src/tqdm_publisher/_publisher.py @@ -5,8 +5,8 @@ class TQDMProgressPublisher(base_tqdm): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, *tqdm_args, **tqdm_kwargs): + super().__init__(*tqdm_args, **tqdm_kwargs) self.progress_bar_id = str(uuid4()) self.callbacks = dict() diff --git a/src/tqdm_publisher/_subscriber.py b/src/tqdm_publisher/_subscriber.py index e5715b3..1b3eb53 100644 --- a/src/tqdm_publisher/_subscriber.py +++ b/src/tqdm_publisher/_subscriber.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, Iterable +from typing import Any, Dict, Iterable, Optional from ._publisher import TQDMProgressPublisher class TQDMProgressSubscriber(TQDMProgressPublisher): - def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs): - super().__init__(iterable=iterable, **tqdm_kwargs) + def __init__(self, *tqdm_args, on_progress_update: callable, **tqdm_kwargs): + + super().__init__(*tqdm_args, **tqdm_kwargs) def run_on_progress_update(format_dict: Dict[str, Any]): """ diff --git a/src/tqdm_publisher/testing.py b/src/tqdm_publisher/testing.py index 1b499e2..4baae0b 100644 --- a/src/tqdm_publisher/testing.py +++ b/src/tqdm_publisher/testing.py @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float: def create_tasks(number_of_tasks: int = 10**5): - number_of_tasks = 10**5 sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)] tasks = list() diff --git a/tests/test_subscriber.py b/tests/test_subscriber.py index 7ac3a0d..f40bb19 100644 --- a/tests/test_subscriber.py +++ b/tests/test_subscriber.py @@ -30,10 +30,49 @@ def test_callback(data): assert "n" in format and "total" in format tasks = create_tasks() - subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks)) + total = len(tasks) + subscriber = TQDMProgressSubscriber( + asyncio.as_completed(tasks), + total=total, + mininterval=0, + on_progress_update=test_callback, + ) # Simulate an update to trigger the callback for f in subscriber: await f - assert n_callback_executions > 1 + assert n_callback_executions == total + 1 + + +# Test manual update management +@pytest.mark.asyncio +async def test_manual_updates(): + n_callback_executions = 0 + + def test_callback(data): + nonlocal n_callback_executions + n_callback_executions += 1 + + print(data) + + assert "progress_bar_id" in data + identifier = data["progress_bar_id"] + assert str(UUID(identifier, version=4)) == identifier + + assert "format_dict" in data + format = data["format_dict"] + assert "n" in format and "total" in format + + tasks = create_tasks(10) + total = len(tasks) + subscriber = TQDMProgressSubscriber(on_progress_update=test_callback, total=total, mininterval=0) + + # Simulate updatse to trigger the callback + for task in asyncio.as_completed(tasks): + await task + subscriber.update(1) # Update by 1 iteration + + assert n_callback_executions == total + 1 + + subscriber.close()