diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml new file mode 100644 index 0000000..18847c6 --- /dev/null +++ b/.github/workflows/linter.yml @@ -0,0 +1,39 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python + +name: Python Linter And Unittest + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: "3.8" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 pyproject-flake8 mypy + pip install -r requirements.txt + - name: Lint with flake8 + run: | + pflake8 . + - name: Lint with MyPy + run: | + mypy . + - name: Run python unittest + run: | + python -m unittest discover -v tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4e5153a --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +__pychace__/ +*.py[cod] +*$py.class + +build/ +dist/ +sdist/ +wheels/ +eggs/ +.eggs/ +.idea/ +*.egg-info/ +*.egg +.mypy_cache/ + +venv*/ +.vscode/ + +dask-worker-space/* +.pre-commit-config.yaml diff --git a/Makefile b/Makefile index 21b1dec..937e42b 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ _doc: rm -fr docsvenv build; mkdir build python3.8 -m venv docsvenv . docsvenv/bin/activate; \ - pip install -r docs/requirements_docs.txt; \ - pip install -r requirements.txt; \ - cd docs; make clean && make html + pip install -r docs/requirements_docs.txt; \ + pip install -r requirements.txt; \ + cd docs; make clean && make html zip -r build/scaler_docs.zip docs/build/html/* diff --git a/README.md b/README.md index 6069d2f..5b24d26 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ Citi -

Citi/scaler

+

Citi/scaler

Efficient, lightweight and reliable distributed computation engine. @@ -22,6 +22,9 @@ with a stable and language agnostic protocol for client and worker communications. ```python +import math +from scaler import Client + with Client(address="tcp://127.0.0.1:2345") as client: # Submits 100 tasks futures = [ @@ -43,20 +46,22 @@ messaging errors, among others. - Distributed computing on **multiple cores and multiple servers** - **Python** reference implementation, with **language agnostic messaging protocol** built on top of -[ZeroMQ](https://zeromq.org) + [Cap'n Proto](https://capnproto.org/) and [ZeroMQ](https://zeromq.org) - **Graph** scheduling, which supports [Dask](https://www.dask.org)-like graph computing, optionally you -can use [GraphBLAS](https://graphblas.org) -- **Automated load balancing**. When workers got full of tasks, these will be scheduled to idle workers + can use [GraphBLAS](https://graphblas.org) for massive graph tasks +- **Automated load balancing**. automatically balance busy workers' loads to idle workers, keep every worker as busy as + possible - **Automated recovery** from faulting workers or clients - Supports for **nested tasks**. Tasks can themselves submit new tasks - `top`-like **monitoring tools** +- GUI monitoring tool Scaler's scheduler can be run on PyPy, which will provide a performance boost ## Installation ```bash -$ pip instal scaler +$ pip install scaler # or with graphblas and uvloop support $ pip install scaler[graphblas,uvloop] @@ -77,7 +82,6 @@ A local scheduler and a local set of workers can be conveniently spawn using `Sc ```python from scaler import SchedulerClusterCombo - cluster = SchedulerClusterCombo(address="tcp://127.0.0.1:2345", n_workers=4) ... @@ -154,9 +158,11 @@ from scaler import Client def inc(i): return i + 1 + def add(a, b): return a + b + def minus(a, b): return a - b @@ -169,7 +175,7 @@ graph = { "e": (minus, "d", "c") # e = d - c = 4 - 3 = 1 } -with Client(address="tcp://127.0.0.1:2345") as client +with Client(address="tcp://127.0.0.1:2345") as client: result = client.get(graph, keys=["e"]) print(result) # {"e": 1} ``` @@ -179,18 +185,21 @@ with Client(address="tcp://127.0.0.1:2345") as client Scaler allows tasks to submit new tasks while being executed. Scaler also supports recursive task calls. ```python -def fibonacci(client: Client, n: int): +from scaler import Client + + +def fibonacci(clnt: Client, n: int): if n == 0: return 0 elif n == 1: return 1 else: - a = client.submit(fibonacci, client, n - 1) - b = client.submit(fibonacci, client, n - 2) + a = clnt.submit(fibonacci, clnt, n - 1) + b = clnt.submit(fibonacci, clnt, n - 2) return a.result() + b.result() -with Client(address="tcp://127.0.0.1:2345") as client +with Client(address="tcp://127.0.0.1:2345") as client: result = client.submit(fibonacci, client, 8).result() print(result) # 21 ``` @@ -256,11 +265,11 @@ W|Linux|15943|a7fe8b5e+ 0.0% 30.7m 0.0% 28.3m 1000 0 0 | - function_id_to_tasks section shows task count for each function used - worker section shows worker details, you can use shortcuts to sort by columns, the char * on column header show which column is sorted right now - - agt_cpu/agt_rss means cpu/memory usage of worker agent - - cpu/rss means cpu/memory usage of worker - - free means number of free task slots for this worker - - sent means how many tasks scheduler sent to the worker - - queued means how many tasks worker received and queued + - agt_cpu/agt_rss means cpu/memory usage of worker agent + - cpu/rss means cpu/memory usage of worker + - free means number of free task slots for this worker + - sent means how many tasks scheduler sent to the worker + - queued means how many tasks worker received and queued ### From the web UI @@ -274,7 +283,8 @@ This will open a web server on port `8081`. ## Contributing -Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**. +Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly +appreciated**. We welcome you to: @@ -297,4 +307,5 @@ This project is distributed under the [Apache-2.0 License](https://www.apache.or ## Contact -If you have a query or require support with this project, [raise an issue](https://github.com/Citi/scaler/issues). Otherwise, reach out to [opensource@citi.com](mailto:opensource@citi.com). \ No newline at end of file +If you have a query or require support with this project, [raise an issue](https://github.com/Citi/scaler/issues). +Otherwise, reach out to [opensource@citi.com](mailto:opensource@citi.com). diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bd17dff --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +requires = ["setuptools", "setuptools-scm", "mypy", "black", "flake8", "pyproject-flake8"] +build-backend = "setuptools.build_meta" + +[project] +name = "scaler" +description = "Scaler Distribution Framework" +requires-python = ">=3.8" +readme = { file = "README.md", content-type = "text/markdown" } +license = { text = "Apache 2.0" } +authors = [{ name = "Citi", email = "opensource@citi.com" }] +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Intended Audience :: Developers", + "Operating System :: OS Independent", + "Topic :: System :: Distributed Computing", +] +dynamic = ["dependencies", "version"] + +[project.urls] +Home = "https://github.com/Citi/scaler" + +[project.scripts] +scaler_scheduler = "scaler.entry_points.scheduler:main" +scaler_cluster = "scaler.entry_points.cluster:main" +scaler_top = "scaler.entry_points.top:main" +scaler_ui = "scaler.entry_points.webui:main" + +[project.optional-dependencies] +uvloop = ["uvloop"] +graphblas = ["python-graphblas", "numpy"] +gui = ["nicegui[plotly]"] +all = ["python-graphblas", "numpy", "uvloop", "nicegui[plotly]"] + +[tool.setuptools] +packages = ["scaler"] +include-package-data = true + +[tool.setuptools.dynamic] +dependencies = { file = "requirements.txt" } +version = { attr = "scaler.about.__version__" } + +[tool.mypy] +no_strict_optional = true +check_untyped_defs = true +ignore_missing_imports = true +exclude = [ + "^docs.*$", + "^benchmark.*$", + "^venv.*$" +] + +[tool.flake8] +count = true +statistics = true +max-line-length = 120 +extend-ignore = ["E203"] +exclude = "venv312" + +[tool.black] +line-length = 120 +skip-magic-trailing-comma = true + +[metadata] +long_description = { file = "README.md" } +long_description_content_type = "text/markdown" diff --git a/requirements.txt b/requirements.txt index 6558092..a3a57df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ -pyzmq -psutil +bidict cloudpickle +graphlib-backport; python_version < '3.9' +psutil +pycapnp +pyzmq tblib -bidict -graphlib-backport; python_version < '3.9' \ No newline at end of file diff --git a/run_top.py b/run_top.py index 31635e4..75d50e2 100644 --- a/run_top.py +++ b/run_top.py @@ -1,4 +1,5 @@ from scaler.entry_points.top import main +from scaler.utility.debug import pdb_wrapped if __name__ == "__main__": - main() + pdb_wrapped(main)() diff --git a/scaler/about.py b/scaler/about.py index a8b3783..29654ee 100644 --- a/scaler/about.py +++ b/scaler/about.py @@ -1 +1 @@ -__version__ = "1.7.14" +__version__ = "1.8.0" diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py index 56c8abf..6bf6507 100644 --- a/scaler/client/agent/client_agent.py +++ b/scaler/client/agent/client_agent.py @@ -18,7 +18,6 @@ ClientShutdownResponse, GraphTask, GraphTaskCancel, - MessageVariant, ObjectInstruction, ObjectRequest, ObjectResponse, @@ -26,6 +25,7 @@ TaskCancel, TaskResult, ) +from scaler.protocol.python.mixins import Message from scaler.utility.event_loop import create_async_loop_routine from scaler.utility.exceptions import ClientCancelledException, ClientQuitException, ClientShutdownException from scaler.utility.zmq_config import ZMQConfig @@ -113,7 +113,7 @@ def run(self): self.__initialize() self.__run_loop() - async def __on_receive_from_client(self, message: MessageVariant): + async def __on_receive_from_client(self, message: Message): if isinstance(message, ClientDisconnect): await self._disconnect_manager.on_client_disconnect(message) return @@ -144,7 +144,7 @@ async def __on_receive_from_client(self, message: MessageVariant): raise TypeError(f"Unknown {message=}") - async def __on_receive_from_scheduler(self, message: MessageVariant): + async def __on_receive_from_scheduler(self, message: Message): if isinstance(message, ClientShutdownResponse): await self._disconnect_manager.on_client_shutdown_response(message) return diff --git a/scaler/client/agent/disconnect_manager.py b/scaler/client/agent/disconnect_manager.py index 08d86ac..50df24b 100644 --- a/scaler/client/agent/disconnect_manager.py +++ b/scaler/client/agent/disconnect_manager.py @@ -2,7 +2,7 @@ from scaler.client.agent.mixins import DisconnectManager from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, DisconnectType +from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse from scaler.utility.exceptions import ClientQuitException, ClientShutdownException @@ -18,7 +18,7 @@ def register(self, connector_internal: AsyncConnector, connector_external: Async async def on_client_disconnect(self, disconnect: ClientDisconnect): await self._connector_external.send(disconnect) - if disconnect.type == DisconnectType.Disconnect: + if disconnect.disconnect_type == ClientDisconnect.DisconnectType.Disconnect: raise ClientQuitException("client disconnecting") async def on_client_shutdown_response(self, response: ClientShutdownResponse): diff --git a/scaler/client/agent/future_manager.py b/scaler/client/agent/future_manager.py index 5f98574..154fcd5 100644 --- a/scaler/client/agent/future_manager.py +++ b/scaler/client/agent/future_manager.py @@ -1,12 +1,13 @@ import logging import threading -from concurrent.futures import InvalidStateError +from concurrent.futures import InvalidStateError, Future from typing import Dict, Tuple from scaler.client.agent.mixins import FutureManager from scaler.client.future import ScalerFuture from scaler.client.serializer.mixins import Serializer -from scaler.protocol.python.message import ObjectResponse, TaskResult, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import ObjectResponse, TaskResult from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError from scaler.utility.metadata.profile_result import retrieve_profiling_result_from_task_result from scaler.utility.object_utility import deserialize_failure @@ -20,7 +21,8 @@ def __init__(self, serializer: Serializer): self._task_id_to_future: Dict[bytes, ScalerFuture] = dict() self._object_id_to_future: Dict[bytes, Tuple[TaskStatus, ScalerFuture]] = dict() - def add_future(self, future: ScalerFuture): + def add_future(self, future: Future): + assert isinstance(future, ScalerFuture) with self._lock: future.set_running_or_notify_cancel() self._task_id_to_future[future.task_id] = future diff --git a/scaler/client/agent/heartbeat_manager.py b/scaler/client/agent/heartbeat_manager.py index c8e46bf..d100a2e 100644 --- a/scaler/client/agent/heartbeat_manager.py +++ b/scaler/client/agent/heartbeat_manager.py @@ -6,6 +6,7 @@ from scaler.client.agent.mixins import HeartbeatManager from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.message import ClientHeartbeat, ClientHeartbeatEcho +from scaler.protocol.python.status import Resource from scaler.utility.mixins import Looper @@ -26,9 +27,12 @@ def register(self, connector_external: AsyncConnector): self._connector_external = connector_external async def send_heartbeat(self): - cpu = self._process.cpu_percent() / 100 - rss = self._process.memory_info().rss - await self._connector_external.send(ClientHeartbeat(cpu, rss, self._latency_us)) + await self._connector_external.send( + ClientHeartbeat.new_msg( + Resource.new_msg(int(self._process.cpu_percent() * 10), self._process.memory_info().rss), + self._latency_us, + ) + ) async def on_heartbeat_echo(self, heartbeat: ClientHeartbeatEcho): if not self._connected: diff --git a/scaler/client/agent/object_manager.py b/scaler/client/agent/object_manager.py index 83b9c56..92ff051 100644 --- a/scaler/client/agent/object_manager.py +++ b/scaler/client/agent/object_manager.py @@ -1,19 +1,14 @@ -from typing import Optional +from typing import Optional, Set from scaler.client.agent.mixins import ObjectManager from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import ( - ObjectContent, - ObjectInstruction, - ObjectInstructionType, - ObjectRequest, - ObjectRequestType, -) +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import ObjectInstruction, ObjectRequest class ClientObjectManager(ObjectManager): def __init__(self, identity: bytes): - self._sent_object_ids = set() + self._sent_object_ids: Set[bytes] = set() self._identity = identity self._connector_internal: Optional[AsyncConnector] = None @@ -24,13 +19,13 @@ def register(self, connector_internal: AsyncConnector, connector_external: Async self._connector_external = connector_external async def on_object_instruction(self, instruction: ObjectInstruction): - if instruction.type == ObjectInstructionType.Create: + if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create: await self.__send_object_creation(instruction) - elif instruction.type == ObjectInstructionType.Delete: + elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete: await self.__delete_objects(instruction) async def on_object_request(self, object_request: ObjectRequest): - assert object_request.type == ObjectRequestType.Get + assert object_request.request_type == ObjectRequest.ObjectRequestType.Get await self._connector_external.send(object_request) def record_task_result(self, task_id: bytes, object_id: bytes): @@ -38,17 +33,25 @@ def record_task_result(self, task_id: bytes, object_id: bytes): async def clean_all_objects(self): await self._connector_external.send( - ObjectInstruction(ObjectInstructionType.Delete, self._identity, ObjectContent(tuple(self._sent_object_ids))) + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Delete, + self._identity, + ObjectContent.new_msg(tuple(self._sent_object_ids)), + ) ) self._sent_object_ids = set() async def __send_object_creation(self, instruction: ObjectInstruction): - assert instruction.type == ObjectInstructionType.Create + assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create - new_object_content = list( - zip( + new_object_ids = set(instruction.object_content.object_ids) - self._sent_object_ids + if not new_object_ids: + return + + new_object_content = ObjectContent.new_msg( + *zip( *filter( - lambda object_pack: object_pack[0] not in self._sent_object_ids, + lambda object_pack: object_pack[0] in new_object_ids, zip( instruction.object_content.object_ids, instruction.object_content.object_names, @@ -58,15 +61,15 @@ async def __send_object_creation(self, instruction: ObjectInstruction): ) ) - if not new_object_content: - return - - instruction.object_content = ObjectContent(*new_object_content) + self._sent_object_ids.update(set(new_object_content.object_ids)) - self._sent_object_ids.update(instruction.object_content.object_ids) - await self._connector_external.send(instruction) + await self._connector_external.send( + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Create, instruction.object_user, new_object_content + ) + ) async def __delete_objects(self, instruction: ObjectInstruction): - assert instruction.type == ObjectInstructionType.Delete + assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete self._sent_object_ids.difference_update(instruction.object_content.object_ids) await self._connector_external.send(instruction) diff --git a/scaler/client/agent/task_manager.py b/scaler/client/agent/task_manager.py index 1c0c021..8b54c77 100644 --- a/scaler/client/agent/task_manager.py +++ b/scaler/client/agent/task_manager.py @@ -3,7 +3,8 @@ from scaler.client.agent.future_manager import ClientFutureManager from scaler.client.agent.mixins import ObjectManager, TaskManager from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult class ClientTaskManager(TaskManager): diff --git a/scaler/client/client.py b/scaler/client/client.py index 79df3c9..100a558 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -1,10 +1,10 @@ +import dataclasses import functools import logging import os import threading import uuid from collections import Counter -from concurrent.futures import Future from inspect import signature from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union @@ -20,15 +20,7 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.config import DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_HEARTBEAT_INTERVAL_SECONDS from scaler.io.sync_connector import SyncConnector -from scaler.protocol.python.message import ( - Argument, - ArgumentType, - ClientDisconnect, - ClientShutdownResponse, - DisconnectType, - GraphTask, - Task, -) +from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, GraphTask, Task from scaler.utility.exceptions import ClientQuitException from scaler.utility.graph.optimization import cull_graph from scaler.utility.graph.topological_sorter import TopologicalSorter @@ -38,6 +30,23 @@ from scaler.worker.agent.processor.processor import Processor +@dataclasses.dataclass +class _CallNode: + func: Callable + args: Tuple[str, ...] + + def __post_init__(self): + if not callable(self.func): + raise TypeError(f"the first item of the tuple must be function, get {self.func}") + + if not isinstance(self.args, tuple): + raise TypeError(f"arguments must be tuple, get {self.args}") + + for arg in self.args: + if not isinstance(arg, str): + raise TypeError(f"argument `{arg}` must be a string and the string has to be in the graph") + + class Client: def __init__( self, @@ -59,6 +68,16 @@ def __init__( :param heartbeat_interval_seconds: Frequency of heartbeat to scheduler in seconds :type heartbeat_interval_seconds: int """ + self.__initialize__(address, profiling, timeout_seconds, heartbeat_interval_seconds, serializer) + + def __initialize__( + self, + address: str, + profiling: bool, + timeout_seconds: int, + heartbeat_interval_seconds: int, + serializer: Serializer = DefaultSerializer(), + ): self._serializer = serializer self._profiling = profiling @@ -149,7 +168,8 @@ def fibonacci(client: Client, n: int): } def __setstate__(self, state: dict) -> None: - self.__init__( + # TODO: fix copy the serializer + self.__initialize__( address=state["address"], profiling=state["profiling"], timeout_seconds=state["timeout_seconds"], @@ -201,8 +221,8 @@ def map(self, fn: Callable, iterable: Iterable[Tuple[Any, ...]]) -> List[Any]: return results def get( - self, graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]], keys: List[str], block: bool = True - ) -> Dict[str, Union[Any, Future]]: + self, graph: Dict[str, Union[Any, Tuple[Union[Callable, str], ...]]], keys: List[str], block: bool = True + ) -> Dict[str, Union[Any, ScalerFuture]]: """ .. code-block:: python :linenos: @@ -227,16 +247,26 @@ def get( graph = cull_graph(graph, keys) - node_name_to_argument, graph = self.__split_data_and_graph(graph) - self.__check_graph(node_name_to_argument, graph, keys) + node_name_to_argument, call_graph = self.__split_data_and_graph(graph) + self.__check_graph(node_name_to_argument, call_graph, keys) - graph, compute_futures, finished_futures = self.__construct_graph(node_name_to_argument, graph, keys, block) + graph_task, compute_futures, finished_futures = self.__construct_graph( + node_name_to_argument, call_graph, keys, block + ) self._object_buffer.commit_send_objects() - self._connector.send(graph) + self._connector.send(graph_task) self._future_manager.add_future( self._future_factory( - task=Task(graph.task_id, self._identity, b"", b"", []), is_delayed=not block, group_task_id=None + task=Task.new_msg( + task_id=graph_task.task_id, + source=self._identity, + metadata=b"", + func_object_id=b"", + function_args=[], + ), + is_delayed=not block, + group_task_id=None, ) ) for future in compute_futures.values(): @@ -294,12 +324,12 @@ def disconnect(self): self._future_manager.cancel_all_futures() - self._connector.send(ClientDisconnect(DisconnectType.Disconnect)) + self._connector.send(ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Disconnect)) self.__destroy() def __receive_shutdown_response(self): - message = None + message: Optional[ClientShutdownResponse] = None while not isinstance(message, ClientShutdownResponse): message = self._connector.receive() @@ -321,7 +351,7 @@ def shutdown(self): self._future_manager.cancel_all_futures() - self._connector.send(ClientDisconnect(DisconnectType.Shutdown)) + self._connector.send(ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Shutdown)) try: self.__receive_shutdown_response() finally: @@ -339,8 +369,14 @@ def __submit(self, function_object_id: bytes, args: Tuple[Any, ...], delayed: bo task_flags_bytes = self.__get_task_flags().serialize() - arguments = [Argument(ArgumentType.ObjectID, object_id) for object_id in object_ids] - task = Task(task_id, self._identity, task_flags_bytes, function_object_id, arguments) + arguments = [Task.Argument(Task.Argument.ArgumentType.ObjectID, object_id) for object_id in object_ids] + task = Task.new_msg( + task_id=task_id, + source=self._identity, + metadata=task_flags_bytes, + func_object_id=function_object_id, + function_args=arguments, + ) future = self._future_factory(task=task, is_delayed=delayed, group_task_id=None) self._future_manager.add_future(future) @@ -357,48 +393,45 @@ def __convert_kwargs_to_args(fn: Callable, args: Tuple[Any, ...], kwargs: Dict[s number_of_required = len([p for p in params if p.default is p.empty]) - args = list(args) + args_list = list(args) kwargs = kwargs.copy() kwargs.update({p.name: p.default for p in all_params if p.kind == p.KEYWORD_ONLY if p.default != p.empty}) - for p in params[len(args) : number_of_required]: + for p in params[len(args_list) : number_of_required]: try: - args.append(kwargs.pop(p.name)) + args_list.append(kwargs.pop(p.name)) except KeyError: - missing = tuple(p.name for p in params[len(args) : number_of_required]) + missing = tuple(p.name for p in params[len(args_list) : number_of_required]) raise TypeError(f"{fn} missing {len(missing)} arguments: {missing}") - for p in params[len(args) :]: - args.append(kwargs.pop(p.name, p.default)) + for p in params[len(args_list) :]: + args_list.append(kwargs.pop(p.name, p.default)) - return tuple(args) + return tuple(args_list) def __split_data_and_graph( - self, graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]] - ) -> Tuple[Dict[str, Tuple[Argument, Any]], Dict[str, Tuple[Union[Callable, Any], ...]]]: - graph = graph.copy() - node_name_to_argument = {} + self, graph: Dict[str, Union[Any, Tuple[Union[Callable, str], ...]]] + ) -> Tuple[Dict[str, Tuple[Task.Argument, Any]], Dict[str, _CallNode]]: + call_graph = {} + node_name_to_argument: Dict[str, Tuple[Task.Argument, Union[Any, Tuple[Union[Callable, Any], ...]]]] = dict() for node_name, node in graph.items(): if isinstance(node, tuple) and len(node) > 0 and callable(node[0]): + call_graph[node_name] = _CallNode(func=node[0], args=node[1:]) # type: ignore[arg-type] continue if isinstance(node, ObjectReference): - node_name_to_argument[node_name] = (Argument(ArgumentType.ObjectID, node.object_id), None) + object_id = node.object_id else: - object_cache = self._object_buffer.buffer_send_object(node, name=node_name) - node_name_to_argument[node_name] = (Argument(ArgumentType.ObjectID, object_cache.object_id), node) + object_id = self._object_buffer.buffer_send_object(node, name=node_name).object_id - for node_name in node_name_to_argument.keys(): - graph.pop(node_name) + node_name_to_argument[node_name] = (Task.Argument(Task.Argument.ArgumentType.ObjectID, object_id), node) - return node_name_to_argument, graph + return node_name_to_argument, call_graph @staticmethod def __check_graph( - node_to_argument: Dict[str, Tuple[Argument, Any]], - graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]], - keys: List[str], + node_to_argument: Dict[str, Tuple[Task.Argument, Any]], call_graph: Dict[str, _CallNode], keys: List[str] ): duplicate_keys = [key for key, count in dict(Counter(keys)).items() if count > 1] if duplicate_keys: @@ -406,76 +439,78 @@ def __check_graph( # sanity check graph for key in keys: - if key not in graph and key not in node_to_argument: + if key not in call_graph and key not in node_to_argument: raise KeyError(f"key {key} has to be in graph") - sorter = TopologicalSorter() - for node_name, node in graph.items(): - assert ( - isinstance(node, tuple) and len(node) > 0 and callable(node[0]) - ), "node has to be tuple and first item should be function" - - for arg in node[1:]: - if arg not in node_to_argument and arg not in graph: - raise KeyError(f"argument {arg} in node '{node_name}': {tuple(node)} is not defined in graph") + sorter: TopologicalSorter[str] = TopologicalSorter() + for node_name, node in call_graph.items(): + for arg in node.args: + if arg not in node_to_argument and arg not in call_graph: + raise KeyError(f"argument {arg} in node '{node_name}': {node} is not defined in graph") - sorter.add(node_name, *node[1:]) + sorter.add(node_name, *node.args) # check cyclic dependencies sorter.prepare() def __construct_graph( self, - node_name_to_arguments: Dict[str, Tuple[Argument, Any]], - graph: Dict[str, Tuple[Union[Callable, Any], ...]], + node_name_to_arguments: Dict[str, Tuple[Task.Argument, Any]], + call_graph: Dict[str, _CallNode], keys: List[str], block: bool, ) -> Tuple[GraphTask, Dict[str, ScalerFuture], Dict[str, ScalerFuture]]: graph_task_id = uuid.uuid1().bytes - node_name_to_task_id = {node_name: uuid.uuid1().bytes for node_name in graph.keys()} + node_name_to_task_id: Dict[str, bytes] = {node_name: uuid.uuid1().bytes for node_name in call_graph.keys()} task_flags_bytes = self.__get_task_flags().serialize() task_id_to_tasks = dict() - for node_name, node in graph.items(): + for node_name, node in call_graph.items(): task_id = node_name_to_task_id[node_name] + function_cache = self._object_buffer.buffer_send_function(node.func) - function, *args = node - function_cache = self._object_buffer.buffer_send_function(function) - - arguments = [] - for arg in args: - assert arg in graph or arg in node_name_to_arguments + arguments: List[Task.Argument] = [] + for arg in node.args: + assert arg in call_graph or arg in node_name_to_arguments - if arg in graph: - argument = Argument(ArgumentType.Task, node_name_to_task_id[arg]) - else: + if arg in call_graph: + arguments.append( + Task.Argument(type=Task.Argument.ArgumentType.Task, data=node_name_to_task_id[arg]) + ) + elif arg in node_name_to_arguments: argument, _ = node_name_to_arguments[arg] - - arguments.append(argument) - - task_id_to_tasks[task_id] = Task( - task_id, self._identity, task_flags_bytes, function_cache.object_id, arguments + arguments.append(argument) + else: + raise ValueError("Not possible") + + task_id_to_tasks[task_id] = Task.new_msg( + task_id=task_id, + source=self._identity, + metadata=task_flags_bytes, + func_object_id=function_cache.object_id, + function_args=arguments, ) - result_task_ids = [node_name_to_task_id[key] for key in keys if key in graph] - graph_task = GraphTask(graph_task_id, self._identity, result_task_ids, list(task_id_to_tasks.values())) + result_task_ids = [node_name_to_task_id[key] for key in keys if key in call_graph] + graph_task = GraphTask.new_msg(graph_task_id, self._identity, result_task_ids, list(task_id_to_tasks.values())) compute_futures = {} ready_futures = {} for key in keys: - if key in graph: - future = self._future_factory( + if key in call_graph: + compute_futures[key] = self._future_factory( task=task_id_to_tasks[node_name_to_task_id[key]], is_delayed=not block, group_task_id=graph_task_id ) - compute_futures[key] = future elif key in node_name_to_arguments: argument, data = node_name_to_arguments[key] future: ScalerFuture = self._future_factory( - task=Task(argument.data, self._identity, b"", b"", []), + task=Task.new_msg( + task_id=argument.data, source=self._identity, metadata=b"", func_object_id=b"", function_args=[] + ), is_delayed=False, group_task_id=graph_task_id, ) @@ -498,6 +533,14 @@ def __get_task_flags(self) -> TaskFlags: return TaskFlags(profiling=self._profiling, priority=task_priority) + def __assert_client_not_stopped(self): + if self._stop_event.is_set(): + raise ClientQuitException("client is already stopped.") + + def __destroy(self): + self._agent.join() + self._internal_context.destroy(linger=1) + @staticmethod def __get_parent_task_priority() -> Optional[int]: """If the client is running inside a Scaler processor, returns the priority of the associated task.""" @@ -511,11 +554,3 @@ def __get_parent_task_priority() -> Optional[int]: assert current_task is not None return retrieve_task_flags_from_task(current_task).priority - - def __assert_client_not_stopped(self): - if self._stop_event.is_set(): - raise ClientQuitException("client is already stopped.") - - def __destroy(self): - self._agent.join() - self._internal_context.destroy(linger=1) diff --git a/scaler/client/future.py b/scaler/client/future.py index 359a199..c5379ad 100644 --- a/scaler/client/future.py +++ b/scaler/client/future.py @@ -3,17 +3,17 @@ from typing import Any, Callable, Optional from scaler.io.sync_connector import SyncConnector -from scaler.protocol.python.message import ObjectRequest, ObjectRequestType, Task, TaskCancel -from scaler.utility.metadata.profile_result import ProfileResult +from scaler.protocol.python.message import ObjectRequest, Task, TaskCancel from scaler.utility.event_list import EventList +from scaler.utility.metadata.profile_result import ProfileResult class ScalerFuture(Future): def __init__(self, task: Task, is_delayed: bool, group_task_id: Optional[bytes], connector: SyncConnector): super().__init__() - self._waiters = EventList(self._waiters) - self._waiters.add_update_callback(self._on_waiters_updated) + self._waiters = EventList(self._waiters) # type: ignore[assignment] + self._waiters.add_update_callback(self._on_waiters_updated) # type: ignore[attr-defined] self._task_id: bytes = task.task_id self._is_delayed: bool = is_delayed @@ -31,14 +31,14 @@ def task_id(self): return self._task_id def profiling_info(self) -> ProfileResult: - with self._condition: + with self._condition: # type: ignore[attr-defined] if self._profiling_info is None: raise ValueError(f"didn't receive profiling info for {self} yet") return self._profiling_info def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ProfileResult] = None) -> None: - with self._condition: + with self._condition: # type: ignore[attr-defined] if self.done(): raise InvalidStateError(f"invalid future state: {self._state}") @@ -55,7 +55,7 @@ def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ self._result_ready_event.set() def set_exception(self, exception: Optional[BaseException], profile_result: Optional[ProfileResult] = None) -> None: - with self._condition: + with self._condition: # type: ignore[attr-defined] if profile_result is not None: self._profiling_info = profile_result @@ -66,7 +66,7 @@ def set_exception(self, exception: Optional[BaseException], profile_result: Opti def result(self, timeout=None): self._result_ready_event.wait(timeout) - with self._condition: + with self._condition: # type: ignore[attr-defined] # if it's delayed future, get the result when future.result() get called if self._is_delayed: self._request_result_object() @@ -75,7 +75,7 @@ def result(self, timeout=None): return super().result(timeout) def cancel(self) -> bool: - with self._condition: + with self._condition: # type: ignore[attr-defined] if self.cancelled(): return True @@ -83,20 +83,20 @@ def cancel(self) -> bool: return False if self._group_task_id is not None: - self._connector.send(TaskCancel(self._group_task_id)) + self._connector.send(TaskCancel.new_msg(self._group_task_id)) else: - self._connector.send(TaskCancel(self._task_id)) + self._connector.send(TaskCancel.new_msg(self._task_id)) self._state = "CANCELLED" self._result_ready_event.set() - self._condition.notify_all() + self._condition.notify_all() # type: ignore[attr-defined] - self._invoke_callbacks() + self._invoke_callbacks() # type: ignore[attr-defined] return True def add_done_callback(self, fn: Callable[[Future], Any]) -> None: - with self._condition: + with self._condition: # type: ignore[attr-defined] # if it's delayed future, get the result when a callback gets added if self._is_delayed: self._request_result_object() @@ -104,17 +104,17 @@ def add_done_callback(self, fn: Callable[[Future], Any]) -> None: return super().add_done_callback(fn) def _on_waiters_updated(self, waiters: EventList): - with self._condition: + with self._condition: # type: ignore[attr-defined] # if it's delayed future, get the result when waiter gets added if self._is_delayed and len(self._waiters) > 0: self._request_result_object() def _has_result_listeners(self) -> bool: - return len(self._done_callbacks) > 0 or len(self._waiters) > 0 + return len(self._done_callbacks) > 0 or len(self._waiters) > 0 # type: ignore[attr-defined] def _request_result_object(self): if self._result_request_sent or self._result_object_id is None or self.cancelled(): return - self._connector.send(ObjectRequest(ObjectRequestType.Get, (self._result_object_id,))) + self._connector.send(ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (self._result_object_id,))) self._result_request_sent = True diff --git a/scaler/client/object_buffer.py b/scaler/client/object_buffer.py index dbcec48..bbb1050 100644 --- a/scaler/client/object_buffer.py +++ b/scaler/client/object_buffer.py @@ -6,7 +6,8 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.sync_connector import SyncConnector -from scaler.protocol.python.message import ObjectContent, ObjectInstruction, ObjectInstructionType +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import ObjectInstruction from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id @@ -56,7 +57,11 @@ def commit_send_objects(self): ] self._connector.send( - ObjectInstruction(ObjectInstructionType.Create, self._identity, ObjectContent(*zip(*objects_to_send))) + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Create, + self._identity, + ObjectContent.new_msg(*zip(*objects_to_send)), + ) ) self._pending_objects = list() @@ -66,8 +71,10 @@ def commit_delete_objects(self): return self._connector.send( - ObjectInstruction( - ObjectInstructionType.Delete, self._identity, ObjectContent(tuple(self._pending_delete_objects)) + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Delete, + self._identity, + ObjectContent.new_msg(tuple(self._pending_delete_objects)), ) ) @@ -89,5 +96,5 @@ def __construct_function(self, fn: Callable) -> ObjectCache: def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCache: object_payload = self._serializer.serialize(obj) object_id = generate_object_id(self._identity, object_payload) - name_bytes = name.encode() if name else f"".encode() + name_bytes = name.encode() if name else f"".encode() return ObjectCache(object_id, name_bytes, object_payload) diff --git a/scaler/client/object_reference.py b/scaler/client/object_reference.py index e4f7c79..ed95195 100644 --- a/scaler/client/object_reference.py +++ b/scaler/client/object_reference.py @@ -2,20 +2,20 @@ @dataclasses.dataclass -class ObjectReference(object): +class ObjectReference: name: bytes object_id: bytes size: int def __repr__(self): - return f"ScalerReference(name={self.name}, id={self.object_id}, size={self.size})" + return f"ScalerReference(name={self.name!r}, id={self.object_id!r}, size={self.size})" def __hash__(self): return hash(self.object_id) - def __eq__(self, other: "ObjectReference"): + def __eq__(self, other: object) -> bool: if not isinstance(other, ObjectReference): - return False + return NotImplemented return self.object_id == other.object_id diff --git a/scaler/cluster/cluster.py b/scaler/cluster/cluster.py index 1f3cc67..2303a17 100644 --- a/scaler/cluster/cluster.py +++ b/scaler/cluster/cluster.py @@ -9,7 +9,7 @@ from scaler.worker.worker import Worker -class Cluster(multiprocessing.get_context("spawn").Process): +class Cluster(multiprocessing.get_context("spawn").Process): # type: ignore[misc] def __init__( self, address: ZMQConfig, diff --git a/scaler/cluster/scheduler.py b/scaler/cluster/scheduler.py index 1926a62..347721a 100644 --- a/scaler/cluster/scheduler.py +++ b/scaler/cluster/scheduler.py @@ -1,6 +1,7 @@ import asyncio import multiprocessing -from typing import Optional, Tuple +from asyncio import AbstractEventLoop, Task +from typing import Optional, Tuple, Any from scaler.scheduler.config import SchedulerConfig from scaler.scheduler.scheduler import Scheduler, scheduler_main @@ -9,7 +10,7 @@ from scaler.utility.zmq_config import ZMQConfig -class SchedulerProcess(multiprocessing.get_context("spawn").Process): +class SchedulerProcess(multiprocessing.get_context("spawn").Process): # type: ignore[misc] def __init__( self, address: ZMQConfig, @@ -45,8 +46,8 @@ def __init__( self._logging_config_file = logging_config_file self._scheduler: Optional[Scheduler] = None - self._loop = None - self._task = None + self._loop: Optional[AbstractEventLoop] = None + self._task: Optional[Task[Any]] = None def run(self) -> None: # scheduler have its own single process diff --git a/scaler/entry_points/top.py b/scaler/entry_points/top.py index 9b01907..c8c598e 100644 --- a/scaler/entry_points/top.py +++ b/scaler/entry_points/top.py @@ -1,10 +1,11 @@ import argparse import curses import functools -from typing import List, Literal +from typing import List, Literal, Dict, Union from scaler.io.sync_subscriber import SyncSubscriber -from scaler.protocol.python.message import MessageVariant, StateScheduler +from scaler.protocol.python.message import StateScheduler +from scaler.protocol.python.mixins import Message from scaler.utility.formatter import ( format_bytes, format_integer, @@ -28,7 +29,7 @@ ord("l"): "lag", } -sort_by_state = {"sort_by": "cpu", "sort_by_previous": "cpu", "sort_reverse": True} +SORT_BY_STATE: Dict[str, Union[str, bool]] = {"sort_by": "cpu", "sort_by_previous": "cpu", "sort_reverse": True} def get_args(): @@ -61,7 +62,7 @@ def poke(screen, args): pass -def show_status(status: MessageVariant, screen): +def show_status(status: Message, screen): if not isinstance(status, StateScheduler): return @@ -72,7 +73,7 @@ def show_status(status: MessageVariant, screen): { "cpu": format_percentage(status.scheduler.cpu), "rss": format_bytes(status.scheduler.rss), - "rss_free": format_bytes(status.scheduler.rss_free), + "rss_free": format_bytes(status.rss_free), }, ) @@ -106,16 +107,16 @@ def show_status(status: MessageVariant, screen): "worker": worker.worker_id.decode(), "agt_cpu": worker.agent.cpu, "agt_rss": worker.agent.rss, - "cpu": worker.total_processors.cpu, - "rss": worker.total_processors.rss, - "os_rss_free": worker.total_processors.rss_free, + "cpu": sum(p.resource.cpu for p in worker.processor_statuses), + "rss": sum(p.resource.rss for p in worker.processor_statuses), + "os_rss_free": worker.rss_free, "free": worker.free, "sent": worker.sent, "queued": worker.queued, "suspended": worker.suspended, "lag": worker.lag_us, "last": worker.last_s, - "ITL": worker.ITL, + "ITL": worker.itl, } for worker in status.worker_manager.workers ], @@ -160,12 +161,14 @@ def format_integer_func(value): return table -def __generate_worker_manager_table(wm_data, worker_length: int): +def __generate_worker_manager_table(wm_data: List[Dict], worker_length: int) -> List[List[str]]: if not wm_data: headers = [["No workers"]] return headers - wm_data = sorted(wm_data, key=lambda item: item[sort_by_state["sort_by"]], reverse=sort_by_state["sort_reverse"]) + wm_data = sorted( + wm_data, key=lambda item: item[SORT_BY_STATE["sort_by"]], reverse=bool(SORT_BY_STATE["sort_reverse"]) + ) for row in wm_data: row["worker"] = __truncate(row["worker"], worker_length, how="left") @@ -179,7 +182,7 @@ def __generate_worker_manager_table(wm_data, worker_length: int): last = f"({format_seconds(last)}) " if last > 5 else "" row["lag"] = last + format_microseconds(row["lag"]) - worker_manager_table = [[f"[{v}]" if v == sort_by_state["sort_by"] else v for v in wm_data[0].keys()]] + worker_manager_table = [[f"[{v}]" if v == SORT_BY_STATE["sort_by"] else v for v in wm_data[0].keys()]] worker_manager_table.extend([list(worker.values()) for worker in wm_data]) return worker_manager_table @@ -264,10 +267,10 @@ def __change_option_state(option: int): if option not in SORT_BY_OPTIONS.keys(): return - sort_by_state["sort_by_previous"] = sort_by_state["sort_by"] - sort_by_state["sort_by"] = SORT_BY_OPTIONS[option] - if sort_by_state["sort_by"] != sort_by_state["sort_by_previous"]: - sort_by_state["sort_reverse"] = True + SORT_BY_STATE["sort_by_previous"] = SORT_BY_STATE["sort_by"] + SORT_BY_STATE["sort_by"] = SORT_BY_OPTIONS[option] + if SORT_BY_STATE["sort_by"] != SORT_BY_STATE["sort_by_previous"]: + SORT_BY_STATE["sort_reverse"] = True return - sort_by_state["sort_reverse"] = not sort_by_state["sort_reverse"] + SORT_BY_STATE["sort_reverse"] = not SORT_BY_STATE["sort_reverse"] diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py index 8423985..68d5e1a 100644 --- a/scaler/io/async_binder.py +++ b/scaler/io/async_binder.py @@ -2,11 +2,12 @@ import os import uuid from collections import defaultdict -from typing import Awaitable, Callable, List, Literal, Optional +from typing import Awaitable, Callable, List, Optional, Dict import zmq.asyncio -from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant +from scaler.io.utility import deserialize, serialize +from scaler.protocol.python.mixins import Message from scaler.protocol.python.status import BinderStatus from scaler.utility.mixins import Looper, Reporter from scaler.utility.zmq_config import ZMQConfig @@ -25,14 +26,15 @@ def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Opt self.__set_socket_options() self._socket.bind(self._address.to_address()) - self._callback: Optional[Callable[[bytes, MessageVariant], Awaitable[None]]] = None + self._callback: Optional[Callable[[bytes, Message], Awaitable[None]]] = None - self._statistics = {"received": defaultdict(lambda: 0), "sent": defaultdict(lambda: 0)} + self._received: Dict[str, int] = defaultdict(lambda: 0) + self._sent: Dict[str, int] = defaultdict(lambda: 0) def destroy(self): self._context.destroy(linger=0) - def register(self, callback: Callable[[bytes, MessageVariant], Awaitable[None]]): + def register(self, callback: Callable[[bytes, Message], Awaitable[None]]): self._callback = callback async def routine(self): @@ -40,23 +42,21 @@ async def routine(self): if not self.__is_valid_message(frames): return - source, message_type_bytes, payload = frames[0], frames[1], frames[2:] - message_type = MessageType(message_type_bytes) - self.__count_one("received", message_type) + source, payload = frames + message: Optional[Message] = deserialize(payload) + if message is None: + logging.error(f"received unknown message from {source!r}: {payload!r}") + return - message = PROTOCOL[message_type].deserialize(payload) + self.__count_received(message.__class__.__name__) await self._callback(source, message) - async def send(self, to: bytes, message: MessageVariant): - message_type = PROTOCOL.inverse[type(message)] - self.__count_one("sent", message_type) - await self._socket.send_multipart([to, message_type.value, *message.serialize()], copy=False) + async def send(self, to: bytes, message: Message): + self.__count_sent(message.__class__.__name__) + await self._socket.send_multipart([to, serialize(message)], copy=False) def get_status(self) -> BinderStatus: - return BinderStatus( - received={k: v for k, v in self._statistics["received"].items()}, - sent={k: v for k, v in self._statistics["sent"].items()}, - ) + return BinderStatus.new_msg(received=self._received, sent=self._sent) def __set_socket_options(self): self._socket.setsockopt(zmq.IDENTITY, self._identity) @@ -64,18 +64,17 @@ def __set_socket_options(self): self._socket.setsockopt(zmq.RCVHWM, 0) def __is_valid_message(self, frames: List[bytes]) -> bool: - if len(frames) < 3: + if len(frames) < 2: logging.error(f"{self.__get_prefix()} received unexpected frames {frames}") return False - if frames[1] not in {member.value for member in MessageType}: - logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}") - return False - return True - def __count_one(self, count_type: Literal["sent", "received"], message_type: MessageType): - self._statistics[count_type][message_type.name] += 1 + def __count_received(self, message_type: str): + self._received[message_type] += 1 + + def __count_sent(self, message_type: str): + self._sent[message_type] += 1 def __get_prefix(self): return f"{self.__class__.__name__}[{self._identity.decode()}]:" diff --git a/scaler/io/async_connector.py b/scaler/io/async_connector.py index faba381..86be509 100644 --- a/scaler/io/async_connector.py +++ b/scaler/io/async_connector.py @@ -5,7 +5,8 @@ import zmq.asyncio -from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant +from scaler.io.utility import deserialize, serialize +from scaler.protocol.python.mixins import Message from scaler.utility.zmq_config import ZMQConfig @@ -17,7 +18,7 @@ def __init__( socket_type: int, address: ZMQConfig, bind_or_connect: Literal["bind", "connect"], - callback: Optional[Callable[[MessageType, MessageVariant], Awaitable[None]]], + callback: Optional[Callable[[Message], Awaitable[None]]], identity: Optional[bytes], ): self._address = address @@ -41,7 +42,7 @@ def __init__( else: raise TypeError("bind_or_connect has to be 'bind' or 'connect'") - self._callback: Optional[Callable[[MessageVariant], Awaitable[None]]] = callback + self._callback: Optional[Callable[[Message], Awaitable[None]]] = callback def __del__(self): self.destroy() @@ -68,41 +69,35 @@ async def routine(self): if self._callback is None: return - message = await self.receive() + message: Optional[Message] = await self.receive() if message is None: return await self._callback(message) - async def receive(self) -> Optional[MessageVariant]: + async def receive(self) -> Optional[Message]: if self._context.closed: return None if self._socket.closed: return None - frames = await self._socket.recv_multipart() - if not self.__is_valid_message(frames): + payload = await self._socket.recv() + result: Optional[Message] = deserialize(payload) + if result is None: + logging.error(f"received unknown message: {payload!r}") return None - message_type_bytes, *payload = frames - message_type = MessageType(message_type_bytes) - message = PROTOCOL[message_type].deserialize(payload) - return message + return result - async def send(self, data: MessageVariant): - message_type = PROTOCOL.inverse[type(data)] - await self._socket.send_multipart([message_type.value, *data.serialize()], copy=False) + async def send(self, message: Message): + await self._socket.send(serialize(message), copy=False) def __is_valid_message(self, frames: List[bytes]) -> bool: - if len(frames) < 2: + if len(frames) > 1: logging.error(f"{self.__get_prefix()} received unexpected frames {frames}") return False - if frames[0] not in {member.value for member in MessageType}: - logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}") - return False - return True def __get_prefix(self): diff --git a/scaler/io/config.py b/scaler/io/config.py index 631f449..5c55fd7 100644 --- a/scaler/io/config.py +++ b/scaler/io/config.py @@ -12,6 +12,9 @@ # number of seconds for profiling PROFILING_INTERVAL_SECONDS = 1 +# message size limitation, max can be 2**64 +MESSAGE_SIZE_LIMIT = 2**64 - 1 + # ========================== # SCHEDULER SPECIFIC OPTIONS diff --git a/scaler/io/sync_connector.py b/scaler/io/sync_connector.py index 2c82c03..a45b6a5 100644 --- a/scaler/io/sync_connector.py +++ b/scaler/io/sync_connector.py @@ -3,11 +3,12 @@ import socket import threading import uuid -from typing import List, Optional +from typing import Optional import zmq -from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant +from scaler.io.utility import deserialize, serialize +from scaler.protocol.python.mixins import Message from scaler.utility.zmq_config import ZMQConfig @@ -44,31 +45,23 @@ def address(self) -> ZMQConfig: def identity(self) -> bytes: return self._identity - def send(self, message: MessageVariant): - message_type = PROTOCOL.inverse[type(message)] - + def send(self, message: Message): with self._lock: - self._socket.send_multipart([message_type.value, *message.serialize()]) + self._socket.send(serialize(message)) - def receive(self) -> Optional[MessageVariant]: + def receive(self) -> Optional[Message]: with self._lock: - frames = self._socket.recv_multipart() - - return self.__compose_message(frames) + payload = self._socket.recv() - def __compose_message(self, frames: List[bytes]) -> Optional[MessageVariant]: - if len(frames) < 2: - logging.error(f"{self.__get_prefix()} received unexpected frames {frames}") - return None + return self.__compose_message(payload) - if frames[0] not in {member.value for member in MessageType}: - logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}") + def __compose_message(self, payload: bytes) -> Optional[Message]: + result: Optional[Message] = deserialize(payload) + if result is None: + logging.error(f"{self.__get_prefix()}: received unknown message: {payload!r}") return None - message_type_bytes, *payload = frames - message_type = MessageType(message_type_bytes) - message = PROTOCOL[message_type].deserialize(payload) - return message + return result def __get_prefix(self): return f"{self.__class__.__name__}[{self._identity.decode()}]:" diff --git a/scaler/io/sync_subscriber.py b/scaler/io/sync_subscriber.py index 00010a6..5423f2c 100644 --- a/scaler/io/sync_subscriber.py +++ b/scaler/io/sync_subscriber.py @@ -1,10 +1,11 @@ import logging import threading -from typing import Callable, List, Optional +from typing import Callable, Optional import zmq -from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant +from scaler.io.utility import deserialize +from scaler.protocol.python.mixins import Message from scaler.utility.zmq_config import ZMQConfig @@ -12,7 +13,7 @@ class SyncSubscriber(threading.Thread): def __init__( self, address: ZMQConfig, - callback: Callable[[MessageVariant], None], + callback: Callable[[Message], None], topic: bytes, exit_callback: Optional[Callable[[], None]] = None, stop_event: threading.Event = threading.Event(), @@ -54,7 +55,7 @@ def run(self) -> None: def __initialize(self): self._context = zmq.Context.instance() - self._socket: zmq.Socket = self._context.socket(zmq.SUB) + self._socket = self._context.socket(zmq.SUB) self._socket.setsockopt(zmq.RCVHWM, 0) if self._timeout_seconds == -1: @@ -68,23 +69,14 @@ def __initialize(self): def __routine_polling(self): try: - frames = self._socket.recv_multipart() - self.__routine_receive(frames) + self.__routine_receive(self._socket.recv()) except zmq.Again: raise TimeoutError(f"Cannot connect to {self._address.to_address()} in {self._timeout_seconds} seconds") - def __routine_receive(self, frames: List[bytes]): - logging.info(frames) - if len(frames) < 2: - logging.error(f"received unexpected frames {frames}") - return + def __routine_receive(self, payload: bytes): + result: Optional[Message] = deserialize(payload) + if result is None: + logging.error(f"received unknown message: {payload!r}") + return None - if frames[0] not in {member.value for member in MessageType}: - logging.error(f"received unexpected message type: {frames[0]}") - return - - message_type_bytes, *payload = frames - message_type = MessageType(message_type_bytes) - message = PROTOCOL[message_type].deserialize(payload) - - self._callback(message) + self._callback(result) diff --git a/scaler/io/utility.py b/scaler/io/utility.py new file mode 100644 index 0000000..6465671 --- /dev/null +++ b/scaler/io/utility.py @@ -0,0 +1,22 @@ +import logging +from typing import Optional + +from scaler.io.config import MESSAGE_SIZE_LIMIT +from scaler.protocol.capnp._python import _message # noqa +from scaler.protocol.python.message import PROTOCOL +from scaler.protocol.python.mixins import Message + + +def deserialize(data: bytes) -> Optional[Message]: + with _message.Message.from_bytes(data, traversal_limit_in_words=MESSAGE_SIZE_LIMIT) as payload: + if not hasattr(payload, payload.which()): + logging.error(f"unknown message type: {payload.which()}") + return None + + message = getattr(payload, payload.which()) + return PROTOCOL[payload.which()](message) + + +def serialize(message: Message) -> bytes: + payload = _message.Message(**{PROTOCOL.inverse[type(message)]: message.get_message()}) + return payload.to_bytes() diff --git a/scaler/protocol/capnp/__init__.py b/scaler/protocol/capnp/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scaler/protocol/capnp/_python.py b/scaler/protocol/capnp/_python.py new file mode 100644 index 0000000..93d725f --- /dev/null +++ b/scaler/protocol/capnp/_python.py @@ -0,0 +1,4 @@ +import capnp # noqa +import scaler.protocol.capnp.common_capnp as _common # noqa +import scaler.protocol.capnp.message_capnp as _message # noqa +import scaler.protocol.capnp.status_capnp as _status # noqa diff --git a/scaler/protocol/capnp/common.capnp b/scaler/protocol/capnp/common.capnp new file mode 100644 index 0000000..3f08c37 --- /dev/null +++ b/scaler/protocol/capnp/common.capnp @@ -0,0 +1,22 @@ +@0xf57f79ac88fab620; + +enum TaskStatus { + # task is accepted by scheduler, but will have below status + success @0; # if submit and task is done and get result + failed @1; # if submit and task is failed on worker + canceled @2; # if submit and task is canceled + notFound @3; # if submit and task is not found in scheduler + workerDied @4; # if submit and worker died (only happened when scheduler keep_task=False) + noWorker @5; # if submit and scheduler is full (not implemented yet) + + # below are only used for monitoring channel, not sent to client + inactive @6; # task is scheduled but not allocate to worker + running @7; # task is running in worker + canceling @8; # task is canceling (can be in Inactive or Running state) +} + +struct ObjectContent { + objectIds @0 :List(Data); + objectNames @1 :List(Data); + objectBytes @2 :List(Data); +} diff --git a/scaler/protocol/capnp/message.capnp b/scaler/protocol/capnp/message.capnp new file mode 100644 index 0000000..c554aa2 --- /dev/null +++ b/scaler/protocol/capnp/message.capnp @@ -0,0 +1,208 @@ +@0xaf44f44ea94a4675; + +using CommonType = import "common.capnp"; +using Status = import "status.capnp"; + +struct Task { + taskId @0 :Data; + source @1 :Data; + metadata @2 :Data; + funcObjectId @3 :Data; + functionArgs @4 :List(Argument); + + struct Argument { + type @0 :ArgumentType; + data @1 :Data; + + enum ArgumentType { + task @0; + objectID @1; + } + } +} + +struct TaskCancel { + struct TaskCancelFlags { + force @0 :Bool; + retrieveTaskObject @1 :Bool; + } + + taskId @0 :Data; + flags @1 :TaskCancelFlags; +} + +struct TaskResult { + taskId @0 :Data; + status @1 :CommonType.TaskStatus; + metadata @2 :Data; + results @3 :List(Data); +} + +struct GraphTask { + taskId @0 :Data; + source @1 :Data; + targets @2 :List(Data); + graph @3 :List(Task); +} + +struct GraphTaskCancel { + taskId @0 :Data; +} + +struct ClientHeartbeat { + resource @0 :Status.Resource; + latencyUS @1 :UInt32; +} + +struct ClientHeartbeatEcho { +} + +struct WorkerHeartbeat { + agent @0 :Status.Resource; + rssFree @1 :UInt64; + queuedTasks @2 :UInt32; + latencyUS @3 :UInt32; + taskLock @4 :Bool; + processors @5 :List(Status.ProcessorStatus); +} + +struct WorkerHeartbeatEcho { +} + +struct ObjectInstruction { + instructionType @0 :ObjectInstructionType; + objectUser @1 :Data; + objectContent @2 :CommonType.ObjectContent; + + enum ObjectInstructionType { + create @0; + delete @1; + } +} + +struct ObjectRequest { + requestType @0 :ObjectRequestType; + objectIds @1 :List(Data); + + enum ObjectRequestType { + get @0; + } +} + +struct ObjectResponse { + responseType @0 :ObjectResponseType; + objectContent @1 :CommonType.ObjectContent; + + enum ObjectResponseType { + content @0; + objectNotExist @1; + } +} + +struct DisconnectRequest { + worker @0 :Data; +} + +struct DisconnectResponse { + worker @0 :Data; +} + +struct ClientDisconnect { + disconnectType @0 :DisconnectType; + + enum DisconnectType { + disconnect @0; + shutdown @1; + } +} + +struct ClientShutdownResponse { + accepted @0 :Bool; +} + +struct StateClient { +} + +struct StateObject { +} + +struct StateBalanceAdvice { + workerId @0 :Data; + taskIds @1 :List(Data); +} + +struct StateScheduler { + binder @0 :Status.BinderStatus; + scheduler @1 :Status.Resource; + rssFree @2 :UInt64; + clientManager @3 :Status.ClientManagerStatus; + objectManager @4 :Status.ObjectManagerStatus; + taskManager @5 :Status.TaskManagerStatus; + workerManager @6 :Status.WorkerManagerStatus; +} + +struct StateWorker { + workerId @0 :Data; + message @1 :Data; +} + +struct StateTask { + taskId @0 :Data; + functionName @1 :Data; + status @2 :CommonType.TaskStatus; + worker @3 :Data; + metadata @4 :Data; +} + +struct StateGraphTask { + enum NodeTaskType { + normal @0; + target @1; + } + + graphTaskId @0 :Data; + taskId @1 :Data; + nodeTaskType @2 :NodeTaskType; + parentTaskIds @3 :List(Data); +} + +struct ProcessorInitialized { +} + + +struct Message { + union { + task @0 :Task; + taskCancel @1 :TaskCancel; + taskResult @2 :TaskResult; + + graphTask @3 :GraphTask; + graphTaskCancel @4 :GraphTaskCancel; + + objectInstruction @5 :ObjectInstruction; + objectRequest @6 :ObjectRequest; + objectResponse @7 :ObjectResponse; + + clientHeartbeat @8 :ClientHeartbeat; + clientHeartbeatEcho @9 :ClientHeartbeatEcho; + + workerHeartbeat @10 :WorkerHeartbeat; + workerHeartbeatEcho @11 :WorkerHeartbeatEcho; + + disconnectRequest @12 :DisconnectRequest; + disconnectResponse @13 :DisconnectResponse; + + stateClient @14 :StateClient; + stateObject @15 :StateObject; + stateBalanceAdvice @16 :StateBalanceAdvice; + stateScheduler @17 :StateScheduler; + stateWorker @18 :StateWorker; + stateTask @19 :StateTask; + stateGraphTask @20 :StateGraphTask; + + clientDisconnect @21 :ClientDisconnect; + clientShutdownResponse @22 :ClientShutdownResponse; + + processorInitialized @23 :ProcessorInitialized; + } +} diff --git a/scaler/protocol/capnp/status.capnp b/scaler/protocol/capnp/status.capnp new file mode 100644 index 0000000..1131bc4 --- /dev/null +++ b/scaler/protocol/capnp/status.capnp @@ -0,0 +1,65 @@ +@0xa4dfa1212ad2d0f0; + +struct Resource { + cpu @0 :UInt16; # 99.2% will be represented as 992 as integer + rss @1 :UInt64; # 32bit is capped to 4GB, so use 64bit to represent +} + +struct ObjectManagerStatus { + numberOfObjects @0 :UInt32; + objectMemory @1 :UInt64; +} + +struct ClientManagerStatus { + clientToNumOfTask @0 :List(Pair); + + struct Pair { + client @0 :Data; + numTask @1 :UInt32; + } +} + +struct TaskManagerStatus { + unassigned @0 :UInt32; + running @1 :UInt32; + success @2 :UInt32; + failed @3 :UInt32; + canceled @4 :UInt32; + notFound @5 :UInt32; +} + +struct ProcessorStatus { + pid @0 :UInt32; + initialized @1 :Bool; + hasTask @2 :Bool; + suspended @3 :Bool; + resource @4 :Resource; +} + +struct WorkerStatus { + workerId @0 :Data; + agent @1 :Resource; + rssFree @2 :UInt64; + free @3 :UInt32; + sent @4 :UInt32; + queued @5 :UInt32; + suspended @6: UInt8; + lagUS @7 :UInt64; + lastS @8 :UInt8; + itl @9 :Text; + processorStatuses @10 :List(ProcessorStatus); +} + +struct WorkerManagerStatus { + workers @0 :List(WorkerStatus); +} + +struct BinderStatus { + received @0 :List(Pair); + sent @1 :List(Pair); + + struct Pair { + client @0 :Text; + number @1 :UInt32; + } +} diff --git a/scaler/protocol/introduction.md b/scaler/protocol/introduction.md index 97fbb2f..f612fd4 100644 --- a/scaler/protocol/introduction.md +++ b/scaler/protocol/introduction.md @@ -1,6 +1,7 @@ # Roles The communication protocol include 3 roles: client, scheduler and worker: + - client is upstream of scheduler, scheduler is upstream of worker - worker is downstream of scheduler, scheduler is downstream of client @@ -29,26 +30,11 @@ each client to scheduler and each worker to scheduler only maintains 1 TCP conne # Message format -Each message is a sequence of bytes, and composed by frames, each frame is a sequence of bytes and has a fixed -length header and a variable length body. - -see below, each frame have a fixed length header, and a variable length body, the header is 8 bytes, so each frame -cannot exceed 2^64 bytes, which is 16 exabytes, which is enough for most of the use cases. - -```plaintext -| Frame 0 | Frame 1 | Frame 2 | Frame 3 -+---+----------+---+-----+-----+---+-----+----+--------------------- -| 7 | "Client1"| 2 | "O" | "I" | 1 | "C" | 36 | ... -+---+----------+---+-----+-----+---+-----+----+--------------------- - | Message | | Object | | | - | Identity | |Instruction| | | - | | - Instruction - Type -``` - +Scaler is using capnp library to serialize/deserialize and use zmq to communicate between client and scheduler and +worker # Message Type Category + In general, there are 2 categories of the message types: object and task object normally has an object id associated with actual object data, object data is immutable bytes, serialized by @@ -60,10 +46,9 @@ function and series of arguments, but task message doesn't contain the actual fu contains object ids, workers are responsible to fetch the function/argument data from scheduler and deserialize and execute the function call. - ## Object Channel -Scheduler is the center of the object storage, client and worker are identical and can push +Scheduler is the center of the object storage, client and worker are identical and can push ```plaintext ObjectInstruction @@ -84,19 +69,20 @@ Scheduler is the center of the object storage, client and worker are identical a ObjectRequest +--------------+ ``` + ObjectInstruction = b"OI" client can send object instruction to scheduler, scheduler can send object instruction to worker it has 2 subtypes: create b"C", delete b"D" when subtype is create, it has to include: + - list of object id (type bytes) - list of object names (type bytes) - list of object bytes (type bytes) -All above 3 lists, the number of items need match + All above 3 lists, the number of items need match ObjectRequest = b"OR" ObjectResponse = b"OA" - ## Task Channel ```plaintext diff --git a/scaler/protocol/python/common.py b/scaler/protocol/python/common.py new file mode 100644 index 0000000..1e94e9c --- /dev/null +++ b/scaler/protocol/python/common.py @@ -0,0 +1,56 @@ +import dataclasses +import enum +from typing import Tuple + +from scaler.protocol.capnp._python import _common # noqa +from scaler.protocol.python.mixins import Message + + +class TaskStatus(enum.Enum): + # task is accepted by scheduler, but will have below status + Success = _common.TaskStatus.success # if submit and task is done and get result + Failed = _common.TaskStatus.failed # if submit and task is failed on worker + Canceled = _common.TaskStatus.canceled # if submit and task is canceled + NotFound = _common.TaskStatus.notFound # if submit and task is not found in scheduler + WorkerDied = ( + _common.TaskStatus.workerDied + ) # if submit and worker died (only happened when scheduler keep_task=False) + NoWorker = _common.TaskStatus.noWorker # if submit and scheduler is full (not implemented yet) + + # below are only used for monitoring channel, not sent to client + Inactive = _common.TaskStatus.inactive # task is scheduled but not allocate to worker + Running = _common.TaskStatus.running # task is running in worker + Canceling = _common.TaskStatus.canceling # task is canceling (can be in Inactive or Running state) + + +@dataclasses.dataclass +class ObjectContent(Message): + def __init__(self, msg): + self._msg = msg + + @property + def object_ids(self) -> Tuple[bytes, ...]: + return tuple(self._msg.objectIds) + + @property + def object_names(self) -> Tuple[bytes, ...]: + return tuple(self._msg.objectNames) + + @property + def object_bytes(self) -> Tuple[bytes, ...]: + return tuple(self._msg.objectBytes) + + @staticmethod + def new_msg( + object_ids: Tuple[bytes, ...], + object_names: Tuple[bytes, ...] = tuple(), + object_bytes: Tuple[bytes, ...] = tuple(), + ) -> "ObjectContent": + return ObjectContent( + _common.ObjectContent( + objectIds=list(object_ids), objectNames=list(object_names), objectBytes=tuple(object_bytes) + ) + ) + + def get_message(self): + return self._msg diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py index bc3f907..43d55b9 100644 --- a/scaler/protocol/python/message.py +++ b/scaler/protocol/python/message.py @@ -1,653 +1,644 @@ import dataclasses import enum import os -import pickle -import struct -from typing import List, Set, Tuple, TypeVar +from typing import List, Set, Tuple, Optional, Type import bidict -from scaler.protocol.python.mixins import _Message +from scaler.protocol.capnp._python import _message # noqa +from scaler.protocol.python.common import TaskStatus, ObjectContent +from scaler.protocol.python.mixins import Message from scaler.protocol.python.status import ( BinderStatus, ClientManagerStatus, ObjectManagerStatus, Resource, + ProcessorStatus, TaskManagerStatus, WorkerManagerStatus, ) -class MessageType(enum.Enum): - Task = b"TK" - TaskCancel = b"TC" - TaskResult = b"TR" +class Task(Message): + @dataclasses.dataclass + class Argument: + type: "ArgumentType" + data: bytes - GraphTask = b"GT" - GraphTaskCancel = b"GC" + def __repr__(self): + return f"Argument(type={self.type}, data={self.data.hex()})" - ObjectInstruction = b"OI" - ObjectRequest = b"OR" - ObjectResponse = b"OA" + class ArgumentType(enum.Enum): + Task = _message.Task.Argument.ArgumentType.task + ObjectID = _message.Task.Argument.ArgumentType.objectID - ClientHeartbeat = b"CB" - ClientHeartbeatEcho = b"CE" + def __init__(self, msg): + super().__init__(msg) - WorkerHeartbeat = b"HB" - WorkerHeartbeatEcho = b"HE" - - DisconnectRequest = b"DR" - DisconnectResponse = b"DP" + def __repr__(self): + return ( + f"Task(task_id={self.task_id.hex()}, source={self.source.hex()}, metadata={self.metadata.hex()}," + f"func_object_id={self.func_object_id.hex()}, function_args={self.function_args})" + ) - StateClient = b"SC" - StateObject = b"SF" - StateBalanceAdvice = b"SA" - StateScheduler = b"SS" - StateWorker = b"SW" - StateTask = b"ST" - StateGraphTask = b"SG" + @property + def task_id(self) -> bytes: + return self._msg.taskId - ClientDisconnect = b"CS" - ClientShutdownResponse = b"CR" + @property + def source(self) -> bytes: + return self._msg.source - ProcessorInitialized = b"PI" + @property + def metadata(self) -> bytes: + return self._msg.metadata - @staticmethod - def allowed_values(): - return {member.value for member in MessageType} + @property + def func_object_id(self) -> bytes: + return self._msg.funcObjectId + @property + def function_args(self) -> List[Argument]: + return [ + Task.Argument(type=Task.Argument.ArgumentType(arg.type.raw), data=arg.data) + for arg in self._msg.functionArgs + ] -class TaskEchoStatus(enum.Enum): - # task echo is the response of task submit to scheduler - SubmitOK = b"O" # if submit ok and task get accepted by scheduler - SubmitDuplicated = b"D" # if submit and find task in scheduler - CancelOK = b"C" # if cancel and success - CancelTaskNotFound = b"N" # if cancel and cannot find task in scheduler + @staticmethod + def new_msg( + task_id: bytes, source: bytes, metadata: bytes, func_object_id: bytes, function_args: List[Argument] + ) -> "Task": + return Task( + _message.Task( + taskId=task_id, + source=source, + metadata=metadata, + funcObjectId=func_object_id, + functionArgs=[_message.Task.Argument(type=arg.type.value, data=arg.data) for arg in function_args], + ) + ) -class TaskStatus(enum.Enum): - # task is accepted by scheduler, but will have below status - Success = b"S" # if submit and task is done and get result - Failed = b"F" # if submit and task is failed on worker - Canceled = b"C" # if submit and task is canceled - NotFound = b"N" # if submit and task is not found in scheduler - WorkerDied = b"K" # if submit and worker died (only happened when scheduler keep_task=False) - NoWorker = b"W" # if submit and scheduler is full (not implemented yet) +class TaskCancel(Message): + def __init__(self, msg): + super().__init__(msg) - # below are only used for monitoring channel, not sent to client - Inactive = b"I" # task is scheduled but not allocate to worker - Running = b"R" # task is running in worker - Canceling = b"X" # task is canceling (can be in Inactive or Running state) + @dataclasses.dataclass + class TaskCancelFlags: + force: bool + retrieve_task_object: bool + @property + def task_id(self) -> bytes: + return self._msg.taskId -class NodeTaskType(enum.Enum): - Normal = b"N" - Target = b"T" + @property + def flags(self) -> TaskCancelFlags: + return TaskCancel.TaskCancelFlags( + force=self._msg.flags.force, retrieve_task_object=self._msg.flags.retrieveTaskObject + ) + @staticmethod + def new_msg(task_id: bytes, flags: Optional[TaskCancelFlags] = None) -> "TaskCancel": + if flags is None: + flags = TaskCancel.TaskCancelFlags(force=False, retrieve_task_object=False) + + return TaskCancel( + _message.TaskCancel( + taskId=task_id, + flags=_message.TaskCancel.TaskCancelFlags( + force=flags.force, retrieveTaskObject=flags.retrieve_task_object + ), + ) + ) -class ObjectInstructionType(enum.Enum): - Create = b"C" - Delete = b"D" +class TaskResult(Message): + def __init__(self, msg): + super().__init__(msg) -class ObjectRequestType(enum.Enum): - Get = b"A" + @property + def task_id(self) -> bytes: + return self._msg.taskId + @property + def status(self) -> TaskStatus: + return TaskStatus(self._msg.status.raw) -class ObjectResponseType(enum.Enum): - Content = b"C" - ObjectNotExist = b"N" + @property + def metadata(self) -> bytes: + return self._msg.metadata + @property + def results(self) -> List[bytes]: + return self._msg.results -class ArgumentType(enum.Enum): - Task = b"T" - ObjectID = b"R" + @staticmethod + def new_msg( + task_id: bytes, status: TaskStatus, metadata: Optional[bytes] = None, results: Optional[List[bytes]] = None + ) -> "TaskResult": + if metadata is None: + metadata = bytes() + if results is None: + results = list() -class DisconnectType(enum.Enum): - Disconnect = b"D" - Shutdown = b"S" + return TaskResult(_message.TaskResult(taskId=task_id, status=status.value, metadata=metadata, results=results)) -@dataclasses.dataclass -class Argument: - type: ArgumentType - data: bytes +class GraphTask(Message): + def __init__(self, msg): + super().__init__(msg) def __repr__(self): - return f"Argument(type={self.type}, data={self.data.hex()})" + return ( + f"GraphTask({os.linesep}" + f" task_id={self.task_id.hex()},{os.linesep}" + f" targets=[{os.linesep}" + f" {[target.hex() + ',' + os.linesep for target in self.targets]}" + f" ]\n" + f" graph={self.graph}\n" + f")" + ) - def serialize(self) -> Tuple[bytes, ...]: - return self.type.value, self.data + @property + def task_id(self) -> bytes: + return self._msg.taskId - @staticmethod - def deserialize(data: List[bytes]): - return Argument(ArgumentType(data[0]), data[1]) + @property + def source(self) -> bytes: + return self._msg.source + @property + def targets(self) -> List[bytes]: + return self._msg.targets -@dataclasses.dataclass -class ObjectContent: - object_ids: Tuple[bytes, ...] - object_names: Tuple[bytes, ...] = dataclasses.field(default_factory=tuple) - object_bytes: Tuple[bytes, ...] = dataclasses.field(default_factory=tuple) - - def serialize(self) -> Tuple[bytes, ...]: - payload = ( - struct.pack("III", len(self.object_ids), len(self.object_names), len(self.object_bytes)), - *self.object_ids, - *self.object_names, - *self.object_bytes, - ) - return payload + @property + def graph(self) -> List[Task]: + return [Task(task) for task in self._msg.graph] @staticmethod - def deserialize(data: List[bytes]) -> "ObjectContent": - num_of_object_ids, num_of_object_names, num_of_object_bytes = struct.unpack("III", data[0]) - - data = data[1:] - object_ids = data[:num_of_object_ids] - - data = data[num_of_object_ids:] - object_names = data[:num_of_object_names] + def new_msg(task_id: bytes, source: bytes, targets: List[bytes], graph: List[Task]) -> "GraphTask": + return GraphTask( + _message.GraphTask( + taskId=task_id, source=source, targets=targets, graph=[task.get_message() for task in graph] + ) + ) - data = data[num_of_object_names:] - object_bytes = data[:num_of_object_bytes] - return ObjectContent(tuple(object_ids), tuple(object_names), tuple(object_bytes)) +class GraphTaskCancel(Message): + def __init__(self, msg): + super().__init__(msg) + @property + def task_id(self) -> bytes: + return self._msg.taskId -MessageVariant = TypeVar("MessageVariant", bound=_Message) + @staticmethod + def new_msg(task_id: bytes) -> "GraphTaskCancel": + return GraphTaskCancel(_message.GraphTaskCancel(taskId=task_id)) + def get_message(self): + return self._msg -@dataclasses.dataclass -class Task(_Message): - task_id: bytes - source: bytes - metadata: bytes - func_object_id: bytes - function_args: List[Argument] - def __repr__(self): - return ( - f"Task(task_id={self.task_id.hex()}, source={self.source.hex()}, metadata={self.metadata.hex()}," - f"func_object_id={self.func_object_id.hex()}, function_args={self.function_args})" - ) +class ClientHeartbeat(Message): + def __init__(self, msg): + super().__init__(msg) - def get_required_object_ids(self) -> Set[bytes]: - return {self.func_object_id} | {arg.data for arg in self.function_args if arg.type == ArgumentType.ObjectID} + @property + def resource(self) -> Resource: + return Resource(self._msg.resource) - def serialize(self) -> Tuple[bytes, ...]: - return ( - self.task_id, - self.source, - self.metadata, - self.func_object_id, - *[d for arg in self.function_args for d in arg.serialize()], - ) + @property + def latency_us(self) -> int: + return self._msg.latencyUS @staticmethod - def deserialize(data: List[bytes]): - return Task( - data[0], data[1], data[2], data[3], [Argument.deserialize(data[i : i + 2]) for i in range(4, len(data), 2)] - ) - - -@dataclasses.dataclass -class TaskCancelFlags: - force: bool = dataclasses.field(default=False) - retrieve_task_object: bool = dataclasses.field(default=False) + def new_msg(resource: Resource, latency_us: int) -> "ClientHeartbeat": + return ClientHeartbeat(_message.ClientHeartbeat(resource=resource.get_message(), latencyUS=latency_us)) - FORMAT = "??" - def serialize(self) -> bytes: - return struct.pack(TaskCancelFlags.FORMAT, self.force, self.retrieve_task_object) +class ClientHeartbeatEcho(Message): + def __init__(self, msg): + super().__init__(msg) @staticmethod - def deserialize(data: bytes) -> "TaskCancelFlags": - return TaskCancelFlags(*struct.unpack(TaskCancelFlags.FORMAT, data)) + def new_msg() -> "ClientHeartbeatEcho": + return ClientHeartbeatEcho(_message.ClientHeartbeatEcho()) -@dataclasses.dataclass -class TaskCancel(_Message): - task_id: bytes - flags: TaskCancelFlags = dataclasses.field(default_factory=TaskCancelFlags) +class WorkerHeartbeat(Message): + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return (self.task_id, self.flags.serialize()) + @property + def agent(self) -> Resource: + return Resource(self._msg.agent) - @staticmethod - def deserialize(data: List[bytes]): - return TaskCancel(data[0], TaskCancelFlags.deserialize(data[1])) # type: ignore + @property + def rss_free(self) -> int: + return self._msg.rssFree + @property + def queued_tasks(self) -> int: + return self._msg.queuedTasks -@dataclasses.dataclass -class TaskResult(_Message): - task_id: bytes - status: TaskStatus - metadata: bytes = dataclasses.field(default=b"") - results: List[bytes] = dataclasses.field(default_factory=list) + @property + def latency_us(self) -> int: + return self._msg.latencyUS - def serialize(self) -> Tuple[bytes, ...]: - return self.task_id, self.status.value, self.metadata, *self.results + @property + def task_lock(self) -> bool: + return self._msg.taskLock - @staticmethod - def deserialize(data: List[bytes]): - return TaskResult(data[0], TaskStatus(data[1]), data[2], data[3:]) - - -@dataclasses.dataclass -class GraphTask(_Message): - task_id: bytes - source: bytes - targets: List[bytes] - graph: List[Task] + @property + def processors(self) -> List[ProcessorStatus]: + return [ProcessorStatus(p) for p in self._msg.processors] - def __repr__(self): - return ( - f"GraphTask({os.linesep}" - f" task_id={self.task_id.hex()},{os.linesep}" - f" targets=[{os.linesep}" - f" {[target.hex() + ',' + os.linesep for target in self.targets]}" - f" ]\n" - f" graph={self.graph}\n" - f")" + @staticmethod + def new_msg( + agent: Resource, + rss_free: int, + queued_tasks: int, + latency_us: int, + task_lock: bool, + processors: List[ProcessorStatus], + ) -> "WorkerHeartbeat": + return WorkerHeartbeat( + _message.WorkerHeartbeat( + agent=agent.get_message(), + rssFree=rss_free, + queuedTasks=queued_tasks, + latencyUS=latency_us, + taskLock=task_lock, + processors=[p.get_message() for p in processors], + ) ) - def serialize(self) -> Tuple[bytes, ...]: - graph_bytes = [] - for task in self.graph: - frames = task.serialize() - graph_bytes.append(struct.pack("I", len(frames))) - graph_bytes.extend(frames) - return self.task_id, self.source, struct.pack("I", len(self.targets)), *self.targets, *graph_bytes +class WorkerHeartbeatEcho(Message): + def __init__(self, msg): + super().__init__(msg) @staticmethod - def deserialize(data: List[bytes]): - index = 0 - task_id = data[index] + def new_msg() -> "WorkerHeartbeatEcho": + return WorkerHeartbeatEcho(_message.WorkerHeartbeatEcho()) - index += 1 - client_id = data[index] - index += 1 - number_of_targets = struct.unpack("I", data[index])[0] +class ObjectInstruction(Message): + class ObjectInstructionType(enum.Enum): + Create = _message.ObjectInstruction.ObjectInstructionType.create + Delete = _message.ObjectInstruction.ObjectInstructionType.delete - index += 1 - targets = data[index : index + number_of_targets] + def __init__(self, msg): + super().__init__(msg) - index += number_of_targets - graph = [] - while index < len(data): - number_of_frames = struct.unpack("I", data[index])[0] - index += 1 - graph.append(Task.deserialize(data[index : index + number_of_frames])) - index += number_of_frames + @property + def instruction_type(self) -> ObjectInstructionType: + return ObjectInstruction.ObjectInstructionType(self._msg.instructionType.raw) - return GraphTask(task_id, client_id, targets, graph) + @property + def object_user(self) -> bytes: + return self._msg.objectUser - -@dataclasses.dataclass -class GraphTaskCancel(_Message): - task_id: bytes - - def serialize(self) -> Tuple[bytes, ...]: - return (self.task_id,) + @property + def object_content(self) -> ObjectContent: + return ObjectContent(self._msg.objectContent) @staticmethod - def deserialize(data: List[bytes]): - return GraphTaskCancel(data[0]) - - -@dataclasses.dataclass -class ClientHeartbeat(_Message): - client_cpu: float - client_rss: int - latency_us: int - - FORMAT = "HQI" - - def serialize(self) -> Tuple[bytes, ...]: - return (struct.pack(ClientHeartbeat.FORMAT, int(self.client_cpu * 1000), self.client_rss, self.latency_us),) + def new_msg( + instruction_type: ObjectInstructionType, object_user: bytes, object_content: ObjectContent + ) -> "ObjectInstruction": + return ObjectInstruction( + _message.ObjectInstruction( + instructionType=instruction_type.value, + objectUser=object_user, + objectContent=object_content.get_message(), + ) + ) - @staticmethod - def deserialize(data: List[bytes]): - client_cpu, client_rss, latency_us = struct.unpack(ClientHeartbeat.FORMAT, data[0]) - return ClientHeartbeat(float(client_cpu / 1000), client_rss, latency_us) +class ObjectRequest(Message): + class ObjectRequestType(enum.Enum): + Get = _message.ObjectRequest.ObjectRequestType.get -@dataclasses.dataclass -class ClientHeartbeatEcho(_Message): - def serialize(self) -> Tuple[bytes, ...]: - return (b"",) + def __init__(self, msg): + super().__init__(msg) - @staticmethod - def deserialize(data: List[bytes]): - return ClientHeartbeatEcho() + def __repr__(self): + return ( + f"ObjectRequest(type={self.request_type}, " + f"object_ids={tuple(object_id.hex() for object_id in self.object_ids)})" + ) + @property + def request_type(self) -> ObjectRequestType: + return ObjectRequest.ObjectRequestType(self._msg.requestType.raw) -@dataclasses.dataclass -class ProcessorHeartbeat: - pid: int - initialized: bool - has_task: bool - suspended: bool - cpu: float - rss: int - - FORMAT = "I???HQ" - - def serialize(self) -> bytes: - return struct.pack( - ProcessorHeartbeat.FORMAT, - self.pid, - self.initialized, - self.has_task, - self.suspended, - int(self.cpu * 1000), - self.rss, - ) + @property + def object_ids(self) -> Tuple[bytes]: + return tuple(self._msg.objectIds) @staticmethod - def deserialize(data: bytes) -> "ProcessorHeartbeat": - pid, initialized, has_task, suspended, cpu, rss = struct.unpack(ProcessorHeartbeat.FORMAT, data) - return ProcessorHeartbeat(pid, initialized, has_task, suspended, float(cpu / 1000), rss) + def new_msg(request_type: ObjectRequestType, object_ids: Tuple[bytes, ...]) -> "ObjectRequest": + return ObjectRequest(_message.ObjectRequest(requestType=request_type.value, objectIds=list(object_ids))) -@dataclasses.dataclass -class WorkerHeartbeat(_Message): - agent_cpu: float - agent_rss: int - rss_free: int - queued_tasks: int - latency_us: int - task_lock: bool +class ObjectResponse(Message): + class ObjectResponseType(enum.Enum): + Content = _message.ObjectResponse.ObjectResponseType.content + ObjectNotExist = _message.ObjectResponse.ObjectResponseType.objectNotExist - processors: List[ProcessorHeartbeat] + def __init__(self, msg): + super().__init__(msg) - FORMAT = "HQQHI?" # processor heartbeats come right after the main fields + @property + def response_type(self) -> ObjectResponseType: + return ObjectResponse.ObjectResponseType(self._msg.responseType.raw) - def serialize(self) -> Tuple[bytes, ...]: - return ( - struct.pack( - WorkerHeartbeat.FORMAT, - int(self.agent_cpu * 1000), - self.agent_rss, - self.rss_free, - self.queued_tasks, - self.latency_us, - self.task_lock, - ), - *(p.serialize() for p in self.processors) - ) + @property + def object_content(self) -> ObjectContent: + return ObjectContent(self._msg.objectContent) @staticmethod - def deserialize(data: List[bytes]): - ( - agent_cpu, - agent_rss, - rss_free, - queued_tasks, - latency_us, - task_lock, - ) = struct.unpack(WorkerHeartbeat.FORMAT, data[0]) - processors = [ProcessorHeartbeat.deserialize(d) for d in data[1:]] - - return WorkerHeartbeat( - float(agent_cpu / 1000), - agent_rss, - rss_free, - queued_tasks, - latency_us, - task_lock, - processors, + def new_msg(response_type: ObjectResponseType, object_content: ObjectContent) -> "ObjectResponse": + return ObjectResponse( + _message.ObjectResponse(responseType=response_type.value, objectContent=object_content.get_message()) ) -@dataclasses.dataclass -class WorkerHeartbeatEcho(_Message): - def serialize(self) -> Tuple[bytes, ...]: - return (b"",) +class DisconnectRequest(Message): + def __init__(self, msg): + super().__init__(msg) + + @property + def worker(self) -> bytes: + return self._msg.worker @staticmethod - def deserialize(data: List[bytes]): - return WorkerHeartbeatEcho() + def new_msg(worker: bytes) -> "DisconnectRequest": + return DisconnectRequest(_message.DisconnectRequest(worker=worker)) @dataclasses.dataclass -class ObjectInstruction(_Message): - type: ObjectInstructionType - object_user: bytes - object_content: ObjectContent +class DisconnectResponse(Message): + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return self.type.value, self.object_user, *self.object_content.serialize() + @property + def worker(self) -> bytes: + return self._msg.worker @staticmethod - def deserialize(data: List[bytes]) -> "ObjectInstruction": - return ObjectInstruction(ObjectInstructionType(data[0]), data[1], ObjectContent.deserialize(data[2:])) + def new_msg(worker: bytes) -> "DisconnectResponse": + return DisconnectResponse(_message.DisconnectResponse(worker=worker)) -@dataclasses.dataclass -class ObjectRequest(_Message): - type: ObjectRequestType - object_ids: Tuple[bytes, ...] +class ClientDisconnect(Message): + class DisconnectType(enum.Enum): + Disconnect = _message.ClientDisconnect.DisconnectType.disconnect + Shutdown = _message.ClientDisconnect.DisconnectType.shutdown - def __repr__(self): - return f"ObjectRequest(type={self.type}, object_ids={tuple(object_id.hex() for object_id in self.object_ids)})" + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return self.type.value, *self.object_ids + @property + def disconnect_type(self) -> DisconnectType: + return ClientDisconnect.DisconnectType(self._msg.disconnectType.raw) @staticmethod - def deserialize(data: List[bytes]): - return ObjectRequest(ObjectRequestType(data[0]), tuple(data[1:])) + def new_msg(disconnect_type: DisconnectType) -> "ClientDisconnect": + return ClientDisconnect(_message.ClientDisconnect(disconnectType=disconnect_type.value)) -@dataclasses.dataclass -class ObjectResponse(_Message): - type: ObjectResponseType - object_content: ObjectContent +class ClientShutdownResponse(Message): + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return self.type.value, *self.object_content.serialize() + @property + def accepted(self) -> bool: + return self._msg.accepted @staticmethod - def deserialize(data: List[bytes]): - request_type = ObjectResponseType(data[0]) - return ObjectResponse(request_type, ObjectContent.deserialize(data[1:])) + def new_msg(accepted: bool) -> "ClientShutdownResponse": + return ClientShutdownResponse(_message.ClientShutdownResponse(accepted=accepted)) -@dataclasses.dataclass -class DisconnectRequest(_Message): - worker: bytes - - def serialize(self) -> Tuple[bytes, ...]: - return (self.worker,) +class StateClient(Message): + def __init__(self, msg): + super().__init__(msg) @staticmethod - def deserialize(data: List[bytes]): - return DisconnectRequest(data[0]) + def new_msg() -> "StateClient": + return StateClient(_message.StateClient()) -@dataclasses.dataclass -class DisconnectResponse(_Message): - worker: bytes - - def serialize(self) -> Tuple[bytes, ...]: - return (self.worker,) +class StateObject(Message): + def __init__(self, msg): + super().__init__(msg) @staticmethod - def deserialize(data: List[bytes]): - return DisconnectResponse(data[0]) + def new_msg() -> "StateObject": + return StateObject(_message.StateObject()) -@dataclasses.dataclass -class ClientDisconnect(_Message): - type: DisconnectType +class StateBalanceAdvice(Message): + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return (self.type.value,) + @property + def worker_id(self) -> bytes: + return self._msg.workerId + + @property + def task_ids(self) -> List[bytes]: + return self._msg.taskIds @staticmethod - def deserialize(data: List[bytes]): - return ClientDisconnect(DisconnectType(data[0])) + def new_msg(worker_id: bytes, task_ids: List[bytes]) -> "StateBalanceAdvice": + return StateBalanceAdvice(_message.StateBalanceAdvice(workerId=worker_id, taskIds=task_ids)) -@dataclasses.dataclass -class ClientShutdownResponse(_Message): - accepted: bool +class StateScheduler(Message): + def __init__(self, msg): + super().__init__(msg) - def serialize(self) -> Tuple[bytes, ...]: - return (struct.pack("?", self.accepted),) + @property + def binder(self) -> BinderStatus: + return BinderStatus(self._msg.binder) - @staticmethod - def deserialize(data: List[bytes]): - return ClientShutdownResponse(struct.unpack("?", data[0])[0]) + @property + def scheduler(self) -> Resource: + return Resource(self._msg.scheduler) + @property + def rss_free(self) -> int: + return self._msg.rssFree -@dataclasses.dataclass -class StateClient(_Message): - # TODO: implement this - def serialize(self) -> Tuple[bytes, ...]: - return (b"",) + @property + def client_manager(self) -> ClientManagerStatus: + return ClientManagerStatus(self._msg.clientManager) - @staticmethod - def deserialize(data: List[bytes]): - return StateClient() + @property + def object_manager(self) -> ObjectManagerStatus: + return ObjectManagerStatus(self._msg.objectManager) + @property + def task_manager(self) -> TaskManagerStatus: + return TaskManagerStatus(self._msg.taskManager) -@dataclasses.dataclass -class StateObject(_Message): - # TODO: implement this - def serialize(self) -> Tuple[bytes, ...]: - return (b"",) + @property + def worker_manager(self) -> WorkerManagerStatus: + return WorkerManagerStatus(self._msg.workerManager) @staticmethod - def deserialize(data: List[bytes]): - return StateObject() + def new_msg( + binder: BinderStatus, + scheduler: Resource, + rss_free: int, + client_manager: ClientManagerStatus, + object_manager: ObjectManagerStatus, + task_manager: TaskManagerStatus, + worker_manager: WorkerManagerStatus, + ) -> "StateScheduler": + return StateScheduler( + _message.StateScheduler( + binder=binder.get_message(), + scheduler=scheduler.get_message(), + rssFree=rss_free, + clientManager=client_manager.get_message(), + objectManager=object_manager.get_message(), + taskManager=task_manager.get_message(), + workerManager=worker_manager.get_message(), + ) + ) -@dataclasses.dataclass -class StateBalanceAdvice(_Message): - worker_id: bytes - task_ids: List[bytes] +class StateWorker(Message): + def __init__(self, msg): + super().__init__(msg) + + @property + def worker_id(self) -> bytes: + return self._msg.workerId - def serialize(self) -> Tuple[bytes, ...]: - return self.worker_id, *self.task_ids + @property + def message(self) -> bytes: + return self._msg.message @staticmethod - def deserialize(data: List[bytes]): - return StateBalanceAdvice(data[0], data[1:]) + def new_msg(worker_id: bytes, message: bytes) -> "StateWorker": + return StateWorker(_message.StateWorker(workerId=worker_id, message=message)) -@dataclasses.dataclass -class StateScheduler(_Message): - binder: BinderStatus - scheduler: Resource - client_manager: ClientManagerStatus - object_manager: ObjectManagerStatus - task_manager: TaskManagerStatus - worker_manager: WorkerManagerStatus - - def serialize(self) -> Tuple[bytes, ...]: - return ( - pickle.dumps( - ( - self.binder, - self.scheduler, - self.client_manager, - self.object_manager, - self.task_manager, - self.worker_manager, - ) - ), - ) +class StateTask(Message): + def __init__(self, msg): + super().__init__(msg) - @staticmethod - def deserialize(data: List[bytes]): - return StateScheduler(*pickle.loads(data[0])) + @property + def task_id(self) -> bytes: + return self._msg.taskId + @property + def function_name(self) -> bytes: + return self._msg.functionName -@dataclasses.dataclass -class StateWorker(_Message): - worker_id: bytes - message: bytes + @property + def status(self) -> TaskStatus: + return TaskStatus(self._msg.status.raw) - def serialize(self) -> Tuple[bytes, ...]: - return self.worker_id, self.message + @property + def worker(self) -> bytes: + return self._msg.worker + + @property + def metadata(self) -> bytes: + return self._msg.metadata @staticmethod - def deserialize(data: List[bytes]): - return StateWorker(data[0], data[1]) + def new_msg( + task_id: bytes, function_name: bytes, status: TaskStatus, worker: bytes, metadata: bytes = b"" + ) -> "StateTask": + return StateTask( + _message.StateTask( + taskId=task_id, functionName=function_name, status=status.value, worker=worker, metadata=metadata + ) + ) -@dataclasses.dataclass -class StateTask(_Message): - task_id: bytes - function_name: bytes - status: TaskStatus - worker: bytes - metadata: bytes = dataclasses.field(default=b"") +class StateGraphTask(Message): + class NodeTaskType(enum.Enum): + Normal = _message.StateGraphTask.NodeTaskType.normal + Target = _message.StateGraphTask.NodeTaskType.target - def serialize(self) -> Tuple[bytes, ...]: - return self.task_id, self.function_name, self.status.value, self.worker, self.metadata + def __init__(self, msg): + super().__init__(msg) - @staticmethod - def deserialize(data: List[bytes]): - return StateTask(data[0], data[1], TaskStatus(data[2]), data[3], data[4]) + @property + def graph_task_id(self) -> bytes: + return self._msg.graphTaskId + @property + def task_id(self) -> bytes: + return self._msg.taskId -@dataclasses.dataclass -class StateGraphTask(_Message): - graph_task_id: bytes - task_id: bytes - node_task_type: NodeTaskType - parent_task_ids: Set[bytes] + @property + def node_task_type(self) -> NodeTaskType: + return StateGraphTask.NodeTaskType(self._msg.nodeTaskType.raw) - def serialize(self) -> Tuple[bytes, ...]: - return self.graph_task_id, self.task_id, self.node_task_type.value, *self.parent_task_ids + @property + def parent_task_ids(self) -> Set[bytes]: + return set(self._msg.parentTaskIds) @staticmethod - def deserialize(data: List[bytes]): - return StateGraphTask(data[0], data[1], NodeTaskType(data[2]), set(data[3:])) + def new_msg( + graph_task_id: bytes, task_id: bytes, node_task_type: NodeTaskType, parent_task_ids: Set[bytes] + ) -> "StateGraphTask": + return StateGraphTask( + _message.StateGraphTask( + graphTaskId=graph_task_id, + taskId=task_id, + nodeTaskType=node_task_type.value, + parentTaskIds=list(parent_task_ids), + ) + ) -@dataclasses.dataclass -class ProcessorInitialized(_Message): - def serialize(self) -> Tuple[bytes, ...]: - return (b"",) +class ProcessorInitialized(Message): + def __init__(self, msg): + super().__init__(msg) @staticmethod - def deserialize(data: List[bytes]): - return ProcessorInitialized() + def new_msg() -> "ProcessorInitialized": + return ProcessorInitialized(_message.ProcessorInitialized()) -PROTOCOL = bidict.bidict( +PROTOCOL: bidict.bidict[str, Type[Message]] = bidict.bidict( { - MessageType.ClientHeartbeat: ClientHeartbeat, - MessageType.ClientHeartbeatEcho: ClientHeartbeatEcho, - MessageType.WorkerHeartbeat: WorkerHeartbeat, - MessageType.WorkerHeartbeatEcho: WorkerHeartbeatEcho, - MessageType.Task: Task, - MessageType.TaskCancel: TaskCancel, - MessageType.TaskResult: TaskResult, - MessageType.GraphTask: GraphTask, - MessageType.GraphTaskCancel: GraphTaskCancel, - MessageType.ObjectInstruction: ObjectInstruction, - MessageType.ObjectRequest: ObjectRequest, - MessageType.ObjectResponse: ObjectResponse, - MessageType.DisconnectRequest: DisconnectRequest, - MessageType.DisconnectResponse: DisconnectResponse, - MessageType.StateClient: StateClient, - MessageType.StateObject: StateObject, - MessageType.StateBalanceAdvice: StateBalanceAdvice, - MessageType.StateScheduler: StateScheduler, - MessageType.StateWorker: StateWorker, - MessageType.StateTask: StateTask, - MessageType.StateGraphTask: StateGraphTask, - MessageType.ClientDisconnect: ClientDisconnect, - MessageType.ClientShutdownResponse: ClientShutdownResponse, - MessageType.ProcessorInitialized: ProcessorInitialized, + "task": Task, + "taskCancel": TaskCancel, + "taskResult": TaskResult, + "graphTask": GraphTask, + "graphTaskCancel": GraphTaskCancel, + "objectInstruction": ObjectInstruction, + "objectRequest": ObjectRequest, + "objectResponse": ObjectResponse, + "clientHeartbeat": ClientHeartbeat, + "clientHeartbeatEcho": ClientHeartbeatEcho, + "workerHeartbeat": WorkerHeartbeat, + "workerHeartbeatEcho": WorkerHeartbeatEcho, + "disconnectRequest": DisconnectRequest, + "disconnectResponse": DisconnectResponse, + "stateClient": StateClient, + "stateObject": StateObject, + "stateBalanceAdvice": StateBalanceAdvice, + "stateScheduler": StateScheduler, + "stateWorker": StateWorker, + "stateTask": StateTask, + "stateGraphTask": StateGraphTask, + "clientDisconnect": ClientDisconnect, + "clientShutdownResponse": ClientShutdownResponse, + "processorInitialized": ProcessorInitialized, } ) diff --git a/scaler/protocol/python/mixins.py b/scaler/protocol/python/mixins.py index 765873d..52c44e6 100644 --- a/scaler/protocol/python/mixins.py +++ b/scaler/protocol/python/mixins.py @@ -1,13 +1,13 @@ import abc -from typing import List, Tuple +from typing import TypeVar -class _Message(metaclass=abc.ABCMeta): - @abc.abstractmethod - def serialize(self) -> Tuple[bytes, ...]: - raise NotImplementedError() +class Message(metaclass=abc.ABCMeta): + def __init__(self, msg): + self._msg = msg - @staticmethod - @abc.abstractmethod - def deserialize(data: List[bytes]): - raise NotImplementedError() + def get_message(self): + return self._msg + + +MessageType = TypeVar("MessageType", bound=Message) diff --git a/scaler/protocol/python/status.py b/scaler/protocol/python/status.py index 1f3cb08..5dfc702 100644 --- a/scaler/protocol/python/status.py +++ b/scaler/protocol/python/status.py @@ -1,65 +1,276 @@ -import dataclasses from typing import Dict, List +from scaler.protocol.capnp._python import _status # noqa +from scaler.protocol.python.mixins import Message -@dataclasses.dataclass -class Resource: - cpu: float - rss: int - rss_free: int +class Resource(Message): + def __init__(self, msg): + self._msg = msg -@dataclasses.dataclass -class ObjectManagerStatus: - number_of_objects: int - object_memory: int + @property + def cpu(self) -> int: + return self._msg.cpu + @property + def rss(self) -> int: + return self._msg.rss -@dataclasses.dataclass -class ClientManagerStatus: - client_to_num_of_tasks: Dict[bytes, int] + @staticmethod + def new_msg(cpu: int, rss: int) -> "Resource": # type: ignore[override] + return Resource(_status.Resource(cpu=cpu, rss=rss)) + def get_message(self): + return self._msg -@dataclasses.dataclass -class TaskManagerStatus: - unassigned: int - running: int - success: int - failed: int - canceled: int - not_found: int +class ObjectManagerStatus(Message): + def __init__(self, msg): + self._msg = msg -@dataclasses.dataclass -class ProcessorStatus: - pid: int - initialized: bool - has_task: bool - suspended: bool - resource: Resource + @property + def number_of_objects(self) -> int: + return self._msg.numberOfObjects + @property + def object_memory(self) -> int: + return self._msg.objectMemory -@dataclasses.dataclass -class WorkerStatus: - worker_id: bytes - agent: Resource - total_processors: Resource - free: int - sent: int - queued: int - suspended: int - lag_us: int - last_s: int - ITL: str - processor_statuses: List[ProcessorStatus] + @staticmethod + def new_msg(number_of_objects: int, object_memory: int) -> "ObjectManagerStatus": # type: ignore[override] + return ObjectManagerStatus( + _status.ObjectManagerStatus(numberOfObjects=number_of_objects, objectMemory=object_memory) + ) + def get_message(self): + return self._msg -@dataclasses.dataclass -class WorkerManagerStatus: - workers: List[WorkerStatus] +class ClientManagerStatus(Message): + def __init__(self, msg): + self._msg = msg -@dataclasses.dataclass -class BinderStatus: - received: Dict[str, int] - sent: Dict[str, int] + @property + def client_to_num_of_tasks(self) -> Dict[bytes, int]: + return {p.client: p.numTask for p in self._msg.clientToNumOfTask} + + @staticmethod + def new_msg(client_to_num_of_tasks: Dict[bytes, int]) -> "ClientManagerStatus": # type: ignore[override] + return ClientManagerStatus( + _status.ClientManagerStatus( + clientToNumOfTask=[ + _status.ClientManagerStatus.Pair(client=p[0], numTask=p[1]) for p in client_to_num_of_tasks.items() + ] + ) + ) + + def get_message(self): + return self._msg + + +class TaskManagerStatus(Message): + def __init__(self, msg): + self._msg = msg + + @property + def unassigned(self) -> int: + return self._msg.unassigned + + @property + def running(self) -> int: + return self._msg.running + + @property + def success(self) -> int: + return self._msg.success + + @property + def failed(self) -> int: + return self._msg.failed + + @property + def canceled(self) -> int: + return self._msg.canceled + + @property + def not_found(self) -> int: + return self._msg.notFound + + @staticmethod + def new_msg( # type: ignore[override] + unassigned: int, running: int, success: int, failed: int, canceled: int, not_found: int + ) -> "TaskManagerStatus": + return TaskManagerStatus( + _status.TaskManagerStatus( + unassigned=unassigned, + running=running, + success=success, + failed=failed, + canceled=canceled, + notFound=not_found, + ) + ) + + def get_message(self): + return self._msg + + +class ProcessorStatus(Message): + def __init__(self, msg): + self._msg = msg + + @property + def pid(self) -> int: + return self._msg.pid + + @property + def initialized(self) -> int: + return self._msg.initialized + + @property + def has_task(self) -> bool: + return self._msg.hasTask + + @property + def suspended(self) -> bool: + return self._msg.suspended + + @property + def resource(self) -> Resource: + return Resource(self._msg.resource) + + @staticmethod + def new_msg( + pid: int, initialized: int, has_task: bool, suspended: bool, resource: Resource # type: ignore[override] + ) -> "ProcessorStatus": + return ProcessorStatus( + _status.ProcessorStatus( + pid=pid, initialized=initialized, hasTask=has_task, suspended=suspended, resource=resource.get_message() + ) + ) + + def get_message(self): + return self._msg + + +class WorkerStatus(Message): + def __init__(self, msg): + self._msg = msg + + @property + def worker_id(self) -> bytes: + return self._msg.workerId + + @property + def agent(self) -> Resource: + return Resource(self._msg.agent) + + @property + def rss_free(self) -> int: + return self._msg.rssFree + + @property + def free(self) -> int: + return self._msg.free + + @property + def sent(self) -> int: + return self._msg.sent + + @property + def queued(self) -> int: + return self._msg.queued + + @property + def suspended(self) -> bool: + return self._msg.suspended + + @property + def lag_us(self) -> int: + return self._msg.lagUS + + @property + def last_s(self) -> int: + return self._msg.lastS + + @property + def itl(self) -> str: + return self._msg.itl + + @property + def processor_statuses(self) -> List[ProcessorStatus]: + return [ProcessorStatus(ps) for ps in self._msg.processorStatuses] + + @staticmethod + def new_msg( # type: ignore[override] + worker_id: bytes, + agent: Resource, + rss_free: int, + free: int, + sent: int, + queued: int, + suspended: int, + lag_us: int, + last_s: int, + itl: str, + processor_statuses: List[ProcessorStatus], + ) -> "WorkerStatus": + return WorkerStatus( + _status.WorkerStatus( + workerId=worker_id, + agent=agent.get_message(), + rssFree=rss_free, + free=free, + sent=sent, + queued=queued, + suspended=suspended, + lagUS=lag_us, + lastS=last_s, + itl=itl, + processorStatuses=[ps.get_message() for ps in processor_statuses], + ) + ) + + def get_message(self): + return self._msg + + +class WorkerManagerStatus(Message): + def __init__(self, msg): + self._msg = msg + + @property + def workers(self) -> List[WorkerStatus]: + return [WorkerStatus(ws) for ws in self._msg.workers] + + @staticmethod + def new_msg(workers: List[WorkerStatus]) -> "WorkerManagerStatus": # type: ignore[override] + return WorkerManagerStatus(_status.WorkerManagerStatus(workers=[ws.get_message() for ws in workers])) + + def get_message(self): + return self._msg + + +class BinderStatus(Message): + def __init__(self, msg): + self._msg = msg + + @property + def received(self) -> Dict[str, int]: + return {p.client: p.number for p in self._msg.received} + + @property + def sent(self) -> Dict[str, int]: + return {p.client: p.number for p in self._msg.sent} + + @staticmethod + def new_msg(received: Dict[str, int], sent: Dict[str, int]) -> "BinderStatus": # type: ignore[override] + return BinderStatus( + _status.BinderStatus( + received=[_status.BinderStatus.Pair(client=p[0], number=p[1]) for p in received.items()], + sent=[_status.BinderStatus.Pair(client=p[0], number=p[1]) for p in sent.items()], + ) + ) + + def get_message(self): + return self._msg diff --git a/scaler/scheduler/allocators/mixins.py b/scaler/scheduler/allocators/mixins.py index 3bbbd14..34b4077 100644 --- a/scaler/scheduler/allocators/mixins.py +++ b/scaler/scheduler/allocators/mixins.py @@ -1,5 +1,5 @@ import abc -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set class TaskAllocator(metaclass=abc.ABCMeta): @@ -14,7 +14,7 @@ def remove_worker(self, worker: bytes) -> List[bytes]: raise NotImplementedError() @abc.abstractmethod - def get_worker_ids(self) -> List[bytes]: + def get_worker_ids(self) -> Set[bytes]: """get all worker ids as list""" raise NotImplementedError() @@ -24,9 +24,9 @@ def get_worker_by_task_id(self, task_id: bytes) -> bytes: raise NotImplementedError() @abc.abstractmethod - def balance(self) -> Dict[bytes, int]: - """balance worker, it should return the number of tasks for over burdened worker, represented as worker - identity to number of tasks dictionary""" + def balance(self) -> Dict[bytes, List[bytes]]: + """balance worker, it should return list of task ids for over burdened worker, represented as worker + identity to list of task ids dictionary""" raise NotImplementedError() @abc.abstractmethod diff --git a/scaler/scheduler/client_manager.py b/scaler/scheduler/client_manager.py index ed6de20..2d2f1e1 100644 --- a/scaler/scheduler/client_manager.py +++ b/scaler/scheduler/client_manager.py @@ -9,7 +9,6 @@ ClientHeartbeat, ClientHeartbeatEcho, ClientShutdownResponse, - DisconnectType, TaskCancel, ) from scaler.protocol.python.status import ClientManagerStatus @@ -63,14 +62,14 @@ def on_task_finish(self, task_id: bytes) -> bytes: return self._client_to_task_ids.remove_value(task_id) async def on_heartbeat(self, client: bytes, info: ClientHeartbeat): - await self._binder.send(client, ClientHeartbeatEcho()) + await self._binder.send(client, ClientHeartbeatEcho.new_msg()) if client not in self._client_last_seen: - logging.info(f"client {client} connected") + logging.info(f"client {client!r} connected") self._client_last_seen[client] = (time.time(), info) async def on_client_disconnect(self, client: bytes, request: ClientDisconnect): - if request.type == DisconnectType.Disconnect: + if request.disconnect_type == ClientDisconnect.DisconnectType.Disconnect: await self.__on_client_disconnect(client) return @@ -78,23 +77,23 @@ async def on_client_disconnect(self, client: bytes, request: ClientDisconnect): logging.warning("cannot shutdown clusters as scheduler is running in protected mode") accepted = False else: - logging.info(f"shutdown scheduler and all clusters as received signal from {client=}") + logging.info(f"shutdown scheduler and all clusters as received signal from {client=!r}") accepted = True - await self._binder.send(client, ClientShutdownResponse(accepted=accepted)) + await self._binder.send(client, ClientShutdownResponse.new_msg(accepted=accepted)) if self._protected: return await self._worker_manager.on_client_shutdown(client) - raise ClientShutdownException(f"received client shutdown from {client}, quiting") + raise ClientShutdownException(f"received client shutdown from {client!r}, quiting") async def routine(self): await self.__routine_cleanup_clients() def get_status(self) -> ClientManagerStatus: - return ClientManagerStatus( + return ClientManagerStatus.new_msg( {client.decode(): len(task_ids) for client, task_ids in self._client_to_task_ids.items()} ) @@ -110,7 +109,7 @@ async def __routine_cleanup_clients(self): await self.__on_client_disconnect(client) async def __on_client_disconnect(self, client_id: bytes): - logging.info(f"client {client_id} disconnected") + logging.info(f"client {client_id!r} disconnected") if client_id in self._client_last_seen: self._client_last_seen.pop(client_id) @@ -123,4 +122,4 @@ async def __cancel_tasks(self, client: bytes): tasks = self._client_to_task_ids.get_values(client).copy() for task in tasks: - await self._task_manager.on_task_cancel(client, TaskCancel(task)) + await self._task_manager.on_task_cancel(client, TaskCancel.new_msg(task)) diff --git a/scaler/scheduler/graph_manager.py b/scaler/scheduler/graph_manager.py index 1ef96ba..392c829 100644 --- a/scaler/scheduler/graph_manager.py +++ b/scaler/scheduler/graph_manager.py @@ -6,18 +6,8 @@ from scaler.io.async_binder import AsyncBinder from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import ( - Argument, - ArgumentType, - GraphTask, - GraphTaskCancel, - NodeTaskType, - StateGraphTask, - Task, - TaskCancel, - TaskResult, - TaskStatus, -) +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import GraphTask, GraphTaskCancel, StateGraphTask, Task, TaskCancel, TaskResult from scaler.scheduler.mixins import ClientManager, GraphTaskManager, ObjectManager, TaskManager from scaler.utility.graph.topological_sorter import TopologicalSorter from scaler.utility.many_to_many_dict import ManyToManyDict @@ -107,7 +97,7 @@ async def on_graph_task(self, client: bytes, graph_task: GraphTask): async def on_graph_task_cancel(self, client: bytes, graph_task_cancel: GraphTaskCancel): if graph_task_cancel.task_id not in self._graph_task_id_to_graph: - await self._binder.send(client, TaskResult(graph_task_cancel.task_id, TaskStatus.NotFound)) + await self._binder.send(client, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.NotFound)) return graph_task_id = self._task_id_to_graph_task_id[graph_task_cancel.task_id] @@ -115,7 +105,7 @@ async def on_graph_task_cancel(self, client: bytes, graph_task_cancel: GraphTask if graph_info.status == _GraphState.Canceling: return - await self.__cancel_one_graph(graph_task_id, TaskResult(graph_task_cancel.task_id, TaskStatus.Canceled)) + await self.__cancel_one_graph(graph_task_id, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.Canceled)) async def on_graph_sub_task_done(self, result: TaskResult): graph_task_id = self._task_id_to_graph_task_id[result.task_id] @@ -148,22 +138,26 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask): self._client_manager.on_task_begin(client, graph_task.task_id) tasks = dict() - depended_task_id_to_task_id = ManyToManyDict() + depended_task_id_to_task_id: ManyToManyDict[bytes, bytes] = ManyToManyDict() for task in graph_task.graph: self._task_id_to_graph_task_id[task.task_id] = graph_task.task_id tasks[task.task_id] = _TaskInfo(_NodeTaskState.Inactive, task) - required_task_ids = {arg.data for arg in task.function_args if arg.type == ArgumentType.Task} + required_task_ids = {arg.data for arg in task.function_args if arg.type == Task.Argument.ArgumentType.Task} for required_task_id in required_task_ids: depended_task_id_to_task_id.add(required_task_id, task.task_id) graph[task.task_id] = required_task_ids await self._binder_monitor.send( - StateGraphTask( + StateGraphTask.new_msg( graph_task.task_id, task.task_id, - NodeTaskType.Target if task.task_id in graph_task.targets else NodeTaskType.Normal, + ( + StateGraphTask.NodeTaskType.Target + if task.task_id in graph_task.targets + else StateGraphTask.NodeTaskType.Normal + ), required_task_ids, ) ) @@ -179,7 +173,7 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask): async def __check_one_graph(self, graph_task_id: bytes): graph_info = self._graph_task_id_to_graph[graph_task_id] if not graph_info.sorter.is_active(): - await self.__finish_one_graph(graph_task_id, TaskResult(graph_task_id, TaskStatus.Success)) + await self.__finish_one_graph(graph_task_id, TaskResult.new_msg(graph_task_id, TaskStatus.Success)) return ready_task_ids = graph_info.sorter.get_ready() @@ -191,12 +185,12 @@ async def __check_one_graph(self, graph_task_id: bytes): task_info.state = _NodeTaskState.Running graph_info.running_task_ids.add(task_id) - task = Task( - task_info.task.task_id, - task_info.task.source, - task_info.task.metadata, - task_info.task.func_object_id, - [self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args], + task = Task.new_msg( + task_id=task_info.task.task_id, + source=task_info.task.source, + metadata=task_info.task.metadata, + func_object_id=task_info.task.func_object_id, + function_args=[self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args], ) await self._task_manager.on_task_new(graph_info.client, task) @@ -240,7 +234,7 @@ async def __cancel_one_graph(self, graph_task_id: bytes, result: TaskResult): await self.__clean_all_inactive_nodes(graph_task_id, result) await self.__finish_one_graph( - graph_task_id, TaskResult(result.task_id, result.status, result.metadata, result.results) + graph_task_id, TaskResult.new_msg(result.task_id, result.status, result.metadata, result.results) ) async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResult): @@ -261,8 +255,10 @@ async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResu ) new_result_object_ids.append(new_result_object_id) - await self._task_manager.on_task_cancel(graph_info.client, TaskCancel(task_id)) - await self.__mark_node_done(TaskResult(task_id, result.status, result.metadata, new_result_object_ids)) + await self._task_manager.on_task_cancel(graph_info.client, TaskCancel.new_msg(task_id)) + await self.__mark_node_done( + TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids) + ) async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskResult): graph_info = self._graph_task_id_to_graph[graph_task_id] @@ -280,12 +276,14 @@ async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskRes ) new_result_object_ids.append(new_result_object_id) - await self.__mark_node_done(TaskResult(task_id, result.status, result.metadata, new_result_object_ids)) + await self.__mark_node_done( + TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids) + ) async def __finish_one_graph(self, graph_task_id: bytes, result: TaskResult): self._client_manager.on_task_finish(graph_task_id) info = self._graph_task_id_to_graph.pop(graph_task_id) - await self._binder.send(info.client, TaskResult(graph_task_id, result.status, results=result.results)) + await self._binder.send(info.client, TaskResult.new_msg(graph_task_id, result.status, results=result.results)) def __is_graph_finished(self, graph_task_id: bytes): graph_info = self._graph_task_id_to_graph[graph_task_id] @@ -299,11 +297,11 @@ def __get_target_results_ids(self, graph_task_id: bytes) -> List[bytes]: for result_object_id in graph_info.tasks[task_id].result_object_ids ] - def __get_argument(self, graph_task_id: bytes, argument: Argument) -> Argument: - if argument.type == ArgumentType.ObjectID: + def __get_argument(self, graph_task_id: bytes, argument: Task.Argument) -> Task.Argument: + if argument.type == Task.Argument.ArgumentType.ObjectID: return argument - assert argument.type == ArgumentType.Task + assert argument.type == Task.Argument.ArgumentType.Task argument_task_id = argument.data graph_info = self._graph_task_id_to_graph[graph_task_id] @@ -311,13 +309,13 @@ def __get_argument(self, graph_task_id: bytes, argument: Argument) -> Argument: assert len(task_info.result_object_ids) == 1 - return Argument(ArgumentType.ObjectID, task_info.result_object_ids[0]) + return Task.Argument(Task.Argument.ArgumentType.ObjectID, task_info.result_object_ids[0]) def __clean_intermediate_result(self, graph_task_id: bytes, task_id: bytes): graph_info = self._graph_task_id_to_graph[graph_task_id] task_info = graph_info.tasks[task_id] - for argument in filter(lambda arg: arg.type == ArgumentType.Task, task_info.task.function_args): + for argument in filter(lambda arg: arg.type == Task.Argument.ArgumentType.Task, task_info.task.function_args): argument_task_id = argument.data graph_info.depended_task_id_to_task_id.remove(argument_task_id, task_id) if graph_info.depended_task_id_to_task_id.has_left_key(argument_task_id): diff --git a/scaler/scheduler/mixins.py b/scaler/scheduler/mixins.py index d1928a1..bb6af88 100644 --- a/scaler/scheduler/mixins.py +++ b/scaler/scheduler/mixins.py @@ -8,21 +8,22 @@ GraphTask, GraphTaskCancel, ObjectRequest, - ObjectResponse, Task, TaskCancel, TaskResult, WorkerHeartbeat, + ObjectInstruction, ) +from scaler.utility.mixins import Reporter -class ObjectManager(metaclass=abc.ABCMeta): +class ObjectManager(Reporter): @abc.abstractmethod - async def on_object_instruction(self, source: bytes, request: ObjectResponse): + async def on_object_instruction(self, source: bytes, request: ObjectInstruction): raise NotImplementedError() @abc.abstractmethod - async def on_object_request(self, source: bytes, response: ObjectRequest): + async def on_object_request(self, source: bytes, request: ObjectRequest): raise NotImplementedError() @abc.abstractmethod @@ -50,7 +51,7 @@ def get_object_content(self, object_id: bytes) -> bytes: raise NotImplementedError() -class ClientManager(metaclass=abc.ABCMeta): +class ClientManager(Reporter): @abc.abstractmethod def get_client_task_ids(self, client: bytes) -> Set[bytes]: raise NotImplementedError() @@ -79,7 +80,7 @@ async def on_client_disconnect(self, client: bytes, request: ClientDisconnect): raise NotImplementedError() -class GraphTaskManager(metaclass=abc.ABCMeta): +class GraphTaskManager(Reporter): @abc.abstractmethod async def on_graph_task(self, client: bytes, graph_task: GraphTask): raise NotImplementedError() @@ -97,23 +98,7 @@ def is_graph_sub_task(self, task_id: bytes): raise NotImplementedError() -class TaskReadyManager(metaclass=abc.ABCMeta): - @abc.abstractmethod - async def on_task_new(self, client: bytes, task: Task): - raise NotImplementedError() - - @abc.abstractmethod - async def on_task_cancel(self, task_cancel: TaskCancel): - raise NotImplementedError() - - -class TaskResultReadyManager(metaclass=abc.ABCMeta): - @abc.abstractmethod - async def on_task_done(self, task_result: TaskResult): - raise NotImplementedError() - - -class TaskManager(metaclass=abc.ABCMeta): +class TaskManager(Reporter): @abc.abstractmethod async def on_task_new(self, client: bytes, task: Task): raise NotImplementedError() @@ -131,7 +116,7 @@ async def on_task_reroute(self, task_id: bytes): raise NotImplementedError() -class WorkerManager(metaclass=abc.ABCMeta): +class WorkerManager(Reporter): @abc.abstractmethod async def assign_task_to_worker(self, task: Task) -> bool: raise NotImplementedError() diff --git a/scaler/scheduler/object_manager.py b/scaler/scheduler/object_manager.py index 9253928..4953e22 100644 --- a/scaler/scheduler/object_manager.py +++ b/scaler/scheduler/object_manager.py @@ -5,15 +5,8 @@ from scaler.io.async_binder import AsyncBinder from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import ( - ObjectContent, - ObjectInstruction, - ObjectInstructionType, - ObjectRequest, - ObjectRequestType, - ObjectResponse, - ObjectResponseType, -) +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse from scaler.protocol.python.status import ObjectManagerStatus from scaler.scheduler.mixins import ClientManager, ObjectManager, WorkerManager from scaler.scheduler.object_usage.object_tracker import ObjectTracker, ObjectUsage @@ -34,7 +27,7 @@ def get_object_key(self) -> bytes: class VanillaObjectManager(ObjectManager, Looper, Reporter): def __init__(self): - self._object_storage: ObjectTracker[_ObjectCreation, bytes] = ObjectTracker( + self._object_storage: ObjectTracker[bytes, _ObjectCreation] = ObjectTracker( "object_usage", self.__finished_object_storage ) @@ -58,31 +51,31 @@ def register( self._worker_manager = worker_manager async def on_object_instruction(self, source: bytes, instruction: ObjectInstruction): - if instruction.type == ObjectInstructionType.Create: + if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create: self.__on_object_create(source, instruction) return - if instruction.type == ObjectInstructionType.Delete: + if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete: self.on_del_objects(instruction.object_user, set(instruction.object_content.object_ids)) return logging.error( - f"received unknown object response type instruction_type={instruction.type} from " + f"received unknown object response type instruction_type={instruction.instruction_type} from " f"source={instruction.object_user}" ) async def on_object_request(self, source: bytes, request: ObjectRequest): - if request.type == ObjectRequestType.Get: + if request.request_type == ObjectRequest.ObjectRequestType.Get: await self.__process_get_request(source, request) return - logging.error(f"received unknown object request type {request=} from {source=}") + logging.error(f"received unknown object request type {request=} from {source=!r}") def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes): creation = _ObjectCreation(object_id, object_user, object_name, object_bytes) logging.debug( f"add object cache " - f"object_name={creation.object_name}, " + f"object_name={creation.object_name!r}, " f"object_id={creation.object_id.hex()}, " f"size={format_bytes(len(creation.object_bytes))}" ) @@ -116,10 +109,8 @@ def get_object_content(self, object_id: bytes) -> bytes: return self._object_storage.get_object(object_id).object_bytes def get_status(self) -> ObjectManagerStatus: - return ObjectManagerStatus( - # self._pending_get_requests.object_count(), - self._object_storage.object_count(), - sum(len(v.object_bytes) for _, v in self._object_storage.items()), + return ObjectManagerStatus.new_msg( + self._object_storage.object_count(), sum(len(v.object_bytes) for _, v in self._object_storage.items()) ) async def __process_get_request(self, source: bytes, request: ObjectRequest): @@ -136,12 +127,16 @@ async def __routine_send_objects_deletions(self): for worker in self._worker_manager.get_worker_ids(): await self._binder.send( worker, - ObjectInstruction(ObjectInstructionType.Delete, worker, ObjectContent(tuple(deleted_object_ids))), + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Delete, + worker, + ObjectContent.new_msg(tuple(deleted_object_ids)), + ), ) def __on_object_create(self, source: bytes, instruction: ObjectInstruction): if not self._client_manager.has_client_id(instruction.object_user): - logging.error(f"received object creation from {source} for unknown client {instruction.object_user}") + logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}") return for object_id, object_name, object_content in zip( @@ -154,7 +149,7 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction): def __finished_object_storage(self, creation: _ObjectCreation): logging.debug( f"del object cache " - f"object_name={creation.object_name}, " + f"object_name={creation.object_name!r}, " f"object_id={creation.object_id.hex()}, " f"size={format_bytes(len(creation.object_bytes))}" ) @@ -173,7 +168,7 @@ def __construct_response(self, request: ObjectRequest) -> ObjectResponse: object_names.append(object_info.object_name) object_bytes.append(object_info.object_bytes) - return ObjectResponse( - ObjectResponseType.Content, - ObjectContent(tuple(request.object_ids), tuple(object_names), tuple(object_bytes)), + return ObjectResponse.new_msg( + ObjectResponse.ObjectResponseType.Content, + ObjectContent.new_msg(tuple(request.object_ids), tuple(object_names), tuple(object_bytes)), ) diff --git a/scaler/scheduler/object_usage/object_tracker.py b/scaler/scheduler/object_usage/object_tracker.py index d6d495e..a61cdd2 100644 --- a/scaler/scheduler/object_usage/object_tracker.py +++ b/scaler/scheduler/object_usage/object_tracker.py @@ -6,23 +6,22 @@ ObjectKeyType = TypeVar("ObjectKeyType") -class ObjectUsage(metaclass=abc.ABCMeta): +class ObjectUsage(Generic[ObjectKeyType], metaclass=abc.ABCMeta): @abc.abstractmethod def get_object_key(self) -> ObjectKeyType: raise NotImplementedError() ObjectType = TypeVar("ObjectType", bound=ObjectUsage) -BlockType = TypeVar("BlockType") -class ObjectTracker(Generic[ObjectType, BlockType]): +class ObjectTracker(Generic[ObjectKeyType, ObjectType]): def __init__(self, prefix: str, callback: Callable[[ObjectType], None]): self._prefix = prefix self._callback = callback - self._current_blocks: Set[BlockType] = set() - self._object_key_to_block: ManyToManyDict[ObjectKeyType, BlockType] = ManyToManyDict() + self._current_blocks: Set[ObjectKeyType] = set() + self._object_key_to_block: ManyToManyDict[ObjectKeyType, ObjectKeyType] = ManyToManyDict() self._object_key_to_object: Dict[ObjectKeyType, ObjectType] = dict() def object_count(self): @@ -43,7 +42,9 @@ def get_object(self, key: ObjectKeyType) -> ObjectType: def add_object(self, obj: ObjectType): self._object_key_to_object[obj.get_object_key()] = obj - def get_object_block_pairs(self, blocks: Set[BlockType]) -> Generator[Tuple[ObjectKeyType, BlockType], None, None]: + def get_object_block_pairs( + self, blocks: Set[ObjectKeyType] + ) -> Generator[Tuple[ObjectKeyType, ObjectKeyType], None, None]: for block in blocks: if not self._object_key_to_block.has_right_key(block): continue @@ -51,16 +52,16 @@ def get_object_block_pairs(self, blocks: Set[BlockType]) -> Generator[Tuple[Obje for object_key in self._object_key_to_block.get_left_items(block): yield object_key, block - def add_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[BlockType]): + def add_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[ObjectKeyType]): if object_key not in self._object_key_to_object: - raise KeyError(f"cannot find key={object_key.hex()} in ObjectTracker") + raise KeyError(f"cannot find key={object_key} in ObjectTracker") for block in blocks: self._object_key_to_block.add(object_key, block) self._current_blocks.update(blocks) - def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[BlockType]): + def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[ObjectKeyType]): ready_objects = [] for block in blocks: obj = self.__remove_block_for_object(object_key, block) @@ -72,16 +73,16 @@ def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[Bl for obj in ready_objects: self._callback(obj) - def add_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: BlockType): + def add_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: ObjectKeyType): for object_key in object_keys: if object_key not in self._object_key_to_object: - raise KeyError(f"cannot find key={object_key.hex()} in ObjectTracker") + raise KeyError(f"cannot find key={object_key} in ObjectTracker") self._object_key_to_block.add(object_key, block) self._current_blocks.add(block) - def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: BlockType): + def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: ObjectKeyType): ready_objects = [] for object_key in object_keys: obj = self.__remove_block_for_object(object_key, block) @@ -93,7 +94,7 @@ def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: B for obj in ready_objects: self._callback(obj) - def remove_blocks(self, blocks: Set[BlockType]): + def remove_blocks(self, blocks: Set[ObjectKeyType]): ready_objects = [] for block in blocks: if not self._object_key_to_block.has_right_key(block): @@ -110,7 +111,7 @@ def remove_blocks(self, blocks: Set[BlockType]): for obj in ready_objects: self._callback(obj) - def __remove_block_for_object(self, object_key: ObjectKeyType, block: BlockType) -> Optional[ObjectType]: + def __remove_block_for_object(self, object_key: ObjectKeyType, block: ObjectKeyType) -> Optional[ObjectType]: if block not in self._current_blocks: return None diff --git a/scaler/scheduler/scheduler.py b/scaler/scheduler/scheduler.py index 066f97f..e0fe274 100644 --- a/scaler/scheduler/scheduler.py +++ b/scaler/scheduler/scheduler.py @@ -13,7 +13,6 @@ DisconnectRequest, GraphTask, GraphTaskCancel, - MessageVariant, ObjectInstruction, ObjectRequest, Task, @@ -21,6 +20,7 @@ TaskResult, WorkerHeartbeat, ) +from scaler.protocol.python.mixins import Message from scaler.scheduler.client_manager import VanillaClientManager from scaler.scheduler.config import SchedulerConfig from scaler.scheduler.graph_manager import VanillaGraphTaskManager @@ -92,7 +92,7 @@ def __init__(self, config: SchedulerConfig): self._binder, self._client_manager, self._object_manager, self._task_manager, self._worker_manager ) - async def on_receive_message(self, source: bytes, message: MessageVariant): + async def on_receive_message(self, source: bytes, message: Message): # ===================================================================================== # receive from upstream if isinstance(message, ClientHeartbeat): diff --git a/scaler/scheduler/status_reporter.py b/scaler/scheduler/status_reporter.py index 601383d..47910b9 100644 --- a/scaler/scheduler/status_reporter.py +++ b/scaler/scheduler/status_reporter.py @@ -7,7 +7,7 @@ from scaler.protocol.python.message import StateScheduler from scaler.protocol.python.status import Resource from scaler.scheduler.mixins import ClientManager, ObjectManager, TaskManager, WorkerManager -from scaler.utility.mixins import Looper, Reporter +from scaler.utility.mixins import Looper class StatusReporter(Looper): @@ -16,10 +16,10 @@ def __init__(self, binder: AsyncConnector): self._process = psutil.Process() self._binder: Optional[AsyncBinder] = None - self._client_manager: Optional[Reporter] = None - self._object_manager: Optional[Reporter] = None - self._task_manager: Optional[Reporter] = None - self._worker_manager: Optional[Reporter] = None + self._client_manager: Optional[ClientManager] = None + self._object_manager: Optional[ObjectManager] = None + self._task_manager: Optional[TaskManager] = None + self._worker_manager: Optional[WorkerManager] = None def register_managers( self, @@ -37,13 +37,10 @@ def register_managers( async def routine(self): await self._monitor_binder.send( - StateScheduler( + StateScheduler.new_msg( binder=self._binder.get_status(), - scheduler=Resource( - self._process.cpu_percent() / 100, - self._process.memory_info().rss, - psutil.virtual_memory().available, - ), + scheduler=Resource.new_msg(int(self._process.cpu_percent() * 10), self._process.memory_info().rss), + rss_free=psutil.virtual_memory().available, client_manager=self._client_manager.get_status(), object_manager=self._object_manager.get_status(), task_manager=self._task_manager.get_status(), diff --git a/scaler/scheduler/task_manager.py b/scaler/scheduler/task_manager.py index 6998382..30f1f1b 100644 --- a/scaler/scheduler/task_manager.py +++ b/scaler/scheduler/task_manager.py @@ -3,7 +3,8 @@ from scaler.io.async_binder import AsyncBinder from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import StateTask, Task, TaskCancel, TaskResult, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import StateTask, Task, TaskCancel, TaskResult from scaler.protocol.python.status import TaskManagerStatus from scaler.scheduler.graph_manager import VanillaGraphTaskManager from scaler.scheduler.mixins import ClientManager, ObjectManager, TaskManager, WorkerManager @@ -24,7 +25,7 @@ def __init__(self, max_number_of_tasks_waiting: int): self._task_id_to_task: Dict[bytes, Task] = dict() - self._unassigned: AsyncIndexedQueue[bytes] = AsyncIndexedQueue() + self._unassigned: AsyncIndexedQueue = AsyncIndexedQueue() self._running: Set[bytes] = set() self._success_count: int = 0 @@ -66,7 +67,7 @@ async def routine(self): ) def get_status(self) -> TaskManagerStatus: - return TaskManagerStatus( + return TaskManagerStatus.new_msg( unassigned=self._unassigned.qsize(), running=len(self._running), success=self._success_count, @@ -80,7 +81,7 @@ async def on_task_new(self, client: bytes, task: Task): 0 <= self._max_number_of_tasks_waiting <= self._unassigned.qsize() and not self._worker_manager.has_available_worker() ): - await self._binder.send(client, TaskResult(task.task_id, TaskStatus.NoWorker)) + await self._binder.send(client, TaskResult.new_msg(task.task_id, TaskStatus.NoWorker)) return self._client_manager.on_task_begin(client, task.task_id) @@ -96,11 +97,11 @@ async def on_task_new(self, client: bytes, task: Task): async def on_task_cancel(self, client: bytes, task_cancel: TaskCancel): if task_cancel.task_id not in self._task_id_to_task: logging.warning(f"cannot cancel, task not found: task_id={task_cancel.task_id.hex()}") - await self.on_task_done(TaskResult(task_cancel.task_id, TaskStatus.NotFound)) + await self.on_task_done(TaskResult.new_msg(task_cancel.task_id, TaskStatus.NotFound)) return if task_cancel.task_id in self._unassigned: - await self.on_task_done(TaskResult(task_cancel.task_id, TaskStatus.Canceled)) + await self.on_task_done(TaskResult.new_msg(task_cancel.task_id, TaskStatus.Canceled)) return await self._worker_manager.on_task_cancel(task_cancel) @@ -162,4 +163,4 @@ async def __send_monitor( self, task_id: bytes, function_name: bytes, status: TaskStatus, metadata: Optional[bytes] = b"" ): worker = self._worker_manager.get_worker_by_task_id(task_id) - await self._binder_monitor.send(StateTask(task_id, function_name, status, worker, metadata)) + await self._binder_monitor.send(StateTask.new_msg(task_id, function_name, status, worker, metadata)) diff --git a/scaler/scheduler/worker_manager.py b/scaler/scheduler/worker_manager.py index 3581a85..751bc4e 100644 --- a/scaler/scheduler/worker_manager.py +++ b/scaler/scheduler/worker_manager.py @@ -4,18 +4,16 @@ from scaler.io.async_binder import AsyncBinder from scaler.io.async_connector import AsyncConnector +from scaler.protocol.python.common import TaskStatus from scaler.protocol.python.message import ( ClientDisconnect, DisconnectRequest, DisconnectResponse, - DisconnectType, StateBalanceAdvice, StateWorker, Task, TaskCancel, - TaskCancelFlags, TaskResult, - TaskStatus, WorkerHeartbeat, WorkerHeartbeatEcho, ) @@ -44,7 +42,7 @@ def __init__( self._worker_alive_since: Dict[bytes, Tuple[float, WorkerHeartbeat]] = dict() self._allocator = QueuedAllocator(per_worker_queue_size) - self._last_balance_advice = None + self._last_balance_advice: Dict[bytes, List[bytes]] = dict() self._load_balance_advice_same_count = 0 def register(self, binder: AsyncBinder, binder_monitor: AsyncConnector, task_manager: TaskManager): @@ -67,7 +65,7 @@ async def on_task_cancel(self, task_cancel: TaskCancel): logging.error(f"cannot find task_id={task_cancel.task_id.hex()} in task workers") return - await self._binder.send(worker, TaskCancel(task_cancel.task_id)) + await self._binder.send(worker, TaskCancel.new_msg(task_cancel.task_id)) async def on_task_result(self, task_result: TaskResult): worker = self._allocator.remove_task(task_result.task_id) @@ -93,11 +91,11 @@ async def on_task_result(self, task_result: TaskResult): async def on_heartbeat(self, worker: bytes, info: WorkerHeartbeat): if await self._allocator.add_worker(worker): - logging.info(f"worker {worker} connected") - await self._binder_monitor.send(StateWorker(worker, b"connected")) + logging.info(f"worker {worker!r} connected") + await self._binder_monitor.send(StateWorker.new_msg(worker, b"connected")) self._worker_alive_since[worker] = (time.time(), info) - await self._binder.send(worker, WorkerHeartbeatEcho()) + await self._binder.send(worker, WorkerHeartbeatEcho.new_msg()) async def on_client_shutdown(self, client: bytes): for worker in self._allocator.get_worker_ids(): @@ -105,7 +103,7 @@ async def on_client_shutdown(self, client: bytes): async def on_disconnect(self, source: bytes, request: DisconnectRequest): await self.__disconnect_worker(request.worker) - await self._binder.send(source, DisconnectResponse(request.worker)) + await self._binder.send(source, DisconnectResponse.new_msg(request.worker)) async def routine(self): await self.__balance_request() @@ -113,49 +111,46 @@ async def routine(self): def get_status(self) -> WorkerManagerStatus: worker_to_task_numbers = self._allocator.statistics() - return WorkerManagerStatus( + return WorkerManagerStatus.new_msg( [ self.__worker_status_from_heartbeat(worker, worker_to_task_numbers[worker], last, info) for worker, (last, info) in self._worker_alive_since.items() ] ) + @staticmethod def __worker_status_from_heartbeat( - self, worker: bytes, worker_task_numbers: Dict, last: float, info: WorkerHeartbeat + worker: bytes, worker_task_numbers: Dict, last: float, info: WorkerHeartbeat ) -> WorkerStatus: current_processor = next((p for p in info.processors if not p.suspended), None) - n_suspended = sum(1 for p in info.processors if p.suspended) + suspended = len([p for p in info.processors if p.suspended]) if current_processor: - ITL = f"{int(current_processor.initialized)}{int(current_processor.has_task)}{int(info.task_lock)}" + debug_info = f"{int(current_processor.initialized)}{int(current_processor.has_task)}{int(info.task_lock)}" else: - ITL = f"00{int(info.task_lock)}" - - processor_statuses = [ - ProcessorStatus( - p.pid, - p.initialized, - p.has_task, - p.suspended, - Resource(p.cpu, p.rss, info.rss_free), - ) - for p in info.processors - ] + debug_info = f"00{int(info.task_lock)}" - return WorkerStatus( + return WorkerStatus.new_msg( worker_id=worker, - agent=Resource(info.agent_cpu, info.agent_rss, info.rss_free), - total_processors=Resource( - sum(p.cpu for p in info.processors), sum(p.rss for p in info.processors), info.rss_free - ), + agent=info.agent, + rss_free=info.rss_free, free=worker_task_numbers["free"], sent=worker_task_numbers["sent"], queued=info.queued_tasks, - suspended=n_suspended, + suspended=suspended, lag_us=info.latency_us, last_s=int(time.time() - last), - ITL=ITL, - processor_statuses=processor_statuses, + itl=debug_info, + processor_statuses=[ + ProcessorStatus.new_msg( + pid=p.pid, + initialized=p.initialized, + has_task=p.has_task, + suspended=p.suspended, + resource=Resource.new_msg(p.resource.cpu, p.resource.rss), + ) + for p in info.processors + ], ) def has_available_worker(self) -> bool: @@ -187,16 +182,17 @@ async def __do_balance(self, current_advice: Dict[bytes, List[bytes]]): if not current_advice: return - logging.info(f"balance: {current_advice}") - for worker, tasks in current_advice.items(): - await self._binder_monitor.send(StateBalanceAdvice(worker, tasks)) + worker_to_num_tasks = {worker: len(task_ids) for worker, task_ids in current_advice.items()} + logging.info(f"balancing task: {worker_to_num_tasks}") + for worker, task_ids in current_advice.items(): + await self._binder_monitor.send(StateBalanceAdvice.new_msg(worker, task_ids)) - task_cancel_flags = TaskCancelFlags(force=True) + task_cancel_flags = TaskCancel.TaskCancelFlags(force=True, retrieve_task_object=False) self._last_balance_advice = current_advice - for worker, tasks in current_advice.items(): - for task in tasks: - await self._binder.send(worker, TaskCancel(task, flags=task_cancel_flags)) + for worker, task_ids in current_advice.items(): + for task_id in task_ids: + await self._binder.send(worker, TaskCancel.new_msg(task_id=task_id, flags=task_cancel_flags)) async def __clean_workers(self): now = time.time() @@ -217,8 +213,8 @@ async def __disconnect_worker(self, worker: bytes): if worker not in self._worker_alive_since: return - logging.info(f"worker {worker} disconnected") - await self._binder_monitor.send(StateWorker(worker, b"disconnected")) + logging.info(f"worker {worker!r} disconnected") + await self._binder_monitor.send(StateWorker.new_msg(worker, b"disconnected")) self._worker_alive_since.pop(worker) task_ids = self._allocator.remove_worker(worker) @@ -229,5 +225,5 @@ async def __disconnect_worker(self, worker: bytes): await self.__reroute_tasks(task_ids) async def __shutdown_worker(self, worker: bytes): - await self._binder.send(worker, ClientDisconnect(DisconnectType.Shutdown)) + await self._binder.send(worker, ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Shutdown)) await self.__disconnect_worker(worker) diff --git a/scaler/ui/live_display.py b/scaler/ui/live_display.py index 9882da3..d4b7f9e 100644 --- a/scaler/ui/live_display.py +++ b/scaler/ui/live_display.py @@ -40,9 +40,9 @@ def delete_section(self): @dataclasses.dataclass class WorkerRow: worker: str = dataclasses.field(default="") - agt_cpu: int = dataclasses.field(default=0) + agt_cpu: float = dataclasses.field(default=0) agt_rss: int = dataclasses.field(default=0) - cpu: int = dataclasses.field(default=0) + cpu: float = dataclasses.field(default=0) rss: int = dataclasses.field(default=0) rss_free: int = dataclasses.field(default=0) free: int = dataclasses.field(default=0) @@ -50,24 +50,24 @@ class WorkerRow: queued: int = dataclasses.field(default=0) suspended: int = dataclasses.field(default=0) lag: str = dataclasses.field(default="") - ITL: str = dataclasses.field(default="") + itl: str = dataclasses.field(default="") last_seen: str = dataclasses.field(default="") handlers: List[Element] = dataclasses.field(default_factory=list) def populate(self, data: WorkerStatus): self.worker = data.worker_id.decode() - self.agt_cpu = int(data.agent.cpu * 100) + self.agt_cpu = data.agent.cpu / 10 self.agt_rss = int(data.agent.rss / 1e6) - self.cpu = int(data.total_processors.cpu * 100) - self.rss = int(data.total_processors.rss / 1e6) - self.rss_free = int(data.total_processors.rss_free / 1e6) + self.cpu = sum(p.resource.cpu for p in data.processor_statuses) / 10 + self.rss = int(sum(p.resource.rss for p in data.processor_statuses) / 1e6) + self.rss_free = int(data.rss_free / 1e6) self.free = data.free self.sent = data.sent self.queued = data.queued self.suspended = data.suspended self.lag = format_microseconds(data.lag_us) - self.ITL = data.ITL + self.itl = data.itl self.last_seen = format_seconds(data.last_s) def draw_row(self): @@ -102,7 +102,8 @@ def draw_section(self): for worker_row in self.workers.values(): worker_row.draw_row() - def __draw_titles(self): + @staticmethod + def __draw_titles(): ui.label("Worker") ui.label("Agt CPU %") ui.label("Agt RSS (in MB)") diff --git a/scaler/ui/task_graph.py b/scaler/ui/task_graph.py index 349cad3..e08aaf7 100644 --- a/scaler/ui/task_graph.py +++ b/scaler/ui/task_graph.py @@ -5,7 +5,8 @@ from nicegui import ui -from scaler.protocol.python.message import StateTask, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import StateTask from scaler.ui.live_display import WorkersSection from scaler.ui.setting_page import Settings from scaler.ui.utility import format_timediff, format_worker_name, get_bounds, make_tick_text, make_ticks @@ -51,10 +52,10 @@ def __init__(self): self._task_id_to_worker: Dict[bytes, str] = {} self._seen_workers = set() - self._lost_workers_queue = SimpleQueue() + self._lost_workers_queue: SimpleQueue[Tuple[datetime.datetime, str]] = SimpleQueue() self._data_update_lock = Lock() - self._busy_workers: List[str] = [] + self._busy_workers: Set[str] = set() self._busy_workers_update_time: datetime.datetime = datetime.datetime.now() def setup_task_stream(self, settings: Settings): @@ -224,14 +225,15 @@ def handle_task_state(self, state: StateTask): if not (worker := state.worker): return - worker = worker.decode() - self._worker_last_update[worker] = now - if worker not in self._seen_workers: - self.__handle_new_worker(worker, now) + worker_string = worker.decode() + self._worker_last_update[worker_string] = now + + if worker_string not in self._seen_workers: + self.__handle_new_worker(worker_string, now) if task_status in {TaskStatus.Running}: - self.__handle_running_task(state, worker, now) + self.__handle_running_task(state, worker_string, now) def __add_lost_worker(self, worker: str, now: datetime.datetime): self._lost_workers_queue.put((now, worker)) @@ -273,7 +275,7 @@ def __split_workers_by_status(self, now: datetime.datetime) -> List[Tuple[str, f def update_data(self, workers_section: WorkersSection): now = datetime.datetime.now() worker_names = sorted(workers_section.workers.keys()) - itls = {w: workers_section.workers[w].ITL for w in worker_names} + itls = {w: workers_section.workers[w].itl for w in worker_names} busy_workers = {w for w in worker_names if len(itls[w]) == 3 and itls[w][1] == "1" and itls[w][2] == "1"} for worker in worker_names: self._worker_last_update[worker] = now diff --git a/scaler/ui/task_log.py b/scaler/ui/task_log.py index 033ac8d..fb444f4 100644 --- a/scaler/ui/task_log.py +++ b/scaler/ui/task_log.py @@ -5,7 +5,8 @@ from nicegui import ui -from scaler.protocol.python.message import StateTask, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import StateTask from scaler.utility.formatter import format_bytes from scaler.utility.metadata.profile_result import ProfileResult diff --git a/scaler/ui/webui.py b/scaler/ui/webui.py index f7fc310..e60a136 100644 --- a/scaler/ui/webui.py +++ b/scaler/ui/webui.py @@ -5,7 +5,8 @@ from nicegui import ui from scaler.io.sync_subscriber import SyncSubscriber -from scaler.protocol.python.message import MessageVariant, StateScheduler, StateTask +from scaler.protocol.python.message import StateScheduler, StateTask +from scaler.protocol.python.mixins import Message from scaler.ui.live_display import SchedulerSection, WorkersSection from scaler.ui.memory_window import MemoryChart from scaler.ui.setting_page import Settings @@ -80,7 +81,7 @@ def start_webui(address: str, host: str, port: int): ui_thread.start() -def __show_status(status: MessageVariant, tables: Sections): +def __show_status(status: Message, tables: Sections): if isinstance(status, StateScheduler): __update_scheduler_state(status, tables) return @@ -95,7 +96,7 @@ def __show_status(status: MessageVariant, tables: Sections): def __update_scheduler_state(data: StateScheduler, tables: Sections): tables.scheduler_section.cpu = format_percentage(data.scheduler.cpu) tables.scheduler_section.rss = format_bytes(data.scheduler.rss) - tables.scheduler_section.rss_free = format_bytes(data.scheduler.rss_free) + tables.scheduler_section.rss_free = format_bytes(data.rss_free) previous_workers = set(tables.workers_section.workers.keys()) current_workers = set(worker_data.worker_id.decode() for worker_data in data.worker_manager.workers) diff --git a/scaler/ui/worker_processors.py b/scaler/ui/worker_processors.py index 6c78f79..0986578 100644 --- a/scaler/ui/worker_processors.py +++ b/scaler/ui/worker_processors.py @@ -31,7 +31,7 @@ def update_data(self, data: List[WorkerStatus]): processor_table = self.workers.get(worker.worker_id) if processor_table is None: - processor_table = WorkerProcessorTable(worker_name, worker.processor_statuses) + processor_table = WorkerProcessorTable(worker_name, worker.rss_free, worker.processor_statuses) processor_table.draw_table() self.workers[worker.worker_id] = processor_table elif processor_table.processor_statuses != worker.processor_statuses: @@ -46,6 +46,7 @@ def update_data(self, data: List[WorkerStatus]): @dataclasses.dataclass class WorkerProcessorTable: worker_name: str + rss_free: int processor_statuses: List[ProcessorStatus] handler: Optional[Element] = dataclasses.field(default=None) @@ -61,7 +62,7 @@ def draw_table(self): with ui.grid(columns=6).classes("w-full"): self.draw_titles() for processor in sorted(self.processor_statuses, key=lambda x: x.pid): - self.draw_row(processor) + self.draw_row(processor, self.rss_free) @staticmethod def draw_titles(): @@ -73,10 +74,10 @@ def draw_titles(): ui.label("Suspended") @staticmethod - def draw_row(processor_status: ProcessorStatus): - cpu = int(processor_status.resource.cpu * 100) + def draw_row(processor_status: ProcessorStatus, rss_free: int): + cpu = processor_status.resource.cpu rss = int(processor_status.resource.rss / 1e6) - rss_free = int(processor_status.resource.rss_free / 1e6) + rss_free = int(rss_free / 1e6) ui.label(processor_status.pid) ui.knob(value=cpu, track_color="grey-2", show_value=True, min=0, max=100) diff --git a/scaler/utility/event_list.py b/scaler/utility/event_list.py index 85bec24..49e80c6 100644 --- a/scaler/utility/event_list.py +++ b/scaler/utility/event_list.py @@ -1,4 +1,3 @@ - import collections from typing import Callable @@ -21,6 +20,10 @@ def __delitem__(self, i): super().__delitem__(i) self._list_updated() + def __add__(self, other): + super().__add__(other) + self._list_updated() + def __iadd__(self, other): super().__iadd__(other) self._list_updated() diff --git a/scaler/utility/event_loop.py b/scaler/utility/event_loop.py index 3ab35f6..cdd704e 100644 --- a/scaler/utility/event_loop.py +++ b/scaler/utility/event_loop.py @@ -17,8 +17,8 @@ def register_event_loop(event_loop_type: str): if event_loop_type not in EventLoopType.allowed_types(): raise TypeError(f"allowed event loop types are: {EventLoopType.allowed_types()}") - event_loop_type = EventLoopType[event_loop_type] - if event_loop_type == EventLoopType.uvloop: + event_loop_type_enum = EventLoopType[event_loop_type] + if event_loop_type_enum == EventLoopType.uvloop: try: import uvloop # noqa except ImportError: @@ -26,12 +26,14 @@ def register_event_loop(event_loop_type: str): uvloop.install() - logging.info(f"use event loop: {event_loop_type.value}") + assert event_loop_type in EventLoopType.allowed_types() + + logging.info(f"use event loop: {event_loop_type}") def create_async_loop_routine(routine: Callable[[], Awaitable], seconds: int): async def loop(): - logging.info(f"{routine.__self__.__class__.__name__}: started") + logging.info(f"{routine.__self__.__class__.__name__}: started") # type: ignore[attr-defined] try: while True: await routine() @@ -41,6 +43,6 @@ async def loop(): except KeyboardInterrupt: pass - logging.info(f"{routine.__self__.__class__.__name__}: exited") + logging.info(f"{routine.__self__.__class__.__name__}: exited") # type: ignore[attr-defined] return loop() diff --git a/scaler/utility/formatter.py b/scaler/utility/formatter.py index 3fb7cbd..e02b4c6 100644 --- a/scaler/utility/formatter.py +++ b/scaler/utility/formatter.py @@ -3,29 +3,31 @@ def format_bytes(number) -> str: - for unit in ["b", "k", "m", "g", "t"]: + for unit in ["B", "K", "M", "G", "T"]: if number >= STORAGE_SIZE_MODULUS: number /= STORAGE_SIZE_MODULUS continue - if unit in {"b", "k"}: + if unit in {"B", "K"}: return f"{int(number)}{unit}" return f"{number:.1f}{unit}" + raise ValueError("This should not happen") + def format_integer(number): return f"{number:,}" -def format_percentage(number: float): - return f"{number:.1%}" +def format_percentage(number: int): + return f"{(number/1000):.1%}" def format_microseconds(number: int): for unit in ["us", "ms", "s"]: if number >= TIME_MODULUS: - number /= TIME_MODULUS + number = int(number / TIME_MODULUS) continue if unit == "us": diff --git a/scaler/utility/graph/topological_sorter.py b/scaler/utility/graph/topological_sorter.py index c6c6c47..f85fbef 100644 --- a/scaler/utility/graph/topological_sorter.py +++ b/scaler/utility/graph/topological_sorter.py @@ -6,6 +6,6 @@ logging.info("using GraphBLAS for calculate graph") except ImportError as e: assert isinstance(e, Exception) - from graphlib import TopologicalSorter + from graphlib import TopologicalSorter # type: ignore[assignment, no-redef] assert isinstance(TopologicalSorter, object) diff --git a/scaler/utility/graph/topological_sorter_graphblas.py b/scaler/utility/graph/topological_sorter_graphblas.py index 2429d32..717608f 100644 --- a/scaler/utility/graph/topological_sorter_graphblas.py +++ b/scaler/utility/graph/topological_sorter_graphblas.py @@ -1,31 +1,33 @@ import collections import graphlib import itertools -from typing import Dict, Hashable, Iterable, List, Optional, Tuple +from typing import Hashable, Iterable, List, Optional, Tuple, TypeVar, Generic, Mapping from bidict import bidict try: import graphblas as gb - import numpy as np + import numpy as np # noqa except ImportError: raise ImportError("Please use 'pip install python-graphblas' to have graph blas support") +GraphKeyType = TypeVar("GraphKeyType", bound=Hashable) -class TopologicalSorter: + +class TopologicalSorter(Generic[GraphKeyType]): """ Implements graphlib's TopologicalSorter, but the graph handling is backed by GraphBLAS Reference: https://github.com/python/cpython/blob/4a3ea1fdd890e5e2ec26540dc3c958a52fba6556/Lib/graphlib.py """ - def __init__(self, graph: Optional[Dict[Hashable, Iterable[Hashable]]] = None): + def __init__(self, graph: Optional[Mapping[GraphKeyType, Iterable[GraphKeyType]]] = None): # the layout of the matrix is (in-vertex, out-vertex) self._matrix = gb.Matrix(gb.dtypes.BOOL) - self._key_to_id = bidict() + self._key_to_id: bidict[GraphKeyType, int] = bidict() - self._graph_matrix_mask = None - self._visited_vertices_mask = None - self._ready_nodes = None + self._graph_matrix_mask: Optional[np.ndarray] = None + self._visited_vertices_mask: Optional[np.ndarray] = None + self._ready_nodes: Optional[List[GraphKeyType]] = None self._n_done = 0 self._n_visited = 0 @@ -33,10 +35,10 @@ def __init__(self, graph: Optional[Dict[Hashable, Iterable[Hashable]]] = None): if graph is not None: self.merge_graph(graph) - def add(self, node: Hashable, *predecessors: Hashable) -> None: + def add(self, node: GraphKeyType, *predecessors: GraphKeyType) -> None: self.merge_graph({node: predecessors}) - def merge_graph(self, graph: Dict[Hashable, Iterable[Hashable]]) -> None: + def merge_graph(self, graph: Mapping[GraphKeyType, Iterable[GraphKeyType]]) -> None: if self._ready_nodes is not None: raise ValueError("nodes cannot be added after a call to prepare()") @@ -86,7 +88,7 @@ def prepare(self) -> None: if self._has_cycle(): raise graphlib.CycleError("cycle detected") - def get_ready(self) -> Tuple[Hashable, ...]: + def get_ready(self) -> Tuple[GraphKeyType, ...]: if self._ready_nodes is None: raise ValueError("prepare() must be called first") @@ -102,7 +104,7 @@ def is_active(self) -> bool: def __bool__(self) -> bool: return self.is_active() - def done(self, *nodes: Hashable) -> None: + def done(self, *nodes: GraphKeyType) -> None: if self._ready_nodes is None: raise ValueError("prepare() must be called first") @@ -127,7 +129,7 @@ def done(self, *nodes: Hashable) -> None: self._ready_nodes.extend(new_ready_nodes) self._n_visited += len(new_ready_nodes) - def static_order(self) -> Iterable[Hashable]: + def static_order(self) -> Iterable[GraphKeyType]: self.prepare() while self.is_active(): node_group = self.get_ready() @@ -150,7 +152,7 @@ def _has_cycle(self) -> bool: return True return False - def _get_zero_degree_keys(self) -> List[Hashable]: + def _get_zero_degree_keys(self) -> List[GraphKeyType]: ids = self._get_mask_diff(self._visited_vertices_mask, self._get_zero_degree_mask(self._get_masked_matrix())) return [self._key_to_id.inverse[_id] for _id in ids] @@ -165,7 +167,7 @@ def _get_masked_matrix(self) -> gb.Matrix: def _get_zero_degree_mask(cls, masked_matrix: gb.Matrix) -> np.ndarray: degrees = masked_matrix.reduce_rowwise(gb.monoid.lor) indices, _ = degrees.to_coo(indices=True, values=False, sort=False) - return np.logical_not(np.in1d(np.arange(masked_matrix.nrows), indices)) + return np.logical_not(np.in1d(np.arange(masked_matrix.nrows), indices)) # type: ignore[attr-defined] @staticmethod def _get_mask_diff(old_mask: np.ndarray, new_mask: np.ndarray) -> List[int]: diff --git a/scaler/utility/logging/scoped_logger.py b/scaler/utility/logging/scoped_logger.py index 8beaa3d..6f6da30 100644 --- a/scaler/utility/logging/scoped_logger.py +++ b/scaler/utility/logging/scoped_logger.py @@ -1,6 +1,7 @@ import datetime import logging import time +from typing import Optional class ScopedLogger: @@ -18,7 +19,7 @@ class TimedLogger: def __init__(self, message: str, logging_level=logging.INFO): self.message = message self.logging_level = logging_level - self.timer = None + self.timer: Optional[int] = None def begin(self): self.timer = time.perf_counter_ns() diff --git a/scaler/utility/many_to_many_dict.py b/scaler/utility/many_to_many_dict.py index cca3acc..f162629 100644 --- a/scaler/utility/many_to_many_dict.py +++ b/scaler/utility/many_to_many_dict.py @@ -9,8 +9,8 @@ class ManyToManyDict(Generic[LeftKeyT, RightKeyT]): def __init__(self): - self._left_key_to_right_key_set: _KeyValueDictSet[LeftKeyT, Set[RightKeyT]] = _KeyValueDictSet() - self._right_key_to_left_key_set: _KeyValueDictSet[RightKeyT, Set[LeftKeyT]] = _KeyValueDictSet() + self._left_key_to_right_key_set: _KeyValueDictSet[LeftKeyT, RightKeyT] = _KeyValueDictSet() + self._right_key_to_left_key_set: _KeyValueDictSet[RightKeyT, LeftKeyT] = _KeyValueDictSet() def left_keys(self): return self._left_key_to_right_key_set.keys() diff --git a/scaler/utility/queues/async_priority_queue.py b/scaler/utility/queues/async_priority_queue.py index 98943fe..f02bf57 100644 --- a/scaler/utility/queues/async_priority_queue.py +++ b/scaler/utility/queues/async_priority_queue.py @@ -38,7 +38,7 @@ def remove(self, data): item = self._locator.pop(data) i = self._queue.index(item) # O(n) item[0] = self.__to_lowest_priority(item[0]) - heapq._siftdown(self._queue, 0, i) # noqa + heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined] assert heapq.heappop(self._queue) == item def decrease_priority(self, data): @@ -47,7 +47,7 @@ def decrease_priority(self, data): item = self._locator[data] i = self._queue.index(item) # O(n) item[0] = self.__to_lower_priority(item[0]) - heapq._siftdown(self._queue, 0, i) # noqa + heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined] def max_priority(self): item = heapq.heappop(self._queue) diff --git a/scaler/utility/zmq_config.py b/scaler/utility/zmq_config.py index b4198b0..5e5b00f 100644 --- a/scaler/utility/zmq_config.py +++ b/scaler/utility/zmq_config.py @@ -54,15 +54,17 @@ def from_string(string: str) -> "ZMQConfig": if socket_type not in ZMQType.allowed_types(): raise ValueError(f"supported ZMQ types are: {ZMQType.allowed_types()}") - socket_type = ZMQType(socket_type) - if socket_type in {ZMQType.inproc, ZMQType.ipc}: + socket_type_enum = ZMQType(socket_type) + if socket_type_enum in {ZMQType.inproc, ZMQType.ipc}: host = host_port - port = None - else: + port_int = None + elif socket_type_enum == ZMQType.tcp: host, port = host_port.split(":") try: - port = int(port) + port_int = int(port) except ValueError: raise ValueError(f"cannot convert '{port}' to port number") + else: + raise ValueError(f"Unsupported ZMQ type: {socket_type}") - return ZMQConfig(ZMQType(socket_type), host, port) + return ZMQConfig(socket_type_enum, host, port_int) diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py index dd8c063..04823f3 100644 --- a/scaler/worker/agent/heartbeat_manager.py +++ b/scaler/worker/agent/heartbeat_manager.py @@ -4,7 +4,8 @@ import psutil from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import ProcessorHeartbeat, WorkerHeartbeat, WorkerHeartbeatEcho +from scaler.protocol.python.message import WorkerHeartbeat, WorkerHeartbeatEcho, Resource +from scaler.protocol.python.status import ProcessorStatus from scaler.utility.mixins import Looper from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, TaskManager, TimeoutManager from scaler.worker.agent.processor_holder import ProcessorHolder @@ -59,14 +60,13 @@ async def routine(self): await self._processor_manager.on_failing_task(self._worker_process.status()) processors = self._processor_manager.processors() - n_suspended_processors = self._processor_manager.n_suspended_processors() + num_suspended_processors = self._processor_manager.num_suspended_processors() await self._connector_external.send( - WorkerHeartbeat( - self._agent_process.cpu_percent() / 100, - self._agent_process.memory_info().rss, + WorkerHeartbeat.new_msg( + Resource.new_msg(int(self._agent_process.cpu_percent() * 10), self._agent_process.memory_info().rss), psutil.virtual_memory().available, - self._worker_task_manager.get_queued_size() - n_suspended_processors, + self._worker_task_manager.get_queued_size() - num_suspended_processors, self._latency_us, self._processor_manager.task_lock(), [self.__get_processor_status_from_holder(processor) for processor in processors], @@ -75,13 +75,12 @@ async def routine(self): self._start_timestamp_ns = time.time_ns() @staticmethod - def __get_processor_status_from_holder(processor: ProcessorHolder) -> ProcessorHeartbeat: + def __get_processor_status_from_holder(processor: ProcessorHolder) -> ProcessorStatus: process = processor.process() - return ProcessorHeartbeat( + return ProcessorStatus.new_msg( processor.pid(), processor.initialized(), processor.task() is not None, processor.suspended(), - process.cpu_percent() / 100, - process.memory_info().rss, + Resource.new_msg(int(process.cpu_percent() * 10), process.memory_info().rss), ) diff --git a/scaler/worker/agent/mixins.py b/scaler/worker/agent/mixins.py index fe81d15..6ad3f53 100644 --- a/scaler/worker/agent/mixins.py +++ b/scaler/worker/agent/mixins.py @@ -2,7 +2,6 @@ from typing import Dict, List, Optional, Set from scaler.protocol.python.message import ( - ObjectContent, ObjectInstruction, ObjectRequest, ObjectResponse, @@ -67,7 +66,7 @@ async def on_task(self, task: Task) -> bool: raise NotImplementedError() @abc.abstractmethod - async def on_cancel_task(self, task_id: bytes) -> bool: + def on_cancel_task(self, task_id: bytes) -> Optional[Task]: raise NotImplementedError() @abc.abstractmethod @@ -107,7 +106,7 @@ def processors(self) -> List[ProcessorHolder]: raise NotImplementedError() @abc.abstractmethod - def n_suspended_processors(self) -> int: + def num_suspended_processors(self) -> int: raise NotImplementedError() @abc.abstractmethod @@ -139,7 +138,7 @@ def on_object_request(self, processor_id: bytes, object_request: ObjectRequest) raise NotImplementedError() @abc.abstractmethod - def on_object_response(self, object_response: ObjectContent) -> Set[bytes]: + def on_object_response(self, object_response: ObjectResponse) -> Set[bytes]: raise NotImplementedError() @abc.abstractmethod diff --git a/scaler/worker/agent/object_tracker.py b/scaler/worker/agent/object_tracker.py index d721ef6..13a28c8 100644 --- a/scaler/worker/agent/object_tracker.py +++ b/scaler/worker/agent/object_tracker.py @@ -1,15 +1,10 @@ +from collections import defaultdict from typing import Dict, Set, Tuple -from scaler.protocol.python.message import ( - ObjectContent, - ObjectInstruction, - ObjectInstructionType, - ObjectRequest, - ObjectResponse, - ObjectResponseType, -) -from scaler.worker.agent.mixins import ObjectTracker +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse from scaler.utility.many_to_many_dict import ManyToManyDict +from scaler.worker.agent.mixins import ObjectTracker class VanillaObjectTracker(ObjectTracker): @@ -23,8 +18,8 @@ def on_object_request(self, processor_id: bytes, object_request: ObjectRequest) def on_object_response(self, object_response: ObjectResponse) -> Set[bytes]: """Returns a list of processor ids that requested this object content.""" - if object_response.type != ObjectResponseType.Content: - raise TypeError(f"invalid object response type received: {object_response.type}.") + if object_response.response_type != ObjectResponse.ObjectResponseType.Content: + raise TypeError(f"invalid object response type received: {object_response.response_type}.") object_ids = object_response.object_content.object_ids @@ -45,26 +40,23 @@ def on_object_instruction(self, object_instruction: ObjectInstruction) -> Dict[b forwarded to processors. """ - if object_instruction.type != ObjectInstructionType.Delete: - raise TypeError(f"invalid object instruction type received: {object_instruction.type}.") + if object_instruction.instruction_type != ObjectInstruction.ObjectInstructionType.Delete: + raise TypeError(f"invalid object instruction type received: {object_instruction.instruction_type}.") - per_processor_object_ids = {} + per_processor_object_ids: Dict[bytes, Set[bytes]] = defaultdict(set) for object_id in object_instruction.object_content.object_ids: if not self._object_id_to_processors_ids.has_left_key(object_id): continue processor_ids = self._object_id_to_processors_ids.remove_left_key(object_id) for processor_id in processor_ids: - if processor_id not in per_processor_object_ids: - per_processor_object_ids[processor_id] = [] - - per_processor_object_ids[processor_id].append(object_id) + per_processor_object_ids[processor_id].add(object_id) return { - processor_id: ObjectInstruction( - type=ObjectInstructionType.Delete, + processor_id: ObjectInstruction.new_msg( + instruction_type=ObjectInstruction.ObjectInstructionType.Delete, object_user=object_instruction.object_user, - object_content=ObjectContent(object_ids=object_ids), + object_content=ObjectContent.new_msg(object_ids=tuple(object_ids)), ) for processor_id, object_ids in per_processor_object_ids.items() } diff --git a/scaler/worker/agent/processor/object_cache.py b/scaler/worker/agent/processor/object_cache.py index b3d0b0a..ee8e42e 100644 --- a/scaler/worker/agent/processor/object_cache.py +++ b/scaler/worker/agent/processor/object_cache.py @@ -12,7 +12,8 @@ from scaler.client.serializer.mixins import Serializer from scaler.io.config import CLEANUP_INTERVAL_SECONDS -from scaler.protocol.python.message import ObjectContent, Task +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import Task from scaler.utility.exceptions import DeserializeObjectError from scaler.utility.object_utility import generate_serializer_object_id, is_object_id_serializer diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py index 1de1251..a35152f 100644 --- a/scaler/worker/agent/processor/processor.py +++ b/scaler/worker/agent/processor/processor.py @@ -12,21 +12,16 @@ from scaler.io.config import DUMMY_CLIENT from scaler.io.sync_connector import SyncConnector +from scaler.protocol.python.common import TaskStatus, ObjectContent from scaler.protocol.python.message import ( - ArgumentType, - MessageVariant, - ObjectContent, ObjectInstruction, - ObjectInstructionType, ObjectRequest, - ObjectRequestType, ObjectResponse, - ObjectResponseType, ProcessorInitialized, Task, TaskResult, - TaskStatus, ) +from scaler.protocol.python.mixins import Message from scaler.utility.exceptions import MissingObjects from scaler.utility.logging.utility import setup_logger from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id, serialize_failure @@ -56,7 +51,7 @@ def __init__( self._logging_paths = logging_paths self._logging_level = logging_level - self._client_to_decorator = {} + # self._client_to_decorator = {} self._object_cache: Optional[ObjectCache] = None @@ -103,9 +98,12 @@ def __interrupt(self, *args): def __run_forever(self): try: - self._connector.send(ProcessorInitialized()) + self._connector.send(ProcessorInitialized.new_msg()) while True: message = self._connector.receive() + if message is None: + continue + self.__on_connector_receive(message) except zmq.error.ZMQError as e: @@ -121,7 +119,7 @@ def __run_forever(self): self._object_cache.join() - def __on_connector_receive(self, message: MessageVariant): + def __on_connector_receive(self, message: Message): if isinstance(message, ObjectInstruction): self.__on_receive_object_instruction(message) return @@ -137,7 +135,7 @@ def __on_connector_receive(self, message: MessageVariant): logging.error(f"unknown {message=}") def __on_receive_object_instruction(self, instruction: ObjectInstruction): - if instruction.type == ObjectInstructionType.Delete: + if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete: for object_id in instruction.object_content.object_ids: self._object_cache.del_object(object_id) return @@ -145,7 +143,7 @@ def __on_receive_object_instruction(self, instruction: ObjectInstruction): logging.error(f"worker received unknown object instruction type {instruction=}") def __on_receive_object_response(self, response: ObjectResponse): - if response.type == ObjectResponseType.Content: + if response.response_type == ObjectResponse.ObjectResponseType.Content: self.__on_receive_object_content(response.object_content) return @@ -179,7 +177,7 @@ def __on_received_task(self, task: Task): self.__process_task(task) return - self._connector.send(ObjectRequest(ObjectRequestType.Get, unknown_object_ids)) + self._connector.send(ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, unknown_object_ids)) def __get_not_ready_object_ids(self, task: Task) -> Tuple[bytes, ...]: required_object_ids = self.__get_required_object_ids_for_task(task) @@ -189,7 +187,9 @@ def __get_not_ready_object_ids(self, task: Task) -> Tuple[bytes, ...]: def __get_required_object_ids_for_task(task: Task) -> List[bytes]: serializer_id = generate_serializer_object_id(task.source) object_ids = [serializer_id, task.func_object_id] - object_ids.extend([argument.data for argument in task.function_args if argument.type == ArgumentType.ObjectID]) + object_ids.extend( + [argument.data for argument in task.function_args if argument.type == Task.Argument.ArgumentType.ObjectID] + ) return object_ids def __process_task(self, task: Task): @@ -246,13 +246,15 @@ def __send_result(self, source: bytes, task_id: bytes, status: TaskStatus, resul # clients result_object_id = generate_object_id(source, uuid.uuid4().bytes) self._connector.send( - ObjectInstruction( - ObjectInstructionType.Create, + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Create, source, - ObjectContent((result_object_id,), (f"".encode(),), (result_bytes,)), + ObjectContent.new_msg( + (result_object_id,), (f"".encode(),), (result_bytes,) + ), ) ) - self._connector.send(TaskResult(task_id, status, results=[result_object_id])) + self._connector.send(TaskResult.new_msg(task_id, status, metadata=b"", results=[result_object_id])) @staticmethod def __set_current_processor(context: Optional["Processor"]) -> Token: diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py index 59fe1fb..8cd29cd 100644 --- a/scaler/worker/agent/processor_holder.py +++ b/scaler/worker/agent/processor_holder.py @@ -91,7 +91,7 @@ def kill(self): # TODO: some processors fail to interrupt because of a blocking 0mq call. Ideally we should interrupt # these blocking calls instead of sending a SIGKILL signal. - logging.warn(f"Processor[{self.pid()}] does not terminate in time, send SIGKILL.") + logging.warning(f"Processor[{self.pid()}] does not terminate in time, send SIGKILL.") self.__send_signal(signal.SIGKILL) self._processor.join() diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index 0e49925..53c5c80 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -10,18 +10,16 @@ # from scaler.utility.logging.utility import setup_logger from scaler.io.async_binder import AsyncBinder from scaler.io.async_connector import AsyncConnector +from scaler.protocol.python.common import TaskStatus, ObjectContent from scaler.protocol.python.message import ( - MessageVariant, - ObjectContent, ObjectInstruction, - ObjectInstructionType, ObjectRequest, ObjectResponse, ProcessorInitialized, Task, TaskResult, - TaskStatus, ) +from scaler.protocol.python.mixins import Message from scaler.utility.exceptions import ProcessorDiedError from scaler.utility.metadata.profile_result import ProfileResult from scaler.utility.mixins import Looper @@ -97,11 +95,11 @@ async def on_object_instruction(self, instruction: ObjectInstruction): for processor_id, instruction in processor_instructions.items(): await self._binder_internal.send(processor_id, instruction) - async def on_object_response(self, request: ObjectResponse): - processors_ids = self._object_tracker.on_object_response(request) + async def on_object_response(self, response: ObjectResponse): + processors_ids = self._object_tracker.on_object_response(response) for process_id in processors_ids: - await self._binder_internal.send(process_id, request) + await self._binder_internal.send(process_id, response) async def acquire_task_active_lock(self): await self._task_active_lock.acquire() @@ -151,15 +149,15 @@ async def on_failing_task(self, process_status: str): result_object_id = generate_object_id(source, uuid.uuid4().bytes) await self._connector_external.send( - ObjectInstruction( - ObjectInstructionType.Create, + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Create, source, - ObjectContent((result_object_id,), (b"",), (result_object_bytes,)), + ObjectContent.new_msg((result_object_id,), (b"",), (result_object_bytes,)), ) ) await self._task_manager.on_task_result( - TaskResult(task_id, TaskStatus.Failed, profile_result.serialize(), [result_object_id]) + TaskResult.new_msg(task_id, TaskStatus.Failed, profile_result.serialize(), [result_object_id]) ) self.restart_current_processor(f"process died {process_status=}") @@ -235,7 +233,7 @@ def current_task_id(self) -> bytes: def processors(self) -> List[ProcessorHolder]: return list(self._holders_by_processor_id.values()) - def n_suspended_processors(self) -> int: + def num_suspended_processors(self) -> int: return len(self._suspended_holders_by_task_id) def task_lock(self) -> bool: @@ -293,7 +291,7 @@ def __end_task(self, processor_holder: ProcessorHolder) -> ProfileResult: return profile_result - async def __on_receive_internal(self, processor_id: bytes, message: MessageVariant): + async def __on_receive_internal(self, processor_id: bytes, message: Message): if isinstance(message, ProcessorInitialized): await self.__on_internal_processor_initialized(processor_id) return @@ -355,9 +353,14 @@ async def __on_internal_task_result(self, processor_id: bytes, task_result: Task else: return - task_result.metadata = profile_result.serialize() - - await self._task_manager.on_task_result(task_result) + await self._task_manager.on_task_result( + TaskResult.new_msg( + task_id=task_id, + status=task_result.status, + metadata=profile_result.serialize(), + results=task_result.results, + ) + ) def __processor_ready_to_process_object(self, processor_id: bytes) -> bool: holder = self._holders_by_processor_id.get(processor_id) diff --git a/scaler/worker/agent/profiling_manager.py b/scaler/worker/agent/profiling_manager.py index f15fbaa..06e4939 100644 --- a/scaler/worker/agent/profiling_manager.py +++ b/scaler/worker/agent/profiling_manager.py @@ -1,5 +1,5 @@ import dataclasses - +import time from typing import Dict, Optional import psutil @@ -15,7 +15,7 @@ class _ProcessProfiler: current_task_id: Optional[bytes] = None - start_time: Optional[int] = None + start_time: Optional[float] = None init_memory_rss: Optional[int] = None peak_memory_rss: Optional[int] = None @@ -57,7 +57,7 @@ def on_task_end(self, pid: int, task_id: bytes) -> ProfileResult: raise ValueError(f"process {pid=} is not registered.") if task_id != process_profiler.current_task_id: - raise ValueError(f"task {task_id=} is not the current task task_id={process_profiler.current_task_id}.") + raise ValueError(f"task {task_id=!r} is not the current task task_id={process_profiler.current_task_id!r}.") assert process_profiler.start_time is not None assert process_profiler.init_memory_rss is not None @@ -83,8 +83,7 @@ async def routine(self): @staticmethod def __process_cpu_time(process: psutil.Process) -> float: - cpu_times = process.cpu_times() - return cpu_times.user + cpu_times.system + return time.monotonic() @staticmethod def __process_memory_rss(process: psutil.Process) -> int: diff --git a/scaler/worker/agent/task_manager.py b/scaler/worker/agent/task_manager.py index 32a2a58..325ac80 100644 --- a/scaler/worker/agent/task_manager.py +++ b/scaler/worker/agent/task_manager.py @@ -1,7 +1,8 @@ from typing import Dict, Optional, Tuple from scaler.io.async_connector import AsyncConnector -from scaler.protocol.python.message import Task, TaskCancel, TaskResult, TaskStatus +from scaler.protocol.python.common import TaskStatus +from scaler.protocol.python.message import Task, TaskCancel, TaskResult from scaler.utility.metadata.task_flags import retrieve_task_flags_from_task from scaler.utility.mixins import Looper from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue @@ -41,9 +42,12 @@ async def on_cancel_task(self, task_cancel: TaskCancel): task = self._queued_task_id_to_task.pop(task_cancel.task_id) self._queued_task_ids.remove(task_cancel.task_id) - result = TaskResult(task_cancel.task_id, TaskStatus.Canceled) if task_cancel.flags.retrieve_task_object: - result.results = list(task.serialize()) + result = TaskResult.new_msg( + task_cancel.task_id, TaskStatus.Canceled, b"", [task.get_message().to_bytes()] + ) + else: + result = TaskResult.new_msg(task_cancel.task_id, TaskStatus.Canceled) await self._connector_external.send(result) return @@ -51,16 +55,17 @@ async def on_cancel_task(self, task_cancel: TaskCancel): if not task_cancel.flags.force: return - cancelled_running_task = self._processor_manager.on_cancel_task(task_cancel.task_id) - if cancelled_running_task is not None: - result = TaskResult(task_cancel.task_id, TaskStatus.Canceled) - if task_cancel.flags.retrieve_task_object: - result.results = list(cancelled_running_task.serialize()) - - await self._connector_external.send(result) + canceled_running_task = self._processor_manager.on_cancel_task(task_cancel.task_id) + if canceled_running_task is not None: + payload = [canceled_running_task.get_message().to_bytes()] if task_cancel.flags.retrieve_task_object else [] + await self._connector_external.send( + TaskResult.new_msg( + task_id=task_cancel.task_id, status=TaskStatus.Canceled, metadata=b"", results=payload + ) + ) return - await self._connector_external.send(TaskResult(task_cancel.task_id, TaskStatus.NotFound)) + await self._connector_external.send(TaskResult.new_msg(task_cancel.task_id, TaskStatus.NotFound)) async def on_task_result(self, result: TaskResult): if result.task_id in self._queued_task_id_to_task: diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index c0dee7d..4de8d2b 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -11,14 +11,13 @@ from scaler.protocol.python.message import ( ClientDisconnect, DisconnectRequest, - DisconnectType, - MessageVariant, ObjectInstruction, ObjectResponse, Task, TaskCancel, WorkerHeartbeatEcho, ) +from scaler.protocol.python.mixins import Message from scaler.utility.event_loop import create_async_loop_routine, register_event_loop from scaler.utility.exceptions import ClientShutdownException from scaler.utility.logging.utility import setup_logger @@ -125,7 +124,7 @@ def __initialize(self): self.__register_signal() self._task = self._loop.create_task(self.__get_loops()) - async def __on_receive_external(self, message: MessageVariant): + async def __on_receive_external(self, message: Message): if isinstance(message, WorkerHeartbeatEcho): await self._heartbeat_manager.on_heartbeat_echo(message) return @@ -147,7 +146,7 @@ async def __on_receive_external(self, message: MessageVariant): return if isinstance(message, ClientDisconnect): - if message.type == DisconnectType.Shutdown: + if message.disconnect_type == ClientDisconnect.DisconnectType.Shutdown: raise ClientShutdownException("received client shutdown, quitting") logging.error(f"Worker received invalid ClientDisconnect type, ignoring {message=}") return @@ -169,11 +168,11 @@ async def __get_loops(self): except (ClientShutdownException, TimeoutError) as e: logging.info(f"Worker[{self.pid}]: {str(e)}") - await self._connector_external.send(DisconnectRequest(self._connector_external.identity)) + await self._connector_external.send(DisconnectRequest.new_msg(self._connector_external.identity)) self._connector_external.destroy() - self._processor_manager.destroy("quitted") - logging.info(f"Worker[{self.pid}]: quitted") + self._processor_manager.destroy("quited") + logging.info(f"Worker[{self.pid}]: quited") def __run_forever(self): self._loop.run_until_complete(self._task) diff --git a/setup.py b/setup.py deleted file mode 100644 index 28af916..0000000 --- a/setup.py +++ /dev/null @@ -1,32 +0,0 @@ -from setuptools import find_packages, setup - -from scaler.about import __version__ - -with open("requirements.txt", "rt") as f: - requirements = [i.strip() for i in f.readlines()] - -setup( - name="scaler", - version=__version__, - packages=find_packages(exclude=("tests",)), - install_requires=requirements, - extras_require={ - "graphblas": ["python-graphblas", "numpy"], - "uvloop": ["uvloop"], - "gui": ["nicegui[plotly]"], - "all": ["python-graphblas", "numpy", "uvloop", "nicegui[plotly]"], - }, - url="", - license="", - author="Citi", - author_email="opensource@citi.com", - description="Scaler Distributed Framework", - entry_points={ - "console_scripts": [ - "scaler_scheduler=scaler.entry_points.scheduler:main", - "scaler_cluster=scaler.entry_points.cluster:main", - "scaler_top=scaler.entry_points.top:main", - "scaler_ui=scaler.entry_points.webui:main", - ] - }, -) diff --git a/tests/test_graph.py b/tests/test_graph.py index b5a5e07..3d22b41 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -148,7 +148,7 @@ def func(a): def test_cull_graph(self): graph = { - "a": (lambda *_: None), + "a": (lambda *_: None,), "b": (lambda *_: None, "a"), "c": (lambda *_: None, "a"), "d": (lambda *_: None, "b"), @@ -156,8 +156,8 @@ def test_cull_graph(self): "f": (lambda *_: None, "c"), } - def filter_keys(graph, keys): - return {key: value for key, value in graph.items() if key in keys} + def filter_keys(_graph, keys): + return {key: value for key, value in _graph.items() if key in keys} self.assertEqual(cull_graph(graph, ["d"]), filter_keys(graph, ["a", "b", "d"])) self.assertEqual(cull_graph(graph, ["e"]), filter_keys(graph, ["a", "b", "c", "e"])) diff --git a/tests/test_object_usage.py b/tests/test_object_usage.py index 6878348..ca05b5d 100644 --- a/tests/test_object_usage.py +++ b/tests/test_object_usage.py @@ -22,7 +22,7 @@ class TestObjectUsage(unittest.TestCase): def test_object_usage(self): setup_logger() - object_usage = ObjectTracker("sample", sample_ready) + object_usage: ObjectTracker[str, Sample] = ObjectTracker("sample", sample_ready) object_usage.add_object(Sample("a", "value1")) object_usage.add_object(Sample("b", "value2")) diff --git a/tests/test_worker_object_tracker.py b/tests/test_worker_object_tracker.py index d97967b..72e6509 100644 --- a/tests/test_worker_object_tracker.py +++ b/tests/test_worker_object_tracker.py @@ -1,14 +1,7 @@ import unittest -from scaler.protocol.python.message import ( - ObjectContent, - ObjectInstruction, - ObjectInstructionType, - ObjectRequest, - ObjectRequestType, - ObjectResponse, - ObjectResponseType, -) +from scaler.protocol.python.common import ObjectContent +from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse from scaler.worker.agent.object_tracker import VanillaObjectTracker @@ -16,33 +9,45 @@ class TestWorkerObjectTracker(unittest.TestCase): def test_object_tracker(self) -> None: tracker = VanillaObjectTracker() - tracker.on_object_request(b"processor_1", ObjectRequest(ObjectRequestType.Get, (b"object_1", b"object_2"))) + tracker.on_object_request( + b"processor_1", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_1", b"object_2")) + ) - tracker.on_object_request(b"processor_2", ObjectRequest(ObjectRequestType.Get, (b"object_1", b"object_2"))) - tracker.on_object_request(b"processor_2", ObjectRequest(ObjectRequestType.Get, (b"object_3",))) + tracker.on_object_request( + b"processor_2", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_1", b"object_2")) + ) + tracker.on_object_request( + b"processor_2", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_3",)) + ) - tracker.on_object_request(b"processor_3", ObjectRequest(ObjectRequestType.Get, (b"object_4", b"object_5"))) + tracker.on_object_request( + b"processor_3", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_4", b"object_5")) + ) response_1 = tracker.on_object_response( - ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_1", b"object_2"))) + ObjectResponse.new_msg( + ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_1", b"object_2")) + ) ) self.assertSetEqual(set(response_1), {b"processor_1", b"processor_2"}) response_2 = tracker.on_object_response( - ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_unknown",))) + ObjectResponse.new_msg( + ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_unknown",)) + ) ) self.assertSetEqual(response_2, set()) response_3 = tracker.on_object_response( - ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_3",))) + ObjectResponse.new_msg(ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_3",))) ) self.assertSetEqual(response_3, {b"processor_2"}) object_instructions = tracker.on_object_instruction( - ObjectInstruction( - ObjectInstructionType.Delete, + ObjectInstruction.new_msg( + ObjectInstruction.ObjectInstructionType.Delete, b"client", - ObjectContent( + ObjectContent.new_msg( (b"object_1", b"object_2", b"object_3"), (b"name_1", b"name_2", b"name_3"), (b"content_1", b"content_2", b"content_3"),