diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml
new file mode 100644
index 0000000..18847c6
--- /dev/null
+++ b/.github/workflows/linter.yml
@@ -0,0 +1,39 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
+
+name: Python Linter And Unittest
+
+on:
+ push:
+ branches: [ "main" ]
+ pull_request:
+ branches: [ "main" ]
+
+permissions:
+ contents: read
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v3
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install flake8 pyproject-flake8 mypy
+ pip install -r requirements.txt
+ - name: Lint with flake8
+ run: |
+ pflake8 .
+ - name: Lint with MyPy
+ run: |
+ mypy .
+ - name: Run python unittest
+ run: |
+ python -m unittest discover -v tests
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..4e5153a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,20 @@
+__pychace__/
+*.py[cod]
+*$py.class
+
+build/
+dist/
+sdist/
+wheels/
+eggs/
+.eggs/
+.idea/
+*.egg-info/
+*.egg
+.mypy_cache/
+
+venv*/
+.vscode/
+
+dask-worker-space/*
+.pre-commit-config.yaml
diff --git a/Makefile b/Makefile
index 21b1dec..937e42b 100644
--- a/Makefile
+++ b/Makefile
@@ -12,7 +12,7 @@ _doc:
rm -fr docsvenv build; mkdir build
python3.8 -m venv docsvenv
. docsvenv/bin/activate; \
- pip install -r docs/requirements_docs.txt; \
- pip install -r requirements.txt; \
- cd docs; make clean && make html
+ pip install -r docs/requirements_docs.txt; \
+ pip install -r requirements.txt; \
+ cd docs; make clean && make html
zip -r build/scaler_docs.zip docs/build/html/*
diff --git a/README.md b/README.md
index 6069d2f..5b24d26 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-
Citi/scaler
+Citi/scaler
Efficient, lightweight and reliable distributed computation engine.
@@ -22,6 +22,9 @@
with a stable and language agnostic protocol for client and worker communications.
```python
+import math
+from scaler import Client
+
with Client(address="tcp://127.0.0.1:2345") as client:
# Submits 100 tasks
futures = [
@@ -43,20 +46,22 @@ messaging errors, among others.
- Distributed computing on **multiple cores and multiple servers**
- **Python** reference implementation, with **language agnostic messaging protocol** built on top of
-[ZeroMQ](https://zeromq.org)
+ [Cap'n Proto](https://capnproto.org/) and [ZeroMQ](https://zeromq.org)
- **Graph** scheduling, which supports [Dask](https://www.dask.org)-like graph computing, optionally you
-can use [GraphBLAS](https://graphblas.org)
-- **Automated load balancing**. When workers got full of tasks, these will be scheduled to idle workers
+ can use [GraphBLAS](https://graphblas.org) for massive graph tasks
+- **Automated load balancing**. automatically balance busy workers' loads to idle workers, keep every worker as busy as
+ possible
- **Automated recovery** from faulting workers or clients
- Supports for **nested tasks**. Tasks can themselves submit new tasks
- `top`-like **monitoring tools**
+- GUI monitoring tool
Scaler's scheduler can be run on PyPy, which will provide a performance boost
## Installation
```bash
-$ pip instal scaler
+$ pip install scaler
# or with graphblas and uvloop support
$ pip install scaler[graphblas,uvloop]
@@ -77,7 +82,6 @@ A local scheduler and a local set of workers can be conveniently spawn using `Sc
```python
from scaler import SchedulerClusterCombo
-
cluster = SchedulerClusterCombo(address="tcp://127.0.0.1:2345", n_workers=4)
...
@@ -154,9 +158,11 @@ from scaler import Client
def inc(i):
return i + 1
+
def add(a, b):
return a + b
+
def minus(a, b):
return a - b
@@ -169,7 +175,7 @@ graph = {
"e": (minus, "d", "c") # e = d - c = 4 - 3 = 1
}
-with Client(address="tcp://127.0.0.1:2345") as client
+with Client(address="tcp://127.0.0.1:2345") as client:
result = client.get(graph, keys=["e"])
print(result) # {"e": 1}
```
@@ -179,18 +185,21 @@ with Client(address="tcp://127.0.0.1:2345") as client
Scaler allows tasks to submit new tasks while being executed. Scaler also supports recursive task calls.
```python
-def fibonacci(client: Client, n: int):
+from scaler import Client
+
+
+def fibonacci(clnt: Client, n: int):
if n == 0:
return 0
elif n == 1:
return 1
else:
- a = client.submit(fibonacci, client, n - 1)
- b = client.submit(fibonacci, client, n - 2)
+ a = clnt.submit(fibonacci, clnt, n - 1)
+ b = clnt.submit(fibonacci, clnt, n - 2)
return a.result() + b.result()
-with Client(address="tcp://127.0.0.1:2345") as client
+with Client(address="tcp://127.0.0.1:2345") as client:
result = client.submit(fibonacci, client, 8).result()
print(result) # 21
```
@@ -256,11 +265,11 @@ W|Linux|15943|a7fe8b5e+ 0.0% 30.7m 0.0% 28.3m 1000 0 0 |
- function_id_to_tasks section shows task count for each function used
- worker section shows worker details, you can use shortcuts to sort by columns, the char * on column header show which
column is sorted right now
- - agt_cpu/agt_rss means cpu/memory usage of worker agent
- - cpu/rss means cpu/memory usage of worker
- - free means number of free task slots for this worker
- - sent means how many tasks scheduler sent to the worker
- - queued means how many tasks worker received and queued
+ - agt_cpu/agt_rss means cpu/memory usage of worker agent
+ - cpu/rss means cpu/memory usage of worker
+ - free means number of free task slots for this worker
+ - sent means how many tasks scheduler sent to the worker
+ - queued means how many tasks worker received and queued
### From the web UI
@@ -274,7 +283,8 @@ This will open a web server on port `8081`.
## Contributing
-Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly appreciated**.
+Your contributions are at the core of making this a true open source project. Any contributions you make are **greatly
+appreciated**.
We welcome you to:
@@ -297,4 +307,5 @@ This project is distributed under the [Apache-2.0 License](https://www.apache.or
## Contact
-If you have a query or require support with this project, [raise an issue](https://github.com/Citi/scaler/issues). Otherwise, reach out to [opensource@citi.com](mailto:opensource@citi.com).
\ No newline at end of file
+If you have a query or require support with this project, [raise an issue](https://github.com/Citi/scaler/issues).
+Otherwise, reach out to [opensource@citi.com](mailto:opensource@citi.com).
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..bd17dff
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,67 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm", "mypy", "black", "flake8", "pyproject-flake8"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "scaler"
+description = "Scaler Distribution Framework"
+requires-python = ">=3.8"
+readme = { file = "README.md", content-type = "text/markdown" }
+license = { text = "Apache 2.0" }
+authors = [{ name = "Citi", email = "opensource@citi.com" }]
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Intended Audience :: Developers",
+ "Operating System :: OS Independent",
+ "Topic :: System :: Distributed Computing",
+]
+dynamic = ["dependencies", "version"]
+
+[project.urls]
+Home = "https://github.com/Citi/scaler"
+
+[project.scripts]
+scaler_scheduler = "scaler.entry_points.scheduler:main"
+scaler_cluster = "scaler.entry_points.cluster:main"
+scaler_top = "scaler.entry_points.top:main"
+scaler_ui = "scaler.entry_points.webui:main"
+
+[project.optional-dependencies]
+uvloop = ["uvloop"]
+graphblas = ["python-graphblas", "numpy"]
+gui = ["nicegui[plotly]"]
+all = ["python-graphblas", "numpy", "uvloop", "nicegui[plotly]"]
+
+[tool.setuptools]
+packages = ["scaler"]
+include-package-data = true
+
+[tool.setuptools.dynamic]
+dependencies = { file = "requirements.txt" }
+version = { attr = "scaler.about.__version__" }
+
+[tool.mypy]
+no_strict_optional = true
+check_untyped_defs = true
+ignore_missing_imports = true
+exclude = [
+ "^docs.*$",
+ "^benchmark.*$",
+ "^venv.*$"
+]
+
+[tool.flake8]
+count = true
+statistics = true
+max-line-length = 120
+extend-ignore = ["E203"]
+exclude = "venv312"
+
+[tool.black]
+line-length = 120
+skip-magic-trailing-comma = true
+
+[metadata]
+long_description = { file = "README.md" }
+long_description_content_type = "text/markdown"
diff --git a/requirements.txt b/requirements.txt
index 6558092..a3a57df 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
-pyzmq
-psutil
+bidict
cloudpickle
+graphlib-backport; python_version < '3.9'
+psutil
+pycapnp
+pyzmq
tblib
-bidict
-graphlib-backport; python_version < '3.9'
\ No newline at end of file
diff --git a/run_top.py b/run_top.py
index 31635e4..75d50e2 100644
--- a/run_top.py
+++ b/run_top.py
@@ -1,4 +1,5 @@
from scaler.entry_points.top import main
+from scaler.utility.debug import pdb_wrapped
if __name__ == "__main__":
- main()
+ pdb_wrapped(main)()
diff --git a/scaler/about.py b/scaler/about.py
index a8b3783..29654ee 100644
--- a/scaler/about.py
+++ b/scaler/about.py
@@ -1 +1 @@
-__version__ = "1.7.14"
+__version__ = "1.8.0"
diff --git a/scaler/client/agent/client_agent.py b/scaler/client/agent/client_agent.py
index 56c8abf..6bf6507 100644
--- a/scaler/client/agent/client_agent.py
+++ b/scaler/client/agent/client_agent.py
@@ -18,7 +18,6 @@
ClientShutdownResponse,
GraphTask,
GraphTaskCancel,
- MessageVariant,
ObjectInstruction,
ObjectRequest,
ObjectResponse,
@@ -26,6 +25,7 @@
TaskCancel,
TaskResult,
)
+from scaler.protocol.python.mixins import Message
from scaler.utility.event_loop import create_async_loop_routine
from scaler.utility.exceptions import ClientCancelledException, ClientQuitException, ClientShutdownException
from scaler.utility.zmq_config import ZMQConfig
@@ -113,7 +113,7 @@ def run(self):
self.__initialize()
self.__run_loop()
- async def __on_receive_from_client(self, message: MessageVariant):
+ async def __on_receive_from_client(self, message: Message):
if isinstance(message, ClientDisconnect):
await self._disconnect_manager.on_client_disconnect(message)
return
@@ -144,7 +144,7 @@ async def __on_receive_from_client(self, message: MessageVariant):
raise TypeError(f"Unknown {message=}")
- async def __on_receive_from_scheduler(self, message: MessageVariant):
+ async def __on_receive_from_scheduler(self, message: Message):
if isinstance(message, ClientShutdownResponse):
await self._disconnect_manager.on_client_shutdown_response(message)
return
diff --git a/scaler/client/agent/disconnect_manager.py b/scaler/client/agent/disconnect_manager.py
index 08d86ac..50df24b 100644
--- a/scaler/client/agent/disconnect_manager.py
+++ b/scaler/client/agent/disconnect_manager.py
@@ -2,7 +2,7 @@
from scaler.client.agent.mixins import DisconnectManager
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, DisconnectType
+from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse
from scaler.utility.exceptions import ClientQuitException, ClientShutdownException
@@ -18,7 +18,7 @@ def register(self, connector_internal: AsyncConnector, connector_external: Async
async def on_client_disconnect(self, disconnect: ClientDisconnect):
await self._connector_external.send(disconnect)
- if disconnect.type == DisconnectType.Disconnect:
+ if disconnect.disconnect_type == ClientDisconnect.DisconnectType.Disconnect:
raise ClientQuitException("client disconnecting")
async def on_client_shutdown_response(self, response: ClientShutdownResponse):
diff --git a/scaler/client/agent/future_manager.py b/scaler/client/agent/future_manager.py
index 5f98574..154fcd5 100644
--- a/scaler/client/agent/future_manager.py
+++ b/scaler/client/agent/future_manager.py
@@ -1,12 +1,13 @@
import logging
import threading
-from concurrent.futures import InvalidStateError
+from concurrent.futures import InvalidStateError, Future
from typing import Dict, Tuple
from scaler.client.agent.mixins import FutureManager
from scaler.client.future import ScalerFuture
from scaler.client.serializer.mixins import Serializer
-from scaler.protocol.python.message import ObjectResponse, TaskResult, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import ObjectResponse, TaskResult
from scaler.utility.exceptions import DisconnectedError, NoWorkerError, TaskNotFoundError, WorkerDiedError
from scaler.utility.metadata.profile_result import retrieve_profiling_result_from_task_result
from scaler.utility.object_utility import deserialize_failure
@@ -20,7 +21,8 @@ def __init__(self, serializer: Serializer):
self._task_id_to_future: Dict[bytes, ScalerFuture] = dict()
self._object_id_to_future: Dict[bytes, Tuple[TaskStatus, ScalerFuture]] = dict()
- def add_future(self, future: ScalerFuture):
+ def add_future(self, future: Future):
+ assert isinstance(future, ScalerFuture)
with self._lock:
future.set_running_or_notify_cancel()
self._task_id_to_future[future.task_id] = future
diff --git a/scaler/client/agent/heartbeat_manager.py b/scaler/client/agent/heartbeat_manager.py
index c8e46bf..d100a2e 100644
--- a/scaler/client/agent/heartbeat_manager.py
+++ b/scaler/client/agent/heartbeat_manager.py
@@ -6,6 +6,7 @@
from scaler.client.agent.mixins import HeartbeatManager
from scaler.io.async_connector import AsyncConnector
from scaler.protocol.python.message import ClientHeartbeat, ClientHeartbeatEcho
+from scaler.protocol.python.status import Resource
from scaler.utility.mixins import Looper
@@ -26,9 +27,12 @@ def register(self, connector_external: AsyncConnector):
self._connector_external = connector_external
async def send_heartbeat(self):
- cpu = self._process.cpu_percent() / 100
- rss = self._process.memory_info().rss
- await self._connector_external.send(ClientHeartbeat(cpu, rss, self._latency_us))
+ await self._connector_external.send(
+ ClientHeartbeat.new_msg(
+ Resource.new_msg(int(self._process.cpu_percent() * 10), self._process.memory_info().rss),
+ self._latency_us,
+ )
+ )
async def on_heartbeat_echo(self, heartbeat: ClientHeartbeatEcho):
if not self._connected:
diff --git a/scaler/client/agent/object_manager.py b/scaler/client/agent/object_manager.py
index 83b9c56..92ff051 100644
--- a/scaler/client/agent/object_manager.py
+++ b/scaler/client/agent/object_manager.py
@@ -1,19 +1,14 @@
-from typing import Optional
+from typing import Optional, Set
from scaler.client.agent.mixins import ObjectManager
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import (
- ObjectContent,
- ObjectInstruction,
- ObjectInstructionType,
- ObjectRequest,
- ObjectRequestType,
-)
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import ObjectInstruction, ObjectRequest
class ClientObjectManager(ObjectManager):
def __init__(self, identity: bytes):
- self._sent_object_ids = set()
+ self._sent_object_ids: Set[bytes] = set()
self._identity = identity
self._connector_internal: Optional[AsyncConnector] = None
@@ -24,13 +19,13 @@ def register(self, connector_internal: AsyncConnector, connector_external: Async
self._connector_external = connector_external
async def on_object_instruction(self, instruction: ObjectInstruction):
- if instruction.type == ObjectInstructionType.Create:
+ if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create:
await self.__send_object_creation(instruction)
- elif instruction.type == ObjectInstructionType.Delete:
+ elif instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete:
await self.__delete_objects(instruction)
async def on_object_request(self, object_request: ObjectRequest):
- assert object_request.type == ObjectRequestType.Get
+ assert object_request.request_type == ObjectRequest.ObjectRequestType.Get
await self._connector_external.send(object_request)
def record_task_result(self, task_id: bytes, object_id: bytes):
@@ -38,17 +33,25 @@ def record_task_result(self, task_id: bytes, object_id: bytes):
async def clean_all_objects(self):
await self._connector_external.send(
- ObjectInstruction(ObjectInstructionType.Delete, self._identity, ObjectContent(tuple(self._sent_object_ids)))
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Delete,
+ self._identity,
+ ObjectContent.new_msg(tuple(self._sent_object_ids)),
+ )
)
self._sent_object_ids = set()
async def __send_object_creation(self, instruction: ObjectInstruction):
- assert instruction.type == ObjectInstructionType.Create
+ assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create
- new_object_content = list(
- zip(
+ new_object_ids = set(instruction.object_content.object_ids) - self._sent_object_ids
+ if not new_object_ids:
+ return
+
+ new_object_content = ObjectContent.new_msg(
+ *zip(
*filter(
- lambda object_pack: object_pack[0] not in self._sent_object_ids,
+ lambda object_pack: object_pack[0] in new_object_ids,
zip(
instruction.object_content.object_ids,
instruction.object_content.object_names,
@@ -58,15 +61,15 @@ async def __send_object_creation(self, instruction: ObjectInstruction):
)
)
- if not new_object_content:
- return
-
- instruction.object_content = ObjectContent(*new_object_content)
+ self._sent_object_ids.update(set(new_object_content.object_ids))
- self._sent_object_ids.update(instruction.object_content.object_ids)
- await self._connector_external.send(instruction)
+ await self._connector_external.send(
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Create, instruction.object_user, new_object_content
+ )
+ )
async def __delete_objects(self, instruction: ObjectInstruction):
- assert instruction.type == ObjectInstructionType.Delete
+ assert instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete
self._sent_object_ids.difference_update(instruction.object_content.object_ids)
await self._connector_external.send(instruction)
diff --git a/scaler/client/agent/task_manager.py b/scaler/client/agent/task_manager.py
index 1c0c021..8b54c77 100644
--- a/scaler/client/agent/task_manager.py
+++ b/scaler/client/agent/task_manager.py
@@ -3,7 +3,8 @@
from scaler.client.agent.future_manager import ClientFutureManager
from scaler.client.agent.mixins import ObjectManager, TaskManager
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import GraphTask, GraphTaskCancel, Task, TaskCancel, TaskResult
class ClientTaskManager(TaskManager):
diff --git a/scaler/client/client.py b/scaler/client/client.py
index 79df3c9..100a558 100644
--- a/scaler/client/client.py
+++ b/scaler/client/client.py
@@ -1,10 +1,10 @@
+import dataclasses
import functools
import logging
import os
import threading
import uuid
from collections import Counter
-from concurrent.futures import Future
from inspect import signature
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -20,15 +20,7 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.config import DEFAULT_CLIENT_TIMEOUT_SECONDS, DEFAULT_HEARTBEAT_INTERVAL_SECONDS
from scaler.io.sync_connector import SyncConnector
-from scaler.protocol.python.message import (
- Argument,
- ArgumentType,
- ClientDisconnect,
- ClientShutdownResponse,
- DisconnectType,
- GraphTask,
- Task,
-)
+from scaler.protocol.python.message import ClientDisconnect, ClientShutdownResponse, GraphTask, Task
from scaler.utility.exceptions import ClientQuitException
from scaler.utility.graph.optimization import cull_graph
from scaler.utility.graph.topological_sorter import TopologicalSorter
@@ -38,6 +30,23 @@
from scaler.worker.agent.processor.processor import Processor
+@dataclasses.dataclass
+class _CallNode:
+ func: Callable
+ args: Tuple[str, ...]
+
+ def __post_init__(self):
+ if not callable(self.func):
+ raise TypeError(f"the first item of the tuple must be function, get {self.func}")
+
+ if not isinstance(self.args, tuple):
+ raise TypeError(f"arguments must be tuple, get {self.args}")
+
+ for arg in self.args:
+ if not isinstance(arg, str):
+ raise TypeError(f"argument `{arg}` must be a string and the string has to be in the graph")
+
+
class Client:
def __init__(
self,
@@ -59,6 +68,16 @@ def __init__(
:param heartbeat_interval_seconds: Frequency of heartbeat to scheduler in seconds
:type heartbeat_interval_seconds: int
"""
+ self.__initialize__(address, profiling, timeout_seconds, heartbeat_interval_seconds, serializer)
+
+ def __initialize__(
+ self,
+ address: str,
+ profiling: bool,
+ timeout_seconds: int,
+ heartbeat_interval_seconds: int,
+ serializer: Serializer = DefaultSerializer(),
+ ):
self._serializer = serializer
self._profiling = profiling
@@ -149,7 +168,8 @@ def fibonacci(client: Client, n: int):
}
def __setstate__(self, state: dict) -> None:
- self.__init__(
+ # TODO: fix copy the serializer
+ self.__initialize__(
address=state["address"],
profiling=state["profiling"],
timeout_seconds=state["timeout_seconds"],
@@ -201,8 +221,8 @@ def map(self, fn: Callable, iterable: Iterable[Tuple[Any, ...]]) -> List[Any]:
return results
def get(
- self, graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]], keys: List[str], block: bool = True
- ) -> Dict[str, Union[Any, Future]]:
+ self, graph: Dict[str, Union[Any, Tuple[Union[Callable, str], ...]]], keys: List[str], block: bool = True
+ ) -> Dict[str, Union[Any, ScalerFuture]]:
"""
.. code-block:: python
:linenos:
@@ -227,16 +247,26 @@ def get(
graph = cull_graph(graph, keys)
- node_name_to_argument, graph = self.__split_data_and_graph(graph)
- self.__check_graph(node_name_to_argument, graph, keys)
+ node_name_to_argument, call_graph = self.__split_data_and_graph(graph)
+ self.__check_graph(node_name_to_argument, call_graph, keys)
- graph, compute_futures, finished_futures = self.__construct_graph(node_name_to_argument, graph, keys, block)
+ graph_task, compute_futures, finished_futures = self.__construct_graph(
+ node_name_to_argument, call_graph, keys, block
+ )
self._object_buffer.commit_send_objects()
- self._connector.send(graph)
+ self._connector.send(graph_task)
self._future_manager.add_future(
self._future_factory(
- task=Task(graph.task_id, self._identity, b"", b"", []), is_delayed=not block, group_task_id=None
+ task=Task.new_msg(
+ task_id=graph_task.task_id,
+ source=self._identity,
+ metadata=b"",
+ func_object_id=b"",
+ function_args=[],
+ ),
+ is_delayed=not block,
+ group_task_id=None,
)
)
for future in compute_futures.values():
@@ -294,12 +324,12 @@ def disconnect(self):
self._future_manager.cancel_all_futures()
- self._connector.send(ClientDisconnect(DisconnectType.Disconnect))
+ self._connector.send(ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Disconnect))
self.__destroy()
def __receive_shutdown_response(self):
- message = None
+ message: Optional[ClientShutdownResponse] = None
while not isinstance(message, ClientShutdownResponse):
message = self._connector.receive()
@@ -321,7 +351,7 @@ def shutdown(self):
self._future_manager.cancel_all_futures()
- self._connector.send(ClientDisconnect(DisconnectType.Shutdown))
+ self._connector.send(ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Shutdown))
try:
self.__receive_shutdown_response()
finally:
@@ -339,8 +369,14 @@ def __submit(self, function_object_id: bytes, args: Tuple[Any, ...], delayed: bo
task_flags_bytes = self.__get_task_flags().serialize()
- arguments = [Argument(ArgumentType.ObjectID, object_id) for object_id in object_ids]
- task = Task(task_id, self._identity, task_flags_bytes, function_object_id, arguments)
+ arguments = [Task.Argument(Task.Argument.ArgumentType.ObjectID, object_id) for object_id in object_ids]
+ task = Task.new_msg(
+ task_id=task_id,
+ source=self._identity,
+ metadata=task_flags_bytes,
+ func_object_id=function_object_id,
+ function_args=arguments,
+ )
future = self._future_factory(task=task, is_delayed=delayed, group_task_id=None)
self._future_manager.add_future(future)
@@ -357,48 +393,45 @@ def __convert_kwargs_to_args(fn: Callable, args: Tuple[Any, ...], kwargs: Dict[s
number_of_required = len([p for p in params if p.default is p.empty])
- args = list(args)
+ args_list = list(args)
kwargs = kwargs.copy()
kwargs.update({p.name: p.default for p in all_params if p.kind == p.KEYWORD_ONLY if p.default != p.empty})
- for p in params[len(args) : number_of_required]:
+ for p in params[len(args_list) : number_of_required]:
try:
- args.append(kwargs.pop(p.name))
+ args_list.append(kwargs.pop(p.name))
except KeyError:
- missing = tuple(p.name for p in params[len(args) : number_of_required])
+ missing = tuple(p.name for p in params[len(args_list) : number_of_required])
raise TypeError(f"{fn} missing {len(missing)} arguments: {missing}")
- for p in params[len(args) :]:
- args.append(kwargs.pop(p.name, p.default))
+ for p in params[len(args_list) :]:
+ args_list.append(kwargs.pop(p.name, p.default))
- return tuple(args)
+ return tuple(args_list)
def __split_data_and_graph(
- self, graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]]
- ) -> Tuple[Dict[str, Tuple[Argument, Any]], Dict[str, Tuple[Union[Callable, Any], ...]]]:
- graph = graph.copy()
- node_name_to_argument = {}
+ self, graph: Dict[str, Union[Any, Tuple[Union[Callable, str], ...]]]
+ ) -> Tuple[Dict[str, Tuple[Task.Argument, Any]], Dict[str, _CallNode]]:
+ call_graph = {}
+ node_name_to_argument: Dict[str, Tuple[Task.Argument, Union[Any, Tuple[Union[Callable, Any], ...]]]] = dict()
for node_name, node in graph.items():
if isinstance(node, tuple) and len(node) > 0 and callable(node[0]):
+ call_graph[node_name] = _CallNode(func=node[0], args=node[1:]) # type: ignore[arg-type]
continue
if isinstance(node, ObjectReference):
- node_name_to_argument[node_name] = (Argument(ArgumentType.ObjectID, node.object_id), None)
+ object_id = node.object_id
else:
- object_cache = self._object_buffer.buffer_send_object(node, name=node_name)
- node_name_to_argument[node_name] = (Argument(ArgumentType.ObjectID, object_cache.object_id), node)
+ object_id = self._object_buffer.buffer_send_object(node, name=node_name).object_id
- for node_name in node_name_to_argument.keys():
- graph.pop(node_name)
+ node_name_to_argument[node_name] = (Task.Argument(Task.Argument.ArgumentType.ObjectID, object_id), node)
- return node_name_to_argument, graph
+ return node_name_to_argument, call_graph
@staticmethod
def __check_graph(
- node_to_argument: Dict[str, Tuple[Argument, Any]],
- graph: Dict[str, Union[Any, Tuple[Union[Callable, Any], ...]]],
- keys: List[str],
+ node_to_argument: Dict[str, Tuple[Task.Argument, Any]], call_graph: Dict[str, _CallNode], keys: List[str]
):
duplicate_keys = [key for key, count in dict(Counter(keys)).items() if count > 1]
if duplicate_keys:
@@ -406,76 +439,78 @@ def __check_graph(
# sanity check graph
for key in keys:
- if key not in graph and key not in node_to_argument:
+ if key not in call_graph and key not in node_to_argument:
raise KeyError(f"key {key} has to be in graph")
- sorter = TopologicalSorter()
- for node_name, node in graph.items():
- assert (
- isinstance(node, tuple) and len(node) > 0 and callable(node[0])
- ), "node has to be tuple and first item should be function"
-
- for arg in node[1:]:
- if arg not in node_to_argument and arg not in graph:
- raise KeyError(f"argument {arg} in node '{node_name}': {tuple(node)} is not defined in graph")
+ sorter: TopologicalSorter[str] = TopologicalSorter()
+ for node_name, node in call_graph.items():
+ for arg in node.args:
+ if arg not in node_to_argument and arg not in call_graph:
+ raise KeyError(f"argument {arg} in node '{node_name}': {node} is not defined in graph")
- sorter.add(node_name, *node[1:])
+ sorter.add(node_name, *node.args)
# check cyclic dependencies
sorter.prepare()
def __construct_graph(
self,
- node_name_to_arguments: Dict[str, Tuple[Argument, Any]],
- graph: Dict[str, Tuple[Union[Callable, Any], ...]],
+ node_name_to_arguments: Dict[str, Tuple[Task.Argument, Any]],
+ call_graph: Dict[str, _CallNode],
keys: List[str],
block: bool,
) -> Tuple[GraphTask, Dict[str, ScalerFuture], Dict[str, ScalerFuture]]:
graph_task_id = uuid.uuid1().bytes
- node_name_to_task_id = {node_name: uuid.uuid1().bytes for node_name in graph.keys()}
+ node_name_to_task_id: Dict[str, bytes] = {node_name: uuid.uuid1().bytes for node_name in call_graph.keys()}
task_flags_bytes = self.__get_task_flags().serialize()
task_id_to_tasks = dict()
- for node_name, node in graph.items():
+ for node_name, node in call_graph.items():
task_id = node_name_to_task_id[node_name]
+ function_cache = self._object_buffer.buffer_send_function(node.func)
- function, *args = node
- function_cache = self._object_buffer.buffer_send_function(function)
-
- arguments = []
- for arg in args:
- assert arg in graph or arg in node_name_to_arguments
+ arguments: List[Task.Argument] = []
+ for arg in node.args:
+ assert arg in call_graph or arg in node_name_to_arguments
- if arg in graph:
- argument = Argument(ArgumentType.Task, node_name_to_task_id[arg])
- else:
+ if arg in call_graph:
+ arguments.append(
+ Task.Argument(type=Task.Argument.ArgumentType.Task, data=node_name_to_task_id[arg])
+ )
+ elif arg in node_name_to_arguments:
argument, _ = node_name_to_arguments[arg]
-
- arguments.append(argument)
-
- task_id_to_tasks[task_id] = Task(
- task_id, self._identity, task_flags_bytes, function_cache.object_id, arguments
+ arguments.append(argument)
+ else:
+ raise ValueError("Not possible")
+
+ task_id_to_tasks[task_id] = Task.new_msg(
+ task_id=task_id,
+ source=self._identity,
+ metadata=task_flags_bytes,
+ func_object_id=function_cache.object_id,
+ function_args=arguments,
)
- result_task_ids = [node_name_to_task_id[key] for key in keys if key in graph]
- graph_task = GraphTask(graph_task_id, self._identity, result_task_ids, list(task_id_to_tasks.values()))
+ result_task_ids = [node_name_to_task_id[key] for key in keys if key in call_graph]
+ graph_task = GraphTask.new_msg(graph_task_id, self._identity, result_task_ids, list(task_id_to_tasks.values()))
compute_futures = {}
ready_futures = {}
for key in keys:
- if key in graph:
- future = self._future_factory(
+ if key in call_graph:
+ compute_futures[key] = self._future_factory(
task=task_id_to_tasks[node_name_to_task_id[key]], is_delayed=not block, group_task_id=graph_task_id
)
- compute_futures[key] = future
elif key in node_name_to_arguments:
argument, data = node_name_to_arguments[key]
future: ScalerFuture = self._future_factory(
- task=Task(argument.data, self._identity, b"", b"", []),
+ task=Task.new_msg(
+ task_id=argument.data, source=self._identity, metadata=b"", func_object_id=b"", function_args=[]
+ ),
is_delayed=False,
group_task_id=graph_task_id,
)
@@ -498,6 +533,14 @@ def __get_task_flags(self) -> TaskFlags:
return TaskFlags(profiling=self._profiling, priority=task_priority)
+ def __assert_client_not_stopped(self):
+ if self._stop_event.is_set():
+ raise ClientQuitException("client is already stopped.")
+
+ def __destroy(self):
+ self._agent.join()
+ self._internal_context.destroy(linger=1)
+
@staticmethod
def __get_parent_task_priority() -> Optional[int]:
"""If the client is running inside a Scaler processor, returns the priority of the associated task."""
@@ -511,11 +554,3 @@ def __get_parent_task_priority() -> Optional[int]:
assert current_task is not None
return retrieve_task_flags_from_task(current_task).priority
-
- def __assert_client_not_stopped(self):
- if self._stop_event.is_set():
- raise ClientQuitException("client is already stopped.")
-
- def __destroy(self):
- self._agent.join()
- self._internal_context.destroy(linger=1)
diff --git a/scaler/client/future.py b/scaler/client/future.py
index 359a199..c5379ad 100644
--- a/scaler/client/future.py
+++ b/scaler/client/future.py
@@ -3,17 +3,17 @@
from typing import Any, Callable, Optional
from scaler.io.sync_connector import SyncConnector
-from scaler.protocol.python.message import ObjectRequest, ObjectRequestType, Task, TaskCancel
-from scaler.utility.metadata.profile_result import ProfileResult
+from scaler.protocol.python.message import ObjectRequest, Task, TaskCancel
from scaler.utility.event_list import EventList
+from scaler.utility.metadata.profile_result import ProfileResult
class ScalerFuture(Future):
def __init__(self, task: Task, is_delayed: bool, group_task_id: Optional[bytes], connector: SyncConnector):
super().__init__()
- self._waiters = EventList(self._waiters)
- self._waiters.add_update_callback(self._on_waiters_updated)
+ self._waiters = EventList(self._waiters) # type: ignore[assignment]
+ self._waiters.add_update_callback(self._on_waiters_updated) # type: ignore[attr-defined]
self._task_id: bytes = task.task_id
self._is_delayed: bool = is_delayed
@@ -31,14 +31,14 @@ def task_id(self):
return self._task_id
def profiling_info(self) -> ProfileResult:
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
if self._profiling_info is None:
raise ValueError(f"didn't receive profiling info for {self} yet")
return self._profiling_info
def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[ProfileResult] = None) -> None:
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
if self.done():
raise InvalidStateError(f"invalid future state: {self._state}")
@@ -55,7 +55,7 @@ def set_result_ready(self, object_id: Optional[bytes], profile_result: Optional[
self._result_ready_event.set()
def set_exception(self, exception: Optional[BaseException], profile_result: Optional[ProfileResult] = None) -> None:
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
if profile_result is not None:
self._profiling_info = profile_result
@@ -66,7 +66,7 @@ def set_exception(self, exception: Optional[BaseException], profile_result: Opti
def result(self, timeout=None):
self._result_ready_event.wait(timeout)
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
# if it's delayed future, get the result when future.result() get called
if self._is_delayed:
self._request_result_object()
@@ -75,7 +75,7 @@ def result(self, timeout=None):
return super().result(timeout)
def cancel(self) -> bool:
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
if self.cancelled():
return True
@@ -83,20 +83,20 @@ def cancel(self) -> bool:
return False
if self._group_task_id is not None:
- self._connector.send(TaskCancel(self._group_task_id))
+ self._connector.send(TaskCancel.new_msg(self._group_task_id))
else:
- self._connector.send(TaskCancel(self._task_id))
+ self._connector.send(TaskCancel.new_msg(self._task_id))
self._state = "CANCELLED"
self._result_ready_event.set()
- self._condition.notify_all()
+ self._condition.notify_all() # type: ignore[attr-defined]
- self._invoke_callbacks()
+ self._invoke_callbacks() # type: ignore[attr-defined]
return True
def add_done_callback(self, fn: Callable[[Future], Any]) -> None:
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
# if it's delayed future, get the result when a callback gets added
if self._is_delayed:
self._request_result_object()
@@ -104,17 +104,17 @@ def add_done_callback(self, fn: Callable[[Future], Any]) -> None:
return super().add_done_callback(fn)
def _on_waiters_updated(self, waiters: EventList):
- with self._condition:
+ with self._condition: # type: ignore[attr-defined]
# if it's delayed future, get the result when waiter gets added
if self._is_delayed and len(self._waiters) > 0:
self._request_result_object()
def _has_result_listeners(self) -> bool:
- return len(self._done_callbacks) > 0 or len(self._waiters) > 0
+ return len(self._done_callbacks) > 0 or len(self._waiters) > 0 # type: ignore[attr-defined]
def _request_result_object(self):
if self._result_request_sent or self._result_object_id is None or self.cancelled():
return
- self._connector.send(ObjectRequest(ObjectRequestType.Get, (self._result_object_id,)))
+ self._connector.send(ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (self._result_object_id,)))
self._result_request_sent = True
diff --git a/scaler/client/object_buffer.py b/scaler/client/object_buffer.py
index dbcec48..bbb1050 100644
--- a/scaler/client/object_buffer.py
+++ b/scaler/client/object_buffer.py
@@ -6,7 +6,8 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.sync_connector import SyncConnector
-from scaler.protocol.python.message import ObjectContent, ObjectInstruction, ObjectInstructionType
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import ObjectInstruction
from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id
@@ -56,7 +57,11 @@ def commit_send_objects(self):
]
self._connector.send(
- ObjectInstruction(ObjectInstructionType.Create, self._identity, ObjectContent(*zip(*objects_to_send)))
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Create,
+ self._identity,
+ ObjectContent.new_msg(*zip(*objects_to_send)),
+ )
)
self._pending_objects = list()
@@ -66,8 +71,10 @@ def commit_delete_objects(self):
return
self._connector.send(
- ObjectInstruction(
- ObjectInstructionType.Delete, self._identity, ObjectContent(tuple(self._pending_delete_objects))
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Delete,
+ self._identity,
+ ObjectContent.new_msg(tuple(self._pending_delete_objects)),
)
)
@@ -89,5 +96,5 @@ def __construct_function(self, fn: Callable) -> ObjectCache:
def __construct_object(self, obj: Any, name: Optional[str] = None) -> ObjectCache:
object_payload = self._serializer.serialize(obj)
object_id = generate_object_id(self._identity, object_payload)
- name_bytes = name.encode() if name else f"".encode()
+ name_bytes = name.encode() if name else f"".encode()
return ObjectCache(object_id, name_bytes, object_payload)
diff --git a/scaler/client/object_reference.py b/scaler/client/object_reference.py
index e4f7c79..ed95195 100644
--- a/scaler/client/object_reference.py
+++ b/scaler/client/object_reference.py
@@ -2,20 +2,20 @@
@dataclasses.dataclass
-class ObjectReference(object):
+class ObjectReference:
name: bytes
object_id: bytes
size: int
def __repr__(self):
- return f"ScalerReference(name={self.name}, id={self.object_id}, size={self.size})"
+ return f"ScalerReference(name={self.name!r}, id={self.object_id!r}, size={self.size})"
def __hash__(self):
return hash(self.object_id)
- def __eq__(self, other: "ObjectReference"):
+ def __eq__(self, other: object) -> bool:
if not isinstance(other, ObjectReference):
- return False
+ return NotImplemented
return self.object_id == other.object_id
diff --git a/scaler/cluster/cluster.py b/scaler/cluster/cluster.py
index 1f3cc67..2303a17 100644
--- a/scaler/cluster/cluster.py
+++ b/scaler/cluster/cluster.py
@@ -9,7 +9,7 @@
from scaler.worker.worker import Worker
-class Cluster(multiprocessing.get_context("spawn").Process):
+class Cluster(multiprocessing.get_context("spawn").Process): # type: ignore[misc]
def __init__(
self,
address: ZMQConfig,
diff --git a/scaler/cluster/scheduler.py b/scaler/cluster/scheduler.py
index 1926a62..347721a 100644
--- a/scaler/cluster/scheduler.py
+++ b/scaler/cluster/scheduler.py
@@ -1,6 +1,7 @@
import asyncio
import multiprocessing
-from typing import Optional, Tuple
+from asyncio import AbstractEventLoop, Task
+from typing import Optional, Tuple, Any
from scaler.scheduler.config import SchedulerConfig
from scaler.scheduler.scheduler import Scheduler, scheduler_main
@@ -9,7 +10,7 @@
from scaler.utility.zmq_config import ZMQConfig
-class SchedulerProcess(multiprocessing.get_context("spawn").Process):
+class SchedulerProcess(multiprocessing.get_context("spawn").Process): # type: ignore[misc]
def __init__(
self,
address: ZMQConfig,
@@ -45,8 +46,8 @@ def __init__(
self._logging_config_file = logging_config_file
self._scheduler: Optional[Scheduler] = None
- self._loop = None
- self._task = None
+ self._loop: Optional[AbstractEventLoop] = None
+ self._task: Optional[Task[Any]] = None
def run(self) -> None:
# scheduler have its own single process
diff --git a/scaler/entry_points/top.py b/scaler/entry_points/top.py
index 9b01907..c8c598e 100644
--- a/scaler/entry_points/top.py
+++ b/scaler/entry_points/top.py
@@ -1,10 +1,11 @@
import argparse
import curses
import functools
-from typing import List, Literal
+from typing import List, Literal, Dict, Union
from scaler.io.sync_subscriber import SyncSubscriber
-from scaler.protocol.python.message import MessageVariant, StateScheduler
+from scaler.protocol.python.message import StateScheduler
+from scaler.protocol.python.mixins import Message
from scaler.utility.formatter import (
format_bytes,
format_integer,
@@ -28,7 +29,7 @@
ord("l"): "lag",
}
-sort_by_state = {"sort_by": "cpu", "sort_by_previous": "cpu", "sort_reverse": True}
+SORT_BY_STATE: Dict[str, Union[str, bool]] = {"sort_by": "cpu", "sort_by_previous": "cpu", "sort_reverse": True}
def get_args():
@@ -61,7 +62,7 @@ def poke(screen, args):
pass
-def show_status(status: MessageVariant, screen):
+def show_status(status: Message, screen):
if not isinstance(status, StateScheduler):
return
@@ -72,7 +73,7 @@ def show_status(status: MessageVariant, screen):
{
"cpu": format_percentage(status.scheduler.cpu),
"rss": format_bytes(status.scheduler.rss),
- "rss_free": format_bytes(status.scheduler.rss_free),
+ "rss_free": format_bytes(status.rss_free),
},
)
@@ -106,16 +107,16 @@ def show_status(status: MessageVariant, screen):
"worker": worker.worker_id.decode(),
"agt_cpu": worker.agent.cpu,
"agt_rss": worker.agent.rss,
- "cpu": worker.total_processors.cpu,
- "rss": worker.total_processors.rss,
- "os_rss_free": worker.total_processors.rss_free,
+ "cpu": sum(p.resource.cpu for p in worker.processor_statuses),
+ "rss": sum(p.resource.rss for p in worker.processor_statuses),
+ "os_rss_free": worker.rss_free,
"free": worker.free,
"sent": worker.sent,
"queued": worker.queued,
"suspended": worker.suspended,
"lag": worker.lag_us,
"last": worker.last_s,
- "ITL": worker.ITL,
+ "ITL": worker.itl,
}
for worker in status.worker_manager.workers
],
@@ -160,12 +161,14 @@ def format_integer_func(value):
return table
-def __generate_worker_manager_table(wm_data, worker_length: int):
+def __generate_worker_manager_table(wm_data: List[Dict], worker_length: int) -> List[List[str]]:
if not wm_data:
headers = [["No workers"]]
return headers
- wm_data = sorted(wm_data, key=lambda item: item[sort_by_state["sort_by"]], reverse=sort_by_state["sort_reverse"])
+ wm_data = sorted(
+ wm_data, key=lambda item: item[SORT_BY_STATE["sort_by"]], reverse=bool(SORT_BY_STATE["sort_reverse"])
+ )
for row in wm_data:
row["worker"] = __truncate(row["worker"], worker_length, how="left")
@@ -179,7 +182,7 @@ def __generate_worker_manager_table(wm_data, worker_length: int):
last = f"({format_seconds(last)}) " if last > 5 else ""
row["lag"] = last + format_microseconds(row["lag"])
- worker_manager_table = [[f"[{v}]" if v == sort_by_state["sort_by"] else v for v in wm_data[0].keys()]]
+ worker_manager_table = [[f"[{v}]" if v == SORT_BY_STATE["sort_by"] else v for v in wm_data[0].keys()]]
worker_manager_table.extend([list(worker.values()) for worker in wm_data])
return worker_manager_table
@@ -264,10 +267,10 @@ def __change_option_state(option: int):
if option not in SORT_BY_OPTIONS.keys():
return
- sort_by_state["sort_by_previous"] = sort_by_state["sort_by"]
- sort_by_state["sort_by"] = SORT_BY_OPTIONS[option]
- if sort_by_state["sort_by"] != sort_by_state["sort_by_previous"]:
- sort_by_state["sort_reverse"] = True
+ SORT_BY_STATE["sort_by_previous"] = SORT_BY_STATE["sort_by"]
+ SORT_BY_STATE["sort_by"] = SORT_BY_OPTIONS[option]
+ if SORT_BY_STATE["sort_by"] != SORT_BY_STATE["sort_by_previous"]:
+ SORT_BY_STATE["sort_reverse"] = True
return
- sort_by_state["sort_reverse"] = not sort_by_state["sort_reverse"]
+ SORT_BY_STATE["sort_reverse"] = not SORT_BY_STATE["sort_reverse"]
diff --git a/scaler/io/async_binder.py b/scaler/io/async_binder.py
index 8423985..68d5e1a 100644
--- a/scaler/io/async_binder.py
+++ b/scaler/io/async_binder.py
@@ -2,11 +2,12 @@
import os
import uuid
from collections import defaultdict
-from typing import Awaitable, Callable, List, Literal, Optional
+from typing import Awaitable, Callable, List, Optional, Dict
import zmq.asyncio
-from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant
+from scaler.io.utility import deserialize, serialize
+from scaler.protocol.python.mixins import Message
from scaler.protocol.python.status import BinderStatus
from scaler.utility.mixins import Looper, Reporter
from scaler.utility.zmq_config import ZMQConfig
@@ -25,14 +26,15 @@ def __init__(self, name: str, address: ZMQConfig, io_threads: int, identity: Opt
self.__set_socket_options()
self._socket.bind(self._address.to_address())
- self._callback: Optional[Callable[[bytes, MessageVariant], Awaitable[None]]] = None
+ self._callback: Optional[Callable[[bytes, Message], Awaitable[None]]] = None
- self._statistics = {"received": defaultdict(lambda: 0), "sent": defaultdict(lambda: 0)}
+ self._received: Dict[str, int] = defaultdict(lambda: 0)
+ self._sent: Dict[str, int] = defaultdict(lambda: 0)
def destroy(self):
self._context.destroy(linger=0)
- def register(self, callback: Callable[[bytes, MessageVariant], Awaitable[None]]):
+ def register(self, callback: Callable[[bytes, Message], Awaitable[None]]):
self._callback = callback
async def routine(self):
@@ -40,23 +42,21 @@ async def routine(self):
if not self.__is_valid_message(frames):
return
- source, message_type_bytes, payload = frames[0], frames[1], frames[2:]
- message_type = MessageType(message_type_bytes)
- self.__count_one("received", message_type)
+ source, payload = frames
+ message: Optional[Message] = deserialize(payload)
+ if message is None:
+ logging.error(f"received unknown message from {source!r}: {payload!r}")
+ return
- message = PROTOCOL[message_type].deserialize(payload)
+ self.__count_received(message.__class__.__name__)
await self._callback(source, message)
- async def send(self, to: bytes, message: MessageVariant):
- message_type = PROTOCOL.inverse[type(message)]
- self.__count_one("sent", message_type)
- await self._socket.send_multipart([to, message_type.value, *message.serialize()], copy=False)
+ async def send(self, to: bytes, message: Message):
+ self.__count_sent(message.__class__.__name__)
+ await self._socket.send_multipart([to, serialize(message)], copy=False)
def get_status(self) -> BinderStatus:
- return BinderStatus(
- received={k: v for k, v in self._statistics["received"].items()},
- sent={k: v for k, v in self._statistics["sent"].items()},
- )
+ return BinderStatus.new_msg(received=self._received, sent=self._sent)
def __set_socket_options(self):
self._socket.setsockopt(zmq.IDENTITY, self._identity)
@@ -64,18 +64,17 @@ def __set_socket_options(self):
self._socket.setsockopt(zmq.RCVHWM, 0)
def __is_valid_message(self, frames: List[bytes]) -> bool:
- if len(frames) < 3:
+ if len(frames) < 2:
logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
return False
- if frames[1] not in {member.value for member in MessageType}:
- logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}")
- return False
-
return True
- def __count_one(self, count_type: Literal["sent", "received"], message_type: MessageType):
- self._statistics[count_type][message_type.name] += 1
+ def __count_received(self, message_type: str):
+ self._received[message_type] += 1
+
+ def __count_sent(self, message_type: str):
+ self._sent[message_type] += 1
def __get_prefix(self):
return f"{self.__class__.__name__}[{self._identity.decode()}]:"
diff --git a/scaler/io/async_connector.py b/scaler/io/async_connector.py
index faba381..86be509 100644
--- a/scaler/io/async_connector.py
+++ b/scaler/io/async_connector.py
@@ -5,7 +5,8 @@
import zmq.asyncio
-from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant
+from scaler.io.utility import deserialize, serialize
+from scaler.protocol.python.mixins import Message
from scaler.utility.zmq_config import ZMQConfig
@@ -17,7 +18,7 @@ def __init__(
socket_type: int,
address: ZMQConfig,
bind_or_connect: Literal["bind", "connect"],
- callback: Optional[Callable[[MessageType, MessageVariant], Awaitable[None]]],
+ callback: Optional[Callable[[Message], Awaitable[None]]],
identity: Optional[bytes],
):
self._address = address
@@ -41,7 +42,7 @@ def __init__(
else:
raise TypeError("bind_or_connect has to be 'bind' or 'connect'")
- self._callback: Optional[Callable[[MessageVariant], Awaitable[None]]] = callback
+ self._callback: Optional[Callable[[Message], Awaitable[None]]] = callback
def __del__(self):
self.destroy()
@@ -68,41 +69,35 @@ async def routine(self):
if self._callback is None:
return
- message = await self.receive()
+ message: Optional[Message] = await self.receive()
if message is None:
return
await self._callback(message)
- async def receive(self) -> Optional[MessageVariant]:
+ async def receive(self) -> Optional[Message]:
if self._context.closed:
return None
if self._socket.closed:
return None
- frames = await self._socket.recv_multipart()
- if not self.__is_valid_message(frames):
+ payload = await self._socket.recv()
+ result: Optional[Message] = deserialize(payload)
+ if result is None:
+ logging.error(f"received unknown message: {payload!r}")
return None
- message_type_bytes, *payload = frames
- message_type = MessageType(message_type_bytes)
- message = PROTOCOL[message_type].deserialize(payload)
- return message
+ return result
- async def send(self, data: MessageVariant):
- message_type = PROTOCOL.inverse[type(data)]
- await self._socket.send_multipart([message_type.value, *data.serialize()], copy=False)
+ async def send(self, message: Message):
+ await self._socket.send(serialize(message), copy=False)
def __is_valid_message(self, frames: List[bytes]) -> bool:
- if len(frames) < 2:
+ if len(frames) > 1:
logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
return False
- if frames[0] not in {member.value for member in MessageType}:
- logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}")
- return False
-
return True
def __get_prefix(self):
diff --git a/scaler/io/config.py b/scaler/io/config.py
index 631f449..5c55fd7 100644
--- a/scaler/io/config.py
+++ b/scaler/io/config.py
@@ -12,6 +12,9 @@
# number of seconds for profiling
PROFILING_INTERVAL_SECONDS = 1
+# message size limitation, max can be 2**64
+MESSAGE_SIZE_LIMIT = 2**64 - 1
+
# ==========================
# SCHEDULER SPECIFIC OPTIONS
diff --git a/scaler/io/sync_connector.py b/scaler/io/sync_connector.py
index 2c82c03..a45b6a5 100644
--- a/scaler/io/sync_connector.py
+++ b/scaler/io/sync_connector.py
@@ -3,11 +3,12 @@
import socket
import threading
import uuid
-from typing import List, Optional
+from typing import Optional
import zmq
-from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant
+from scaler.io.utility import deserialize, serialize
+from scaler.protocol.python.mixins import Message
from scaler.utility.zmq_config import ZMQConfig
@@ -44,31 +45,23 @@ def address(self) -> ZMQConfig:
def identity(self) -> bytes:
return self._identity
- def send(self, message: MessageVariant):
- message_type = PROTOCOL.inverse[type(message)]
-
+ def send(self, message: Message):
with self._lock:
- self._socket.send_multipart([message_type.value, *message.serialize()])
+ self._socket.send(serialize(message))
- def receive(self) -> Optional[MessageVariant]:
+ def receive(self) -> Optional[Message]:
with self._lock:
- frames = self._socket.recv_multipart()
-
- return self.__compose_message(frames)
+ payload = self._socket.recv()
- def __compose_message(self, frames: List[bytes]) -> Optional[MessageVariant]:
- if len(frames) < 2:
- logging.error(f"{self.__get_prefix()} received unexpected frames {frames}")
- return None
+ return self.__compose_message(payload)
- if frames[0] not in {member.value for member in MessageType}:
- logging.error(f"{self.__get_prefix()} received unexpected message type: {frames[0]}: {frames}")
+ def __compose_message(self, payload: bytes) -> Optional[Message]:
+ result: Optional[Message] = deserialize(payload)
+ if result is None:
+ logging.error(f"{self.__get_prefix()}: received unknown message: {payload!r}")
return None
- message_type_bytes, *payload = frames
- message_type = MessageType(message_type_bytes)
- message = PROTOCOL[message_type].deserialize(payload)
- return message
+ return result
def __get_prefix(self):
return f"{self.__class__.__name__}[{self._identity.decode()}]:"
diff --git a/scaler/io/sync_subscriber.py b/scaler/io/sync_subscriber.py
index 00010a6..5423f2c 100644
--- a/scaler/io/sync_subscriber.py
+++ b/scaler/io/sync_subscriber.py
@@ -1,10 +1,11 @@
import logging
import threading
-from typing import Callable, List, Optional
+from typing import Callable, Optional
import zmq
-from scaler.protocol.python.message import PROTOCOL, MessageType, MessageVariant
+from scaler.io.utility import deserialize
+from scaler.protocol.python.mixins import Message
from scaler.utility.zmq_config import ZMQConfig
@@ -12,7 +13,7 @@ class SyncSubscriber(threading.Thread):
def __init__(
self,
address: ZMQConfig,
- callback: Callable[[MessageVariant], None],
+ callback: Callable[[Message], None],
topic: bytes,
exit_callback: Optional[Callable[[], None]] = None,
stop_event: threading.Event = threading.Event(),
@@ -54,7 +55,7 @@ def run(self) -> None:
def __initialize(self):
self._context = zmq.Context.instance()
- self._socket: zmq.Socket = self._context.socket(zmq.SUB)
+ self._socket = self._context.socket(zmq.SUB)
self._socket.setsockopt(zmq.RCVHWM, 0)
if self._timeout_seconds == -1:
@@ -68,23 +69,14 @@ def __initialize(self):
def __routine_polling(self):
try:
- frames = self._socket.recv_multipart()
- self.__routine_receive(frames)
+ self.__routine_receive(self._socket.recv())
except zmq.Again:
raise TimeoutError(f"Cannot connect to {self._address.to_address()} in {self._timeout_seconds} seconds")
- def __routine_receive(self, frames: List[bytes]):
- logging.info(frames)
- if len(frames) < 2:
- logging.error(f"received unexpected frames {frames}")
- return
+ def __routine_receive(self, payload: bytes):
+ result: Optional[Message] = deserialize(payload)
+ if result is None:
+ logging.error(f"received unknown message: {payload!r}")
+ return None
- if frames[0] not in {member.value for member in MessageType}:
- logging.error(f"received unexpected message type: {frames[0]}")
- return
-
- message_type_bytes, *payload = frames
- message_type = MessageType(message_type_bytes)
- message = PROTOCOL[message_type].deserialize(payload)
-
- self._callback(message)
+ self._callback(result)
diff --git a/scaler/io/utility.py b/scaler/io/utility.py
new file mode 100644
index 0000000..6465671
--- /dev/null
+++ b/scaler/io/utility.py
@@ -0,0 +1,22 @@
+import logging
+from typing import Optional
+
+from scaler.io.config import MESSAGE_SIZE_LIMIT
+from scaler.protocol.capnp._python import _message # noqa
+from scaler.protocol.python.message import PROTOCOL
+from scaler.protocol.python.mixins import Message
+
+
+def deserialize(data: bytes) -> Optional[Message]:
+ with _message.Message.from_bytes(data, traversal_limit_in_words=MESSAGE_SIZE_LIMIT) as payload:
+ if not hasattr(payload, payload.which()):
+ logging.error(f"unknown message type: {payload.which()}")
+ return None
+
+ message = getattr(payload, payload.which())
+ return PROTOCOL[payload.which()](message)
+
+
+def serialize(message: Message) -> bytes:
+ payload = _message.Message(**{PROTOCOL.inverse[type(message)]: message.get_message()})
+ return payload.to_bytes()
diff --git a/scaler/protocol/capnp/__init__.py b/scaler/protocol/capnp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scaler/protocol/capnp/_python.py b/scaler/protocol/capnp/_python.py
new file mode 100644
index 0000000..93d725f
--- /dev/null
+++ b/scaler/protocol/capnp/_python.py
@@ -0,0 +1,4 @@
+import capnp # noqa
+import scaler.protocol.capnp.common_capnp as _common # noqa
+import scaler.protocol.capnp.message_capnp as _message # noqa
+import scaler.protocol.capnp.status_capnp as _status # noqa
diff --git a/scaler/protocol/capnp/common.capnp b/scaler/protocol/capnp/common.capnp
new file mode 100644
index 0000000..3f08c37
--- /dev/null
+++ b/scaler/protocol/capnp/common.capnp
@@ -0,0 +1,22 @@
+@0xf57f79ac88fab620;
+
+enum TaskStatus {
+ # task is accepted by scheduler, but will have below status
+ success @0; # if submit and task is done and get result
+ failed @1; # if submit and task is failed on worker
+ canceled @2; # if submit and task is canceled
+ notFound @3; # if submit and task is not found in scheduler
+ workerDied @4; # if submit and worker died (only happened when scheduler keep_task=False)
+ noWorker @5; # if submit and scheduler is full (not implemented yet)
+
+ # below are only used for monitoring channel, not sent to client
+ inactive @6; # task is scheduled but not allocate to worker
+ running @7; # task is running in worker
+ canceling @8; # task is canceling (can be in Inactive or Running state)
+}
+
+struct ObjectContent {
+ objectIds @0 :List(Data);
+ objectNames @1 :List(Data);
+ objectBytes @2 :List(Data);
+}
diff --git a/scaler/protocol/capnp/message.capnp b/scaler/protocol/capnp/message.capnp
new file mode 100644
index 0000000..c554aa2
--- /dev/null
+++ b/scaler/protocol/capnp/message.capnp
@@ -0,0 +1,208 @@
+@0xaf44f44ea94a4675;
+
+using CommonType = import "common.capnp";
+using Status = import "status.capnp";
+
+struct Task {
+ taskId @0 :Data;
+ source @1 :Data;
+ metadata @2 :Data;
+ funcObjectId @3 :Data;
+ functionArgs @4 :List(Argument);
+
+ struct Argument {
+ type @0 :ArgumentType;
+ data @1 :Data;
+
+ enum ArgumentType {
+ task @0;
+ objectID @1;
+ }
+ }
+}
+
+struct TaskCancel {
+ struct TaskCancelFlags {
+ force @0 :Bool;
+ retrieveTaskObject @1 :Bool;
+ }
+
+ taskId @0 :Data;
+ flags @1 :TaskCancelFlags;
+}
+
+struct TaskResult {
+ taskId @0 :Data;
+ status @1 :CommonType.TaskStatus;
+ metadata @2 :Data;
+ results @3 :List(Data);
+}
+
+struct GraphTask {
+ taskId @0 :Data;
+ source @1 :Data;
+ targets @2 :List(Data);
+ graph @3 :List(Task);
+}
+
+struct GraphTaskCancel {
+ taskId @0 :Data;
+}
+
+struct ClientHeartbeat {
+ resource @0 :Status.Resource;
+ latencyUS @1 :UInt32;
+}
+
+struct ClientHeartbeatEcho {
+}
+
+struct WorkerHeartbeat {
+ agent @0 :Status.Resource;
+ rssFree @1 :UInt64;
+ queuedTasks @2 :UInt32;
+ latencyUS @3 :UInt32;
+ taskLock @4 :Bool;
+ processors @5 :List(Status.ProcessorStatus);
+}
+
+struct WorkerHeartbeatEcho {
+}
+
+struct ObjectInstruction {
+ instructionType @0 :ObjectInstructionType;
+ objectUser @1 :Data;
+ objectContent @2 :CommonType.ObjectContent;
+
+ enum ObjectInstructionType {
+ create @0;
+ delete @1;
+ }
+}
+
+struct ObjectRequest {
+ requestType @0 :ObjectRequestType;
+ objectIds @1 :List(Data);
+
+ enum ObjectRequestType {
+ get @0;
+ }
+}
+
+struct ObjectResponse {
+ responseType @0 :ObjectResponseType;
+ objectContent @1 :CommonType.ObjectContent;
+
+ enum ObjectResponseType {
+ content @0;
+ objectNotExist @1;
+ }
+}
+
+struct DisconnectRequest {
+ worker @0 :Data;
+}
+
+struct DisconnectResponse {
+ worker @0 :Data;
+}
+
+struct ClientDisconnect {
+ disconnectType @0 :DisconnectType;
+
+ enum DisconnectType {
+ disconnect @0;
+ shutdown @1;
+ }
+}
+
+struct ClientShutdownResponse {
+ accepted @0 :Bool;
+}
+
+struct StateClient {
+}
+
+struct StateObject {
+}
+
+struct StateBalanceAdvice {
+ workerId @0 :Data;
+ taskIds @1 :List(Data);
+}
+
+struct StateScheduler {
+ binder @0 :Status.BinderStatus;
+ scheduler @1 :Status.Resource;
+ rssFree @2 :UInt64;
+ clientManager @3 :Status.ClientManagerStatus;
+ objectManager @4 :Status.ObjectManagerStatus;
+ taskManager @5 :Status.TaskManagerStatus;
+ workerManager @6 :Status.WorkerManagerStatus;
+}
+
+struct StateWorker {
+ workerId @0 :Data;
+ message @1 :Data;
+}
+
+struct StateTask {
+ taskId @0 :Data;
+ functionName @1 :Data;
+ status @2 :CommonType.TaskStatus;
+ worker @3 :Data;
+ metadata @4 :Data;
+}
+
+struct StateGraphTask {
+ enum NodeTaskType {
+ normal @0;
+ target @1;
+ }
+
+ graphTaskId @0 :Data;
+ taskId @1 :Data;
+ nodeTaskType @2 :NodeTaskType;
+ parentTaskIds @3 :List(Data);
+}
+
+struct ProcessorInitialized {
+}
+
+
+struct Message {
+ union {
+ task @0 :Task;
+ taskCancel @1 :TaskCancel;
+ taskResult @2 :TaskResult;
+
+ graphTask @3 :GraphTask;
+ graphTaskCancel @4 :GraphTaskCancel;
+
+ objectInstruction @5 :ObjectInstruction;
+ objectRequest @6 :ObjectRequest;
+ objectResponse @7 :ObjectResponse;
+
+ clientHeartbeat @8 :ClientHeartbeat;
+ clientHeartbeatEcho @9 :ClientHeartbeatEcho;
+
+ workerHeartbeat @10 :WorkerHeartbeat;
+ workerHeartbeatEcho @11 :WorkerHeartbeatEcho;
+
+ disconnectRequest @12 :DisconnectRequest;
+ disconnectResponse @13 :DisconnectResponse;
+
+ stateClient @14 :StateClient;
+ stateObject @15 :StateObject;
+ stateBalanceAdvice @16 :StateBalanceAdvice;
+ stateScheduler @17 :StateScheduler;
+ stateWorker @18 :StateWorker;
+ stateTask @19 :StateTask;
+ stateGraphTask @20 :StateGraphTask;
+
+ clientDisconnect @21 :ClientDisconnect;
+ clientShutdownResponse @22 :ClientShutdownResponse;
+
+ processorInitialized @23 :ProcessorInitialized;
+ }
+}
diff --git a/scaler/protocol/capnp/status.capnp b/scaler/protocol/capnp/status.capnp
new file mode 100644
index 0000000..1131bc4
--- /dev/null
+++ b/scaler/protocol/capnp/status.capnp
@@ -0,0 +1,65 @@
+@0xa4dfa1212ad2d0f0;
+
+struct Resource {
+ cpu @0 :UInt16; # 99.2% will be represented as 992 as integer
+ rss @1 :UInt64; # 32bit is capped to 4GB, so use 64bit to represent
+}
+
+struct ObjectManagerStatus {
+ numberOfObjects @0 :UInt32;
+ objectMemory @1 :UInt64;
+}
+
+struct ClientManagerStatus {
+ clientToNumOfTask @0 :List(Pair);
+
+ struct Pair {
+ client @0 :Data;
+ numTask @1 :UInt32;
+ }
+}
+
+struct TaskManagerStatus {
+ unassigned @0 :UInt32;
+ running @1 :UInt32;
+ success @2 :UInt32;
+ failed @3 :UInt32;
+ canceled @4 :UInt32;
+ notFound @5 :UInt32;
+}
+
+struct ProcessorStatus {
+ pid @0 :UInt32;
+ initialized @1 :Bool;
+ hasTask @2 :Bool;
+ suspended @3 :Bool;
+ resource @4 :Resource;
+}
+
+struct WorkerStatus {
+ workerId @0 :Data;
+ agent @1 :Resource;
+ rssFree @2 :UInt64;
+ free @3 :UInt32;
+ sent @4 :UInt32;
+ queued @5 :UInt32;
+ suspended @6: UInt8;
+ lagUS @7 :UInt64;
+ lastS @8 :UInt8;
+ itl @9 :Text;
+ processorStatuses @10 :List(ProcessorStatus);
+}
+
+struct WorkerManagerStatus {
+ workers @0 :List(WorkerStatus);
+}
+
+struct BinderStatus {
+ received @0 :List(Pair);
+ sent @1 :List(Pair);
+
+ struct Pair {
+ client @0 :Text;
+ number @1 :UInt32;
+ }
+}
diff --git a/scaler/protocol/introduction.md b/scaler/protocol/introduction.md
index 97fbb2f..f612fd4 100644
--- a/scaler/protocol/introduction.md
+++ b/scaler/protocol/introduction.md
@@ -1,6 +1,7 @@
# Roles
The communication protocol include 3 roles: client, scheduler and worker:
+
- client is upstream of scheduler, scheduler is upstream of worker
- worker is downstream of scheduler, scheduler is downstream of client
@@ -29,26 +30,11 @@ each client to scheduler and each worker to scheduler only maintains 1 TCP conne
# Message format
-Each message is a sequence of bytes, and composed by frames, each frame is a sequence of bytes and has a fixed
-length header and a variable length body.
-
-see below, each frame have a fixed length header, and a variable length body, the header is 8 bytes, so each frame
-cannot exceed 2^64 bytes, which is 16 exabytes, which is enough for most of the use cases.
-
-```plaintext
-| Frame 0 | Frame 1 | Frame 2 | Frame 3
-+---+----------+---+-----+-----+---+-----+----+---------------------
-| 7 | "Client1"| 2 | "O" | "I" | 1 | "C" | 36 | ...
-+---+----------+---+-----+-----+---+-----+----+---------------------
- | Message | | Object | | |
- | Identity | |Instruction| | |
- | |
- Instruction
- Type
-```
-
+Scaler is using capnp library to serialize/deserialize and use zmq to communicate between client and scheduler and
+worker
# Message Type Category
+
In general, there are 2 categories of the message types: object and task
object normally has an object id associated with actual object data, object data is immutable bytes, serialized by
@@ -60,10 +46,9 @@ function and series of arguments, but task message doesn't contain the actual fu
contains object ids, workers are responsible to fetch the function/argument data from scheduler and deserialize and
execute the function call.
-
## Object Channel
-Scheduler is the center of the object storage, client and worker are identical and can push
+Scheduler is the center of the object storage, client and worker are identical and can push
```plaintext
ObjectInstruction
@@ -84,19 +69,20 @@ Scheduler is the center of the object storage, client and worker are identical a
ObjectRequest +--------------+
```
+
ObjectInstruction = b"OI"
client can send object instruction to scheduler, scheduler can send object instruction to worker
it has 2 subtypes: create b"C", delete b"D"
when subtype is create, it has to include:
+
- list of object id (type bytes)
- list of object names (type bytes)
- list of object bytes (type bytes)
-All above 3 lists, the number of items need match
+ All above 3 lists, the number of items need match
ObjectRequest = b"OR"
ObjectResponse = b"OA"
-
## Task Channel
```plaintext
diff --git a/scaler/protocol/python/common.py b/scaler/protocol/python/common.py
new file mode 100644
index 0000000..1e94e9c
--- /dev/null
+++ b/scaler/protocol/python/common.py
@@ -0,0 +1,56 @@
+import dataclasses
+import enum
+from typing import Tuple
+
+from scaler.protocol.capnp._python import _common # noqa
+from scaler.protocol.python.mixins import Message
+
+
+class TaskStatus(enum.Enum):
+ # task is accepted by scheduler, but will have below status
+ Success = _common.TaskStatus.success # if submit and task is done and get result
+ Failed = _common.TaskStatus.failed # if submit and task is failed on worker
+ Canceled = _common.TaskStatus.canceled # if submit and task is canceled
+ NotFound = _common.TaskStatus.notFound # if submit and task is not found in scheduler
+ WorkerDied = (
+ _common.TaskStatus.workerDied
+ ) # if submit and worker died (only happened when scheduler keep_task=False)
+ NoWorker = _common.TaskStatus.noWorker # if submit and scheduler is full (not implemented yet)
+
+ # below are only used for monitoring channel, not sent to client
+ Inactive = _common.TaskStatus.inactive # task is scheduled but not allocate to worker
+ Running = _common.TaskStatus.running # task is running in worker
+ Canceling = _common.TaskStatus.canceling # task is canceling (can be in Inactive or Running state)
+
+
+@dataclasses.dataclass
+class ObjectContent(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def object_ids(self) -> Tuple[bytes, ...]:
+ return tuple(self._msg.objectIds)
+
+ @property
+ def object_names(self) -> Tuple[bytes, ...]:
+ return tuple(self._msg.objectNames)
+
+ @property
+ def object_bytes(self) -> Tuple[bytes, ...]:
+ return tuple(self._msg.objectBytes)
+
+ @staticmethod
+ def new_msg(
+ object_ids: Tuple[bytes, ...],
+ object_names: Tuple[bytes, ...] = tuple(),
+ object_bytes: Tuple[bytes, ...] = tuple(),
+ ) -> "ObjectContent":
+ return ObjectContent(
+ _common.ObjectContent(
+ objectIds=list(object_ids), objectNames=list(object_names), objectBytes=tuple(object_bytes)
+ )
+ )
+
+ def get_message(self):
+ return self._msg
diff --git a/scaler/protocol/python/message.py b/scaler/protocol/python/message.py
index bc3f907..43d55b9 100644
--- a/scaler/protocol/python/message.py
+++ b/scaler/protocol/python/message.py
@@ -1,653 +1,644 @@
import dataclasses
import enum
import os
-import pickle
-import struct
-from typing import List, Set, Tuple, TypeVar
+from typing import List, Set, Tuple, Optional, Type
import bidict
-from scaler.protocol.python.mixins import _Message
+from scaler.protocol.capnp._python import _message # noqa
+from scaler.protocol.python.common import TaskStatus, ObjectContent
+from scaler.protocol.python.mixins import Message
from scaler.protocol.python.status import (
BinderStatus,
ClientManagerStatus,
ObjectManagerStatus,
Resource,
+ ProcessorStatus,
TaskManagerStatus,
WorkerManagerStatus,
)
-class MessageType(enum.Enum):
- Task = b"TK"
- TaskCancel = b"TC"
- TaskResult = b"TR"
+class Task(Message):
+ @dataclasses.dataclass
+ class Argument:
+ type: "ArgumentType"
+ data: bytes
- GraphTask = b"GT"
- GraphTaskCancel = b"GC"
+ def __repr__(self):
+ return f"Argument(type={self.type}, data={self.data.hex()})"
- ObjectInstruction = b"OI"
- ObjectRequest = b"OR"
- ObjectResponse = b"OA"
+ class ArgumentType(enum.Enum):
+ Task = _message.Task.Argument.ArgumentType.task
+ ObjectID = _message.Task.Argument.ArgumentType.objectID
- ClientHeartbeat = b"CB"
- ClientHeartbeatEcho = b"CE"
+ def __init__(self, msg):
+ super().__init__(msg)
- WorkerHeartbeat = b"HB"
- WorkerHeartbeatEcho = b"HE"
-
- DisconnectRequest = b"DR"
- DisconnectResponse = b"DP"
+ def __repr__(self):
+ return (
+ f"Task(task_id={self.task_id.hex()}, source={self.source.hex()}, metadata={self.metadata.hex()},"
+ f"func_object_id={self.func_object_id.hex()}, function_args={self.function_args})"
+ )
- StateClient = b"SC"
- StateObject = b"SF"
- StateBalanceAdvice = b"SA"
- StateScheduler = b"SS"
- StateWorker = b"SW"
- StateTask = b"ST"
- StateGraphTask = b"SG"
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
- ClientDisconnect = b"CS"
- ClientShutdownResponse = b"CR"
+ @property
+ def source(self) -> bytes:
+ return self._msg.source
- ProcessorInitialized = b"PI"
+ @property
+ def metadata(self) -> bytes:
+ return self._msg.metadata
- @staticmethod
- def allowed_values():
- return {member.value for member in MessageType}
+ @property
+ def func_object_id(self) -> bytes:
+ return self._msg.funcObjectId
+ @property
+ def function_args(self) -> List[Argument]:
+ return [
+ Task.Argument(type=Task.Argument.ArgumentType(arg.type.raw), data=arg.data)
+ for arg in self._msg.functionArgs
+ ]
-class TaskEchoStatus(enum.Enum):
- # task echo is the response of task submit to scheduler
- SubmitOK = b"O" # if submit ok and task get accepted by scheduler
- SubmitDuplicated = b"D" # if submit and find task in scheduler
- CancelOK = b"C" # if cancel and success
- CancelTaskNotFound = b"N" # if cancel and cannot find task in scheduler
+ @staticmethod
+ def new_msg(
+ task_id: bytes, source: bytes, metadata: bytes, func_object_id: bytes, function_args: List[Argument]
+ ) -> "Task":
+ return Task(
+ _message.Task(
+ taskId=task_id,
+ source=source,
+ metadata=metadata,
+ funcObjectId=func_object_id,
+ functionArgs=[_message.Task.Argument(type=arg.type.value, data=arg.data) for arg in function_args],
+ )
+ )
-class TaskStatus(enum.Enum):
- # task is accepted by scheduler, but will have below status
- Success = b"S" # if submit and task is done and get result
- Failed = b"F" # if submit and task is failed on worker
- Canceled = b"C" # if submit and task is canceled
- NotFound = b"N" # if submit and task is not found in scheduler
- WorkerDied = b"K" # if submit and worker died (only happened when scheduler keep_task=False)
- NoWorker = b"W" # if submit and scheduler is full (not implemented yet)
+class TaskCancel(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- # below are only used for monitoring channel, not sent to client
- Inactive = b"I" # task is scheduled but not allocate to worker
- Running = b"R" # task is running in worker
- Canceling = b"X" # task is canceling (can be in Inactive or Running state)
+ @dataclasses.dataclass
+ class TaskCancelFlags:
+ force: bool
+ retrieve_task_object: bool
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
-class NodeTaskType(enum.Enum):
- Normal = b"N"
- Target = b"T"
+ @property
+ def flags(self) -> TaskCancelFlags:
+ return TaskCancel.TaskCancelFlags(
+ force=self._msg.flags.force, retrieve_task_object=self._msg.flags.retrieveTaskObject
+ )
+ @staticmethod
+ def new_msg(task_id: bytes, flags: Optional[TaskCancelFlags] = None) -> "TaskCancel":
+ if flags is None:
+ flags = TaskCancel.TaskCancelFlags(force=False, retrieve_task_object=False)
+
+ return TaskCancel(
+ _message.TaskCancel(
+ taskId=task_id,
+ flags=_message.TaskCancel.TaskCancelFlags(
+ force=flags.force, retrieveTaskObject=flags.retrieve_task_object
+ ),
+ )
+ )
-class ObjectInstructionType(enum.Enum):
- Create = b"C"
- Delete = b"D"
+class TaskResult(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
-class ObjectRequestType(enum.Enum):
- Get = b"A"
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
+ @property
+ def status(self) -> TaskStatus:
+ return TaskStatus(self._msg.status.raw)
-class ObjectResponseType(enum.Enum):
- Content = b"C"
- ObjectNotExist = b"N"
+ @property
+ def metadata(self) -> bytes:
+ return self._msg.metadata
+ @property
+ def results(self) -> List[bytes]:
+ return self._msg.results
-class ArgumentType(enum.Enum):
- Task = b"T"
- ObjectID = b"R"
+ @staticmethod
+ def new_msg(
+ task_id: bytes, status: TaskStatus, metadata: Optional[bytes] = None, results: Optional[List[bytes]] = None
+ ) -> "TaskResult":
+ if metadata is None:
+ metadata = bytes()
+ if results is None:
+ results = list()
-class DisconnectType(enum.Enum):
- Disconnect = b"D"
- Shutdown = b"S"
+ return TaskResult(_message.TaskResult(taskId=task_id, status=status.value, metadata=metadata, results=results))
-@dataclasses.dataclass
-class Argument:
- type: ArgumentType
- data: bytes
+class GraphTask(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
def __repr__(self):
- return f"Argument(type={self.type}, data={self.data.hex()})"
+ return (
+ f"GraphTask({os.linesep}"
+ f" task_id={self.task_id.hex()},{os.linesep}"
+ f" targets=[{os.linesep}"
+ f" {[target.hex() + ',' + os.linesep for target in self.targets]}"
+ f" ]\n"
+ f" graph={self.graph}\n"
+ f")"
+ )
- def serialize(self) -> Tuple[bytes, ...]:
- return self.type.value, self.data
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
- @staticmethod
- def deserialize(data: List[bytes]):
- return Argument(ArgumentType(data[0]), data[1])
+ @property
+ def source(self) -> bytes:
+ return self._msg.source
+ @property
+ def targets(self) -> List[bytes]:
+ return self._msg.targets
-@dataclasses.dataclass
-class ObjectContent:
- object_ids: Tuple[bytes, ...]
- object_names: Tuple[bytes, ...] = dataclasses.field(default_factory=tuple)
- object_bytes: Tuple[bytes, ...] = dataclasses.field(default_factory=tuple)
-
- def serialize(self) -> Tuple[bytes, ...]:
- payload = (
- struct.pack("III", len(self.object_ids), len(self.object_names), len(self.object_bytes)),
- *self.object_ids,
- *self.object_names,
- *self.object_bytes,
- )
- return payload
+ @property
+ def graph(self) -> List[Task]:
+ return [Task(task) for task in self._msg.graph]
@staticmethod
- def deserialize(data: List[bytes]) -> "ObjectContent":
- num_of_object_ids, num_of_object_names, num_of_object_bytes = struct.unpack("III", data[0])
-
- data = data[1:]
- object_ids = data[:num_of_object_ids]
-
- data = data[num_of_object_ids:]
- object_names = data[:num_of_object_names]
+ def new_msg(task_id: bytes, source: bytes, targets: List[bytes], graph: List[Task]) -> "GraphTask":
+ return GraphTask(
+ _message.GraphTask(
+ taskId=task_id, source=source, targets=targets, graph=[task.get_message() for task in graph]
+ )
+ )
- data = data[num_of_object_names:]
- object_bytes = data[:num_of_object_bytes]
- return ObjectContent(tuple(object_ids), tuple(object_names), tuple(object_bytes))
+class GraphTaskCancel(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
-MessageVariant = TypeVar("MessageVariant", bound=_Message)
+ @staticmethod
+ def new_msg(task_id: bytes) -> "GraphTaskCancel":
+ return GraphTaskCancel(_message.GraphTaskCancel(taskId=task_id))
+ def get_message(self):
+ return self._msg
-@dataclasses.dataclass
-class Task(_Message):
- task_id: bytes
- source: bytes
- metadata: bytes
- func_object_id: bytes
- function_args: List[Argument]
- def __repr__(self):
- return (
- f"Task(task_id={self.task_id.hex()}, source={self.source.hex()}, metadata={self.metadata.hex()},"
- f"func_object_id={self.func_object_id.hex()}, function_args={self.function_args})"
- )
+class ClientHeartbeat(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def get_required_object_ids(self) -> Set[bytes]:
- return {self.func_object_id} | {arg.data for arg in self.function_args if arg.type == ArgumentType.ObjectID}
+ @property
+ def resource(self) -> Resource:
+ return Resource(self._msg.resource)
- def serialize(self) -> Tuple[bytes, ...]:
- return (
- self.task_id,
- self.source,
- self.metadata,
- self.func_object_id,
- *[d for arg in self.function_args for d in arg.serialize()],
- )
+ @property
+ def latency_us(self) -> int:
+ return self._msg.latencyUS
@staticmethod
- def deserialize(data: List[bytes]):
- return Task(
- data[0], data[1], data[2], data[3], [Argument.deserialize(data[i : i + 2]) for i in range(4, len(data), 2)]
- )
-
-
-@dataclasses.dataclass
-class TaskCancelFlags:
- force: bool = dataclasses.field(default=False)
- retrieve_task_object: bool = dataclasses.field(default=False)
+ def new_msg(resource: Resource, latency_us: int) -> "ClientHeartbeat":
+ return ClientHeartbeat(_message.ClientHeartbeat(resource=resource.get_message(), latencyUS=latency_us))
- FORMAT = "??"
- def serialize(self) -> bytes:
- return struct.pack(TaskCancelFlags.FORMAT, self.force, self.retrieve_task_object)
+class ClientHeartbeatEcho(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
@staticmethod
- def deserialize(data: bytes) -> "TaskCancelFlags":
- return TaskCancelFlags(*struct.unpack(TaskCancelFlags.FORMAT, data))
+ def new_msg() -> "ClientHeartbeatEcho":
+ return ClientHeartbeatEcho(_message.ClientHeartbeatEcho())
-@dataclasses.dataclass
-class TaskCancel(_Message):
- task_id: bytes
- flags: TaskCancelFlags = dataclasses.field(default_factory=TaskCancelFlags)
+class WorkerHeartbeat(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return (self.task_id, self.flags.serialize())
+ @property
+ def agent(self) -> Resource:
+ return Resource(self._msg.agent)
- @staticmethod
- def deserialize(data: List[bytes]):
- return TaskCancel(data[0], TaskCancelFlags.deserialize(data[1])) # type: ignore
+ @property
+ def rss_free(self) -> int:
+ return self._msg.rssFree
+ @property
+ def queued_tasks(self) -> int:
+ return self._msg.queuedTasks
-@dataclasses.dataclass
-class TaskResult(_Message):
- task_id: bytes
- status: TaskStatus
- metadata: bytes = dataclasses.field(default=b"")
- results: List[bytes] = dataclasses.field(default_factory=list)
+ @property
+ def latency_us(self) -> int:
+ return self._msg.latencyUS
- def serialize(self) -> Tuple[bytes, ...]:
- return self.task_id, self.status.value, self.metadata, *self.results
+ @property
+ def task_lock(self) -> bool:
+ return self._msg.taskLock
- @staticmethod
- def deserialize(data: List[bytes]):
- return TaskResult(data[0], TaskStatus(data[1]), data[2], data[3:])
-
-
-@dataclasses.dataclass
-class GraphTask(_Message):
- task_id: bytes
- source: bytes
- targets: List[bytes]
- graph: List[Task]
+ @property
+ def processors(self) -> List[ProcessorStatus]:
+ return [ProcessorStatus(p) for p in self._msg.processors]
- def __repr__(self):
- return (
- f"GraphTask({os.linesep}"
- f" task_id={self.task_id.hex()},{os.linesep}"
- f" targets=[{os.linesep}"
- f" {[target.hex() + ',' + os.linesep for target in self.targets]}"
- f" ]\n"
- f" graph={self.graph}\n"
- f")"
+ @staticmethod
+ def new_msg(
+ agent: Resource,
+ rss_free: int,
+ queued_tasks: int,
+ latency_us: int,
+ task_lock: bool,
+ processors: List[ProcessorStatus],
+ ) -> "WorkerHeartbeat":
+ return WorkerHeartbeat(
+ _message.WorkerHeartbeat(
+ agent=agent.get_message(),
+ rssFree=rss_free,
+ queuedTasks=queued_tasks,
+ latencyUS=latency_us,
+ taskLock=task_lock,
+ processors=[p.get_message() for p in processors],
+ )
)
- def serialize(self) -> Tuple[bytes, ...]:
- graph_bytes = []
- for task in self.graph:
- frames = task.serialize()
- graph_bytes.append(struct.pack("I", len(frames)))
- graph_bytes.extend(frames)
- return self.task_id, self.source, struct.pack("I", len(self.targets)), *self.targets, *graph_bytes
+class WorkerHeartbeatEcho(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
@staticmethod
- def deserialize(data: List[bytes]):
- index = 0
- task_id = data[index]
+ def new_msg() -> "WorkerHeartbeatEcho":
+ return WorkerHeartbeatEcho(_message.WorkerHeartbeatEcho())
- index += 1
- client_id = data[index]
- index += 1
- number_of_targets = struct.unpack("I", data[index])[0]
+class ObjectInstruction(Message):
+ class ObjectInstructionType(enum.Enum):
+ Create = _message.ObjectInstruction.ObjectInstructionType.create
+ Delete = _message.ObjectInstruction.ObjectInstructionType.delete
- index += 1
- targets = data[index : index + number_of_targets]
+ def __init__(self, msg):
+ super().__init__(msg)
- index += number_of_targets
- graph = []
- while index < len(data):
- number_of_frames = struct.unpack("I", data[index])[0]
- index += 1
- graph.append(Task.deserialize(data[index : index + number_of_frames]))
- index += number_of_frames
+ @property
+ def instruction_type(self) -> ObjectInstructionType:
+ return ObjectInstruction.ObjectInstructionType(self._msg.instructionType.raw)
- return GraphTask(task_id, client_id, targets, graph)
+ @property
+ def object_user(self) -> bytes:
+ return self._msg.objectUser
-
-@dataclasses.dataclass
-class GraphTaskCancel(_Message):
- task_id: bytes
-
- def serialize(self) -> Tuple[bytes, ...]:
- return (self.task_id,)
+ @property
+ def object_content(self) -> ObjectContent:
+ return ObjectContent(self._msg.objectContent)
@staticmethod
- def deserialize(data: List[bytes]):
- return GraphTaskCancel(data[0])
-
-
-@dataclasses.dataclass
-class ClientHeartbeat(_Message):
- client_cpu: float
- client_rss: int
- latency_us: int
-
- FORMAT = "HQI"
-
- def serialize(self) -> Tuple[bytes, ...]:
- return (struct.pack(ClientHeartbeat.FORMAT, int(self.client_cpu * 1000), self.client_rss, self.latency_us),)
+ def new_msg(
+ instruction_type: ObjectInstructionType, object_user: bytes, object_content: ObjectContent
+ ) -> "ObjectInstruction":
+ return ObjectInstruction(
+ _message.ObjectInstruction(
+ instructionType=instruction_type.value,
+ objectUser=object_user,
+ objectContent=object_content.get_message(),
+ )
+ )
- @staticmethod
- def deserialize(data: List[bytes]):
- client_cpu, client_rss, latency_us = struct.unpack(ClientHeartbeat.FORMAT, data[0])
- return ClientHeartbeat(float(client_cpu / 1000), client_rss, latency_us)
+class ObjectRequest(Message):
+ class ObjectRequestType(enum.Enum):
+ Get = _message.ObjectRequest.ObjectRequestType.get
-@dataclasses.dataclass
-class ClientHeartbeatEcho(_Message):
- def serialize(self) -> Tuple[bytes, ...]:
- return (b"",)
+ def __init__(self, msg):
+ super().__init__(msg)
- @staticmethod
- def deserialize(data: List[bytes]):
- return ClientHeartbeatEcho()
+ def __repr__(self):
+ return (
+ f"ObjectRequest(type={self.request_type}, "
+ f"object_ids={tuple(object_id.hex() for object_id in self.object_ids)})"
+ )
+ @property
+ def request_type(self) -> ObjectRequestType:
+ return ObjectRequest.ObjectRequestType(self._msg.requestType.raw)
-@dataclasses.dataclass
-class ProcessorHeartbeat:
- pid: int
- initialized: bool
- has_task: bool
- suspended: bool
- cpu: float
- rss: int
-
- FORMAT = "I???HQ"
-
- def serialize(self) -> bytes:
- return struct.pack(
- ProcessorHeartbeat.FORMAT,
- self.pid,
- self.initialized,
- self.has_task,
- self.suspended,
- int(self.cpu * 1000),
- self.rss,
- )
+ @property
+ def object_ids(self) -> Tuple[bytes]:
+ return tuple(self._msg.objectIds)
@staticmethod
- def deserialize(data: bytes) -> "ProcessorHeartbeat":
- pid, initialized, has_task, suspended, cpu, rss = struct.unpack(ProcessorHeartbeat.FORMAT, data)
- return ProcessorHeartbeat(pid, initialized, has_task, suspended, float(cpu / 1000), rss)
+ def new_msg(request_type: ObjectRequestType, object_ids: Tuple[bytes, ...]) -> "ObjectRequest":
+ return ObjectRequest(_message.ObjectRequest(requestType=request_type.value, objectIds=list(object_ids)))
-@dataclasses.dataclass
-class WorkerHeartbeat(_Message):
- agent_cpu: float
- agent_rss: int
- rss_free: int
- queued_tasks: int
- latency_us: int
- task_lock: bool
+class ObjectResponse(Message):
+ class ObjectResponseType(enum.Enum):
+ Content = _message.ObjectResponse.ObjectResponseType.content
+ ObjectNotExist = _message.ObjectResponse.ObjectResponseType.objectNotExist
- processors: List[ProcessorHeartbeat]
+ def __init__(self, msg):
+ super().__init__(msg)
- FORMAT = "HQQHI?" # processor heartbeats come right after the main fields
+ @property
+ def response_type(self) -> ObjectResponseType:
+ return ObjectResponse.ObjectResponseType(self._msg.responseType.raw)
- def serialize(self) -> Tuple[bytes, ...]:
- return (
- struct.pack(
- WorkerHeartbeat.FORMAT,
- int(self.agent_cpu * 1000),
- self.agent_rss,
- self.rss_free,
- self.queued_tasks,
- self.latency_us,
- self.task_lock,
- ),
- *(p.serialize() for p in self.processors)
- )
+ @property
+ def object_content(self) -> ObjectContent:
+ return ObjectContent(self._msg.objectContent)
@staticmethod
- def deserialize(data: List[bytes]):
- (
- agent_cpu,
- agent_rss,
- rss_free,
- queued_tasks,
- latency_us,
- task_lock,
- ) = struct.unpack(WorkerHeartbeat.FORMAT, data[0])
- processors = [ProcessorHeartbeat.deserialize(d) for d in data[1:]]
-
- return WorkerHeartbeat(
- float(agent_cpu / 1000),
- agent_rss,
- rss_free,
- queued_tasks,
- latency_us,
- task_lock,
- processors,
+ def new_msg(response_type: ObjectResponseType, object_content: ObjectContent) -> "ObjectResponse":
+ return ObjectResponse(
+ _message.ObjectResponse(responseType=response_type.value, objectContent=object_content.get_message())
)
-@dataclasses.dataclass
-class WorkerHeartbeatEcho(_Message):
- def serialize(self) -> Tuple[bytes, ...]:
- return (b"",)
+class DisconnectRequest(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
+
+ @property
+ def worker(self) -> bytes:
+ return self._msg.worker
@staticmethod
- def deserialize(data: List[bytes]):
- return WorkerHeartbeatEcho()
+ def new_msg(worker: bytes) -> "DisconnectRequest":
+ return DisconnectRequest(_message.DisconnectRequest(worker=worker))
@dataclasses.dataclass
-class ObjectInstruction(_Message):
- type: ObjectInstructionType
- object_user: bytes
- object_content: ObjectContent
+class DisconnectResponse(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return self.type.value, self.object_user, *self.object_content.serialize()
+ @property
+ def worker(self) -> bytes:
+ return self._msg.worker
@staticmethod
- def deserialize(data: List[bytes]) -> "ObjectInstruction":
- return ObjectInstruction(ObjectInstructionType(data[0]), data[1], ObjectContent.deserialize(data[2:]))
+ def new_msg(worker: bytes) -> "DisconnectResponse":
+ return DisconnectResponse(_message.DisconnectResponse(worker=worker))
-@dataclasses.dataclass
-class ObjectRequest(_Message):
- type: ObjectRequestType
- object_ids: Tuple[bytes, ...]
+class ClientDisconnect(Message):
+ class DisconnectType(enum.Enum):
+ Disconnect = _message.ClientDisconnect.DisconnectType.disconnect
+ Shutdown = _message.ClientDisconnect.DisconnectType.shutdown
- def __repr__(self):
- return f"ObjectRequest(type={self.type}, object_ids={tuple(object_id.hex() for object_id in self.object_ids)})"
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return self.type.value, *self.object_ids
+ @property
+ def disconnect_type(self) -> DisconnectType:
+ return ClientDisconnect.DisconnectType(self._msg.disconnectType.raw)
@staticmethod
- def deserialize(data: List[bytes]):
- return ObjectRequest(ObjectRequestType(data[0]), tuple(data[1:]))
+ def new_msg(disconnect_type: DisconnectType) -> "ClientDisconnect":
+ return ClientDisconnect(_message.ClientDisconnect(disconnectType=disconnect_type.value))
-@dataclasses.dataclass
-class ObjectResponse(_Message):
- type: ObjectResponseType
- object_content: ObjectContent
+class ClientShutdownResponse(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return self.type.value, *self.object_content.serialize()
+ @property
+ def accepted(self) -> bool:
+ return self._msg.accepted
@staticmethod
- def deserialize(data: List[bytes]):
- request_type = ObjectResponseType(data[0])
- return ObjectResponse(request_type, ObjectContent.deserialize(data[1:]))
+ def new_msg(accepted: bool) -> "ClientShutdownResponse":
+ return ClientShutdownResponse(_message.ClientShutdownResponse(accepted=accepted))
-@dataclasses.dataclass
-class DisconnectRequest(_Message):
- worker: bytes
-
- def serialize(self) -> Tuple[bytes, ...]:
- return (self.worker,)
+class StateClient(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
@staticmethod
- def deserialize(data: List[bytes]):
- return DisconnectRequest(data[0])
+ def new_msg() -> "StateClient":
+ return StateClient(_message.StateClient())
-@dataclasses.dataclass
-class DisconnectResponse(_Message):
- worker: bytes
-
- def serialize(self) -> Tuple[bytes, ...]:
- return (self.worker,)
+class StateObject(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
@staticmethod
- def deserialize(data: List[bytes]):
- return DisconnectResponse(data[0])
+ def new_msg() -> "StateObject":
+ return StateObject(_message.StateObject())
-@dataclasses.dataclass
-class ClientDisconnect(_Message):
- type: DisconnectType
+class StateBalanceAdvice(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return (self.type.value,)
+ @property
+ def worker_id(self) -> bytes:
+ return self._msg.workerId
+
+ @property
+ def task_ids(self) -> List[bytes]:
+ return self._msg.taskIds
@staticmethod
- def deserialize(data: List[bytes]):
- return ClientDisconnect(DisconnectType(data[0]))
+ def new_msg(worker_id: bytes, task_ids: List[bytes]) -> "StateBalanceAdvice":
+ return StateBalanceAdvice(_message.StateBalanceAdvice(workerId=worker_id, taskIds=task_ids))
-@dataclasses.dataclass
-class ClientShutdownResponse(_Message):
- accepted: bool
+class StateScheduler(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- def serialize(self) -> Tuple[bytes, ...]:
- return (struct.pack("?", self.accepted),)
+ @property
+ def binder(self) -> BinderStatus:
+ return BinderStatus(self._msg.binder)
- @staticmethod
- def deserialize(data: List[bytes]):
- return ClientShutdownResponse(struct.unpack("?", data[0])[0])
+ @property
+ def scheduler(self) -> Resource:
+ return Resource(self._msg.scheduler)
+ @property
+ def rss_free(self) -> int:
+ return self._msg.rssFree
-@dataclasses.dataclass
-class StateClient(_Message):
- # TODO: implement this
- def serialize(self) -> Tuple[bytes, ...]:
- return (b"",)
+ @property
+ def client_manager(self) -> ClientManagerStatus:
+ return ClientManagerStatus(self._msg.clientManager)
- @staticmethod
- def deserialize(data: List[bytes]):
- return StateClient()
+ @property
+ def object_manager(self) -> ObjectManagerStatus:
+ return ObjectManagerStatus(self._msg.objectManager)
+ @property
+ def task_manager(self) -> TaskManagerStatus:
+ return TaskManagerStatus(self._msg.taskManager)
-@dataclasses.dataclass
-class StateObject(_Message):
- # TODO: implement this
- def serialize(self) -> Tuple[bytes, ...]:
- return (b"",)
+ @property
+ def worker_manager(self) -> WorkerManagerStatus:
+ return WorkerManagerStatus(self._msg.workerManager)
@staticmethod
- def deserialize(data: List[bytes]):
- return StateObject()
+ def new_msg(
+ binder: BinderStatus,
+ scheduler: Resource,
+ rss_free: int,
+ client_manager: ClientManagerStatus,
+ object_manager: ObjectManagerStatus,
+ task_manager: TaskManagerStatus,
+ worker_manager: WorkerManagerStatus,
+ ) -> "StateScheduler":
+ return StateScheduler(
+ _message.StateScheduler(
+ binder=binder.get_message(),
+ scheduler=scheduler.get_message(),
+ rssFree=rss_free,
+ clientManager=client_manager.get_message(),
+ objectManager=object_manager.get_message(),
+ taskManager=task_manager.get_message(),
+ workerManager=worker_manager.get_message(),
+ )
+ )
-@dataclasses.dataclass
-class StateBalanceAdvice(_Message):
- worker_id: bytes
- task_ids: List[bytes]
+class StateWorker(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
+
+ @property
+ def worker_id(self) -> bytes:
+ return self._msg.workerId
- def serialize(self) -> Tuple[bytes, ...]:
- return self.worker_id, *self.task_ids
+ @property
+ def message(self) -> bytes:
+ return self._msg.message
@staticmethod
- def deserialize(data: List[bytes]):
- return StateBalanceAdvice(data[0], data[1:])
+ def new_msg(worker_id: bytes, message: bytes) -> "StateWorker":
+ return StateWorker(_message.StateWorker(workerId=worker_id, message=message))
-@dataclasses.dataclass
-class StateScheduler(_Message):
- binder: BinderStatus
- scheduler: Resource
- client_manager: ClientManagerStatus
- object_manager: ObjectManagerStatus
- task_manager: TaskManagerStatus
- worker_manager: WorkerManagerStatus
-
- def serialize(self) -> Tuple[bytes, ...]:
- return (
- pickle.dumps(
- (
- self.binder,
- self.scheduler,
- self.client_manager,
- self.object_manager,
- self.task_manager,
- self.worker_manager,
- )
- ),
- )
+class StateTask(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
- @staticmethod
- def deserialize(data: List[bytes]):
- return StateScheduler(*pickle.loads(data[0]))
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
+ @property
+ def function_name(self) -> bytes:
+ return self._msg.functionName
-@dataclasses.dataclass
-class StateWorker(_Message):
- worker_id: bytes
- message: bytes
+ @property
+ def status(self) -> TaskStatus:
+ return TaskStatus(self._msg.status.raw)
- def serialize(self) -> Tuple[bytes, ...]:
- return self.worker_id, self.message
+ @property
+ def worker(self) -> bytes:
+ return self._msg.worker
+
+ @property
+ def metadata(self) -> bytes:
+ return self._msg.metadata
@staticmethod
- def deserialize(data: List[bytes]):
- return StateWorker(data[0], data[1])
+ def new_msg(
+ task_id: bytes, function_name: bytes, status: TaskStatus, worker: bytes, metadata: bytes = b""
+ ) -> "StateTask":
+ return StateTask(
+ _message.StateTask(
+ taskId=task_id, functionName=function_name, status=status.value, worker=worker, metadata=metadata
+ )
+ )
-@dataclasses.dataclass
-class StateTask(_Message):
- task_id: bytes
- function_name: bytes
- status: TaskStatus
- worker: bytes
- metadata: bytes = dataclasses.field(default=b"")
+class StateGraphTask(Message):
+ class NodeTaskType(enum.Enum):
+ Normal = _message.StateGraphTask.NodeTaskType.normal
+ Target = _message.StateGraphTask.NodeTaskType.target
- def serialize(self) -> Tuple[bytes, ...]:
- return self.task_id, self.function_name, self.status.value, self.worker, self.metadata
+ def __init__(self, msg):
+ super().__init__(msg)
- @staticmethod
- def deserialize(data: List[bytes]):
- return StateTask(data[0], data[1], TaskStatus(data[2]), data[3], data[4])
+ @property
+ def graph_task_id(self) -> bytes:
+ return self._msg.graphTaskId
+ @property
+ def task_id(self) -> bytes:
+ return self._msg.taskId
-@dataclasses.dataclass
-class StateGraphTask(_Message):
- graph_task_id: bytes
- task_id: bytes
- node_task_type: NodeTaskType
- parent_task_ids: Set[bytes]
+ @property
+ def node_task_type(self) -> NodeTaskType:
+ return StateGraphTask.NodeTaskType(self._msg.nodeTaskType.raw)
- def serialize(self) -> Tuple[bytes, ...]:
- return self.graph_task_id, self.task_id, self.node_task_type.value, *self.parent_task_ids
+ @property
+ def parent_task_ids(self) -> Set[bytes]:
+ return set(self._msg.parentTaskIds)
@staticmethod
- def deserialize(data: List[bytes]):
- return StateGraphTask(data[0], data[1], NodeTaskType(data[2]), set(data[3:]))
+ def new_msg(
+ graph_task_id: bytes, task_id: bytes, node_task_type: NodeTaskType, parent_task_ids: Set[bytes]
+ ) -> "StateGraphTask":
+ return StateGraphTask(
+ _message.StateGraphTask(
+ graphTaskId=graph_task_id,
+ taskId=task_id,
+ nodeTaskType=node_task_type.value,
+ parentTaskIds=list(parent_task_ids),
+ )
+ )
-@dataclasses.dataclass
-class ProcessorInitialized(_Message):
- def serialize(self) -> Tuple[bytes, ...]:
- return (b"",)
+class ProcessorInitialized(Message):
+ def __init__(self, msg):
+ super().__init__(msg)
@staticmethod
- def deserialize(data: List[bytes]):
- return ProcessorInitialized()
+ def new_msg() -> "ProcessorInitialized":
+ return ProcessorInitialized(_message.ProcessorInitialized())
-PROTOCOL = bidict.bidict(
+PROTOCOL: bidict.bidict[str, Type[Message]] = bidict.bidict(
{
- MessageType.ClientHeartbeat: ClientHeartbeat,
- MessageType.ClientHeartbeatEcho: ClientHeartbeatEcho,
- MessageType.WorkerHeartbeat: WorkerHeartbeat,
- MessageType.WorkerHeartbeatEcho: WorkerHeartbeatEcho,
- MessageType.Task: Task,
- MessageType.TaskCancel: TaskCancel,
- MessageType.TaskResult: TaskResult,
- MessageType.GraphTask: GraphTask,
- MessageType.GraphTaskCancel: GraphTaskCancel,
- MessageType.ObjectInstruction: ObjectInstruction,
- MessageType.ObjectRequest: ObjectRequest,
- MessageType.ObjectResponse: ObjectResponse,
- MessageType.DisconnectRequest: DisconnectRequest,
- MessageType.DisconnectResponse: DisconnectResponse,
- MessageType.StateClient: StateClient,
- MessageType.StateObject: StateObject,
- MessageType.StateBalanceAdvice: StateBalanceAdvice,
- MessageType.StateScheduler: StateScheduler,
- MessageType.StateWorker: StateWorker,
- MessageType.StateTask: StateTask,
- MessageType.StateGraphTask: StateGraphTask,
- MessageType.ClientDisconnect: ClientDisconnect,
- MessageType.ClientShutdownResponse: ClientShutdownResponse,
- MessageType.ProcessorInitialized: ProcessorInitialized,
+ "task": Task,
+ "taskCancel": TaskCancel,
+ "taskResult": TaskResult,
+ "graphTask": GraphTask,
+ "graphTaskCancel": GraphTaskCancel,
+ "objectInstruction": ObjectInstruction,
+ "objectRequest": ObjectRequest,
+ "objectResponse": ObjectResponse,
+ "clientHeartbeat": ClientHeartbeat,
+ "clientHeartbeatEcho": ClientHeartbeatEcho,
+ "workerHeartbeat": WorkerHeartbeat,
+ "workerHeartbeatEcho": WorkerHeartbeatEcho,
+ "disconnectRequest": DisconnectRequest,
+ "disconnectResponse": DisconnectResponse,
+ "stateClient": StateClient,
+ "stateObject": StateObject,
+ "stateBalanceAdvice": StateBalanceAdvice,
+ "stateScheduler": StateScheduler,
+ "stateWorker": StateWorker,
+ "stateTask": StateTask,
+ "stateGraphTask": StateGraphTask,
+ "clientDisconnect": ClientDisconnect,
+ "clientShutdownResponse": ClientShutdownResponse,
+ "processorInitialized": ProcessorInitialized,
}
)
diff --git a/scaler/protocol/python/mixins.py b/scaler/protocol/python/mixins.py
index 765873d..52c44e6 100644
--- a/scaler/protocol/python/mixins.py
+++ b/scaler/protocol/python/mixins.py
@@ -1,13 +1,13 @@
import abc
-from typing import List, Tuple
+from typing import TypeVar
-class _Message(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def serialize(self) -> Tuple[bytes, ...]:
- raise NotImplementedError()
+class Message(metaclass=abc.ABCMeta):
+ def __init__(self, msg):
+ self._msg = msg
- @staticmethod
- @abc.abstractmethod
- def deserialize(data: List[bytes]):
- raise NotImplementedError()
+ def get_message(self):
+ return self._msg
+
+
+MessageType = TypeVar("MessageType", bound=Message)
diff --git a/scaler/protocol/python/status.py b/scaler/protocol/python/status.py
index 1f3cb08..5dfc702 100644
--- a/scaler/protocol/python/status.py
+++ b/scaler/protocol/python/status.py
@@ -1,65 +1,276 @@
-import dataclasses
from typing import Dict, List
+from scaler.protocol.capnp._python import _status # noqa
+from scaler.protocol.python.mixins import Message
-@dataclasses.dataclass
-class Resource:
- cpu: float
- rss: int
- rss_free: int
+class Resource(Message):
+ def __init__(self, msg):
+ self._msg = msg
-@dataclasses.dataclass
-class ObjectManagerStatus:
- number_of_objects: int
- object_memory: int
+ @property
+ def cpu(self) -> int:
+ return self._msg.cpu
+ @property
+ def rss(self) -> int:
+ return self._msg.rss
-@dataclasses.dataclass
-class ClientManagerStatus:
- client_to_num_of_tasks: Dict[bytes, int]
+ @staticmethod
+ def new_msg(cpu: int, rss: int) -> "Resource": # type: ignore[override]
+ return Resource(_status.Resource(cpu=cpu, rss=rss))
+ def get_message(self):
+ return self._msg
-@dataclasses.dataclass
-class TaskManagerStatus:
- unassigned: int
- running: int
- success: int
- failed: int
- canceled: int
- not_found: int
+class ObjectManagerStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
-@dataclasses.dataclass
-class ProcessorStatus:
- pid: int
- initialized: bool
- has_task: bool
- suspended: bool
- resource: Resource
+ @property
+ def number_of_objects(self) -> int:
+ return self._msg.numberOfObjects
+ @property
+ def object_memory(self) -> int:
+ return self._msg.objectMemory
-@dataclasses.dataclass
-class WorkerStatus:
- worker_id: bytes
- agent: Resource
- total_processors: Resource
- free: int
- sent: int
- queued: int
- suspended: int
- lag_us: int
- last_s: int
- ITL: str
- processor_statuses: List[ProcessorStatus]
+ @staticmethod
+ def new_msg(number_of_objects: int, object_memory: int) -> "ObjectManagerStatus": # type: ignore[override]
+ return ObjectManagerStatus(
+ _status.ObjectManagerStatus(numberOfObjects=number_of_objects, objectMemory=object_memory)
+ )
+ def get_message(self):
+ return self._msg
-@dataclasses.dataclass
-class WorkerManagerStatus:
- workers: List[WorkerStatus]
+class ClientManagerStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
-@dataclasses.dataclass
-class BinderStatus:
- received: Dict[str, int]
- sent: Dict[str, int]
+ @property
+ def client_to_num_of_tasks(self) -> Dict[bytes, int]:
+ return {p.client: p.numTask for p in self._msg.clientToNumOfTask}
+
+ @staticmethod
+ def new_msg(client_to_num_of_tasks: Dict[bytes, int]) -> "ClientManagerStatus": # type: ignore[override]
+ return ClientManagerStatus(
+ _status.ClientManagerStatus(
+ clientToNumOfTask=[
+ _status.ClientManagerStatus.Pair(client=p[0], numTask=p[1]) for p in client_to_num_of_tasks.items()
+ ]
+ )
+ )
+
+ def get_message(self):
+ return self._msg
+
+
+class TaskManagerStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def unassigned(self) -> int:
+ return self._msg.unassigned
+
+ @property
+ def running(self) -> int:
+ return self._msg.running
+
+ @property
+ def success(self) -> int:
+ return self._msg.success
+
+ @property
+ def failed(self) -> int:
+ return self._msg.failed
+
+ @property
+ def canceled(self) -> int:
+ return self._msg.canceled
+
+ @property
+ def not_found(self) -> int:
+ return self._msg.notFound
+
+ @staticmethod
+ def new_msg( # type: ignore[override]
+ unassigned: int, running: int, success: int, failed: int, canceled: int, not_found: int
+ ) -> "TaskManagerStatus":
+ return TaskManagerStatus(
+ _status.TaskManagerStatus(
+ unassigned=unassigned,
+ running=running,
+ success=success,
+ failed=failed,
+ canceled=canceled,
+ notFound=not_found,
+ )
+ )
+
+ def get_message(self):
+ return self._msg
+
+
+class ProcessorStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def pid(self) -> int:
+ return self._msg.pid
+
+ @property
+ def initialized(self) -> int:
+ return self._msg.initialized
+
+ @property
+ def has_task(self) -> bool:
+ return self._msg.hasTask
+
+ @property
+ def suspended(self) -> bool:
+ return self._msg.suspended
+
+ @property
+ def resource(self) -> Resource:
+ return Resource(self._msg.resource)
+
+ @staticmethod
+ def new_msg(
+ pid: int, initialized: int, has_task: bool, suspended: bool, resource: Resource # type: ignore[override]
+ ) -> "ProcessorStatus":
+ return ProcessorStatus(
+ _status.ProcessorStatus(
+ pid=pid, initialized=initialized, hasTask=has_task, suspended=suspended, resource=resource.get_message()
+ )
+ )
+
+ def get_message(self):
+ return self._msg
+
+
+class WorkerStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def worker_id(self) -> bytes:
+ return self._msg.workerId
+
+ @property
+ def agent(self) -> Resource:
+ return Resource(self._msg.agent)
+
+ @property
+ def rss_free(self) -> int:
+ return self._msg.rssFree
+
+ @property
+ def free(self) -> int:
+ return self._msg.free
+
+ @property
+ def sent(self) -> int:
+ return self._msg.sent
+
+ @property
+ def queued(self) -> int:
+ return self._msg.queued
+
+ @property
+ def suspended(self) -> bool:
+ return self._msg.suspended
+
+ @property
+ def lag_us(self) -> int:
+ return self._msg.lagUS
+
+ @property
+ def last_s(self) -> int:
+ return self._msg.lastS
+
+ @property
+ def itl(self) -> str:
+ return self._msg.itl
+
+ @property
+ def processor_statuses(self) -> List[ProcessorStatus]:
+ return [ProcessorStatus(ps) for ps in self._msg.processorStatuses]
+
+ @staticmethod
+ def new_msg( # type: ignore[override]
+ worker_id: bytes,
+ agent: Resource,
+ rss_free: int,
+ free: int,
+ sent: int,
+ queued: int,
+ suspended: int,
+ lag_us: int,
+ last_s: int,
+ itl: str,
+ processor_statuses: List[ProcessorStatus],
+ ) -> "WorkerStatus":
+ return WorkerStatus(
+ _status.WorkerStatus(
+ workerId=worker_id,
+ agent=agent.get_message(),
+ rssFree=rss_free,
+ free=free,
+ sent=sent,
+ queued=queued,
+ suspended=suspended,
+ lagUS=lag_us,
+ lastS=last_s,
+ itl=itl,
+ processorStatuses=[ps.get_message() for ps in processor_statuses],
+ )
+ )
+
+ def get_message(self):
+ return self._msg
+
+
+class WorkerManagerStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def workers(self) -> List[WorkerStatus]:
+ return [WorkerStatus(ws) for ws in self._msg.workers]
+
+ @staticmethod
+ def new_msg(workers: List[WorkerStatus]) -> "WorkerManagerStatus": # type: ignore[override]
+ return WorkerManagerStatus(_status.WorkerManagerStatus(workers=[ws.get_message() for ws in workers]))
+
+ def get_message(self):
+ return self._msg
+
+
+class BinderStatus(Message):
+ def __init__(self, msg):
+ self._msg = msg
+
+ @property
+ def received(self) -> Dict[str, int]:
+ return {p.client: p.number for p in self._msg.received}
+
+ @property
+ def sent(self) -> Dict[str, int]:
+ return {p.client: p.number for p in self._msg.sent}
+
+ @staticmethod
+ def new_msg(received: Dict[str, int], sent: Dict[str, int]) -> "BinderStatus": # type: ignore[override]
+ return BinderStatus(
+ _status.BinderStatus(
+ received=[_status.BinderStatus.Pair(client=p[0], number=p[1]) for p in received.items()],
+ sent=[_status.BinderStatus.Pair(client=p[0], number=p[1]) for p in sent.items()],
+ )
+ )
+
+ def get_message(self):
+ return self._msg
diff --git a/scaler/scheduler/allocators/mixins.py b/scaler/scheduler/allocators/mixins.py
index 3bbbd14..34b4077 100644
--- a/scaler/scheduler/allocators/mixins.py
+++ b/scaler/scheduler/allocators/mixins.py
@@ -1,5 +1,5 @@
import abc
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set
class TaskAllocator(metaclass=abc.ABCMeta):
@@ -14,7 +14,7 @@ def remove_worker(self, worker: bytes) -> List[bytes]:
raise NotImplementedError()
@abc.abstractmethod
- def get_worker_ids(self) -> List[bytes]:
+ def get_worker_ids(self) -> Set[bytes]:
"""get all worker ids as list"""
raise NotImplementedError()
@@ -24,9 +24,9 @@ def get_worker_by_task_id(self, task_id: bytes) -> bytes:
raise NotImplementedError()
@abc.abstractmethod
- def balance(self) -> Dict[bytes, int]:
- """balance worker, it should return the number of tasks for over burdened worker, represented as worker
- identity to number of tasks dictionary"""
+ def balance(self) -> Dict[bytes, List[bytes]]:
+ """balance worker, it should return list of task ids for over burdened worker, represented as worker
+ identity to list of task ids dictionary"""
raise NotImplementedError()
@abc.abstractmethod
diff --git a/scaler/scheduler/client_manager.py b/scaler/scheduler/client_manager.py
index ed6de20..2d2f1e1 100644
--- a/scaler/scheduler/client_manager.py
+++ b/scaler/scheduler/client_manager.py
@@ -9,7 +9,6 @@
ClientHeartbeat,
ClientHeartbeatEcho,
ClientShutdownResponse,
- DisconnectType,
TaskCancel,
)
from scaler.protocol.python.status import ClientManagerStatus
@@ -63,14 +62,14 @@ def on_task_finish(self, task_id: bytes) -> bytes:
return self._client_to_task_ids.remove_value(task_id)
async def on_heartbeat(self, client: bytes, info: ClientHeartbeat):
- await self._binder.send(client, ClientHeartbeatEcho())
+ await self._binder.send(client, ClientHeartbeatEcho.new_msg())
if client not in self._client_last_seen:
- logging.info(f"client {client} connected")
+ logging.info(f"client {client!r} connected")
self._client_last_seen[client] = (time.time(), info)
async def on_client_disconnect(self, client: bytes, request: ClientDisconnect):
- if request.type == DisconnectType.Disconnect:
+ if request.disconnect_type == ClientDisconnect.DisconnectType.Disconnect:
await self.__on_client_disconnect(client)
return
@@ -78,23 +77,23 @@ async def on_client_disconnect(self, client: bytes, request: ClientDisconnect):
logging.warning("cannot shutdown clusters as scheduler is running in protected mode")
accepted = False
else:
- logging.info(f"shutdown scheduler and all clusters as received signal from {client=}")
+ logging.info(f"shutdown scheduler and all clusters as received signal from {client=!r}")
accepted = True
- await self._binder.send(client, ClientShutdownResponse(accepted=accepted))
+ await self._binder.send(client, ClientShutdownResponse.new_msg(accepted=accepted))
if self._protected:
return
await self._worker_manager.on_client_shutdown(client)
- raise ClientShutdownException(f"received client shutdown from {client}, quiting")
+ raise ClientShutdownException(f"received client shutdown from {client!r}, quiting")
async def routine(self):
await self.__routine_cleanup_clients()
def get_status(self) -> ClientManagerStatus:
- return ClientManagerStatus(
+ return ClientManagerStatus.new_msg(
{client.decode(): len(task_ids) for client, task_ids in self._client_to_task_ids.items()}
)
@@ -110,7 +109,7 @@ async def __routine_cleanup_clients(self):
await self.__on_client_disconnect(client)
async def __on_client_disconnect(self, client_id: bytes):
- logging.info(f"client {client_id} disconnected")
+ logging.info(f"client {client_id!r} disconnected")
if client_id in self._client_last_seen:
self._client_last_seen.pop(client_id)
@@ -123,4 +122,4 @@ async def __cancel_tasks(self, client: bytes):
tasks = self._client_to_task_ids.get_values(client).copy()
for task in tasks:
- await self._task_manager.on_task_cancel(client, TaskCancel(task))
+ await self._task_manager.on_task_cancel(client, TaskCancel.new_msg(task))
diff --git a/scaler/scheduler/graph_manager.py b/scaler/scheduler/graph_manager.py
index 1ef96ba..392c829 100644
--- a/scaler/scheduler/graph_manager.py
+++ b/scaler/scheduler/graph_manager.py
@@ -6,18 +6,8 @@
from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import (
- Argument,
- ArgumentType,
- GraphTask,
- GraphTaskCancel,
- NodeTaskType,
- StateGraphTask,
- Task,
- TaskCancel,
- TaskResult,
- TaskStatus,
-)
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import GraphTask, GraphTaskCancel, StateGraphTask, Task, TaskCancel, TaskResult
from scaler.scheduler.mixins import ClientManager, GraphTaskManager, ObjectManager, TaskManager
from scaler.utility.graph.topological_sorter import TopologicalSorter
from scaler.utility.many_to_many_dict import ManyToManyDict
@@ -107,7 +97,7 @@ async def on_graph_task(self, client: bytes, graph_task: GraphTask):
async def on_graph_task_cancel(self, client: bytes, graph_task_cancel: GraphTaskCancel):
if graph_task_cancel.task_id not in self._graph_task_id_to_graph:
- await self._binder.send(client, TaskResult(graph_task_cancel.task_id, TaskStatus.NotFound))
+ await self._binder.send(client, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.NotFound))
return
graph_task_id = self._task_id_to_graph_task_id[graph_task_cancel.task_id]
@@ -115,7 +105,7 @@ async def on_graph_task_cancel(self, client: bytes, graph_task_cancel: GraphTask
if graph_info.status == _GraphState.Canceling:
return
- await self.__cancel_one_graph(graph_task_id, TaskResult(graph_task_cancel.task_id, TaskStatus.Canceled))
+ await self.__cancel_one_graph(graph_task_id, TaskResult.new_msg(graph_task_cancel.task_id, TaskStatus.Canceled))
async def on_graph_sub_task_done(self, result: TaskResult):
graph_task_id = self._task_id_to_graph_task_id[result.task_id]
@@ -148,22 +138,26 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask):
self._client_manager.on_task_begin(client, graph_task.task_id)
tasks = dict()
- depended_task_id_to_task_id = ManyToManyDict()
+ depended_task_id_to_task_id: ManyToManyDict[bytes, bytes] = ManyToManyDict()
for task in graph_task.graph:
self._task_id_to_graph_task_id[task.task_id] = graph_task.task_id
tasks[task.task_id] = _TaskInfo(_NodeTaskState.Inactive, task)
- required_task_ids = {arg.data for arg in task.function_args if arg.type == ArgumentType.Task}
+ required_task_ids = {arg.data for arg in task.function_args if arg.type == Task.Argument.ArgumentType.Task}
for required_task_id in required_task_ids:
depended_task_id_to_task_id.add(required_task_id, task.task_id)
graph[task.task_id] = required_task_ids
await self._binder_monitor.send(
- StateGraphTask(
+ StateGraphTask.new_msg(
graph_task.task_id,
task.task_id,
- NodeTaskType.Target if task.task_id in graph_task.targets else NodeTaskType.Normal,
+ (
+ StateGraphTask.NodeTaskType.Target
+ if task.task_id in graph_task.targets
+ else StateGraphTask.NodeTaskType.Normal
+ ),
required_task_ids,
)
)
@@ -179,7 +173,7 @@ async def __add_new_graph(self, client: bytes, graph_task: GraphTask):
async def __check_one_graph(self, graph_task_id: bytes):
graph_info = self._graph_task_id_to_graph[graph_task_id]
if not graph_info.sorter.is_active():
- await self.__finish_one_graph(graph_task_id, TaskResult(graph_task_id, TaskStatus.Success))
+ await self.__finish_one_graph(graph_task_id, TaskResult.new_msg(graph_task_id, TaskStatus.Success))
return
ready_task_ids = graph_info.sorter.get_ready()
@@ -191,12 +185,12 @@ async def __check_one_graph(self, graph_task_id: bytes):
task_info.state = _NodeTaskState.Running
graph_info.running_task_ids.add(task_id)
- task = Task(
- task_info.task.task_id,
- task_info.task.source,
- task_info.task.metadata,
- task_info.task.func_object_id,
- [self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args],
+ task = Task.new_msg(
+ task_id=task_info.task.task_id,
+ source=task_info.task.source,
+ metadata=task_info.task.metadata,
+ func_object_id=task_info.task.func_object_id,
+ function_args=[self.__get_argument(graph_task_id, arg) for arg in task_info.task.function_args],
)
await self._task_manager.on_task_new(graph_info.client, task)
@@ -240,7 +234,7 @@ async def __cancel_one_graph(self, graph_task_id: bytes, result: TaskResult):
await self.__clean_all_inactive_nodes(graph_task_id, result)
await self.__finish_one_graph(
- graph_task_id, TaskResult(result.task_id, result.status, result.metadata, result.results)
+ graph_task_id, TaskResult.new_msg(result.task_id, result.status, result.metadata, result.results)
)
async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResult):
@@ -261,8 +255,10 @@ async def __clean_all_running_nodes(self, graph_task_id: bytes, result: TaskResu
)
new_result_object_ids.append(new_result_object_id)
- await self._task_manager.on_task_cancel(graph_info.client, TaskCancel(task_id))
- await self.__mark_node_done(TaskResult(task_id, result.status, result.metadata, new_result_object_ids))
+ await self._task_manager.on_task_cancel(graph_info.client, TaskCancel.new_msg(task_id))
+ await self.__mark_node_done(
+ TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids)
+ )
async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskResult):
graph_info = self._graph_task_id_to_graph[graph_task_id]
@@ -280,12 +276,14 @@ async def __clean_all_inactive_nodes(self, graph_task_id: bytes, result: TaskRes
)
new_result_object_ids.append(new_result_object_id)
- await self.__mark_node_done(TaskResult(task_id, result.status, result.metadata, new_result_object_ids))
+ await self.__mark_node_done(
+ TaskResult.new_msg(task_id, result.status, result.metadata, new_result_object_ids)
+ )
async def __finish_one_graph(self, graph_task_id: bytes, result: TaskResult):
self._client_manager.on_task_finish(graph_task_id)
info = self._graph_task_id_to_graph.pop(graph_task_id)
- await self._binder.send(info.client, TaskResult(graph_task_id, result.status, results=result.results))
+ await self._binder.send(info.client, TaskResult.new_msg(graph_task_id, result.status, results=result.results))
def __is_graph_finished(self, graph_task_id: bytes):
graph_info = self._graph_task_id_to_graph[graph_task_id]
@@ -299,11 +297,11 @@ def __get_target_results_ids(self, graph_task_id: bytes) -> List[bytes]:
for result_object_id in graph_info.tasks[task_id].result_object_ids
]
- def __get_argument(self, graph_task_id: bytes, argument: Argument) -> Argument:
- if argument.type == ArgumentType.ObjectID:
+ def __get_argument(self, graph_task_id: bytes, argument: Task.Argument) -> Task.Argument:
+ if argument.type == Task.Argument.ArgumentType.ObjectID:
return argument
- assert argument.type == ArgumentType.Task
+ assert argument.type == Task.Argument.ArgumentType.Task
argument_task_id = argument.data
graph_info = self._graph_task_id_to_graph[graph_task_id]
@@ -311,13 +309,13 @@ def __get_argument(self, graph_task_id: bytes, argument: Argument) -> Argument:
assert len(task_info.result_object_ids) == 1
- return Argument(ArgumentType.ObjectID, task_info.result_object_ids[0])
+ return Task.Argument(Task.Argument.ArgumentType.ObjectID, task_info.result_object_ids[0])
def __clean_intermediate_result(self, graph_task_id: bytes, task_id: bytes):
graph_info = self._graph_task_id_to_graph[graph_task_id]
task_info = graph_info.tasks[task_id]
- for argument in filter(lambda arg: arg.type == ArgumentType.Task, task_info.task.function_args):
+ for argument in filter(lambda arg: arg.type == Task.Argument.ArgumentType.Task, task_info.task.function_args):
argument_task_id = argument.data
graph_info.depended_task_id_to_task_id.remove(argument_task_id, task_id)
if graph_info.depended_task_id_to_task_id.has_left_key(argument_task_id):
diff --git a/scaler/scheduler/mixins.py b/scaler/scheduler/mixins.py
index d1928a1..bb6af88 100644
--- a/scaler/scheduler/mixins.py
+++ b/scaler/scheduler/mixins.py
@@ -8,21 +8,22 @@
GraphTask,
GraphTaskCancel,
ObjectRequest,
- ObjectResponse,
Task,
TaskCancel,
TaskResult,
WorkerHeartbeat,
+ ObjectInstruction,
)
+from scaler.utility.mixins import Reporter
-class ObjectManager(metaclass=abc.ABCMeta):
+class ObjectManager(Reporter):
@abc.abstractmethod
- async def on_object_instruction(self, source: bytes, request: ObjectResponse):
+ async def on_object_instruction(self, source: bytes, request: ObjectInstruction):
raise NotImplementedError()
@abc.abstractmethod
- async def on_object_request(self, source: bytes, response: ObjectRequest):
+ async def on_object_request(self, source: bytes, request: ObjectRequest):
raise NotImplementedError()
@abc.abstractmethod
@@ -50,7 +51,7 @@ def get_object_content(self, object_id: bytes) -> bytes:
raise NotImplementedError()
-class ClientManager(metaclass=abc.ABCMeta):
+class ClientManager(Reporter):
@abc.abstractmethod
def get_client_task_ids(self, client: bytes) -> Set[bytes]:
raise NotImplementedError()
@@ -79,7 +80,7 @@ async def on_client_disconnect(self, client: bytes, request: ClientDisconnect):
raise NotImplementedError()
-class GraphTaskManager(metaclass=abc.ABCMeta):
+class GraphTaskManager(Reporter):
@abc.abstractmethod
async def on_graph_task(self, client: bytes, graph_task: GraphTask):
raise NotImplementedError()
@@ -97,23 +98,7 @@ def is_graph_sub_task(self, task_id: bytes):
raise NotImplementedError()
-class TaskReadyManager(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- async def on_task_new(self, client: bytes, task: Task):
- raise NotImplementedError()
-
- @abc.abstractmethod
- async def on_task_cancel(self, task_cancel: TaskCancel):
- raise NotImplementedError()
-
-
-class TaskResultReadyManager(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- async def on_task_done(self, task_result: TaskResult):
- raise NotImplementedError()
-
-
-class TaskManager(metaclass=abc.ABCMeta):
+class TaskManager(Reporter):
@abc.abstractmethod
async def on_task_new(self, client: bytes, task: Task):
raise NotImplementedError()
@@ -131,7 +116,7 @@ async def on_task_reroute(self, task_id: bytes):
raise NotImplementedError()
-class WorkerManager(metaclass=abc.ABCMeta):
+class WorkerManager(Reporter):
@abc.abstractmethod
async def assign_task_to_worker(self, task: Task) -> bool:
raise NotImplementedError()
diff --git a/scaler/scheduler/object_manager.py b/scaler/scheduler/object_manager.py
index 9253928..4953e22 100644
--- a/scaler/scheduler/object_manager.py
+++ b/scaler/scheduler/object_manager.py
@@ -5,15 +5,8 @@
from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import (
- ObjectContent,
- ObjectInstruction,
- ObjectInstructionType,
- ObjectRequest,
- ObjectRequestType,
- ObjectResponse,
- ObjectResponseType,
-)
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse
from scaler.protocol.python.status import ObjectManagerStatus
from scaler.scheduler.mixins import ClientManager, ObjectManager, WorkerManager
from scaler.scheduler.object_usage.object_tracker import ObjectTracker, ObjectUsage
@@ -34,7 +27,7 @@ def get_object_key(self) -> bytes:
class VanillaObjectManager(ObjectManager, Looper, Reporter):
def __init__(self):
- self._object_storage: ObjectTracker[_ObjectCreation, bytes] = ObjectTracker(
+ self._object_storage: ObjectTracker[bytes, _ObjectCreation] = ObjectTracker(
"object_usage", self.__finished_object_storage
)
@@ -58,31 +51,31 @@ def register(
self._worker_manager = worker_manager
async def on_object_instruction(self, source: bytes, instruction: ObjectInstruction):
- if instruction.type == ObjectInstructionType.Create:
+ if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Create:
self.__on_object_create(source, instruction)
return
- if instruction.type == ObjectInstructionType.Delete:
+ if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete:
self.on_del_objects(instruction.object_user, set(instruction.object_content.object_ids))
return
logging.error(
- f"received unknown object response type instruction_type={instruction.type} from "
+ f"received unknown object response type instruction_type={instruction.instruction_type} from "
f"source={instruction.object_user}"
)
async def on_object_request(self, source: bytes, request: ObjectRequest):
- if request.type == ObjectRequestType.Get:
+ if request.request_type == ObjectRequest.ObjectRequestType.Get:
await self.__process_get_request(source, request)
return
- logging.error(f"received unknown object request type {request=} from {source=}")
+ logging.error(f"received unknown object request type {request=} from {source=!r}")
def on_add_object(self, object_user: bytes, object_id: bytes, object_name: bytes, object_bytes: bytes):
creation = _ObjectCreation(object_id, object_user, object_name, object_bytes)
logging.debug(
f"add object cache "
- f"object_name={creation.object_name}, "
+ f"object_name={creation.object_name!r}, "
f"object_id={creation.object_id.hex()}, "
f"size={format_bytes(len(creation.object_bytes))}"
)
@@ -116,10 +109,8 @@ def get_object_content(self, object_id: bytes) -> bytes:
return self._object_storage.get_object(object_id).object_bytes
def get_status(self) -> ObjectManagerStatus:
- return ObjectManagerStatus(
- # self._pending_get_requests.object_count(),
- self._object_storage.object_count(),
- sum(len(v.object_bytes) for _, v in self._object_storage.items()),
+ return ObjectManagerStatus.new_msg(
+ self._object_storage.object_count(), sum(len(v.object_bytes) for _, v in self._object_storage.items())
)
async def __process_get_request(self, source: bytes, request: ObjectRequest):
@@ -136,12 +127,16 @@ async def __routine_send_objects_deletions(self):
for worker in self._worker_manager.get_worker_ids():
await self._binder.send(
worker,
- ObjectInstruction(ObjectInstructionType.Delete, worker, ObjectContent(tuple(deleted_object_ids))),
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Delete,
+ worker,
+ ObjectContent.new_msg(tuple(deleted_object_ids)),
+ ),
)
def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
if not self._client_manager.has_client_id(instruction.object_user):
- logging.error(f"received object creation from {source} for unknown client {instruction.object_user}")
+ logging.error(f"received object creation from {source!r} for unknown client {instruction.object_user!r}")
return
for object_id, object_name, object_content in zip(
@@ -154,7 +149,7 @@ def __on_object_create(self, source: bytes, instruction: ObjectInstruction):
def __finished_object_storage(self, creation: _ObjectCreation):
logging.debug(
f"del object cache "
- f"object_name={creation.object_name}, "
+ f"object_name={creation.object_name!r}, "
f"object_id={creation.object_id.hex()}, "
f"size={format_bytes(len(creation.object_bytes))}"
)
@@ -173,7 +168,7 @@ def __construct_response(self, request: ObjectRequest) -> ObjectResponse:
object_names.append(object_info.object_name)
object_bytes.append(object_info.object_bytes)
- return ObjectResponse(
- ObjectResponseType.Content,
- ObjectContent(tuple(request.object_ids), tuple(object_names), tuple(object_bytes)),
+ return ObjectResponse.new_msg(
+ ObjectResponse.ObjectResponseType.Content,
+ ObjectContent.new_msg(tuple(request.object_ids), tuple(object_names), tuple(object_bytes)),
)
diff --git a/scaler/scheduler/object_usage/object_tracker.py b/scaler/scheduler/object_usage/object_tracker.py
index d6d495e..a61cdd2 100644
--- a/scaler/scheduler/object_usage/object_tracker.py
+++ b/scaler/scheduler/object_usage/object_tracker.py
@@ -6,23 +6,22 @@
ObjectKeyType = TypeVar("ObjectKeyType")
-class ObjectUsage(metaclass=abc.ABCMeta):
+class ObjectUsage(Generic[ObjectKeyType], metaclass=abc.ABCMeta):
@abc.abstractmethod
def get_object_key(self) -> ObjectKeyType:
raise NotImplementedError()
ObjectType = TypeVar("ObjectType", bound=ObjectUsage)
-BlockType = TypeVar("BlockType")
-class ObjectTracker(Generic[ObjectType, BlockType]):
+class ObjectTracker(Generic[ObjectKeyType, ObjectType]):
def __init__(self, prefix: str, callback: Callable[[ObjectType], None]):
self._prefix = prefix
self._callback = callback
- self._current_blocks: Set[BlockType] = set()
- self._object_key_to_block: ManyToManyDict[ObjectKeyType, BlockType] = ManyToManyDict()
+ self._current_blocks: Set[ObjectKeyType] = set()
+ self._object_key_to_block: ManyToManyDict[ObjectKeyType, ObjectKeyType] = ManyToManyDict()
self._object_key_to_object: Dict[ObjectKeyType, ObjectType] = dict()
def object_count(self):
@@ -43,7 +42,9 @@ def get_object(self, key: ObjectKeyType) -> ObjectType:
def add_object(self, obj: ObjectType):
self._object_key_to_object[obj.get_object_key()] = obj
- def get_object_block_pairs(self, blocks: Set[BlockType]) -> Generator[Tuple[ObjectKeyType, BlockType], None, None]:
+ def get_object_block_pairs(
+ self, blocks: Set[ObjectKeyType]
+ ) -> Generator[Tuple[ObjectKeyType, ObjectKeyType], None, None]:
for block in blocks:
if not self._object_key_to_block.has_right_key(block):
continue
@@ -51,16 +52,16 @@ def get_object_block_pairs(self, blocks: Set[BlockType]) -> Generator[Tuple[Obje
for object_key in self._object_key_to_block.get_left_items(block):
yield object_key, block
- def add_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[BlockType]):
+ def add_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[ObjectKeyType]):
if object_key not in self._object_key_to_object:
- raise KeyError(f"cannot find key={object_key.hex()} in ObjectTracker")
+ raise KeyError(f"cannot find key={object_key} in ObjectTracker")
for block in blocks:
self._object_key_to_block.add(object_key, block)
self._current_blocks.update(blocks)
- def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[BlockType]):
+ def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[ObjectKeyType]):
ready_objects = []
for block in blocks:
obj = self.__remove_block_for_object(object_key, block)
@@ -72,16 +73,16 @@ def remove_blocks_for_one_object(self, object_key: ObjectKeyType, blocks: Set[Bl
for obj in ready_objects:
self._callback(obj)
- def add_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: BlockType):
+ def add_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: ObjectKeyType):
for object_key in object_keys:
if object_key not in self._object_key_to_object:
- raise KeyError(f"cannot find key={object_key.hex()} in ObjectTracker")
+ raise KeyError(f"cannot find key={object_key} in ObjectTracker")
self._object_key_to_block.add(object_key, block)
self._current_blocks.add(block)
- def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: BlockType):
+ def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: ObjectKeyType):
ready_objects = []
for object_key in object_keys:
obj = self.__remove_block_for_object(object_key, block)
@@ -93,7 +94,7 @@ def remove_one_block_for_objects(self, object_keys: Set[ObjectKeyType], block: B
for obj in ready_objects:
self._callback(obj)
- def remove_blocks(self, blocks: Set[BlockType]):
+ def remove_blocks(self, blocks: Set[ObjectKeyType]):
ready_objects = []
for block in blocks:
if not self._object_key_to_block.has_right_key(block):
@@ -110,7 +111,7 @@ def remove_blocks(self, blocks: Set[BlockType]):
for obj in ready_objects:
self._callback(obj)
- def __remove_block_for_object(self, object_key: ObjectKeyType, block: BlockType) -> Optional[ObjectType]:
+ def __remove_block_for_object(self, object_key: ObjectKeyType, block: ObjectKeyType) -> Optional[ObjectType]:
if block not in self._current_blocks:
return None
diff --git a/scaler/scheduler/scheduler.py b/scaler/scheduler/scheduler.py
index 066f97f..e0fe274 100644
--- a/scaler/scheduler/scheduler.py
+++ b/scaler/scheduler/scheduler.py
@@ -13,7 +13,6 @@
DisconnectRequest,
GraphTask,
GraphTaskCancel,
- MessageVariant,
ObjectInstruction,
ObjectRequest,
Task,
@@ -21,6 +20,7 @@
TaskResult,
WorkerHeartbeat,
)
+from scaler.protocol.python.mixins import Message
from scaler.scheduler.client_manager import VanillaClientManager
from scaler.scheduler.config import SchedulerConfig
from scaler.scheduler.graph_manager import VanillaGraphTaskManager
@@ -92,7 +92,7 @@ def __init__(self, config: SchedulerConfig):
self._binder, self._client_manager, self._object_manager, self._task_manager, self._worker_manager
)
- async def on_receive_message(self, source: bytes, message: MessageVariant):
+ async def on_receive_message(self, source: bytes, message: Message):
# =====================================================================================
# receive from upstream
if isinstance(message, ClientHeartbeat):
diff --git a/scaler/scheduler/status_reporter.py b/scaler/scheduler/status_reporter.py
index 601383d..47910b9 100644
--- a/scaler/scheduler/status_reporter.py
+++ b/scaler/scheduler/status_reporter.py
@@ -7,7 +7,7 @@
from scaler.protocol.python.message import StateScheduler
from scaler.protocol.python.status import Resource
from scaler.scheduler.mixins import ClientManager, ObjectManager, TaskManager, WorkerManager
-from scaler.utility.mixins import Looper, Reporter
+from scaler.utility.mixins import Looper
class StatusReporter(Looper):
@@ -16,10 +16,10 @@ def __init__(self, binder: AsyncConnector):
self._process = psutil.Process()
self._binder: Optional[AsyncBinder] = None
- self._client_manager: Optional[Reporter] = None
- self._object_manager: Optional[Reporter] = None
- self._task_manager: Optional[Reporter] = None
- self._worker_manager: Optional[Reporter] = None
+ self._client_manager: Optional[ClientManager] = None
+ self._object_manager: Optional[ObjectManager] = None
+ self._task_manager: Optional[TaskManager] = None
+ self._worker_manager: Optional[WorkerManager] = None
def register_managers(
self,
@@ -37,13 +37,10 @@ def register_managers(
async def routine(self):
await self._monitor_binder.send(
- StateScheduler(
+ StateScheduler.new_msg(
binder=self._binder.get_status(),
- scheduler=Resource(
- self._process.cpu_percent() / 100,
- self._process.memory_info().rss,
- psutil.virtual_memory().available,
- ),
+ scheduler=Resource.new_msg(int(self._process.cpu_percent() * 10), self._process.memory_info().rss),
+ rss_free=psutil.virtual_memory().available,
client_manager=self._client_manager.get_status(),
object_manager=self._object_manager.get_status(),
task_manager=self._task_manager.get_status(),
diff --git a/scaler/scheduler/task_manager.py b/scaler/scheduler/task_manager.py
index 6998382..30f1f1b 100644
--- a/scaler/scheduler/task_manager.py
+++ b/scaler/scheduler/task_manager.py
@@ -3,7 +3,8 @@
from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import StateTask, Task, TaskCancel, TaskResult, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import StateTask, Task, TaskCancel, TaskResult
from scaler.protocol.python.status import TaskManagerStatus
from scaler.scheduler.graph_manager import VanillaGraphTaskManager
from scaler.scheduler.mixins import ClientManager, ObjectManager, TaskManager, WorkerManager
@@ -24,7 +25,7 @@ def __init__(self, max_number_of_tasks_waiting: int):
self._task_id_to_task: Dict[bytes, Task] = dict()
- self._unassigned: AsyncIndexedQueue[bytes] = AsyncIndexedQueue()
+ self._unassigned: AsyncIndexedQueue = AsyncIndexedQueue()
self._running: Set[bytes] = set()
self._success_count: int = 0
@@ -66,7 +67,7 @@ async def routine(self):
)
def get_status(self) -> TaskManagerStatus:
- return TaskManagerStatus(
+ return TaskManagerStatus.new_msg(
unassigned=self._unassigned.qsize(),
running=len(self._running),
success=self._success_count,
@@ -80,7 +81,7 @@ async def on_task_new(self, client: bytes, task: Task):
0 <= self._max_number_of_tasks_waiting <= self._unassigned.qsize()
and not self._worker_manager.has_available_worker()
):
- await self._binder.send(client, TaskResult(task.task_id, TaskStatus.NoWorker))
+ await self._binder.send(client, TaskResult.new_msg(task.task_id, TaskStatus.NoWorker))
return
self._client_manager.on_task_begin(client, task.task_id)
@@ -96,11 +97,11 @@ async def on_task_new(self, client: bytes, task: Task):
async def on_task_cancel(self, client: bytes, task_cancel: TaskCancel):
if task_cancel.task_id not in self._task_id_to_task:
logging.warning(f"cannot cancel, task not found: task_id={task_cancel.task_id.hex()}")
- await self.on_task_done(TaskResult(task_cancel.task_id, TaskStatus.NotFound))
+ await self.on_task_done(TaskResult.new_msg(task_cancel.task_id, TaskStatus.NotFound))
return
if task_cancel.task_id in self._unassigned:
- await self.on_task_done(TaskResult(task_cancel.task_id, TaskStatus.Canceled))
+ await self.on_task_done(TaskResult.new_msg(task_cancel.task_id, TaskStatus.Canceled))
return
await self._worker_manager.on_task_cancel(task_cancel)
@@ -162,4 +163,4 @@ async def __send_monitor(
self, task_id: bytes, function_name: bytes, status: TaskStatus, metadata: Optional[bytes] = b""
):
worker = self._worker_manager.get_worker_by_task_id(task_id)
- await self._binder_monitor.send(StateTask(task_id, function_name, status, worker, metadata))
+ await self._binder_monitor.send(StateTask.new_msg(task_id, function_name, status, worker, metadata))
diff --git a/scaler/scheduler/worker_manager.py b/scaler/scheduler/worker_manager.py
index 3581a85..751bc4e 100644
--- a/scaler/scheduler/worker_manager.py
+++ b/scaler/scheduler/worker_manager.py
@@ -4,18 +4,16 @@
from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
+from scaler.protocol.python.common import TaskStatus
from scaler.protocol.python.message import (
ClientDisconnect,
DisconnectRequest,
DisconnectResponse,
- DisconnectType,
StateBalanceAdvice,
StateWorker,
Task,
TaskCancel,
- TaskCancelFlags,
TaskResult,
- TaskStatus,
WorkerHeartbeat,
WorkerHeartbeatEcho,
)
@@ -44,7 +42,7 @@ def __init__(
self._worker_alive_since: Dict[bytes, Tuple[float, WorkerHeartbeat]] = dict()
self._allocator = QueuedAllocator(per_worker_queue_size)
- self._last_balance_advice = None
+ self._last_balance_advice: Dict[bytes, List[bytes]] = dict()
self._load_balance_advice_same_count = 0
def register(self, binder: AsyncBinder, binder_monitor: AsyncConnector, task_manager: TaskManager):
@@ -67,7 +65,7 @@ async def on_task_cancel(self, task_cancel: TaskCancel):
logging.error(f"cannot find task_id={task_cancel.task_id.hex()} in task workers")
return
- await self._binder.send(worker, TaskCancel(task_cancel.task_id))
+ await self._binder.send(worker, TaskCancel.new_msg(task_cancel.task_id))
async def on_task_result(self, task_result: TaskResult):
worker = self._allocator.remove_task(task_result.task_id)
@@ -93,11 +91,11 @@ async def on_task_result(self, task_result: TaskResult):
async def on_heartbeat(self, worker: bytes, info: WorkerHeartbeat):
if await self._allocator.add_worker(worker):
- logging.info(f"worker {worker} connected")
- await self._binder_monitor.send(StateWorker(worker, b"connected"))
+ logging.info(f"worker {worker!r} connected")
+ await self._binder_monitor.send(StateWorker.new_msg(worker, b"connected"))
self._worker_alive_since[worker] = (time.time(), info)
- await self._binder.send(worker, WorkerHeartbeatEcho())
+ await self._binder.send(worker, WorkerHeartbeatEcho.new_msg())
async def on_client_shutdown(self, client: bytes):
for worker in self._allocator.get_worker_ids():
@@ -105,7 +103,7 @@ async def on_client_shutdown(self, client: bytes):
async def on_disconnect(self, source: bytes, request: DisconnectRequest):
await self.__disconnect_worker(request.worker)
- await self._binder.send(source, DisconnectResponse(request.worker))
+ await self._binder.send(source, DisconnectResponse.new_msg(request.worker))
async def routine(self):
await self.__balance_request()
@@ -113,49 +111,46 @@ async def routine(self):
def get_status(self) -> WorkerManagerStatus:
worker_to_task_numbers = self._allocator.statistics()
- return WorkerManagerStatus(
+ return WorkerManagerStatus.new_msg(
[
self.__worker_status_from_heartbeat(worker, worker_to_task_numbers[worker], last, info)
for worker, (last, info) in self._worker_alive_since.items()
]
)
+ @staticmethod
def __worker_status_from_heartbeat(
- self, worker: bytes, worker_task_numbers: Dict, last: float, info: WorkerHeartbeat
+ worker: bytes, worker_task_numbers: Dict, last: float, info: WorkerHeartbeat
) -> WorkerStatus:
current_processor = next((p for p in info.processors if not p.suspended), None)
- n_suspended = sum(1 for p in info.processors if p.suspended)
+ suspended = len([p for p in info.processors if p.suspended])
if current_processor:
- ITL = f"{int(current_processor.initialized)}{int(current_processor.has_task)}{int(info.task_lock)}"
+ debug_info = f"{int(current_processor.initialized)}{int(current_processor.has_task)}{int(info.task_lock)}"
else:
- ITL = f"00{int(info.task_lock)}"
-
- processor_statuses = [
- ProcessorStatus(
- p.pid,
- p.initialized,
- p.has_task,
- p.suspended,
- Resource(p.cpu, p.rss, info.rss_free),
- )
- for p in info.processors
- ]
+ debug_info = f"00{int(info.task_lock)}"
- return WorkerStatus(
+ return WorkerStatus.new_msg(
worker_id=worker,
- agent=Resource(info.agent_cpu, info.agent_rss, info.rss_free),
- total_processors=Resource(
- sum(p.cpu for p in info.processors), sum(p.rss for p in info.processors), info.rss_free
- ),
+ agent=info.agent,
+ rss_free=info.rss_free,
free=worker_task_numbers["free"],
sent=worker_task_numbers["sent"],
queued=info.queued_tasks,
- suspended=n_suspended,
+ suspended=suspended,
lag_us=info.latency_us,
last_s=int(time.time() - last),
- ITL=ITL,
- processor_statuses=processor_statuses,
+ itl=debug_info,
+ processor_statuses=[
+ ProcessorStatus.new_msg(
+ pid=p.pid,
+ initialized=p.initialized,
+ has_task=p.has_task,
+ suspended=p.suspended,
+ resource=Resource.new_msg(p.resource.cpu, p.resource.rss),
+ )
+ for p in info.processors
+ ],
)
def has_available_worker(self) -> bool:
@@ -187,16 +182,17 @@ async def __do_balance(self, current_advice: Dict[bytes, List[bytes]]):
if not current_advice:
return
- logging.info(f"balance: {current_advice}")
- for worker, tasks in current_advice.items():
- await self._binder_monitor.send(StateBalanceAdvice(worker, tasks))
+ worker_to_num_tasks = {worker: len(task_ids) for worker, task_ids in current_advice.items()}
+ logging.info(f"balancing task: {worker_to_num_tasks}")
+ for worker, task_ids in current_advice.items():
+ await self._binder_monitor.send(StateBalanceAdvice.new_msg(worker, task_ids))
- task_cancel_flags = TaskCancelFlags(force=True)
+ task_cancel_flags = TaskCancel.TaskCancelFlags(force=True, retrieve_task_object=False)
self._last_balance_advice = current_advice
- for worker, tasks in current_advice.items():
- for task in tasks:
- await self._binder.send(worker, TaskCancel(task, flags=task_cancel_flags))
+ for worker, task_ids in current_advice.items():
+ for task_id in task_ids:
+ await self._binder.send(worker, TaskCancel.new_msg(task_id=task_id, flags=task_cancel_flags))
async def __clean_workers(self):
now = time.time()
@@ -217,8 +213,8 @@ async def __disconnect_worker(self, worker: bytes):
if worker not in self._worker_alive_since:
return
- logging.info(f"worker {worker} disconnected")
- await self._binder_monitor.send(StateWorker(worker, b"disconnected"))
+ logging.info(f"worker {worker!r} disconnected")
+ await self._binder_monitor.send(StateWorker.new_msg(worker, b"disconnected"))
self._worker_alive_since.pop(worker)
task_ids = self._allocator.remove_worker(worker)
@@ -229,5 +225,5 @@ async def __disconnect_worker(self, worker: bytes):
await self.__reroute_tasks(task_ids)
async def __shutdown_worker(self, worker: bytes):
- await self._binder.send(worker, ClientDisconnect(DisconnectType.Shutdown))
+ await self._binder.send(worker, ClientDisconnect.new_msg(ClientDisconnect.DisconnectType.Shutdown))
await self.__disconnect_worker(worker)
diff --git a/scaler/ui/live_display.py b/scaler/ui/live_display.py
index 9882da3..d4b7f9e 100644
--- a/scaler/ui/live_display.py
+++ b/scaler/ui/live_display.py
@@ -40,9 +40,9 @@ def delete_section(self):
@dataclasses.dataclass
class WorkerRow:
worker: str = dataclasses.field(default="")
- agt_cpu: int = dataclasses.field(default=0)
+ agt_cpu: float = dataclasses.field(default=0)
agt_rss: int = dataclasses.field(default=0)
- cpu: int = dataclasses.field(default=0)
+ cpu: float = dataclasses.field(default=0)
rss: int = dataclasses.field(default=0)
rss_free: int = dataclasses.field(default=0)
free: int = dataclasses.field(default=0)
@@ -50,24 +50,24 @@ class WorkerRow:
queued: int = dataclasses.field(default=0)
suspended: int = dataclasses.field(default=0)
lag: str = dataclasses.field(default="")
- ITL: str = dataclasses.field(default="")
+ itl: str = dataclasses.field(default="")
last_seen: str = dataclasses.field(default="")
handlers: List[Element] = dataclasses.field(default_factory=list)
def populate(self, data: WorkerStatus):
self.worker = data.worker_id.decode()
- self.agt_cpu = int(data.agent.cpu * 100)
+ self.agt_cpu = data.agent.cpu / 10
self.agt_rss = int(data.agent.rss / 1e6)
- self.cpu = int(data.total_processors.cpu * 100)
- self.rss = int(data.total_processors.rss / 1e6)
- self.rss_free = int(data.total_processors.rss_free / 1e6)
+ self.cpu = sum(p.resource.cpu for p in data.processor_statuses) / 10
+ self.rss = int(sum(p.resource.rss for p in data.processor_statuses) / 1e6)
+ self.rss_free = int(data.rss_free / 1e6)
self.free = data.free
self.sent = data.sent
self.queued = data.queued
self.suspended = data.suspended
self.lag = format_microseconds(data.lag_us)
- self.ITL = data.ITL
+ self.itl = data.itl
self.last_seen = format_seconds(data.last_s)
def draw_row(self):
@@ -102,7 +102,8 @@ def draw_section(self):
for worker_row in self.workers.values():
worker_row.draw_row()
- def __draw_titles(self):
+ @staticmethod
+ def __draw_titles():
ui.label("Worker")
ui.label("Agt CPU %")
ui.label("Agt RSS (in MB)")
diff --git a/scaler/ui/task_graph.py b/scaler/ui/task_graph.py
index 349cad3..e08aaf7 100644
--- a/scaler/ui/task_graph.py
+++ b/scaler/ui/task_graph.py
@@ -5,7 +5,8 @@
from nicegui import ui
-from scaler.protocol.python.message import StateTask, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import StateTask
from scaler.ui.live_display import WorkersSection
from scaler.ui.setting_page import Settings
from scaler.ui.utility import format_timediff, format_worker_name, get_bounds, make_tick_text, make_ticks
@@ -51,10 +52,10 @@ def __init__(self):
self._task_id_to_worker: Dict[bytes, str] = {}
self._seen_workers = set()
- self._lost_workers_queue = SimpleQueue()
+ self._lost_workers_queue: SimpleQueue[Tuple[datetime.datetime, str]] = SimpleQueue()
self._data_update_lock = Lock()
- self._busy_workers: List[str] = []
+ self._busy_workers: Set[str] = set()
self._busy_workers_update_time: datetime.datetime = datetime.datetime.now()
def setup_task_stream(self, settings: Settings):
@@ -224,14 +225,15 @@ def handle_task_state(self, state: StateTask):
if not (worker := state.worker):
return
- worker = worker.decode()
- self._worker_last_update[worker] = now
- if worker not in self._seen_workers:
- self.__handle_new_worker(worker, now)
+ worker_string = worker.decode()
+ self._worker_last_update[worker_string] = now
+
+ if worker_string not in self._seen_workers:
+ self.__handle_new_worker(worker_string, now)
if task_status in {TaskStatus.Running}:
- self.__handle_running_task(state, worker, now)
+ self.__handle_running_task(state, worker_string, now)
def __add_lost_worker(self, worker: str, now: datetime.datetime):
self._lost_workers_queue.put((now, worker))
@@ -273,7 +275,7 @@ def __split_workers_by_status(self, now: datetime.datetime) -> List[Tuple[str, f
def update_data(self, workers_section: WorkersSection):
now = datetime.datetime.now()
worker_names = sorted(workers_section.workers.keys())
- itls = {w: workers_section.workers[w].ITL for w in worker_names}
+ itls = {w: workers_section.workers[w].itl for w in worker_names}
busy_workers = {w for w in worker_names if len(itls[w]) == 3 and itls[w][1] == "1" and itls[w][2] == "1"}
for worker in worker_names:
self._worker_last_update[worker] = now
diff --git a/scaler/ui/task_log.py b/scaler/ui/task_log.py
index 033ac8d..fb444f4 100644
--- a/scaler/ui/task_log.py
+++ b/scaler/ui/task_log.py
@@ -5,7 +5,8 @@
from nicegui import ui
-from scaler.protocol.python.message import StateTask, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import StateTask
from scaler.utility.formatter import format_bytes
from scaler.utility.metadata.profile_result import ProfileResult
diff --git a/scaler/ui/webui.py b/scaler/ui/webui.py
index f7fc310..e60a136 100644
--- a/scaler/ui/webui.py
+++ b/scaler/ui/webui.py
@@ -5,7 +5,8 @@
from nicegui import ui
from scaler.io.sync_subscriber import SyncSubscriber
-from scaler.protocol.python.message import MessageVariant, StateScheduler, StateTask
+from scaler.protocol.python.message import StateScheduler, StateTask
+from scaler.protocol.python.mixins import Message
from scaler.ui.live_display import SchedulerSection, WorkersSection
from scaler.ui.memory_window import MemoryChart
from scaler.ui.setting_page import Settings
@@ -80,7 +81,7 @@ def start_webui(address: str, host: str, port: int):
ui_thread.start()
-def __show_status(status: MessageVariant, tables: Sections):
+def __show_status(status: Message, tables: Sections):
if isinstance(status, StateScheduler):
__update_scheduler_state(status, tables)
return
@@ -95,7 +96,7 @@ def __show_status(status: MessageVariant, tables: Sections):
def __update_scheduler_state(data: StateScheduler, tables: Sections):
tables.scheduler_section.cpu = format_percentage(data.scheduler.cpu)
tables.scheduler_section.rss = format_bytes(data.scheduler.rss)
- tables.scheduler_section.rss_free = format_bytes(data.scheduler.rss_free)
+ tables.scheduler_section.rss_free = format_bytes(data.rss_free)
previous_workers = set(tables.workers_section.workers.keys())
current_workers = set(worker_data.worker_id.decode() for worker_data in data.worker_manager.workers)
diff --git a/scaler/ui/worker_processors.py b/scaler/ui/worker_processors.py
index 6c78f79..0986578 100644
--- a/scaler/ui/worker_processors.py
+++ b/scaler/ui/worker_processors.py
@@ -31,7 +31,7 @@ def update_data(self, data: List[WorkerStatus]):
processor_table = self.workers.get(worker.worker_id)
if processor_table is None:
- processor_table = WorkerProcessorTable(worker_name, worker.processor_statuses)
+ processor_table = WorkerProcessorTable(worker_name, worker.rss_free, worker.processor_statuses)
processor_table.draw_table()
self.workers[worker.worker_id] = processor_table
elif processor_table.processor_statuses != worker.processor_statuses:
@@ -46,6 +46,7 @@ def update_data(self, data: List[WorkerStatus]):
@dataclasses.dataclass
class WorkerProcessorTable:
worker_name: str
+ rss_free: int
processor_statuses: List[ProcessorStatus]
handler: Optional[Element] = dataclasses.field(default=None)
@@ -61,7 +62,7 @@ def draw_table(self):
with ui.grid(columns=6).classes("w-full"):
self.draw_titles()
for processor in sorted(self.processor_statuses, key=lambda x: x.pid):
- self.draw_row(processor)
+ self.draw_row(processor, self.rss_free)
@staticmethod
def draw_titles():
@@ -73,10 +74,10 @@ def draw_titles():
ui.label("Suspended")
@staticmethod
- def draw_row(processor_status: ProcessorStatus):
- cpu = int(processor_status.resource.cpu * 100)
+ def draw_row(processor_status: ProcessorStatus, rss_free: int):
+ cpu = processor_status.resource.cpu
rss = int(processor_status.resource.rss / 1e6)
- rss_free = int(processor_status.resource.rss_free / 1e6)
+ rss_free = int(rss_free / 1e6)
ui.label(processor_status.pid)
ui.knob(value=cpu, track_color="grey-2", show_value=True, min=0, max=100)
diff --git a/scaler/utility/event_list.py b/scaler/utility/event_list.py
index 85bec24..49e80c6 100644
--- a/scaler/utility/event_list.py
+++ b/scaler/utility/event_list.py
@@ -1,4 +1,3 @@
-
import collections
from typing import Callable
@@ -21,6 +20,10 @@ def __delitem__(self, i):
super().__delitem__(i)
self._list_updated()
+ def __add__(self, other):
+ super().__add__(other)
+ self._list_updated()
+
def __iadd__(self, other):
super().__iadd__(other)
self._list_updated()
diff --git a/scaler/utility/event_loop.py b/scaler/utility/event_loop.py
index 3ab35f6..cdd704e 100644
--- a/scaler/utility/event_loop.py
+++ b/scaler/utility/event_loop.py
@@ -17,8 +17,8 @@ def register_event_loop(event_loop_type: str):
if event_loop_type not in EventLoopType.allowed_types():
raise TypeError(f"allowed event loop types are: {EventLoopType.allowed_types()}")
- event_loop_type = EventLoopType[event_loop_type]
- if event_loop_type == EventLoopType.uvloop:
+ event_loop_type_enum = EventLoopType[event_loop_type]
+ if event_loop_type_enum == EventLoopType.uvloop:
try:
import uvloop # noqa
except ImportError:
@@ -26,12 +26,14 @@ def register_event_loop(event_loop_type: str):
uvloop.install()
- logging.info(f"use event loop: {event_loop_type.value}")
+ assert event_loop_type in EventLoopType.allowed_types()
+
+ logging.info(f"use event loop: {event_loop_type}")
def create_async_loop_routine(routine: Callable[[], Awaitable], seconds: int):
async def loop():
- logging.info(f"{routine.__self__.__class__.__name__}: started")
+ logging.info(f"{routine.__self__.__class__.__name__}: started") # type: ignore[attr-defined]
try:
while True:
await routine()
@@ -41,6 +43,6 @@ async def loop():
except KeyboardInterrupt:
pass
- logging.info(f"{routine.__self__.__class__.__name__}: exited")
+ logging.info(f"{routine.__self__.__class__.__name__}: exited") # type: ignore[attr-defined]
return loop()
diff --git a/scaler/utility/formatter.py b/scaler/utility/formatter.py
index 3fb7cbd..e02b4c6 100644
--- a/scaler/utility/formatter.py
+++ b/scaler/utility/formatter.py
@@ -3,29 +3,31 @@
def format_bytes(number) -> str:
- for unit in ["b", "k", "m", "g", "t"]:
+ for unit in ["B", "K", "M", "G", "T"]:
if number >= STORAGE_SIZE_MODULUS:
number /= STORAGE_SIZE_MODULUS
continue
- if unit in {"b", "k"}:
+ if unit in {"B", "K"}:
return f"{int(number)}{unit}"
return f"{number:.1f}{unit}"
+ raise ValueError("This should not happen")
+
def format_integer(number):
return f"{number:,}"
-def format_percentage(number: float):
- return f"{number:.1%}"
+def format_percentage(number: int):
+ return f"{(number/1000):.1%}"
def format_microseconds(number: int):
for unit in ["us", "ms", "s"]:
if number >= TIME_MODULUS:
- number /= TIME_MODULUS
+ number = int(number / TIME_MODULUS)
continue
if unit == "us":
diff --git a/scaler/utility/graph/topological_sorter.py b/scaler/utility/graph/topological_sorter.py
index c6c6c47..f85fbef 100644
--- a/scaler/utility/graph/topological_sorter.py
+++ b/scaler/utility/graph/topological_sorter.py
@@ -6,6 +6,6 @@
logging.info("using GraphBLAS for calculate graph")
except ImportError as e:
assert isinstance(e, Exception)
- from graphlib import TopologicalSorter
+ from graphlib import TopologicalSorter # type: ignore[assignment, no-redef]
assert isinstance(TopologicalSorter, object)
diff --git a/scaler/utility/graph/topological_sorter_graphblas.py b/scaler/utility/graph/topological_sorter_graphblas.py
index 2429d32..717608f 100644
--- a/scaler/utility/graph/topological_sorter_graphblas.py
+++ b/scaler/utility/graph/topological_sorter_graphblas.py
@@ -1,31 +1,33 @@
import collections
import graphlib
import itertools
-from typing import Dict, Hashable, Iterable, List, Optional, Tuple
+from typing import Hashable, Iterable, List, Optional, Tuple, TypeVar, Generic, Mapping
from bidict import bidict
try:
import graphblas as gb
- import numpy as np
+ import numpy as np # noqa
except ImportError:
raise ImportError("Please use 'pip install python-graphblas' to have graph blas support")
+GraphKeyType = TypeVar("GraphKeyType", bound=Hashable)
-class TopologicalSorter:
+
+class TopologicalSorter(Generic[GraphKeyType]):
"""
Implements graphlib's TopologicalSorter, but the graph handling is backed by GraphBLAS
Reference: https://github.com/python/cpython/blob/4a3ea1fdd890e5e2ec26540dc3c958a52fba6556/Lib/graphlib.py
"""
- def __init__(self, graph: Optional[Dict[Hashable, Iterable[Hashable]]] = None):
+ def __init__(self, graph: Optional[Mapping[GraphKeyType, Iterable[GraphKeyType]]] = None):
# the layout of the matrix is (in-vertex, out-vertex)
self._matrix = gb.Matrix(gb.dtypes.BOOL)
- self._key_to_id = bidict()
+ self._key_to_id: bidict[GraphKeyType, int] = bidict()
- self._graph_matrix_mask = None
- self._visited_vertices_mask = None
- self._ready_nodes = None
+ self._graph_matrix_mask: Optional[np.ndarray] = None
+ self._visited_vertices_mask: Optional[np.ndarray] = None
+ self._ready_nodes: Optional[List[GraphKeyType]] = None
self._n_done = 0
self._n_visited = 0
@@ -33,10 +35,10 @@ def __init__(self, graph: Optional[Dict[Hashable, Iterable[Hashable]]] = None):
if graph is not None:
self.merge_graph(graph)
- def add(self, node: Hashable, *predecessors: Hashable) -> None:
+ def add(self, node: GraphKeyType, *predecessors: GraphKeyType) -> None:
self.merge_graph({node: predecessors})
- def merge_graph(self, graph: Dict[Hashable, Iterable[Hashable]]) -> None:
+ def merge_graph(self, graph: Mapping[GraphKeyType, Iterable[GraphKeyType]]) -> None:
if self._ready_nodes is not None:
raise ValueError("nodes cannot be added after a call to prepare()")
@@ -86,7 +88,7 @@ def prepare(self) -> None:
if self._has_cycle():
raise graphlib.CycleError("cycle detected")
- def get_ready(self) -> Tuple[Hashable, ...]:
+ def get_ready(self) -> Tuple[GraphKeyType, ...]:
if self._ready_nodes is None:
raise ValueError("prepare() must be called first")
@@ -102,7 +104,7 @@ def is_active(self) -> bool:
def __bool__(self) -> bool:
return self.is_active()
- def done(self, *nodes: Hashable) -> None:
+ def done(self, *nodes: GraphKeyType) -> None:
if self._ready_nodes is None:
raise ValueError("prepare() must be called first")
@@ -127,7 +129,7 @@ def done(self, *nodes: Hashable) -> None:
self._ready_nodes.extend(new_ready_nodes)
self._n_visited += len(new_ready_nodes)
- def static_order(self) -> Iterable[Hashable]:
+ def static_order(self) -> Iterable[GraphKeyType]:
self.prepare()
while self.is_active():
node_group = self.get_ready()
@@ -150,7 +152,7 @@ def _has_cycle(self) -> bool:
return True
return False
- def _get_zero_degree_keys(self) -> List[Hashable]:
+ def _get_zero_degree_keys(self) -> List[GraphKeyType]:
ids = self._get_mask_diff(self._visited_vertices_mask, self._get_zero_degree_mask(self._get_masked_matrix()))
return [self._key_to_id.inverse[_id] for _id in ids]
@@ -165,7 +167,7 @@ def _get_masked_matrix(self) -> gb.Matrix:
def _get_zero_degree_mask(cls, masked_matrix: gb.Matrix) -> np.ndarray:
degrees = masked_matrix.reduce_rowwise(gb.monoid.lor)
indices, _ = degrees.to_coo(indices=True, values=False, sort=False)
- return np.logical_not(np.in1d(np.arange(masked_matrix.nrows), indices))
+ return np.logical_not(np.in1d(np.arange(masked_matrix.nrows), indices)) # type: ignore[attr-defined]
@staticmethod
def _get_mask_diff(old_mask: np.ndarray, new_mask: np.ndarray) -> List[int]:
diff --git a/scaler/utility/logging/scoped_logger.py b/scaler/utility/logging/scoped_logger.py
index 8beaa3d..6f6da30 100644
--- a/scaler/utility/logging/scoped_logger.py
+++ b/scaler/utility/logging/scoped_logger.py
@@ -1,6 +1,7 @@
import datetime
import logging
import time
+from typing import Optional
class ScopedLogger:
@@ -18,7 +19,7 @@ class TimedLogger:
def __init__(self, message: str, logging_level=logging.INFO):
self.message = message
self.logging_level = logging_level
- self.timer = None
+ self.timer: Optional[int] = None
def begin(self):
self.timer = time.perf_counter_ns()
diff --git a/scaler/utility/many_to_many_dict.py b/scaler/utility/many_to_many_dict.py
index cca3acc..f162629 100644
--- a/scaler/utility/many_to_many_dict.py
+++ b/scaler/utility/many_to_many_dict.py
@@ -9,8 +9,8 @@
class ManyToManyDict(Generic[LeftKeyT, RightKeyT]):
def __init__(self):
- self._left_key_to_right_key_set: _KeyValueDictSet[LeftKeyT, Set[RightKeyT]] = _KeyValueDictSet()
- self._right_key_to_left_key_set: _KeyValueDictSet[RightKeyT, Set[LeftKeyT]] = _KeyValueDictSet()
+ self._left_key_to_right_key_set: _KeyValueDictSet[LeftKeyT, RightKeyT] = _KeyValueDictSet()
+ self._right_key_to_left_key_set: _KeyValueDictSet[RightKeyT, LeftKeyT] = _KeyValueDictSet()
def left_keys(self):
return self._left_key_to_right_key_set.keys()
diff --git a/scaler/utility/queues/async_priority_queue.py b/scaler/utility/queues/async_priority_queue.py
index 98943fe..f02bf57 100644
--- a/scaler/utility/queues/async_priority_queue.py
+++ b/scaler/utility/queues/async_priority_queue.py
@@ -38,7 +38,7 @@ def remove(self, data):
item = self._locator.pop(data)
i = self._queue.index(item) # O(n)
item[0] = self.__to_lowest_priority(item[0])
- heapq._siftdown(self._queue, 0, i) # noqa
+ heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined]
assert heapq.heappop(self._queue) == item
def decrease_priority(self, data):
@@ -47,7 +47,7 @@ def decrease_priority(self, data):
item = self._locator[data]
i = self._queue.index(item) # O(n)
item[0] = self.__to_lower_priority(item[0])
- heapq._siftdown(self._queue, 0, i) # noqa
+ heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined]
def max_priority(self):
item = heapq.heappop(self._queue)
diff --git a/scaler/utility/zmq_config.py b/scaler/utility/zmq_config.py
index b4198b0..5e5b00f 100644
--- a/scaler/utility/zmq_config.py
+++ b/scaler/utility/zmq_config.py
@@ -54,15 +54,17 @@ def from_string(string: str) -> "ZMQConfig":
if socket_type not in ZMQType.allowed_types():
raise ValueError(f"supported ZMQ types are: {ZMQType.allowed_types()}")
- socket_type = ZMQType(socket_type)
- if socket_type in {ZMQType.inproc, ZMQType.ipc}:
+ socket_type_enum = ZMQType(socket_type)
+ if socket_type_enum in {ZMQType.inproc, ZMQType.ipc}:
host = host_port
- port = None
- else:
+ port_int = None
+ elif socket_type_enum == ZMQType.tcp:
host, port = host_port.split(":")
try:
- port = int(port)
+ port_int = int(port)
except ValueError:
raise ValueError(f"cannot convert '{port}' to port number")
+ else:
+ raise ValueError(f"Unsupported ZMQ type: {socket_type}")
- return ZMQConfig(ZMQType(socket_type), host, port)
+ return ZMQConfig(socket_type_enum, host, port_int)
diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py
index dd8c063..04823f3 100644
--- a/scaler/worker/agent/heartbeat_manager.py
+++ b/scaler/worker/agent/heartbeat_manager.py
@@ -4,7 +4,8 @@
import psutil
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import ProcessorHeartbeat, WorkerHeartbeat, WorkerHeartbeatEcho
+from scaler.protocol.python.message import WorkerHeartbeat, WorkerHeartbeatEcho, Resource
+from scaler.protocol.python.status import ProcessorStatus
from scaler.utility.mixins import Looper
from scaler.worker.agent.mixins import HeartbeatManager, ProcessorManager, TaskManager, TimeoutManager
from scaler.worker.agent.processor_holder import ProcessorHolder
@@ -59,14 +60,13 @@ async def routine(self):
await self._processor_manager.on_failing_task(self._worker_process.status())
processors = self._processor_manager.processors()
- n_suspended_processors = self._processor_manager.n_suspended_processors()
+ num_suspended_processors = self._processor_manager.num_suspended_processors()
await self._connector_external.send(
- WorkerHeartbeat(
- self._agent_process.cpu_percent() / 100,
- self._agent_process.memory_info().rss,
+ WorkerHeartbeat.new_msg(
+ Resource.new_msg(int(self._agent_process.cpu_percent() * 10), self._agent_process.memory_info().rss),
psutil.virtual_memory().available,
- self._worker_task_manager.get_queued_size() - n_suspended_processors,
+ self._worker_task_manager.get_queued_size() - num_suspended_processors,
self._latency_us,
self._processor_manager.task_lock(),
[self.__get_processor_status_from_holder(processor) for processor in processors],
@@ -75,13 +75,12 @@ async def routine(self):
self._start_timestamp_ns = time.time_ns()
@staticmethod
- def __get_processor_status_from_holder(processor: ProcessorHolder) -> ProcessorHeartbeat:
+ def __get_processor_status_from_holder(processor: ProcessorHolder) -> ProcessorStatus:
process = processor.process()
- return ProcessorHeartbeat(
+ return ProcessorStatus.new_msg(
processor.pid(),
processor.initialized(),
processor.task() is not None,
processor.suspended(),
- process.cpu_percent() / 100,
- process.memory_info().rss,
+ Resource.new_msg(int(process.cpu_percent() * 10), process.memory_info().rss),
)
diff --git a/scaler/worker/agent/mixins.py b/scaler/worker/agent/mixins.py
index fe81d15..6ad3f53 100644
--- a/scaler/worker/agent/mixins.py
+++ b/scaler/worker/agent/mixins.py
@@ -2,7 +2,6 @@
from typing import Dict, List, Optional, Set
from scaler.protocol.python.message import (
- ObjectContent,
ObjectInstruction,
ObjectRequest,
ObjectResponse,
@@ -67,7 +66,7 @@ async def on_task(self, task: Task) -> bool:
raise NotImplementedError()
@abc.abstractmethod
- async def on_cancel_task(self, task_id: bytes) -> bool:
+ def on_cancel_task(self, task_id: bytes) -> Optional[Task]:
raise NotImplementedError()
@abc.abstractmethod
@@ -107,7 +106,7 @@ def processors(self) -> List[ProcessorHolder]:
raise NotImplementedError()
@abc.abstractmethod
- def n_suspended_processors(self) -> int:
+ def num_suspended_processors(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
@@ -139,7 +138,7 @@ def on_object_request(self, processor_id: bytes, object_request: ObjectRequest)
raise NotImplementedError()
@abc.abstractmethod
- def on_object_response(self, object_response: ObjectContent) -> Set[bytes]:
+ def on_object_response(self, object_response: ObjectResponse) -> Set[bytes]:
raise NotImplementedError()
@abc.abstractmethod
diff --git a/scaler/worker/agent/object_tracker.py b/scaler/worker/agent/object_tracker.py
index d721ef6..13a28c8 100644
--- a/scaler/worker/agent/object_tracker.py
+++ b/scaler/worker/agent/object_tracker.py
@@ -1,15 +1,10 @@
+from collections import defaultdict
from typing import Dict, Set, Tuple
-from scaler.protocol.python.message import (
- ObjectContent,
- ObjectInstruction,
- ObjectInstructionType,
- ObjectRequest,
- ObjectResponse,
- ObjectResponseType,
-)
-from scaler.worker.agent.mixins import ObjectTracker
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse
from scaler.utility.many_to_many_dict import ManyToManyDict
+from scaler.worker.agent.mixins import ObjectTracker
class VanillaObjectTracker(ObjectTracker):
@@ -23,8 +18,8 @@ def on_object_request(self, processor_id: bytes, object_request: ObjectRequest)
def on_object_response(self, object_response: ObjectResponse) -> Set[bytes]:
"""Returns a list of processor ids that requested this object content."""
- if object_response.type != ObjectResponseType.Content:
- raise TypeError(f"invalid object response type received: {object_response.type}.")
+ if object_response.response_type != ObjectResponse.ObjectResponseType.Content:
+ raise TypeError(f"invalid object response type received: {object_response.response_type}.")
object_ids = object_response.object_content.object_ids
@@ -45,26 +40,23 @@ def on_object_instruction(self, object_instruction: ObjectInstruction) -> Dict[b
forwarded to processors.
"""
- if object_instruction.type != ObjectInstructionType.Delete:
- raise TypeError(f"invalid object instruction type received: {object_instruction.type}.")
+ if object_instruction.instruction_type != ObjectInstruction.ObjectInstructionType.Delete:
+ raise TypeError(f"invalid object instruction type received: {object_instruction.instruction_type}.")
- per_processor_object_ids = {}
+ per_processor_object_ids: Dict[bytes, Set[bytes]] = defaultdict(set)
for object_id in object_instruction.object_content.object_ids:
if not self._object_id_to_processors_ids.has_left_key(object_id):
continue
processor_ids = self._object_id_to_processors_ids.remove_left_key(object_id)
for processor_id in processor_ids:
- if processor_id not in per_processor_object_ids:
- per_processor_object_ids[processor_id] = []
-
- per_processor_object_ids[processor_id].append(object_id)
+ per_processor_object_ids[processor_id].add(object_id)
return {
- processor_id: ObjectInstruction(
- type=ObjectInstructionType.Delete,
+ processor_id: ObjectInstruction.new_msg(
+ instruction_type=ObjectInstruction.ObjectInstructionType.Delete,
object_user=object_instruction.object_user,
- object_content=ObjectContent(object_ids=object_ids),
+ object_content=ObjectContent.new_msg(object_ids=tuple(object_ids)),
)
for processor_id, object_ids in per_processor_object_ids.items()
}
diff --git a/scaler/worker/agent/processor/object_cache.py b/scaler/worker/agent/processor/object_cache.py
index b3d0b0a..ee8e42e 100644
--- a/scaler/worker/agent/processor/object_cache.py
+++ b/scaler/worker/agent/processor/object_cache.py
@@ -12,7 +12,8 @@
from scaler.client.serializer.mixins import Serializer
from scaler.io.config import CLEANUP_INTERVAL_SECONDS
-from scaler.protocol.python.message import ObjectContent, Task
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import Task
from scaler.utility.exceptions import DeserializeObjectError
from scaler.utility.object_utility import generate_serializer_object_id, is_object_id_serializer
diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py
index 1de1251..a35152f 100644
--- a/scaler/worker/agent/processor/processor.py
+++ b/scaler/worker/agent/processor/processor.py
@@ -12,21 +12,16 @@
from scaler.io.config import DUMMY_CLIENT
from scaler.io.sync_connector import SyncConnector
+from scaler.protocol.python.common import TaskStatus, ObjectContent
from scaler.protocol.python.message import (
- ArgumentType,
- MessageVariant,
- ObjectContent,
ObjectInstruction,
- ObjectInstructionType,
ObjectRequest,
- ObjectRequestType,
ObjectResponse,
- ObjectResponseType,
ProcessorInitialized,
Task,
TaskResult,
- TaskStatus,
)
+from scaler.protocol.python.mixins import Message
from scaler.utility.exceptions import MissingObjects
from scaler.utility.logging.utility import setup_logger
from scaler.utility.object_utility import generate_object_id, generate_serializer_object_id, serialize_failure
@@ -56,7 +51,7 @@ def __init__(
self._logging_paths = logging_paths
self._logging_level = logging_level
- self._client_to_decorator = {}
+ # self._client_to_decorator = {}
self._object_cache: Optional[ObjectCache] = None
@@ -103,9 +98,12 @@ def __interrupt(self, *args):
def __run_forever(self):
try:
- self._connector.send(ProcessorInitialized())
+ self._connector.send(ProcessorInitialized.new_msg())
while True:
message = self._connector.receive()
+ if message is None:
+ continue
+
self.__on_connector_receive(message)
except zmq.error.ZMQError as e:
@@ -121,7 +119,7 @@ def __run_forever(self):
self._object_cache.join()
- def __on_connector_receive(self, message: MessageVariant):
+ def __on_connector_receive(self, message: Message):
if isinstance(message, ObjectInstruction):
self.__on_receive_object_instruction(message)
return
@@ -137,7 +135,7 @@ def __on_connector_receive(self, message: MessageVariant):
logging.error(f"unknown {message=}")
def __on_receive_object_instruction(self, instruction: ObjectInstruction):
- if instruction.type == ObjectInstructionType.Delete:
+ if instruction.instruction_type == ObjectInstruction.ObjectInstructionType.Delete:
for object_id in instruction.object_content.object_ids:
self._object_cache.del_object(object_id)
return
@@ -145,7 +143,7 @@ def __on_receive_object_instruction(self, instruction: ObjectInstruction):
logging.error(f"worker received unknown object instruction type {instruction=}")
def __on_receive_object_response(self, response: ObjectResponse):
- if response.type == ObjectResponseType.Content:
+ if response.response_type == ObjectResponse.ObjectResponseType.Content:
self.__on_receive_object_content(response.object_content)
return
@@ -179,7 +177,7 @@ def __on_received_task(self, task: Task):
self.__process_task(task)
return
- self._connector.send(ObjectRequest(ObjectRequestType.Get, unknown_object_ids))
+ self._connector.send(ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, unknown_object_ids))
def __get_not_ready_object_ids(self, task: Task) -> Tuple[bytes, ...]:
required_object_ids = self.__get_required_object_ids_for_task(task)
@@ -189,7 +187,9 @@ def __get_not_ready_object_ids(self, task: Task) -> Tuple[bytes, ...]:
def __get_required_object_ids_for_task(task: Task) -> List[bytes]:
serializer_id = generate_serializer_object_id(task.source)
object_ids = [serializer_id, task.func_object_id]
- object_ids.extend([argument.data for argument in task.function_args if argument.type == ArgumentType.ObjectID])
+ object_ids.extend(
+ [argument.data for argument in task.function_args if argument.type == Task.Argument.ArgumentType.ObjectID]
+ )
return object_ids
def __process_task(self, task: Task):
@@ -246,13 +246,15 @@ def __send_result(self, source: bytes, task_id: bytes, status: TaskStatus, resul
# clients
result_object_id = generate_object_id(source, uuid.uuid4().bytes)
self._connector.send(
- ObjectInstruction(
- ObjectInstructionType.Create,
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Create,
source,
- ObjectContent((result_object_id,), (f"".encode(),), (result_bytes,)),
+ ObjectContent.new_msg(
+ (result_object_id,), (f"".encode(),), (result_bytes,)
+ ),
)
)
- self._connector.send(TaskResult(task_id, status, results=[result_object_id]))
+ self._connector.send(TaskResult.new_msg(task_id, status, metadata=b"", results=[result_object_id]))
@staticmethod
def __set_current_processor(context: Optional["Processor"]) -> Token:
diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py
index 59fe1fb..8cd29cd 100644
--- a/scaler/worker/agent/processor_holder.py
+++ b/scaler/worker/agent/processor_holder.py
@@ -91,7 +91,7 @@ def kill(self):
# TODO: some processors fail to interrupt because of a blocking 0mq call. Ideally we should interrupt
# these blocking calls instead of sending a SIGKILL signal.
- logging.warn(f"Processor[{self.pid()}] does not terminate in time, send SIGKILL.")
+ logging.warning(f"Processor[{self.pid()}] does not terminate in time, send SIGKILL.")
self.__send_signal(signal.SIGKILL)
self._processor.join()
diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py
index 0e49925..53c5c80 100644
--- a/scaler/worker/agent/processor_manager.py
+++ b/scaler/worker/agent/processor_manager.py
@@ -10,18 +10,16 @@
# from scaler.utility.logging.utility import setup_logger
from scaler.io.async_binder import AsyncBinder
from scaler.io.async_connector import AsyncConnector
+from scaler.protocol.python.common import TaskStatus, ObjectContent
from scaler.protocol.python.message import (
- MessageVariant,
- ObjectContent,
ObjectInstruction,
- ObjectInstructionType,
ObjectRequest,
ObjectResponse,
ProcessorInitialized,
Task,
TaskResult,
- TaskStatus,
)
+from scaler.protocol.python.mixins import Message
from scaler.utility.exceptions import ProcessorDiedError
from scaler.utility.metadata.profile_result import ProfileResult
from scaler.utility.mixins import Looper
@@ -97,11 +95,11 @@ async def on_object_instruction(self, instruction: ObjectInstruction):
for processor_id, instruction in processor_instructions.items():
await self._binder_internal.send(processor_id, instruction)
- async def on_object_response(self, request: ObjectResponse):
- processors_ids = self._object_tracker.on_object_response(request)
+ async def on_object_response(self, response: ObjectResponse):
+ processors_ids = self._object_tracker.on_object_response(response)
for process_id in processors_ids:
- await self._binder_internal.send(process_id, request)
+ await self._binder_internal.send(process_id, response)
async def acquire_task_active_lock(self):
await self._task_active_lock.acquire()
@@ -151,15 +149,15 @@ async def on_failing_task(self, process_status: str):
result_object_id = generate_object_id(source, uuid.uuid4().bytes)
await self._connector_external.send(
- ObjectInstruction(
- ObjectInstructionType.Create,
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Create,
source,
- ObjectContent((result_object_id,), (b"",), (result_object_bytes,)),
+ ObjectContent.new_msg((result_object_id,), (b"",), (result_object_bytes,)),
)
)
await self._task_manager.on_task_result(
- TaskResult(task_id, TaskStatus.Failed, profile_result.serialize(), [result_object_id])
+ TaskResult.new_msg(task_id, TaskStatus.Failed, profile_result.serialize(), [result_object_id])
)
self.restart_current_processor(f"process died {process_status=}")
@@ -235,7 +233,7 @@ def current_task_id(self) -> bytes:
def processors(self) -> List[ProcessorHolder]:
return list(self._holders_by_processor_id.values())
- def n_suspended_processors(self) -> int:
+ def num_suspended_processors(self) -> int:
return len(self._suspended_holders_by_task_id)
def task_lock(self) -> bool:
@@ -293,7 +291,7 @@ def __end_task(self, processor_holder: ProcessorHolder) -> ProfileResult:
return profile_result
- async def __on_receive_internal(self, processor_id: bytes, message: MessageVariant):
+ async def __on_receive_internal(self, processor_id: bytes, message: Message):
if isinstance(message, ProcessorInitialized):
await self.__on_internal_processor_initialized(processor_id)
return
@@ -355,9 +353,14 @@ async def __on_internal_task_result(self, processor_id: bytes, task_result: Task
else:
return
- task_result.metadata = profile_result.serialize()
-
- await self._task_manager.on_task_result(task_result)
+ await self._task_manager.on_task_result(
+ TaskResult.new_msg(
+ task_id=task_id,
+ status=task_result.status,
+ metadata=profile_result.serialize(),
+ results=task_result.results,
+ )
+ )
def __processor_ready_to_process_object(self, processor_id: bytes) -> bool:
holder = self._holders_by_processor_id.get(processor_id)
diff --git a/scaler/worker/agent/profiling_manager.py b/scaler/worker/agent/profiling_manager.py
index f15fbaa..06e4939 100644
--- a/scaler/worker/agent/profiling_manager.py
+++ b/scaler/worker/agent/profiling_manager.py
@@ -1,5 +1,5 @@
import dataclasses
-
+import time
from typing import Dict, Optional
import psutil
@@ -15,7 +15,7 @@ class _ProcessProfiler:
current_task_id: Optional[bytes] = None
- start_time: Optional[int] = None
+ start_time: Optional[float] = None
init_memory_rss: Optional[int] = None
peak_memory_rss: Optional[int] = None
@@ -57,7 +57,7 @@ def on_task_end(self, pid: int, task_id: bytes) -> ProfileResult:
raise ValueError(f"process {pid=} is not registered.")
if task_id != process_profiler.current_task_id:
- raise ValueError(f"task {task_id=} is not the current task task_id={process_profiler.current_task_id}.")
+ raise ValueError(f"task {task_id=!r} is not the current task task_id={process_profiler.current_task_id!r}.")
assert process_profiler.start_time is not None
assert process_profiler.init_memory_rss is not None
@@ -83,8 +83,7 @@ async def routine(self):
@staticmethod
def __process_cpu_time(process: psutil.Process) -> float:
- cpu_times = process.cpu_times()
- return cpu_times.user + cpu_times.system
+ return time.monotonic()
@staticmethod
def __process_memory_rss(process: psutil.Process) -> int:
diff --git a/scaler/worker/agent/task_manager.py b/scaler/worker/agent/task_manager.py
index 32a2a58..325ac80 100644
--- a/scaler/worker/agent/task_manager.py
+++ b/scaler/worker/agent/task_manager.py
@@ -1,7 +1,8 @@
from typing import Dict, Optional, Tuple
from scaler.io.async_connector import AsyncConnector
-from scaler.protocol.python.message import Task, TaskCancel, TaskResult, TaskStatus
+from scaler.protocol.python.common import TaskStatus
+from scaler.protocol.python.message import Task, TaskCancel, TaskResult
from scaler.utility.metadata.task_flags import retrieve_task_flags_from_task
from scaler.utility.mixins import Looper
from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue
@@ -41,9 +42,12 @@ async def on_cancel_task(self, task_cancel: TaskCancel):
task = self._queued_task_id_to_task.pop(task_cancel.task_id)
self._queued_task_ids.remove(task_cancel.task_id)
- result = TaskResult(task_cancel.task_id, TaskStatus.Canceled)
if task_cancel.flags.retrieve_task_object:
- result.results = list(task.serialize())
+ result = TaskResult.new_msg(
+ task_cancel.task_id, TaskStatus.Canceled, b"", [task.get_message().to_bytes()]
+ )
+ else:
+ result = TaskResult.new_msg(task_cancel.task_id, TaskStatus.Canceled)
await self._connector_external.send(result)
return
@@ -51,16 +55,17 @@ async def on_cancel_task(self, task_cancel: TaskCancel):
if not task_cancel.flags.force:
return
- cancelled_running_task = self._processor_manager.on_cancel_task(task_cancel.task_id)
- if cancelled_running_task is not None:
- result = TaskResult(task_cancel.task_id, TaskStatus.Canceled)
- if task_cancel.flags.retrieve_task_object:
- result.results = list(cancelled_running_task.serialize())
-
- await self._connector_external.send(result)
+ canceled_running_task = self._processor_manager.on_cancel_task(task_cancel.task_id)
+ if canceled_running_task is not None:
+ payload = [canceled_running_task.get_message().to_bytes()] if task_cancel.flags.retrieve_task_object else []
+ await self._connector_external.send(
+ TaskResult.new_msg(
+ task_id=task_cancel.task_id, status=TaskStatus.Canceled, metadata=b"", results=payload
+ )
+ )
return
- await self._connector_external.send(TaskResult(task_cancel.task_id, TaskStatus.NotFound))
+ await self._connector_external.send(TaskResult.new_msg(task_cancel.task_id, TaskStatus.NotFound))
async def on_task_result(self, result: TaskResult):
if result.task_id in self._queued_task_id_to_task:
diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py
index c0dee7d..4de8d2b 100644
--- a/scaler/worker/worker.py
+++ b/scaler/worker/worker.py
@@ -11,14 +11,13 @@
from scaler.protocol.python.message import (
ClientDisconnect,
DisconnectRequest,
- DisconnectType,
- MessageVariant,
ObjectInstruction,
ObjectResponse,
Task,
TaskCancel,
WorkerHeartbeatEcho,
)
+from scaler.protocol.python.mixins import Message
from scaler.utility.event_loop import create_async_loop_routine, register_event_loop
from scaler.utility.exceptions import ClientShutdownException
from scaler.utility.logging.utility import setup_logger
@@ -125,7 +124,7 @@ def __initialize(self):
self.__register_signal()
self._task = self._loop.create_task(self.__get_loops())
- async def __on_receive_external(self, message: MessageVariant):
+ async def __on_receive_external(self, message: Message):
if isinstance(message, WorkerHeartbeatEcho):
await self._heartbeat_manager.on_heartbeat_echo(message)
return
@@ -147,7 +146,7 @@ async def __on_receive_external(self, message: MessageVariant):
return
if isinstance(message, ClientDisconnect):
- if message.type == DisconnectType.Shutdown:
+ if message.disconnect_type == ClientDisconnect.DisconnectType.Shutdown:
raise ClientShutdownException("received client shutdown, quitting")
logging.error(f"Worker received invalid ClientDisconnect type, ignoring {message=}")
return
@@ -169,11 +168,11 @@ async def __get_loops(self):
except (ClientShutdownException, TimeoutError) as e:
logging.info(f"Worker[{self.pid}]: {str(e)}")
- await self._connector_external.send(DisconnectRequest(self._connector_external.identity))
+ await self._connector_external.send(DisconnectRequest.new_msg(self._connector_external.identity))
self._connector_external.destroy()
- self._processor_manager.destroy("quitted")
- logging.info(f"Worker[{self.pid}]: quitted")
+ self._processor_manager.destroy("quited")
+ logging.info(f"Worker[{self.pid}]: quited")
def __run_forever(self):
self._loop.run_until_complete(self._task)
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 28af916..0000000
--- a/setup.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from setuptools import find_packages, setup
-
-from scaler.about import __version__
-
-with open("requirements.txt", "rt") as f:
- requirements = [i.strip() for i in f.readlines()]
-
-setup(
- name="scaler",
- version=__version__,
- packages=find_packages(exclude=("tests",)),
- install_requires=requirements,
- extras_require={
- "graphblas": ["python-graphblas", "numpy"],
- "uvloop": ["uvloop"],
- "gui": ["nicegui[plotly]"],
- "all": ["python-graphblas", "numpy", "uvloop", "nicegui[plotly]"],
- },
- url="",
- license="",
- author="Citi",
- author_email="opensource@citi.com",
- description="Scaler Distributed Framework",
- entry_points={
- "console_scripts": [
- "scaler_scheduler=scaler.entry_points.scheduler:main",
- "scaler_cluster=scaler.entry_points.cluster:main",
- "scaler_top=scaler.entry_points.top:main",
- "scaler_ui=scaler.entry_points.webui:main",
- ]
- },
-)
diff --git a/tests/test_graph.py b/tests/test_graph.py
index b5a5e07..3d22b41 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -148,7 +148,7 @@ def func(a):
def test_cull_graph(self):
graph = {
- "a": (lambda *_: None),
+ "a": (lambda *_: None,),
"b": (lambda *_: None, "a"),
"c": (lambda *_: None, "a"),
"d": (lambda *_: None, "b"),
@@ -156,8 +156,8 @@ def test_cull_graph(self):
"f": (lambda *_: None, "c"),
}
- def filter_keys(graph, keys):
- return {key: value for key, value in graph.items() if key in keys}
+ def filter_keys(_graph, keys):
+ return {key: value for key, value in _graph.items() if key in keys}
self.assertEqual(cull_graph(graph, ["d"]), filter_keys(graph, ["a", "b", "d"]))
self.assertEqual(cull_graph(graph, ["e"]), filter_keys(graph, ["a", "b", "c", "e"]))
diff --git a/tests/test_object_usage.py b/tests/test_object_usage.py
index 6878348..ca05b5d 100644
--- a/tests/test_object_usage.py
+++ b/tests/test_object_usage.py
@@ -22,7 +22,7 @@ class TestObjectUsage(unittest.TestCase):
def test_object_usage(self):
setup_logger()
- object_usage = ObjectTracker("sample", sample_ready)
+ object_usage: ObjectTracker[str, Sample] = ObjectTracker("sample", sample_ready)
object_usage.add_object(Sample("a", "value1"))
object_usage.add_object(Sample("b", "value2"))
diff --git a/tests/test_worker_object_tracker.py b/tests/test_worker_object_tracker.py
index d97967b..72e6509 100644
--- a/tests/test_worker_object_tracker.py
+++ b/tests/test_worker_object_tracker.py
@@ -1,14 +1,7 @@
import unittest
-from scaler.protocol.python.message import (
- ObjectContent,
- ObjectInstruction,
- ObjectInstructionType,
- ObjectRequest,
- ObjectRequestType,
- ObjectResponse,
- ObjectResponseType,
-)
+from scaler.protocol.python.common import ObjectContent
+from scaler.protocol.python.message import ObjectInstruction, ObjectRequest, ObjectResponse
from scaler.worker.agent.object_tracker import VanillaObjectTracker
@@ -16,33 +9,45 @@ class TestWorkerObjectTracker(unittest.TestCase):
def test_object_tracker(self) -> None:
tracker = VanillaObjectTracker()
- tracker.on_object_request(b"processor_1", ObjectRequest(ObjectRequestType.Get, (b"object_1", b"object_2")))
+ tracker.on_object_request(
+ b"processor_1", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_1", b"object_2"))
+ )
- tracker.on_object_request(b"processor_2", ObjectRequest(ObjectRequestType.Get, (b"object_1", b"object_2")))
- tracker.on_object_request(b"processor_2", ObjectRequest(ObjectRequestType.Get, (b"object_3",)))
+ tracker.on_object_request(
+ b"processor_2", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_1", b"object_2"))
+ )
+ tracker.on_object_request(
+ b"processor_2", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_3",))
+ )
- tracker.on_object_request(b"processor_3", ObjectRequest(ObjectRequestType.Get, (b"object_4", b"object_5")))
+ tracker.on_object_request(
+ b"processor_3", ObjectRequest.new_msg(ObjectRequest.ObjectRequestType.Get, (b"object_4", b"object_5"))
+ )
response_1 = tracker.on_object_response(
- ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_1", b"object_2")))
+ ObjectResponse.new_msg(
+ ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_1", b"object_2"))
+ )
)
self.assertSetEqual(set(response_1), {b"processor_1", b"processor_2"})
response_2 = tracker.on_object_response(
- ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_unknown",)))
+ ObjectResponse.new_msg(
+ ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_unknown",))
+ )
)
self.assertSetEqual(response_2, set())
response_3 = tracker.on_object_response(
- ObjectResponse(ObjectResponseType.Content, ObjectContent((b"object_3",)))
+ ObjectResponse.new_msg(ObjectResponse.ObjectResponseType.Content, ObjectContent.new_msg((b"object_3",)))
)
self.assertSetEqual(response_3, {b"processor_2"})
object_instructions = tracker.on_object_instruction(
- ObjectInstruction(
- ObjectInstructionType.Delete,
+ ObjectInstruction.new_msg(
+ ObjectInstruction.ObjectInstructionType.Delete,
b"client",
- ObjectContent(
+ ObjectContent.new_msg(
(b"object_1", b"object_2", b"object_3"),
(b"name_1", b"name_2", b"name_3"),
(b"content_1", b"content_2", b"content_3"),