Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc Improvements I] Add more annotations types and docstrings #52

Merged
merged 23 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/tqdm_publisher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._handler import TQDMProgressHandler
from ._publisher import TQDMPublisher
from ._publisher import TQDMProgressPublisher
from ._subscriber import TQDMProgressSubscriber

__all__ = ["TQDMPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"]
__all__ = ["TQDMProgressPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"]
1 change: 0 additions & 1 deletion src/tqdm_publisher/_demos/_demo_command_line_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,5 @@ def _command_line_interface():
webbrowser.open_new_tab(f"http://localhost:{CLIENT_PORT}/{client_relative_path}")

demo_info["server"]()

else:
print(f"{command} is an invalid command.")
2 changes: 1 addition & 1 deletion src/tqdm_publisher/_demos/_multiple_bars/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def run(self):
showcase of an alternative approach to defining and scoping the execution.
"""
all_task_durations_in_seconds = [1.0 for _ in range(10)] # Ten seconds at one task per second
self.progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds)
self.progress_bar = tqdm_publisher.TQDMProgressPublisher(iterable=all_task_durations_in_seconds)
self.progress_bar.subscribe(callback=self.update)

for task_duration in self.progress_bar:
Expand Down
71 changes: 39 additions & 32 deletions src/tqdm_publisher/_demos/_parallel_bars/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import time
import uuid
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List
from typing import List, Union

import requests
from flask import Flask, Response, jsonify, request
from flask_cors import CORS, cross_origin

from tqdm_publisher import TQDMProgressHandler, TQDMPublisher
from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher
from tqdm_publisher._demos._parallel_bars._client import (
create_http_server,
find_free_port,
Expand All @@ -24,20 +24,24 @@
# Each outer entry is a list of 'tasks' to perform on a particular worker
# For demonstration purposes, each in the list of tasks is the length of time in seconds
# that each iteration of the task takes to run and update the progress bar (emulated by sleeping)
SECONDS_PER_TASK = 1
NUMBER_OF_TASKS_PER_JOB = 6
BASE_SECONDS_PER_TASK = 0.5 # The base time for each task; actual time increases proportional to the index of the task
NUMBER_OF_TASKS_PER_JOB = 10
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved
TASK_TIMES: List[List[float]] = [
[SECONDS_PER_TASK * task_index] * task_index for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1)
[BASE_SECONDS_PER_TASK * task_index] * NUMBER_OF_TASKS_PER_JOB
for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1)
]

WEBSOCKETS = {}

## NOTE: TQDMProgressHandler cannot be called from a process...so we just use a queue directly
## NOTE: TQDMProgressHandler cannot be called from a process...so we just use a global reference exposed to each subprocess
progress_handler = TQDMProgressHandler()


def forward_updates_over_sse(request_id, id, n, total, **kwargs):
progress_handler._announce(dict(request_id=request_id, id=id, format_dict=dict(n=n, total=total)))
def forward_updates_over_server_side_events(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs):
# TODO: shouldn't this line use `create_progress_subscriber`? Otherwise consider making `._announce` non-private
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
progress_handler._announce(
dict(request_id=request_id, progress_bar_id=progress_bar_id, format_dict=dict(n=n, total=total), **kwargs)
)


class ThreadedHTTPServer:
Expand Down Expand Up @@ -83,14 +87,14 @@ def _run_sleep_tasks_in_subprocess(
The index of this task in the list of all tasks from the buffer map.
Each index would map to a different tqdm position.
request_id : int
Identifier of ??.
Identifier of the request.
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved
url : str
The localhost URL to send progress updates to.
"""

subprogress_bar_id = uuid.uuid4()

