-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9364169
commit 1272fc1
Showing
6 changed files
with
139 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[run] | ||
omit = */_demos/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import asyncio | ||
|
||
import pytest | ||
|
||
from uuid import UUID | ||
from tqdm_publisher import TQDMProgressHandler | ||
from tqdm_publisher.testing import create_tasks | ||
|
||
N_SUBSCRIBERS = 3 | ||
N_TASKS_PER_SUBSCRIBER = 3 | ||
|
||
# Test concurrent callback execution | ||
@pytest.mark.asyncio | ||
async def test_subscription_and_callback_execution(): | ||
|
||
handler = TQDMProgressHandler(asyncio.Queue) | ||
|
||
n_callback_executions = dict() | ||
|
||
def test_callback(data): | ||
|
||
nonlocal n_callback_executions | ||
|
||
assert "progress_bar_id" in data | ||
identifier = data["progress_bar_id"] | ||
assert str(UUID(identifier, version=4)) == identifier | ||
|
||
if identifier not in n_callback_executions: | ||
n_callback_executions[identifier] = 0 | ||
|
||
n_callback_executions[identifier] += 1 | ||
|
||
assert "format_dict" in data | ||
format = data["format_dict"] | ||
assert "n" in format and "total" in format | ||
|
||
queue = handler.listen() | ||
|
||
for _ in range(N_SUBSCRIBERS): | ||
|
||
subscriber = handler.create_progress_subscriber( | ||
asyncio.as_completed(create_tasks(number_of_tasks=N_TASKS_PER_SUBSCRIBER)), | ||
total=N_TASKS_PER_SUBSCRIBER | ||
) | ||
|
||
for f in subscriber: | ||
await f | ||
|
||
while not queue.empty(): | ||
message = await queue.get() | ||
test_callback(message) | ||
queue.task_done() | ||
|
||
assert len(n_callback_executions) == N_SUBSCRIBERS | ||
|
||
for _, n_executions in n_callback_executions.items(): | ||
assert n_executions > 1 | ||
|
||
|
||
def test_unsubscription(): | ||
handler = TQDMProgressHandler(asyncio.Queue) | ||
queue = handler.listen() | ||
assert len(handler.listeners) == 1 | ||
result = handler.unsubscribe(queue) | ||
assert result == True | ||
assert len(handler.listeners) == 0 | ||
result = handler.unsubscribe(queue) | ||
assert result == False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import asyncio | ||
|
||
import pytest | ||
|
||
from uuid import UUID | ||
from tqdm_publisher import TQDMProgressSubscriber | ||
from tqdm_publisher.testing import create_tasks | ||
|
||
def test_initialization(): | ||
subscriber = TQDMProgressSubscriber(iterable=[], on_progress_update=lambda x: x) | ||
assert len(subscriber.callbacks) == 1 | ||
|
||
|
||
# Test concurrent callback execution | ||
@pytest.mark.asyncio | ||
async def test_subscription_and_callback_execution(): | ||
n_callback_executions = 0 | ||
|
||
def test_callback(data): | ||
nonlocal n_callback_executions | ||
n_callback_executions += 1 | ||
|
||
assert "progress_bar_id" in data | ||
identifier = data["progress_bar_id"] | ||
assert str(UUID(identifier, version=4)) == identifier | ||
|
||
assert "format_dict" in data | ||
format = data["format_dict"] | ||
assert "n" in format and "total" in format | ||
|
||
tasks = create_tasks() | ||
subscriber = TQDMProgressSubscriber(asyncio.as_completed(tasks), test_callback, total=len(tasks)) | ||
|
||
# Simulate an update to trigger the callback | ||
for f in subscriber: | ||
await f | ||
|
||
assert n_callback_executions > 1 |