Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow manual updates #65

Merged
merged 8 commits into from
May 20, 2024
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_subscriber.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, Optional

from ._publisher import TQDMProgressPublisher


class TQDMProgressSubscriber(TQDMProgressPublisher):
def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs):
def __init__(self, on_progress_update: callable, iterable: Optional[Iterable[Any]] = None, **tqdm_kwargs):
super().__init__(iterable=iterable, **tqdm_kwargs)
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved

def run_on_progress_update(format_dict: Dict[str, Any]):
Expand Down
1 change: 0 additions & 1 deletion src/tqdm_publisher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float:


def create_tasks(number_of_tasks: int = 10**5):
number_of_tasks = 10**5
sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)]

tasks = list()
Expand Down
43 changes: 41 additions & 2 deletions tests/test_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,49 @@ def test_callback(data):
assert "n" in format and "total" in format

tasks = create_tasks()
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks))
total = len(tasks)
subscriber = TQDMProgressSubscriber(
test_callback,
asyncio.as_completed(tasks),
total=total,
mininterval=0,
)

# Simulate an update to trigger the callback
for f in subscriber:
await f

assert n_callback_executions > 1
assert n_callback_executions == total + 1


# Test manual update management
@pytest.mark.asyncio
async def test_manual_updates():
n_callback_executions = 0

def test_callback(data):
nonlocal n_callback_executions
n_callback_executions += 1

print(data)

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

tasks = create_tasks(10)
total = len(tasks)
subscriber = TQDMProgressSubscriber(test_callback, total=total, mininterval=0)

# Simulate updatse to trigger the callback
for task in asyncio.as_completed(tasks):
await task
subscriber.update(1) # Update by 1 iteration

assert n_callback_executions == total + 1

subscriber.close()
Loading