sub_progress_bar = TQDMPublisher(
sub_progress_bar = TQDMProgressPublisher(
iterable=task_times,
position=iteration_index + 1,
desc=f"Progress on iteration {iteration_index} ({id})",
Expand All @@ -107,43 +111,45 @@ def _run_sleep_tasks_in_subprocess(
time.sleep(sleep_time)


def run_parallel_processes(request_id, url: str):
def run_parallel_processes(*, all_task_times: List[List[float]], request_id: str, url: str):
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved

futures = list()
with ProcessPoolExecutor(max_workers=N_JOBS) as executor:

# # Assign the parallel jobs
for iteration_index, task_times in enumerate(TASK_TIMES):
for iteration_index, task_times_per_job in enumerate(all_task_times):
futures.append(
executor.submit(
_run_sleep_tasks_in_subprocess,
task_times=task_times,
task_times=task_times_per_job,
iteration_index=iteration_index,
request_id=request_id,
url=url,
)
)

total_tasks_iterable = as_completed(futures)
total_tasks_progress_bar = TQDMPublisher(
total_tasks_progress_bar = TQDMProgressPublisher(
iterable=total_tasks_iterable, total=len(TASK_TIMES), desc="Total tasks completed"
)

# The 'total' progress bar has an ID equivalent to the request ID
total_tasks_progress_bar.subscribe(
lambda format_dict: forward_to_http_server(
url=url, request_id=request_id, progress_bar_id=request_id, **format_dict
)
)

# Trigger the deployment of the parallel jobs
for _ in total_tasks_progress_bar:
pass


def format_sse(data: str, event=None) -> str:
msg = f"data: {json.dumps(data)}\n\n"
def format_sse(data: str, event: Union[str, None] = None) -> str:
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
message = f"data: {json.dumps(data)}\n\n"
if event is not None:
msg = f"event: {event}\n{msg}"
return msg
message = f"event: {event}\n{message}"
return message


def listen_to_events():
Expand All @@ -164,7 +170,7 @@ def listen_to_events():
def start():
data = json.loads(request.data) if request.data else {}
request_id = data["request_id"]
run_parallel_processes(request_id, f"http://localhost:{PORT}")
run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=f"http://localhost:{PORT}")
return jsonify({"status": "success"})


Expand Down Expand Up @@ -197,8 +203,10 @@ async def start_server(port):
# await asyncio.Future()

# DEMO TWO: Queue
def update_queue(request_id, id, n, total, **kwargs):
forward_updates_over_sse(request_id, id, n, total)
def update_queue(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs):
forward_updates_over_server_side_events(
request_id=request_id, progress_bar_id=progress_bar_id, n=n, total=total
)

http_server = ThreadedHTTPServer(port=PORT, callback=update_queue)
http_server.start()
Expand All @@ -207,12 +215,18 @@ def update_queue(request_id, id, n, total, **kwargs):


def run_parallel_bar_demo() -> None:
"""Asynchronously start the servers"""
asyncio.run(start_server(PORT))
"""Asynchronously start the servers."""
asyncio.run(start_server(port=PORT))


def _run_parallel_bars_demo(port: str, host: str):
URL = f"http://{HOST}:{PORT}"
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved

request_id = uuid.uuid4()
run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=URL)

if __name__ == "__main__":

if __name__ == "main":
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
flags_list = sys.argv[1:]

port_flag = "--port" in flags_list
Expand All @@ -228,11 +242,4 @@ def run_parallel_bar_demo() -> None:
else:
HOST = "localhost"

URL = f"http://{HOST}:{PORT}" if port_flag else None

if URL is None:
raise ValueError("URL is not defined.")

# Just run the parallel processes
request_id = uuid.uuid4()
run_parallel_processes(request_id, URL)
_run_parallel_bars_demo(port=PORT, host=HOST)
8 changes: 4 additions & 4 deletions src/tqdm_publisher/_demos/_parallel_bars/_server_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import requests
import websockets

from tqdm_publisher import TQDMProgressHandler, TQDMPublisher
from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher
from tqdm_publisher._demos._parallel_bars._client import (
create_http_server,
find_free_port,
Expand Down Expand Up @@ -119,7 +119,7 @@ def _run_sleep_tasks_in_subprocess(

id = uuid.uuid4()

sub_progress_bar = TQDMPublisher(
sub_progress_bar = TQDMProgressPublisher(
iterable=task_times,
position=iteration_index + 1,
desc=f"Progress on iteration {iteration_index} ({id})",
Expand All @@ -132,7 +132,7 @@ def _run_sleep_tasks_in_subprocess(
time.sleep(sleep_time)


def run_parallel_processes(request_id, url: str):
def run_parallel_processes(*, request_id: str, url: str):

with ProcessPoolExecutor(max_workers=N_JOBS) as executor:

Expand Down Expand Up @@ -222,4 +222,4 @@ def run_parallel_bar_demo() -> None:

# Just run the parallel processes
request_id = uuid.uuid4()
run_parallel_processes(request_id, URL)
run_parallel_processes(request_id=request_id, url=URL)
2 changes: 1 addition & 1 deletion src/tqdm_publisher/_demos/_single_bar/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def start_progress_bar(*, progress_callback: callable) -> None:
Defaults are chosen for a deterministic and regular update period of one second for a total time of 10 seconds.
"""
all_task_durations_in_seconds = [1.0 for _ in range(10)] # Ten seconds at one second per task
progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds)
progress_bar = tqdm_publisher.TQDMProgressPublisher(iterable=all_task_durations_in_seconds)

def run_function_on_progress_update(format_dict: dict) -> None:
"""
Expand Down
59 changes: 39 additions & 20 deletions src/tqdm_publisher/_handler.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,46 @@
import queue
from typing import Any, Dict, Iterable, List

from ._subscriber import TQDMProgressSubscriber


class TQDMProgressHandler:
def __init__(self):
self.listeners = []

def listen(self):
q = queue.Queue(maxsize=25)
self.listeners.append(q)
return q

def create(self, iterable, additional_metadata: dict = dict(), **tqdm_kwargs):
return TQDMProgressSubscriber(
iterable,
lambda progress_update: self._announce(dict(**progress_update, **additional_metadata)),
**tqdm_kwargs,
)

def _announce(self, msg):
for i in reversed(range(len(self.listeners))):
try:
self.listeners[i].put_nowait(msg)
except queue.Full:
del self.listeners[i]
self.listeners: List[queue.Queue] = []

def listen(self) -> queue.Queue:
new_queue = queue.Queue(maxsize=25)
self.listeners.append(new_queue)
return new_queue

def create_progress_subscriber(
self, iterable: Iterable[Any], additional_metadata: dict = dict(), **tqdm_kwargs
) -> TQDMProgressSubscriber:

def on_progress_update(progress_update: dict):
"""
This is the injection called on every update of the progress bar.

It triggers the announcement event over all listeners on each update of the progress bar.

It must be defined inside this local scope to communicate the `additional_metadata` from the level above
without including it in the method signature.
"""
self._announce(message=dict(**progress_update, **additional_metadata))

return TQDMProgressSubscriber(iterable=iterable, on_progress_update=on_progress_update, **tqdm_kwargs)

def _announce(self, message: Dict[Any, Any]):
"""
Announce a message to all listeners.

@garrett - can you describe the expected structure of this message?
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
"""
number_of_listeners = len(self.listeners)
listener_indices = range(number_of_listeners)
listener_indices_from_newest_to_oldest = reversed(listener_indices)
for listener_index in listener_indices_from_newest_to_oldest:
if not self.listeners[listener_index].full():
self.listeners[listener_index].put_nowait(item=message)
else: # When full, remove the newest listener in the stack
del self.listeners[listener_index]
12 changes: 6 additions & 6 deletions src/tqdm_publisher/_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from tqdm import tqdm as base_tqdm


class TQDMPublisher(base_tqdm):
class TQDMProgressPublisher(base_tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.id = str(uuid4())
self.progress_bar_id = str(uuid4())
self.callbacks = dict()

# Override the update method to run callbacks
Expand All @@ -19,7 +19,7 @@ def update(self, n: int = 1) -> Union[bool, None]:

return displayed

def subscribe(self, callback: callable):
def subscribe(self, callback: callable) -> str:
"""
Subscribe to updates from the progress bar.

Expand All @@ -37,7 +37,7 @@ def subscribe(self, callback: callable):

Returns
-------
str
callback_id : str
A unique identifier for the callback. This ID is a UUID string and can be used
to reference the registered callback in future operations.

Expand All @@ -55,7 +55,7 @@ def subscribe(self, callback: callable):
self.callbacks[callback_id] = callback
return callback_id

def unsubscribe(self, callback_id: str):
def unsubscribe(self, callback_id: str) -> bool:
"""
Unsubscribe a previously registered callback from the progress bar updates.

Expand All @@ -72,7 +72,7 @@ def unsubscribe(self, callback_id: str):

Returns
-------
bool
success : bool
Returns True if the callback was successfully removed, or False if no callback was
found with the given ID.

Expand Down
25 changes: 20 additions & 5 deletions src/tqdm_publisher/_subscriber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
from ._publisher import TQDMPublisher
from typing import Any, Dict, Iterable

from ._publisher import TQDMProgressPublisher

class TQDMProgressSubscriber(TQDMPublisher):
def __init__(self, iterable, on_progress_update: callable, **tqdm_kwargs):
super().__init__(iterable, **tqdm_kwargs)
self.subscribe(lambda format_dict: on_progress_update(dict(progress_bar_id=self.id, format_dict=format_dict)))

class TQDMProgressSubscriber(TQDMProgressPublisher):
def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs):
super().__init__(iterable=iterable, **tqdm_kwargs)

def run_on_progress_update(format_dict: Dict[str, Any]):
"""
This is the injection called on every update of the progress bar.

It calls the `on_progress_update` function, which must take a dictionary
containing the progress bar ID and `format_dict`.

It must be defined inside this local scope to include the `.progress_bar_id` attribute from the level above
without including it in the method signature.
"""
on_progress_update(dict(progress_bar_id=self.progress_bar_id, format_dict=format_dict))

self.subscribe(run_on_progress_update)
Loading
Loading