Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow manual updates #65

Merged
merged 8 commits into from
May 20, 2024
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
## Unreleased
* Update the arguments for all `tqdm_publisher` classes to mirror the `tqdm` constructor, adding additional parameters as required keyword arguments. [PR #65](https://github.com/catalystneuro/tqdm_publisher/pull/65)

## v0.1.0 (April 26th, 2024)
* The first alpha release of `tqdm_publisher`.

Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def listen(self) -> queue.Queue:
return new_queue

def create_progress_subscriber(
self, iterable: Iterable[Any], additional_metadata: dict = dict(), **tqdm_kwargs
self, *tqdm_args, additional_metadata: dict = dict(), **tqdm_kwargs
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
) -> TQDMProgressSubscriber:

def on_progress_update(progress_update: dict):
Expand All @@ -31,7 +31,7 @@ def on_progress_update(progress_update: dict):
"""
self.announce(message=dict(**progress_update, **additional_metadata))

return TQDMProgressSubscriber(iterable=iterable, on_progress_update=on_progress_update, **tqdm_kwargs)
return TQDMProgressSubscriber(*tqdm_args, on_progress_update=on_progress_update, **tqdm_kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return TQDMProgressSubscriber(*tqdm_args, on_progress_update=on_progress_update, **tqdm_kwargs)
return TQDMProgressSubscriber(*tqdm_args, *, on_progress_update=on_progress_update, **tqdm_kwargs)

Actually, to match the statement in the CHANGELOG (required), need this star here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not seem acceptable when instantiating a class. Is this where you meant to put this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, meant in the init

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already handled by the *tqdm_args argument, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, * says 'everything to the right of this position is required to be a keyword argument only'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

* by itself that is, *args just unpacks a list of ordered unnamed positional values to the signature

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I guess you're right, if it sees *args it imposes keyword only to the right of that


def announce(self, message: Dict[Any, Any]):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


class TQDMProgressPublisher(base_tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, *tqdm_args, **tqdm_kwargs):
super().__init__(*tqdm_args, **tqdm_kwargs)
self.progress_bar_id = str(uuid4())
self.callbacks = dict()

Expand Down
7 changes: 4 additions & 3 deletions src/tqdm_publisher/_subscriber.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, Optional

from ._publisher import TQDMProgressPublisher


class TQDMProgressSubscriber(TQDMProgressPublisher):
def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs):
super().__init__(iterable=iterable, **tqdm_kwargs)
def __init__(self, *tqdm_args, on_progress_update: callable, **tqdm_kwargs):

CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(*tqdm_args, **tqdm_kwargs)

def run_on_progress_update(format_dict: Dict[str, Any]):
"""
Expand Down
1 change: 0 additions & 1 deletion src/tqdm_publisher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float:


def create_tasks(number_of_tasks: int = 10**5):
number_of_tasks = 10**5
sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)]

tasks = list()
Expand Down
43 changes: 41 additions & 2 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,49 @@ def test_callback(data):
assert "n" in format and "total" in format

tasks = create_tasks()
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks))
total = len(tasks)
subscriber = TQDMProgressSubscriber(
asyncio.as_completed(tasks),
total=total,
mininterval=0,
on_progress_update=test_callback,
)

# Simulate an update to trigger the callback
for f in subscriber:
await f

assert n_callback_executions > 1
assert n_callback_executions == total + 1


# Test manual update management
@pytest.mark.asyncio
async def test_manual_updates():
n_callback_executions = 0

def test_callback(data):
nonlocal n_callback_executions
n_callback_executions += 1

print(data)

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

tasks = create_tasks(10)
total = len(tasks)
subscriber = TQDMProgressSubscriber(on_progress_update=test_callback, total=total, mininterval=0)

# Simulate updatse to trigger the callback
for task in asyncio.as_completed(tasks):
await task
subscriber.update(1) # Update by 1 iteration

assert n_callback_executions == total + 1

subscriber.close()
Loading