Skip to content

Commit

Permalink
Allow manual updates (#65)
Browse files Browse the repository at this point in the history
* Allow manual updates

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_subscriber.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update arguments and changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update CHANGELOG.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Cody Baker <[email protected]>
  • Loading branch information
3 people authored May 20, 2024
1 parent 0e61338 commit 8ab8504
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## Unreleased
* Update the arguments for all `tqdm_publisher` classes to mirror the `tqdm` constructor, adding additional parameters as required keyword arguments. [PR #65](https://github.com/catalystneuro/tqdm_publisher/pull/65)

## v0.1.0 (April 26th, 2024)
* The first alpha release of `tqdm_publisher`.

Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def listen(self) -> queue.Queue:
return new_queue

def create_progress_subscriber(
self, iterable: Iterable[Any], additional_metadata: dict = dict(), **tqdm_kwargs
self, *tqdm_args, additional_metadata: dict = dict(), **tqdm_kwargs
) -> TQDMProgressSubscriber:

def on_progress_update(progress_update: dict):
Expand All @@ -31,7 +31,7 @@ def on_progress_update(progress_update: dict):
"""
self.announce(message=dict(**progress_update, **additional_metadata))

return TQDMProgressSubscriber(iterable=iterable, on_progress_update=on_progress_update, **tqdm_kwargs)
return TQDMProgressSubscriber(*tqdm_args, on_progress_update=on_progress_update, **tqdm_kwargs)

def announce(self, message: Dict[Any, Any]):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class TQDMProgressPublisher(base_tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, *tqdm_args, **tqdm_kwargs):
super().__init__(*tqdm_args, **tqdm_kwargs)
self.progress_bar_id = str(uuid4())
self.callbacks = dict()

Expand Down
7 changes: 4 additions & 3 deletions src/tqdm_publisher/_subscriber.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, Optional

from ._publisher import TQDMProgressPublisher


class TQDMProgressSubscriber(TQDMProgressPublisher):
def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs):
super().__init__(iterable=iterable, **tqdm_kwargs)
def __init__(self, *tqdm_args, on_progress_update: callable, **tqdm_kwargs):

super().__init__(*tqdm_args, **tqdm_kwargs)

def run_on_progress_update(format_dict: Dict[str, Any]):
"""
Expand Down
1 change: 0 additions & 1 deletion src/tqdm_publisher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float:


def create_tasks(number_of_tasks: int = 10**5):
number_of_tasks = 10**5
sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)]

tasks = list()
Expand Down
43 changes: 41 additions & 2 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,49 @@ def test_callback(data):
assert "n" in format and "total" in format

tasks = create_tasks()
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks))
total = len(tasks)
subscriber = TQDMProgressSubscriber(
asyncio.as_completed(tasks),
total=total,
mininterval=0,
on_progress_update=test_callback,
)

# Simulate an update to trigger the callback
for f in subscriber:
await f

assert n_callback_executions > 1
assert n_callback_executions == total + 1


# Test manual update management
@pytest.mark.asyncio
async def test_manual_updates():
n_callback_executions = 0

def test_callback(data):
nonlocal n_callback_executions
n_callback_executions += 1

print(data)

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

tasks = create_tasks(10)
total = len(tasks)
subscriber = TQDMProgressSubscriber(on_progress_update=test_callback, total=total, mininterval=0)

# Simulate updatse to trigger the callback
for task in asyncio.as_completed(tasks):
await task
subscriber.update(1) # Update by 1 iteration

assert n_callback_executions == total + 1

subscriber.close()

0 comments on commit 8ab8504

Please sign in to comment.