Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc Improvements I] Add more annotations types and docstrings #52

Merged
merged 23 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ test = [
]

demo = [
"websockets==12.0",
"flask==2.3.2",
"flask-cors==4.0.0"
"requests",
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved
"websockets",
"flask",
"flask-cors"
]

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._handler import TQDMProgressHandler
from ._publisher import TQDMPublisher
from ._publisher import TQDMProgressPublisher
from ._subscriber import TQDMProgressSubscriber

__all__ = ["TQDMPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"]
__all__ = ["TQDMProgressPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"]
1 change: 0 additions & 1 deletion src/tqdm_publisher/_demos/_demo_command_line_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,5 @@ def _command_line_interface():
webbrowser.open_new_tab(f"http://localhost:{CLIENT_PORT}/{client_relative_path}")

demo_info["server"]()

else:
print(f"{command} is an invalid command.")
2 changes: 1 addition & 1 deletion src/tqdm_publisher/_demos/_multiple_bars/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def run(self):
showcase of an alternative approach to defining and scoping the execution.
"""
all_task_durations_in_seconds = [1.0 for _ in range(10)] # Ten seconds at one task per second
self.progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds)
self.progress_bar = tqdm_publisher.TQDMProgressPublisher(iterable=all_task_durations_in_seconds)
self.progress_bar.subscribe(callback=self.update)

for task_duration in self.progress_bar:
Expand Down
4 changes: 2 additions & 2 deletions src/tqdm_publisher/_demos/_parallel_bars/_client.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ const getBar = (request_id, id) => {

// Update the specified progress bar when a message is received from the server
const onProgressUpdate = (event) => {
const { request_id, id, format_dict } = JSON.parse(event.data);
const bar = getBar(request_id, id);
const { request_id, progress_bar_id, format_dict } = JSON.parse(event.data);
const bar = getBar(request_id, progress_bar_id);
bar.style.width = 100 * (format_dict.n / format_dict.total) + '%';
}

Expand Down
144 changes: 100 additions & 44 deletions src/tqdm_publisher/_demos/_parallel_bars/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import time
import uuid
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List
from typing import List, Union

import requests
from flask import Flask, Response, jsonify, request
from flask_cors import CORS, cross_origin

from tqdm_publisher import TQDMProgressHandler, TQDMPublisher
from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher
from tqdm_publisher._demos._parallel_bars._client import (
create_http_server,
find_free_port,
Expand All @@ -24,20 +24,21 @@
# Each outer entry is a list of 'tasks' to perform on a particular worker
# For demonstration purposes, each in the list of tasks is the length of time in seconds
# that each iteration of the task takes to run and update the progress bar (emulated by sleeping)
SECONDS_PER_TASK = 1
NUMBER_OF_TASKS_PER_JOB = 6
BASE_SECONDS_PER_TASK = 0.5 # The base time for each task; actual time increases proportional to the index of the task
NUMBER_OF_TASKS_PER_JOB = 10
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved
TASK_TIMES: List[List[float]] = [
[SECONDS_PER_TASK * task_index] * task_index for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1)
[BASE_SECONDS_PER_TASK * task_index] * NUMBER_OF_TASKS_PER_JOB
for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1)
]

WEBSOCKETS = {}

## NOTE: TQDMProgressHandler cannot be called from a process...so we just use a queue directly
## TQDMProgressHandler cannot be called from a process...so we just use a global reference exposed to each subprocess
progress_handler = TQDMProgressHandler()


def forward_updates_over_sse(request_id, id, n, total, **kwargs):
progress_handler._announce(dict(request_id=request_id, id=id, format_dict=dict(n=n, total=total)))
def forward_updates_over_server_sent_events(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs):
progress_handler.announce(
dict(request_id=request_id, progress_bar_id=progress_bar_id, format_dict=dict(n=n, total=total), **kwargs)
)


class ThreadedHTTPServer:
Expand Down Expand Up @@ -83,14 +84,14 @@ def _run_sleep_tasks_in_subprocess(
The index of this task in the list of all tasks from the buffer map.
Each index would map to a different tqdm position.
request_id : int
Identifier of ??.
Identifier of the request, provided by the client.
url : str
The localhost URL to sent progress updates to.
"""

subprogress_bar_id = uuid.uuid4()

