Skip to content

Commit

Permalink
Allow manual updates
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettmflynn committed May 15, 2024
1 parent 0e61338 commit fcfde43
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_subscriber.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import Any, Dict, Iterable
from typing import Any, Dict, Iterable, Optional

from ._publisher import TQDMProgressPublisher


class TQDMProgressSubscriber(TQDMProgressPublisher):
def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs):
def __init__(self, on_progress_update: callable, iterable: Optional[Iterable[Any]]=None, **tqdm_kwargs):
super().__init__(iterable=iterable, **tqdm_kwargs)

def run_on_progress_update(format_dict: Dict[str, Any]):
Expand Down
1 change: 0 additions & 1 deletion src/tqdm_publisher/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ async def sleep_func(sleep_duration: float = 1) -> float:


def create_tasks(number_of_tasks: int = 10**5):
number_of_tasks = 10**5
sleep_durations = [random.uniform(0, 5.0) for _ in range(number_of_tasks)]

tasks = list()
Expand Down
35 changes: 34 additions & 1 deletion tests/test_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,43 @@ def test_callback(data):
assert "n" in format and "total" in format

tasks = create_tasks()
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks))
subscriber = TQDMProgressSubscriber(test_callback, asyncio.as_completed(tasks), total=len(tasks))

# Simulate an update to trigger the callback
for f in subscriber:
await f

assert n_callback_executions > 1


# Test manual update management
@pytest.mark.asyncio
async def test_manual_updates():
n_callback_executions = 0

def test_callback(data):
nonlocal n_callback_executions
n_callback_executions += 1

print(data)

assert "progress_bar_id" in data
identifier = data["progress_bar_id"]
assert str(UUID(identifier, version=4)) == identifier

assert "format_dict" in data
format = data["format_dict"]
assert "n" in format and "total" in format

tasks = create_tasks(10)
total = len(tasks)
subscriber = TQDMProgressSubscriber(test_callback, total=total)

# Simulate updatse to trigger the callback
for task in asyncio.as_completed(tasks):
await task
subscriber.update(1) # Update by 1 iteration

assert n_callback_executions > 1

subscriber.close()

0 comments on commit fcfde43

Please sign in to comment.