diff --git a/pyproject.toml b/pyproject.toml index 9f735ab..0aa5d2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,10 @@ test = [ ] demo = [ - "websockets==12.0", - "flask==2.3.2", - "flask-cors==4.0.0" + "requests", + "websockets", + "flask", + "flask-cors" ] [project.urls] diff --git a/src/tqdm_publisher/__init__.py b/src/tqdm_publisher/__init__.py index 37c74a8..770ada9 100644 --- a/src/tqdm_publisher/__init__.py +++ b/src/tqdm_publisher/__init__.py @@ -1,5 +1,5 @@ from ._handler import TQDMProgressHandler -from ._publisher import TQDMPublisher +from ._publisher import TQDMProgressPublisher from ._subscriber import TQDMProgressSubscriber -__all__ = ["TQDMPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"] +__all__ = ["TQDMProgressPublisher", "TQDMProgressSubscriber", "TQDMProgressHandler"] diff --git a/src/tqdm_publisher/_demos/_demo_command_line_interface.py b/src/tqdm_publisher/_demos/_demo_command_line_interface.py index 107fe18..da27593 100644 --- a/src/tqdm_publisher/_demos/_demo_command_line_interface.py +++ b/src/tqdm_publisher/_demos/_demo_command_line_interface.py @@ -54,6 +54,5 @@ def _command_line_interface(): webbrowser.open_new_tab(f"http://localhost:{CLIENT_PORT}/{client_relative_path}") demo_info["server"]() - else: print(f"{command} is an invalid command.") diff --git a/src/tqdm_publisher/_demos/_multiple_bars/_server.py b/src/tqdm_publisher/_demos/_multiple_bars/_server.py index ca4a6bf..d81b912 100644 --- a/src/tqdm_publisher/_demos/_multiple_bars/_server.py +++ b/src/tqdm_publisher/_demos/_multiple_bars/_server.py @@ -44,7 +44,7 @@ def run(self): showcase of an alternative approach to defining and scoping the execution. """ all_task_durations_in_seconds = [1.0 for _ in range(10)] # Ten seconds at one task per second - self.progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds) + self.progress_bar = tqdm_publisher.TQDMProgressPublisher(iterable=all_task_durations_in_seconds) self.progress_bar.subscribe(callback=self.update) for task_duration in self.progress_bar: diff --git a/src/tqdm_publisher/_demos/_parallel_bars/_client.js b/src/tqdm_publisher/_demos/_parallel_bars/_client.js index 43bbcf9..aa46df6 100644 --- a/src/tqdm_publisher/_demos/_parallel_bars/_client.js +++ b/src/tqdm_publisher/_demos/_parallel_bars/_client.js @@ -41,8 +41,8 @@ const getBar = (request_id, id) => { // Update the specified progress bar when a message is received from the server const onProgressUpdate = (event) => { - const { request_id, id, format_dict } = JSON.parse(event.data); - const bar = getBar(request_id, id); + const { request_id, progress_bar_id, format_dict } = JSON.parse(event.data); + const bar = getBar(request_id, progress_bar_id); bar.style.width = 100 * (format_dict.n / format_dict.total) + '%'; } diff --git a/src/tqdm_publisher/_demos/_parallel_bars/_server.py b/src/tqdm_publisher/_demos/_parallel_bars/_server.py index 0ab5190..34da542 100644 --- a/src/tqdm_publisher/_demos/_parallel_bars/_server.py +++ b/src/tqdm_publisher/_demos/_parallel_bars/_server.py @@ -7,13 +7,13 @@ import time import uuid from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import List +from typing import List, Union import requests from flask import Flask, Response, jsonify, request from flask_cors import CORS, cross_origin -from tqdm_publisher import TQDMProgressHandler, TQDMPublisher +from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher from tqdm_publisher._demos._parallel_bars._client import ( create_http_server, find_free_port, @@ -24,20 +24,21 @@ # Each outer entry is a list of 'tasks' to perform on a particular worker # For demonstration purposes, each in the list of tasks is the length of time in seconds # that each iteration of the task takes to run and update the progress bar (emulated by sleeping) -SECONDS_PER_TASK = 1 -NUMBER_OF_TASKS_PER_JOB = 6 +BASE_SECONDS_PER_TASK = 0.5 # The base time for each task; actual time increases proportional to the index of the task +NUMBER_OF_TASKS_PER_JOB = 10 TASK_TIMES: List[List[float]] = [ - [SECONDS_PER_TASK * task_index] * task_index for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1) + [BASE_SECONDS_PER_TASK * task_index] * NUMBER_OF_TASKS_PER_JOB + for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1) ] -WEBSOCKETS = {} - -## NOTE: TQDMProgressHandler cannot be called from a process...so we just use a queue directly +## TQDMProgressHandler cannot be called from a process...so we just use a global reference exposed to each subprocess progress_handler = TQDMProgressHandler() -def forward_updates_over_sse(request_id, id, n, total, **kwargs): - progress_handler._announce(dict(request_id=request_id, id=id, format_dict=dict(n=n, total=total))) +def forward_updates_over_server_sent_events(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs): + progress_handler.announce( + dict(request_id=request_id, progress_bar_id=progress_bar_id, format_dict=dict(n=n, total=total), **kwargs) + ) class ThreadedHTTPServer: @@ -83,14 +84,14 @@ def _run_sleep_tasks_in_subprocess( The index of this task in the list of all tasks from the buffer map. Each index would map to a different tqdm position. request_id : int - Identifier of ??. + Identifier of the request, provided by the client. url : str The localhost URL to sent progress updates to. """ subprogress_bar_id = uuid.uuid4() - sub_progress_bar = TQDMPublisher( + sub_progress_bar = TQDMProgressPublisher( iterable=task_times, position=iteration_index + 1, desc=f"Progress on iteration {iteration_index} ({id})", @@ -107,17 +108,17 @@ def _run_sleep_tasks_in_subprocess( time.sleep(sleep_time) -def run_parallel_processes(request_id, url: str): +def run_parallel_processes(*, all_task_times: List[List[float]], request_id: str, url: str): futures = list() with ProcessPoolExecutor(max_workers=N_JOBS) as executor: # # Assign the parallel jobs - for iteration_index, task_times in enumerate(TASK_TIMES): + for iteration_index, task_times_per_job in enumerate(all_task_times): futures.append( executor.submit( _run_sleep_tasks_in_subprocess, - task_times=task_times, + task_times=task_times_per_job, iteration_index=iteration_index, request_id=request_id, url=url, @@ -125,32 +126,92 @@ def run_parallel_processes(request_id, url: str): ) total_tasks_iterable = as_completed(futures) - total_tasks_progress_bar = TQDMPublisher( - iterable=total_tasks_iterable, total=len(TASK_TIMES), desc="Total tasks completed" + total_tasks_progress_bar = TQDMProgressPublisher( + iterable=total_tasks_iterable, total=len(all_task_times), desc="Total tasks completed" ) + # The 'total' progress bar bas an ID equivalent to the request ID total_tasks_progress_bar.subscribe( lambda format_dict: forward_to_http_server( url=url, request_id=request_id, progress_bar_id=request_id, **format_dict ) ) + # Trigger the deployment of the parallel jobs for _ in total_tasks_progress_bar: pass -def format_sse(data: str, event=None) -> str: - msg = f"data: {json.dumps(data)}\n\n" - if event is not None: - msg = f"event: {event}\n{msg}" - return msg +def format_server_sent_events(*, message_data: str, event_type: str = "message") -> str: + """ + Format an `event_type` type server-sent event with `data` in a way expected by the EventSource browser implementation. + + With reference to the following demonstration of frontend elements. + + ```javascript + const server_sent_event = new EventSource("/api/v1/sse"); + + /* + * This will listen only for events + * similar to the following: + * + * event: notice + * data: useful data + * id: someid + */ + server_sent_event.addEventListener("notice", (event) => { + console.log(event.data); + }); + + /* + * Similarly, this will listen for events + * with the field `event: update` + */ + server_sent_event.addEventListener("update", (event) => { + console.log(event.data); + }); + + /* + * The event "message" is a special case, as it + * will capture events without an event field + * as well as events that have the specific type + * `event: message` It will not trigger on any + * other event type. + */ + server_sent_event.addEventListener("message", (event) => { + console.log(event.data); + }); + ``` + + Parameters + ---------- + message_data : str + The message data to be sent to the client. + event_type : str, default="message" + The type of event corresponding to the message data. + + Returns + ------- + formatted_message : str + The formatted message to be sent to the client. + """ + + # message = f"event: {event_type}\n" if event_type != "" else "" + # message += f"data: {message_data}\n\n" + # return message + + message = f"data: {message_data}\n\n" + if event_type != "": + message = f"event: {event_type}\n{message}" + return message def listen_to_events(): messages = progress_handler.listen() # returns a queue.Queue while True: - msg = messages.get() # blocks until a new message arrives - yield format_sse(msg) + message_data = messages.get() # blocks until a new message arrives + print("Message data", message_data) + yield format_server_sent_events(message_data=json.dumps(message_data)) app = Flask(__name__) @@ -164,7 +225,7 @@ def listen_to_events(): def start(): data = json.loads(request.data) if request.data else {} request_id = data["request_id"] - run_parallel_processes(request_id, f"http://localhost:{PORT}") + run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=f"http://localhost:{PORT}") return jsonify({"status": "success"}) @@ -191,14 +252,10 @@ async def start_server(port): flask_server = ThreadedFlaskServer(port=3768) flask_server.start() - # # DEMO ONE: Direct updates from HTTP server - # http_server = ThreadedHTTPServer(port=port, callback=forward_updates_over_sse) - # http_server.start() - # await asyncio.Future() - - # DEMO TWO: Queue - def update_queue(request_id, id, n, total, **kwargs): - forward_updates_over_sse(request_id, id, n, total) + def update_queue(request_id: str, progress_bar_id: str, n: int, total: int, **kwargs): + forward_updates_over_server_sent_events( + request_id=request_id, progress_bar_id=progress_bar_id, n=n, total=total + ) http_server = ThreadedHTTPServer(port=PORT, callback=update_queue) http_server.start() @@ -207,12 +264,18 @@ def update_queue(request_id, id, n, total, **kwargs): def run_parallel_bar_demo() -> None: - """Asynchronously start the servers""" - asyncio.run(start_server(PORT)) + """Asynchronously start the servers.""" + asyncio.run(start_server(port=PORT)) + + +def _run_parallel_bars_demo(port: str, host: str): + URL = f"http://{host}:{port}" + request_id = uuid.uuid4() + run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=URL) -if __name__ == "__main__": +if __name__ == "main": flags_list = sys.argv[1:] port_flag = "--port" in flags_list @@ -228,11 +291,4 @@ def run_parallel_bar_demo() -> None: else: HOST = "localhost" - URL = f"http://{HOST}:{PORT}" if port_flag else None - - if URL is None: - raise ValueError("URL is not defined.") - - # Just run the parallel processes - request_id = uuid.uuid4() - run_parallel_processes(request_id, URL) + _run_parallel_bars_demo(port=PORT, host=HOST) diff --git a/src/tqdm_publisher/_demos/_parallel_bars/_server_ws.py b/src/tqdm_publisher/_demos/_parallel_bars/_server_ws.py index 088a8c7..6f301cf 100644 --- a/src/tqdm_publisher/_demos/_parallel_bars/_server_ws.py +++ b/src/tqdm_publisher/_demos/_parallel_bars/_server_ws.py @@ -6,13 +6,13 @@ import threading import time import uuid -from concurrent.futures import ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor, as_completed from typing import List import requests import websockets -from tqdm_publisher import TQDMProgressHandler, TQDMPublisher +from tqdm_publisher import TQDMProgressHandler, TQDMProgressPublisher from tqdm_publisher._demos._parallel_bars._client import ( create_http_server, find_free_port, @@ -20,24 +20,25 @@ N_JOBS = 3 +N_JOBS = 3 + # Each outer entry is a list of 'tasks' to perform on a particular worker # For demonstration purposes, each in the list of tasks is the length of time in seconds # that each iteration of the task takes to run and update the progress bar (emulated by sleeping) +BASE_SECONDS_PER_TASK = 0.5 # The base time for each task; actual time increases proportional to the index of the task +NUMBER_OF_TASKS_PER_JOB = 10 TASK_TIMES: List[List[float]] = [ - [0.1 for _ in range(100)], - [0.2 for _ in range(100)], - [0.3 for _ in range(10)], - [0.4 for _ in range(10)], - [0.5 for _ in range(10)], + [BASE_SECONDS_PER_TASK * task_index] * NUMBER_OF_TASKS_PER_JOB + for task_index in range(1, NUMBER_OF_TASKS_PER_JOB + 1) ] WEBSOCKETS = {} -## NOTE: TQDMProgressHandler cannot be called from a process...so we just use a queue directly +## TQDMProgressHandler cannot be called from a process...so we just use a queue directly progress_handler = TQDMProgressHandler() -def forward_updates_over_websocket(request_id, id, n, total, **kwargs): +def forward_updates_over_websocket(request_id, progress_bar_id, n, total, **kwargs): ws = WEBSOCKETS.get(request_id) if ws: @@ -47,7 +48,7 @@ def forward_updates_over_websocket(request_id, id, n, total, **kwargs): message=json.dumps( obj=dict( format_dict=dict(n=n, total=total), - id=id, + progress_bar_id=progress_bar_id, request_id=request_id, ) ) @@ -82,21 +83,18 @@ def start(self): thread.start() -def forward_to_http_server(url: str, request_id: str, id: int, n: int, total: int, **kwargs): +def forward_to_http_server(url: str, request_id: str, progress_bar_id: int, n: int, total: int, **kwargs): """ This is the parallel callback definition. Its parameters are attributes of a tqdm instance and their values are what a typical default tqdm printout to console would contain (update step `n` out of `total` iterations). """ - json_data = json.dumps(obj=dict(request_id=request_id, id=str(id), data=dict(n=n, total=total))) + json_data = json.dumps(obj=dict(request_id=request_id, id=str(progress_bar_id), data=dict(n=n, total=total))) requests.post(url=url, data=json_data, headers={"Content-Type": "application/json"}) -def _run_sleep_tasks_in_subprocess( - args, - # task_times: List[float], iteration_index: int, id: int, url: str -): +def _run_sleep_tasks_in_subprocess(task_times: List[float], iteration_index: int, request_id: int, url: str): """ Run a 'task' that takes a certain amount of time to run on each worker. @@ -110,42 +108,59 @@ def _run_sleep_tasks_in_subprocess( The index of this task in the list of all tasks from the buffer map. Each index would map to a different tqdm position. id : int - Identifier of ??. + Identifier of the request, provided by the client. url : str The localhost URL to sent progress updates to. """ - task_times, iteration_index, request_id, url = args - - id = uuid.uuid4() + subprogress_bar_id = uuid.uuid4() - sub_progress_bar = TQDMPublisher( + sub_progress_bar = TQDMProgressPublisher( iterable=task_times, position=iteration_index + 1, - desc=f"Progress on iteration {iteration_index} ({id})", + desc=f"Progress on iteration {iteration_index} ({subprogress_bar_id})", leave=False, ) - sub_progress_bar.subscribe(lambda format_dict: forward_to_http_server(url, request_id, id, **format_dict)) + sub_progress_bar.subscribe( + lambda format_dict: forward_to_http_server(url, request_id, subprogress_bar_id, **format_dict) + ) for sleep_time in sub_progress_bar: time.sleep(sleep_time) -def run_parallel_processes(request_id, url: str): +def run_parallel_processes(*, all_task_times: List[List[float]], request_id: str, url: str): + futures = list() with ProcessPoolExecutor(max_workers=N_JOBS) as executor: # Assign the parallel jobs - job_map = executor.map( - _run_sleep_tasks_in_subprocess, - [(task_times, iteration_index, request_id, url) for iteration_index, task_times in enumerate(TASK_TIMES)], + for iteration_index, task_times_per_job in enumerate(all_task_times): + futures.append( + executor.submit( + _run_sleep_tasks_in_subprocess, + task_times=task_times_per_job, + iteration_index=iteration_index, + request_id=request_id, + url=url, + ) + ) + + total_tasks_iterable = as_completed(futures) + total_tasks_progress_bar = TQDMProgressPublisher( + iterable=total_tasks_iterable, total=len(all_task_times), desc="Total tasks completed" ) - # Send initialization for pool progress bar - forward_to_http_server(url, request_id, id=request_id, n=0, total=len(TASK_TIMES)) + # The 'total' progress bar bas an ID equivalent to the request ID + total_tasks_progress_bar.subscribe( + lambda format_dict: forward_to_http_server( + url=url, request_id=request_id, progress_bar_id=request_id, **format_dict + ) + ) - for _ in job_map: + # Trigger the deployment of the parallel jobs + for _ in total_tasks_progress_bar: pass @@ -164,7 +179,7 @@ async def handler(url: str, websocket: websockets.WebSocketServerProtocol) -> No if message_from_client["command"] == "start": request_id = message_from_client["request_id"] WEBSOCKETS[request_id] = dict(ref=websocket, id=connection_id) - run_parallel_processes(request_id, url) + run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=url) async def spawn_server() -> None: @@ -176,14 +191,8 @@ async def spawn_server() -> None: async with websockets.serve(ws_handler=lambda websocket: handler(URL, websocket), host="", port=3768): - # # DEMO ONE: Direct updates from HTTP server - # http_server = ThreadedHTTPServer(port=PORT, callback=forward_updates_over_websocket) - # http_server.start() - # await asyncio.Future() - - # DEMO TWO: Queue - def update_queue(request_id, id, n, total, **kwargs): - progress_handler._announce(dict(request_id=request_id, id=id, n=n, total=total)) + def update_queue(request_id, progress_bar_id, n, total, **kwargs): + progress_handler.announce(dict(request_id=request_id, progress_bar_id=progress_bar_id, n=n, total=total)) http_server = ThreadedHTTPServer(port=PORT, callback=update_queue) http_server.start() @@ -222,4 +231,4 @@ def run_parallel_bar_demo() -> None: # Just run the parallel processes request_id = uuid.uuid4() - run_parallel_processes(request_id, URL) + run_parallel_processes(all_task_times=TASK_TIMES, request_id=request_id, url=URL) diff --git a/src/tqdm_publisher/_demos/_single_bar/_server.py b/src/tqdm_publisher/_demos/_single_bar/_server.py index 32025fb..2e4fd03 100644 --- a/src/tqdm_publisher/_demos/_single_bar/_server.py +++ b/src/tqdm_publisher/_demos/_single_bar/_server.py @@ -15,7 +15,7 @@ def start_progress_bar(*, progress_callback: callable) -> None: Defaults are chosen for a deterministic and regular update period of one second for a total time of 10 seconds. """ all_task_durations_in_seconds = [1.0 for _ in range(10)] # Ten seconds at one second per task - progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds) + progress_bar = tqdm_publisher.TQDMProgressPublisher(iterable=all_task_durations_in_seconds) def run_function_on_progress_update(format_dict: dict) -> None: """ @@ -29,7 +29,7 @@ def run_function_on_progress_update(format_dict: dict) -> None: This specifically requires the `id` of the progress bar and the `format_dict` of the TQDM instance. """ - progress_callback(format_dict=format_dict, progress_bar_id=progress_bar.id) + progress_callback(format_dict=format_dict, progress_bar_id=progress_bar.progress_bar_id) progress_bar.subscribe(callback=run_function_on_progress_update) diff --git a/src/tqdm_publisher/_handler.py b/src/tqdm_publisher/_handler.py index d72555e..9cabad4 100644 --- a/src/tqdm_publisher/_handler.py +++ b/src/tqdm_publisher/_handler.py @@ -1,27 +1,49 @@ import queue +from typing import Any, Dict, Iterable, List from ._subscriber import TQDMProgressSubscriber class TQDMProgressHandler: def __init__(self): - self.listeners = [] - - def listen(self): - q = queue.Queue(maxsize=25) - self.listeners.append(q) - return q - - def create(self, iterable, additional_metadata: dict = dict(), **tqdm_kwargs): - return TQDMProgressSubscriber( - iterable, - lambda progress_update: self._announce(dict(**progress_update, **additional_metadata)), - **tqdm_kwargs, - ) - - def _announce(self, msg): - for i in reversed(range(len(self.listeners))): - try: - self.listeners[i].put_nowait(msg) - except queue.Full: - del self.listeners[i] + self.listeners: List[queue.Queue] = [] + + def listen(self) -> queue.Queue: + new_queue = queue.Queue(maxsize=25) + self.listeners.append(new_queue) + return new_queue + + def create_progress_subscriber( + self, iterable: Iterable[Any], additional_metadata: dict = dict(), **tqdm_kwargs + ) -> TQDMProgressSubscriber: + + def on_progress_update(progress_update: dict): + """ + This is the injection called on every update of the progress bar. + + It triggers the announcement event over all listeners on each update of the progress bar. + + It must be defined inside this local scope to communicate the `additional_metadata` from the level above + without including it in the method signature. + """ + self.announce(message=dict(**progress_update, **additional_metadata)) + + return TQDMProgressSubscriber(iterable=iterable, on_progress_update=on_progress_update, **tqdm_kwargs) + + def announce(self, message: Dict[Any, Any]): + """ + Announce a message to all listeners. + + This message can be any dictionary. But, when used internally, is + expected to contain the progress_bar_id and format_dict of the TQDMProgressSubscriber update function, + as well as any additional metadata supplied by the create_progress_subscriber method. + + """ + number_of_listeners = len(self.listeners) + listener_indices = range(number_of_listeners) + listener_indices_from_newest_to_oldest = reversed(listener_indices) + for listener_index in listener_indices_from_newest_to_oldest: + if not self.listeners[listener_index].full(): + self.listeners[listener_index].put_nowait(item=message) + else: # When full, remove the newest listener in the stack + del self.listeners[listener_index] diff --git a/src/tqdm_publisher/_publisher.py b/src/tqdm_publisher/_publisher.py index 81a7ee8..7c5d5ce 100644 --- a/src/tqdm_publisher/_publisher.py +++ b/src/tqdm_publisher/_publisher.py @@ -4,10 +4,10 @@ from tqdm import tqdm as base_tqdm -class TQDMPublisher(base_tqdm): +class TQDMProgressPublisher(base_tqdm): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.id = str(uuid4()) + self.progress_bar_id = str(uuid4()) self.callbacks = dict() # Override the update method to run callbacks @@ -19,7 +19,7 @@ def update(self, n: int = 1) -> Union[bool, None]: return displayed - def subscribe(self, callback: callable): + def subscribe(self, callback: callable) -> str: """ Subscribe to updates from the progress bar. @@ -37,7 +37,7 @@ def subscribe(self, callback: callable): Returns ------- - str + callback_id : str A unique identifier for the callback. This ID is a UUID string and can be used to reference the registered callback in future operations. @@ -55,7 +55,7 @@ def subscribe(self, callback: callable): self.callbacks[callback_id] = callback return callback_id - def unsubscribe(self, callback_id: str): + def unsubscribe(self, callback_id: str) -> bool: """ Unsubscribe a previously registered callback from the progress bar updates. @@ -72,7 +72,7 @@ def unsubscribe(self, callback_id: str): Returns ------- - bool + success : bool Returns True if the callback was successfully removed, or False if no callback was found with the given ID. diff --git a/src/tqdm_publisher/_subscriber.py b/src/tqdm_publisher/_subscriber.py index 289d3d6..e5715b3 100644 --- a/src/tqdm_publisher/_subscriber.py +++ b/src/tqdm_publisher/_subscriber.py @@ -1,7 +1,22 @@ -from ._publisher import TQDMPublisher +from typing import Any, Dict, Iterable +from ._publisher import TQDMProgressPublisher -class TQDMProgressSubscriber(TQDMPublisher): - def __init__(self, iterable, on_progress_update: callable, **tqdm_kwargs): - super().__init__(iterable, **tqdm_kwargs) - self.subscribe(lambda format_dict: on_progress_update(dict(progress_bar_id=self.id, format_dict=format_dict))) + +class TQDMProgressSubscriber(TQDMProgressPublisher): + def __init__(self, iterable: Iterable[Any], on_progress_update: callable, **tqdm_kwargs): + super().__init__(iterable=iterable, **tqdm_kwargs) + + def run_on_progress_update(format_dict: Dict[str, Any]): + """ + This is the injection called on every update of the progress bar. + + It calls the `on_progress_update` function, which must take a dictionary + containing the progress bar ID and `format_dict`. + + It must be defined inside this local scope to include the `.progress_bar_id` attribute from the level above + without including it in the method signature. + """ + on_progress_update(dict(progress_bar_id=self.progress_bar_id, format_dict=format_dict)) + + self.subscribe(run_on_progress_update) diff --git a/tests/test_basic.py b/tests/test_basic.py index 5c2b10f..a511080 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,12 +2,12 @@ import pytest -from tqdm_publisher import TQDMPublisher +from tqdm_publisher import TQDMProgressPublisher from tqdm_publisher.testing import create_tasks def test_initialization(): - publisher = TQDMPublisher() + publisher = TQDMProgressPublisher() assert len(publisher.callbacks) == 0 @@ -27,7 +27,7 @@ def test_callback(identifier, data): assert "n" in data and "total" in data tasks = create_tasks() - publisher = TQDMPublisher(asyncio.as_completed(tasks), total=len(tasks)) + publisher = TQDMProgressPublisher(asyncio.as_completed(tasks), total=len(tasks)) n_subscriptions = 10 for i in range(n_subscriptions): @@ -51,7 +51,7 @@ def dummy_callback(data): pass tasks = [] - publisher = TQDMPublisher(asyncio.as_completed(tasks), total=len(tasks)) + publisher = TQDMProgressPublisher(asyncio.as_completed(tasks), total=len(tasks)) callback_id = publisher.subscribe(dummy_callback) result = publisher.unsubscribe(callback_id) assert result == True