From d167e494414ede953068282a402dc409fc2cd207 Mon Sep 17 00:00:00 2001 From: Garrett Michael Flynn Date: Mon, 15 Apr 2024 13:35:01 -0500 Subject: [PATCH] Simplified parallel structure --- .../_demos/_parallel_bars/_client.py | 62 --------------- .../_demos/_parallel_bars/_server.py | 75 +++++-------------- 2 files changed, 20 insertions(+), 117 deletions(-) delete mode 100644 src/tqdm_publisher/_demos/_parallel_bars/_client.py diff --git a/src/tqdm_publisher/_demos/_parallel_bars/_client.py b/src/tqdm_publisher/_demos/_parallel_bars/_client.py deleted file mode 100644 index 0f45979..0000000 --- a/src/tqdm_publisher/_demos/_parallel_bars/_client.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Demo of parallel tqdm client.""" - -# HTTP server addition -import http.server -import json -import signal -import socket -import socketserver -import sys - - -def find_free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) # Bind to a free port provided by the host - return s.getsockname()[1] # Return the port number assigned - - -def GLOBAL_CALLBACK(request_id, id, format_dict): - print("Global Update", request_id, id, f"{format_dict['n']}/{format_dict['total']}") - - -def create_http_server(port: int, callback): - class MyHttpRequestHandler(http.server.SimpleHTTPRequestHandler): - - def do_POST(self): - content_length = int(self.headers["Content-Length"]) - post_data = json.loads(self.rfile.read(content_length).decode("utf-8")) - callback(post_data["request_id"], post_data["id"], post_data["data"]) - self.send_response(200) - self.end_headers() - - with socketserver.TCPServer(("", port), MyHttpRequestHandler) as httpd: - - def signal_handler(signal, frame): - print("\n\nInterrupt signal received. Closing server...") - httpd.server_close() - httpd.socket.close() - print("Server closed.") - sys.exit(0) - - try: - signal.signal(signal.SIGINT, signal_handler) - except: - pass # Allow to work in thread - - print(f"Serving HTTP on port {port}") - httpd.serve_forever() - - -if __name__ == "__main__": - - flags_list = sys.argv[1:] - - port_flag = "--port" in flags_list - - if port_flag: - port_index = flags_list.index("--port") - PORT = int(flags_list[port_index + 1]) - else: - PORT = find_free_port() - - create_http_server(port=PORT, callback=GLOBAL_CALLBACK) diff --git a/src/tqdm_publisher/_demos/_parallel_bars/_server.py b/src/tqdm_publisher/_demos/_parallel_bars/_server.py index 57d1f89..3ff1b3f 100644 --- a/src/tqdm_publisher/_demos/_parallel_bars/_server.py +++ b/src/tqdm_publisher/_demos/_parallel_bars/_server.py @@ -3,21 +3,16 @@ import asyncio import json import sys -import threading import time import uuid from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import List, Union +from typing import List import requests from flask import Flask, Response, jsonify, request from flask_cors import CORS, cross_origin from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher -from tqdm_publisher._demos._parallel_bars._client import ( - create_http_server, - find_free_port, -) N_JOBS = 3 @@ -32,25 +27,7 @@ ] ## TQDMProgressHandler cannot be called from a process...so we just use a global reference exposed to each subprocess -progress_handler = TQDMProgressHandler() - - -def forward_updates_over_server_sent_events(request_id, progress_bar_id, format_dict): - progress_handler.announce(dict(request_id=request_id, progress_bar_id=progress_bar_id, format_dict=format_dict)) - - -class ThreadedHTTPServer: - def __init__(self, port: int, callback): - self.port = port - self.callback = callback - - def run(self): - create_http_server(port=self.port, callback=self.callback) - - def start(self): - thread = threading.Thread(target=self.run) - thread.start() - +progress_handler = TQDMProgressHandler() def forward_to_http_server(url: str, request_id: str, progress_bar_id: int, format_dict: dict): """ @@ -214,7 +191,7 @@ def listen_to_events(): app = Flask(__name__) cors = CORS(app) app.config["CORS_HEADERS"] = "Content-Type" -PORT = find_free_port() +PORT = 3768 # find_free_port() @app.route("/start", methods=["POST"]) @@ -222,42 +199,31 @@ def listen_to_events(): def start(): data = json.loads(request.data) if request.data else {} request_id = data["request_id"] - run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=f"http://localhost:{PORT}") + url = f"http://localhost:{PORT}/update" + app.logger.info(url) + + run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=url) return jsonify({"status": "success"}) +@app.route("/update", methods=["POST"]) +@cross_origin() +def update(): + data = json.loads(request.data) if request.data else {} + request_id = data["request_id"] + progress_bar_id = data["id"] + format_dict = data["data"] + + # Forward updates over Sever-Side Events + progress_handler.announce(dict(request_id=request_id, progress_bar_id=progress_bar_id, format_dict=format_dict)) + @app.route("/events", methods=["GET"]) @cross_origin() def events(): return Response(listen_to_events(), mimetype="text/event-stream") - -class ThreadedFlaskServer: - def __init__(self, port: int): - self.port = port - - def run(self): - app.run(host="localhost", port=self.port) - - def start(self): - thread = threading.Thread(target=self.run) - thread.start() - - async def start_server(port): - - flask_server = ThreadedFlaskServer(port=3768) - flask_server.start() - - def update_queue(request_id: str, progress_bar_id: str, format_dict: dict): - forward_updates_over_server_sent_events( - request_id=request_id, progress_bar_id=progress_bar_id, format_dict=format_dict - ) - - http_server = ThreadedHTTPServer(port=PORT, callback=update_queue) - http_server.start() - - await asyncio.Future() + app.run(host="localhost", port=port) def run_parallel_bar_demo() -> None: @@ -266,8 +232,7 @@ def run_parallel_bar_demo() -> None: def _run_parallel_bars_demo(port: str, host: str): - URL = f"http://{host}:{port}" - + URL = f"http://{host}:{port}/update" request_id = uuid.uuid4() run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=URL)