sub_progress_bar = TQDMPublisher(
sub_progress_bar = TQDMProgressPublisher(
iterable=task_times,
position=iteration_index + 1,
desc=f"Progress on iteration {iteration_index} ({id})",
Expand All @@ -107,50 +108,110 @@ def _run_sleep_tasks_in_subprocess(
time.sleep(sleep_time)


def run_parallel_processes(request_id, url: str):
def run_parallel_processes(*, all_task_times: List[List[float]], request_id: str, url: str):
garrettmflynn marked this conversation as resolved.
Show resolved Hide resolved

futures = list()
with ProcessPoolExecutor(max_workers=N_JOBS) as executor:

# # Assign the parallel jobs
for iteration_index, task_times in enumerate(TASK_TIMES):
for iteration_index, task_times_per_job in enumerate(all_task_times):
futures.append(
executor.submit(
_run_sleep_tasks_in_subprocess,
task_times=task_times,
task_times=task_times_per_job,
iteration_index=iteration_index,
request_id=request_id,
url=url,
)
)

total_tasks_iterable = as_completed(futures)
total_tasks_progress_bar = TQDMPublisher(
iterable=total_tasks_iterable, total=len(TASK_TIMES), desc="Total tasks completed"
total_tasks_progress_bar = TQDMProgressPublisher(
iterable=total_tasks_iterable, total=len(all_task_times), desc="Total tasks completed"
)

# The 'total' progress bar has an ID equivalent to the request ID
total_tasks_progress_bar.subscribe(
lambda format_dict: forward_to_http_server(
url=url, request_id=request_id, progress_bar_id=request_id, **format_dict
)
)

# Trigger the deployment of the parallel jobs
for _ in total_tasks_progress_bar:
pass


def format_sse(data: str, event=None) -> str:
msg = f"data: {json.dumps(data)}\n\n"
if event is not None:
msg = f"event: {event}\n{msg}"
return msg
def format_server_sent_events(*, message_data: str, event_type: str = "message") -> str:
"""
Format an `event_type` type server-sent event with `data` in a way expected by the EventSource browser implementation.

With reference to the following demonstration of frontend elements.

```javascript
const server_sent_event = new EventSource("/api/v1/sse");

/*
* This will listen only for events
* similar to the following:
*
* event: notice
* data: useful data
* id: someid
*/
server_sent_event.addEventListener("notice", (event) => {
console.log(event.data);
});

/*
* Similarly, this will listen for events
* with the field `event: update`
*/
server_sent_event.addEventListener("update", (event) => {
console.log(event.data);
});

/*
* The event "message" is a special case, as it
* will capture events without an event field
* as well as events that have the specific type
* `event: message` It will not trigger on any
* other event type.
*/
server_sent_event.addEventListener("message", (event) => {
console.log(event.data);
});
```

Parameters
----------
message_data : str
The message data to be sent to the client.
event_type : str, default="message"
The type of event corresponding to the message data.

Returns
-------
formatted_message : str
The formatted message to be sent to the client.
"""

# message = f"event: {event_type}\n" if event_type != "" else ""
# message += f"data: {message_data}\n\n"
# return message

message = f"data: {message_data}\n\n"
if event_type != "":
message = f"event: {event_type}\n{message}"
return message


def listen_to_events():
messages = progress_handler.listen() # returns a queue.Queue
while True:
msg = messages.get() # blocks until a new message arrives
yield format_sse(msg)
message_data = messages.get() # blocks until a new message arrives
print("Message data", message_data)
yield format_server_sent_events(message_data=json.dumps(message_data))


app = Flask(__name__)
Expand All @@ -164,7 +225,7 @@ def listen_to_events():
def start():
data = json.loads(request.data) if request.data else {}
request_id = data["request_id"]
run_parallel_processes(request_id, f"http://localhost:{PORT}")
run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=f"http://localhost:{PORT}")
return jsonify({"status": "success"})


Expand All @@ -191,14 +252,10 @@ async def start_server(port):
flask_server = ThreadedFlaskServer(port=3768)
flask_server.start()

# # DEMO ONE: Direct updates from HTTP server
# http_server = ThreadedHTTPServer(port=port, callback=forward_updates_over_sse)
# http_server.start()
# await asyncio.Future()

# DEMO TWO: Queue
def update_queue(request_id, id, n, total, **kwargs):
forward_updates_over_sse(request_id, id, n, total)
def update_queue(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs):
forward_updates_over_server_sent_events(
request_id=request_id, progress_bar_id=progress_bar_id, n=n, total=total
)

http_server = ThreadedHTTPServer(port=PORT, callback=update_queue)
http_server.start()
Expand All @@ -207,12 +264,18 @@ def update_queue(request_id, id, n, total, **kwargs):


def run_parallel_bar_demo() -> None:
"""Asynchronously start the servers"""
asyncio.run(start_server(PORT))
"""Asynchronously start the servers."""
asyncio.run(start_server(port=PORT))


def _run_parallel_bars_demo(port: str, host: str):
URL = f"http://{host}:{port}"

request_id = uuid.uuid4()
run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=URL)

if __name__ == "__main__":

if __name__ == "main":
CodyCBakerPhD marked this conversation as resolved.
Show resolved Hide resolved
flags_list = sys.argv[1:]

port_flag = "--port" in flags_list
Expand All @@ -228,11 +291,4 @@ def run_parallel_bar_demo() -> None:
else:
HOST = "localhost"

URL = f"http://{HOST}:{PORT}" if port_flag else None

if URL is None:
raise ValueError("URL is not defined.")

# Just run the parallel processes
request_id = uuid.uuid4()
run_parallel_processes(request_id, URL)
_run_parallel_bars_demo(port=PORT, host=HOST)
Loading
Loading