Skip to content

Commit

Permalink
Functional websocket declaration (#42)
Browse files Browse the repository at this point in the history
* Update _server.py

* more keyword arguments; break up docstring; remove unused capture

* more keyword arguments

---------

Co-authored-by: Garrett Michael Flynn <[email protected]>
  • Loading branch information
CodyCBakerPhD and garrettmflynn authored Mar 11, 2024
1 parent 9426dc5 commit d24f4a6
Showing 1 changed file with 27 additions and 52 deletions.
79 changes: 27 additions & 52 deletions src/tqdm_publisher/_demo/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,20 @@
import json
import threading
import time
from typing import Any, Dict, List
from uuid import uuid4

import websockets

from tqdm_publisher import TQDMPublisher
import tqdm_publisher


def start_progress_bar(*, client_id: str, progress_bar_id: str, client_callback: callable) -> None:
def start_progress_bar(*, progress_bar_id: str, client_callback: callable) -> None:
"""
Emulate running the specified number of tasks by sleeping the specified amount of time on each iteration.
Defaults are chosen for a deterministic and regular update period of one second for a total time of one minute.
"""
all_task_durations_in_seconds = [1.0 for _ in range(60)] # One minute at one second per update
progress_bar = TQDMPublisher(iterable=all_task_durations_in_seconds)
progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds)

def run_function_on_progress_update(format_dict: dict) -> None:
"""
Expand All @@ -29,68 +27,45 @@ def run_function_on_progress_update(format_dict: dict) -> None:
In this demo, we will execute the `client_callback` whose protocol is known only to the WebSocketHandler.
"""
client_callback(client_id=client_id, progress_bar_id=progress_bar_id, format_dict=format_dict)
client_callback(progress_bar_id=progress_bar_id, format_dict=format_dict)

progress_bar.subscribe(callback=run_function_on_progress_update)

for task_duration in progress_bar:
time.sleep(task_duration)


class WebSocketHandler:
"""
This is a class that handles the WebSocket connections and the communication protocol
between the server and the client.
While we could have implemented this as a function, a class is chosen here to maintain reference
to the clients within a defined scope.
"""

def __init__(self) -> None:
"""Initialize the mapping of client IDs to ."""
self.clients: Dict[str, Any] = dict()

def forward_progress_to_client(self, *, client_id: str, progress_bar_id: str, format_dict: dict) -> None:
"""This is the function that will run on every update of the TQDM object. It will forward the progress to the client."""
asyncio.run(self.send(client_id=client_id, data=dict(progress_bar_id=progress_bar_id, format_dict=format_dict)))
async def handler(websocket: websockets.WebSocketServerProtocol) -> None:
"""Handle messages from the client and manage the client connections."""

async def send(self, client_id: str, data: dict) -> None:
"""Send an arbitrary JSON object `data` to the client identifier by `client_id`."""
await self.clients[client_id].send(json.dumps(obj=data))

async def handler(self, websocket: websockets.WebSocketServerProtocol) -> None:
"""Register new WebSocket clients and handle their messages."""
client_id = str(uuid4())

# Register client connection
self.clients[client_id] = websocket
def forward_progress_to_client(*, progress_bar_id: str, format_dict: dict) -> None:
"""
This is the function that will run on every update of the TQDM object.
# Wait for messages from the client
try:
async for message in websocket:
message_from_client = json.loads(message)
It will forward the progress to the client.
"""
asyncio.run(
websocket.send(message=json.dumps(obj=dict(progress_bar_id=progress_bar_id, format_dict=format_dict)))
)

if message_from_client["command"] == "start":
thread = threading.Thread(
target=start_progress_bar,
kwargs=dict(
client_id=client_id,
progress_bar_id=message_from_client["progress_bar_id"],
client_callback=self.forward_progress_to_client,
),
)
thread.start()
# Wait for messages from the client
async for message in websocket:
message_from_client = json.loads(message)

# Deregister the client when the connection is closed
finally:
if client_id in self.clients:
del self.clients[client_id]
if message_from_client["command"] == "start":
thread = threading.Thread(
target=start_progress_bar,
kwargs=dict(
progress_bar_id=message_from_client["progress_bar_id"],
client_callback=forward_progress_to_client,
),
)
thread.start()


async def spawn_server() -> None:
"""Spawn the server asynchronously."""
handler = WebSocketHandler().handler
async with websockets.serve(handler, "", 8000):
async with websockets.serve(ws_handler=handler, host="", port=8000):
await asyncio.Future()


Expand Down

0 comments on commit d24f4a6

Please sign in to comment.