diff --git a/scaler/about.py b/scaler/about.py index 0a40090..880e579 100644 --- a/scaler/about.py +++ b/scaler/about.py @@ -1 +1 @@ -__version__ = "1.8.12" +__version__ = "1.8.13" diff --git a/scaler/client/client.py b/scaler/client/client.py index e5c9899..ea15d8f 100644 --- a/scaler/client/client.py +++ b/scaler/client/client.py @@ -524,7 +524,7 @@ def __get_task_flags(self) -> TaskFlags: parent_task_priority = self.__get_parent_task_priority() if parent_task_priority is not None: - task_priority = parent_task_priority - 1 + task_priority = parent_task_priority + 1 else: task_priority = 0 diff --git a/scaler/io/config.py b/scaler/io/config.py index 7e22715..09ccb0d 100644 --- a/scaler/io/config.py +++ b/scaler/io/config.py @@ -34,7 +34,7 @@ # if didn't receive heartbeat for following seconds, then scheduler will treat client as dead and cancel remaining # tasks for this client -DEFAULT_CLIENT_TIMEOUT_SECONDS = 600 +DEFAULT_CLIENT_TIMEOUT_SECONDS = 60 # number of seconds for load balance, if value is 0 means disable load balance DEFAULT_LOAD_BALANCE_SECONDS = 1 diff --git a/scaler/utility/queues/async_priority_queue.py b/scaler/utility/queues/async_priority_queue.py index 6f58ed4..123dfba 100644 --- a/scaler/utility/queues/async_priority_queue.py +++ b/scaler/utility/queues/async_priority_queue.py @@ -1,4 +1,5 @@ import heapq +import sys from asyncio import Queue from typing import Dict, List, Tuple, Union @@ -59,7 +60,7 @@ def __to_lowest_priority(cls, original_priority: PriorityType) -> PriorityType: if isinstance(original_priority, tuple): return tuple(cls.__to_lowest_priority(value) for value in original_priority) else: - return -1 + return -sys.maxsize - 1 @classmethod def __to_lower_priority(cls, original_priority: PriorityType) -> PriorityType: diff --git a/scaler/worker/agent/heartbeat_manager.py b/scaler/worker/agent/heartbeat_manager.py index 98b51be..b85274c 100644 --- a/scaler/worker/agent/heartbeat_manager.py +++ b/scaler/worker/agent/heartbeat_manager.py @@ -14,7 +14,6 @@ class VanillaHeartbeatManager(Looper, HeartbeatManager): def __init__(self): self._agent_process = psutil.Process() - self._worker_process: Optional[psutil.Process] = None self._connector_external: Optional[AsyncConnector] = None self._worker_task_manager: Optional[TaskManager] = None @@ -36,9 +35,6 @@ def register( self._timeout_manager = timeout_manager self._processor_manager = processor_manager - def set_processor_pid(self, process_id: int): - self._worker_process = psutil.Process(process_id) - async def on_heartbeat_echo(self, heartbeat: WorkerHeartbeatEcho): if self._start_timestamp_ns == 0: # not handling echo if we didn't send out heartbeat @@ -49,17 +45,21 @@ async def on_heartbeat_echo(self, heartbeat: WorkerHeartbeatEcho): self._timeout_manager.update_last_seen_time() async def routine(self): - if self._worker_process is None: + processors = self._processor_manager.processors() + + if len(processors) == 0: return if self._start_timestamp_ns != 0: # already sent heartbeat, expecting heartbeat echo, so not sending return - if self._worker_process.status() in {psutil.STATUS_ZOMBIE, psutil.STATUS_DEAD}: - await self._processor_manager.on_failing_task(self._worker_process.status()) + for processor_holder in processors: + status = processor_holder.process().status() + if status in {psutil.STATUS_ZOMBIE, psutil.STATUS_DEAD}: + await self._processor_manager.on_failing_processor(processor_holder.processor_id(), status) - processors = self._processor_manager.processors() + processors = self._processor_manager.processors() # refreshes for removed dead and zombie processors num_suspended_processors = self._processor_manager.num_suspended_processors() await self._connector_external.send( @@ -68,7 +68,7 @@ async def routine(self): psutil.virtual_memory().available, self._worker_task_manager.get_queued_size() - num_suspended_processors, self._latency_us, - self._processor_manager.task_lock(), + self._processor_manager.can_accept_task(), [self.__get_processor_status_from_holder(processor) for processor in processors], ) ) @@ -77,10 +77,17 @@ async def routine(self): @staticmethod def __get_processor_status_from_holder(processor: ProcessorHolder) -> ProcessorStatus: process = processor.process() + + try: + resource = Resource.new_msg(int(process.cpu_percent() * 10), process.memory_info().rss) + except psutil.ZombieProcess: + # Assumes dead processes do not use any resources + resource = Resource.new_msg(0, 0) + return ProcessorStatus.new_msg( processor.pid(), processor.initialized(), processor.task() is not None, processor.suspended(), - Resource.new_msg(int(process.cpu_percent() * 10), process.memory_info().rss), + resource, ) diff --git a/scaler/worker/agent/mixins.py b/scaler/worker/agent/mixins.py index 6ad3f53..8cec6de 100644 --- a/scaler/worker/agent/mixins.py +++ b/scaler/worker/agent/mixins.py @@ -15,10 +15,6 @@ class HeartbeatManager(metaclass=abc.ABCMeta): - @abc.abstractmethod - def set_processor_pid(self, process_id: int): - raise NotImplementedError() - @abc.abstractmethod async def on_heartbeat_echo(self, heartbeat: WorkerHeartbeatEcho): raise NotImplementedError() @@ -58,31 +54,31 @@ def on_object_response(self, request: ObjectResponse): raise NotImplementedError() @abc.abstractmethod - async def acquire_task_active_lock(self): + def can_accept_task(self) -> bool: raise NotImplementedError() @abc.abstractmethod - async def on_task(self, task: Task) -> bool: + async def wait_until_can_accept_task(self): raise NotImplementedError() @abc.abstractmethod - def on_cancel_task(self, task_id: bytes) -> Optional[Task]: + async def on_task(self, task: Task) -> bool: raise NotImplementedError() @abc.abstractmethod - async def on_failing_task(self, error: str): + async def on_cancel_task(self, task_id: bytes) -> Optional[Task]: raise NotImplementedError() @abc.abstractmethod - def on_suspend_task(self, task_id: bytes) -> bool: + async def on_failing_processor(self, processor_id: bytes, process_status: str): raise NotImplementedError() @abc.abstractmethod - def on_resume_task(self, task_id: bytes) -> bool: + async def on_suspend_task(self, task_id: bytes) -> bool: raise NotImplementedError() @abc.abstractmethod - def restart_current_processor(self, reason: str): + def on_resume_task(self, task_id: bytes) -> bool: raise NotImplementedError() @abc.abstractmethod @@ -109,10 +105,6 @@ def processors(self) -> List[ProcessorHolder]: def num_suspended_processors(self) -> int: raise NotImplementedError() - @abc.abstractmethod - def task_lock(self) -> bool: - raise NotImplementedError() - class ProfilingManager(metaclass=abc.ABCMeta): @abc.abstractmethod diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py index b331bee..ccbaa23 100644 --- a/scaler/worker/agent/processor/processor.py +++ b/scaler/worker/agent/processor/processor.py @@ -41,6 +41,7 @@ def __init__( event_loop: str, address: ZMQConfig, resume_event: Optional[EventType], + resumed_event: Optional[EventType], garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, logging_paths: Tuple[str, ...], @@ -52,6 +53,7 @@ def __init__( self._address = address self._resume_event = resume_event + self._resumed_event = resumed_event self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes @@ -108,8 +110,14 @@ def __interrupt(self, *args): def __suspend(self, *args): assert self._resume_event is not None + assert self._resumed_event is not None + self._resume_event.wait() # stops any computation in the main thread until the event is triggered + # Ensures the processor agent knows we stopped waiting on `_resume_event`, as to avoid re-entrant wait on the + # event. + self._resumed_event.set() + def __run_forever(self): try: self._connector.send(ProcessorInitialized.new_msg()) diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py index 48ea1c2..bacd8d9 100644 --- a/scaler/worker/agent/processor_holder.py +++ b/scaler/worker/agent/processor_holder.py @@ -1,10 +1,9 @@ -import asyncio import logging import os import signal -from multiprocessing import Event from typing import Optional, Tuple +import multiprocessing import psutil from scaler.io.config import DEFAULT_PROCESSOR_KILL_DELAY_SECONDS @@ -26,19 +25,22 @@ def __init__( ): self._processor_id: Optional[bytes] = None self._task: Optional[Task] = None - self._initialized = asyncio.Event() self._suspended = False self._hard_suspend = hard_suspend if hard_suspend: self._resume_event = None + self._resumed_event = None else: - self._resume_event = Event() + context = multiprocessing.get_context("spawn") + self._resume_event = context.Event() + self._resumed_event = context.Event() self._processor = Processor( event_loop=event_loop, address=address, resume_event=self._resume_event, + resumed_event=self._resumed_event, garbage_collect_interval_seconds=garbage_collect_interval_seconds, trim_memory_threshold_bytes=trim_memory_threshold_bytes, logging_paths=logging_paths, @@ -59,14 +61,10 @@ def processor_id(self) -> bytes: return self._processor_id def initialized(self) -> bool: - return self._initialized.is_set() + return self._processor_id is not None - def wait_initialized(self): - return self._initialized.wait() - - def set_initialized(self, processor_id: bytes): + def initialize(self, processor_id: bytes): self._processor_id = processor_id - self._initialized.set() def task(self) -> Optional[Task]: return self._task @@ -81,6 +79,7 @@ def suspend(self): assert self._processor is not None assert self._task is not None assert self._suspended is False + assert self.initialized() if self._hard_suspend: self.__send_signal("SIGSTOP") @@ -92,7 +91,9 @@ def suspend(self): # See https://github.com/Citi/scaler/issues/14 assert self._resume_event is not None + assert self._resumed_event is not None self._resume_event.clear() + self._resumed_event.clear() self.__send_signal(SUSPEND_SIGNAL) @@ -106,8 +107,14 @@ def resume(self): self.__send_signal("SIGCONT") else: assert self._resume_event is not None + assert self._resumed_event is not None + self._resume_event.set() + # Waits until the processor resumes processing. This avoids any future call to `suspend()` while the + # processor hasn't returned from the `_resumed_event.wait()` call yet (causes a re-entrant error on Linux). + self._resumed_event.wait() + self._suspended = False def kill(self): diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index 10e9fc8..26c559d 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -65,7 +65,7 @@ def __init__( self._suspended_holders_by_task_id: Dict[bytes, ProcessorHolder] = {} self._holders_by_processor_id: Dict[bytes, ProcessorHolder] = {} - self._task_active_lock: asyncio.Lock = asyncio.Lock() + self._can_accept_task_lock: asyncio.Lock = asyncio.Lock() self._binder_internal: AsyncBinder = AsyncBinder( context=context, name="processor_manager", address=self._address, identity=None @@ -86,8 +86,9 @@ def register( self._object_tracker = object_tracker self._connector_external = connector_external - def initialize(self): + async def initialize(self): # setup_logger() + await self._can_accept_task_lock.acquire() # prevents processor to accept task until initialized self.__start_new_processor() async def routine(self): @@ -105,29 +106,39 @@ async def on_object_response(self, response: ObjectResponse): for process_id in processors_ids: await self._binder_internal.send(process_id, response) - async def acquire_task_active_lock(self): - await self._task_active_lock.acquire() + def can_accept_task(self) -> bool: + return self._can_accept_task_lock.locked() + + async def wait_until_can_accept_task(self): + """ + Makes sure a processor is ready to start processing a new or suspended task. + + Must be called before any call to `on_task()` or `on_task_resume()`. + """ + + await self._can_accept_task_lock.acquire() async def on_task(self, task: Task) -> bool: - assert self._current_holder is not None - assert self.current_task() is None + assert self._can_accept_task_lock.locked() + assert self.initialized() + + holder = self._current_holder - await self._current_holder.wait_initialized() + assert holder.task() is None + holder.set_task(task) - self._current_holder.set_task(task) + self._profiling_manager.on_task_start(holder.pid(), task.task_id) - await self._binder_internal.send(self._current_holder.processor_id(), task) - self._profiling_manager.on_task_start(self._current_holder.pid(), task.task_id) + await self._binder_internal.send(holder.processor_id(), task) return True - def on_cancel_task(self, task_id: bytes) -> Optional[Task]: + async def on_cancel_task(self, task_id: bytes) -> Optional[Task]: assert self._current_holder is not None if self.current_task_id() == task_id: current_task = self.current_task() - self._task_active_lock.release() - self.restart_current_processor(f"cancel task_id={task_id.hex()}") + self.__restart_current_processor(f"cancel task_id={task_id.hex()}") return current_task if task_id in self._suspended_holders_by_task_id: @@ -138,17 +149,30 @@ def on_cancel_task(self, task_id: bytes) -> Optional[Task]: return None - async def on_failing_task(self, process_status: str): + async def on_failing_processor(self, processor_id: bytes, process_status: str): assert self._current_holder is not None - task = self.current_task() + holder = self._holders_by_processor_id.get(processor_id) + + if holder is None: + return + + task = holder.task() + if task is not None: + profile_result = self.__end_task(holder) # profiling the task should happen before killing the processor + else: + profile_result = None + + reason = f"process died {process_status=}" + if holder == self._current_holder: + self.__restart_current_processor(reason) + else: + self.__kill_processor(reason, holder) if task is not None: source = task.source task_id = task.task_id - profile_result = self.__end_task(self._current_holder) - result_object_bytes = chunk_to_list_of_bytes(serialize_failure(ProcessorDiedError(f"{process_status=}"))) result_object_id = generate_object_id(source, uuid.uuid4().bytes) @@ -164,29 +188,27 @@ async def on_failing_task(self, process_status: str): TaskResult.new_msg(task_id, TaskStatus.Failed, profile_result.serialize(), [result_object_id]) ) - self.restart_current_processor(f"process died {process_status=}") - - def on_suspend_task(self, task_id: bytes) -> bool: + async def on_suspend_task(self, task_id: bytes) -> bool: assert self._current_holder is not None + holder = self._current_holder - current_task = self.current_task() + current_task = holder.task() if current_task is None or current_task.task_id != task_id: return False - self._current_holder.suspend() - self._suspended_holders_by_task_id[task_id] = self._current_holder + holder.suspend() + self._suspended_holders_by_task_id[task_id] = holder - logging.info(f"Worker[{os.getpid()}]: suspend Processor[{self._current_holder.pid()}]") + logging.info(f"Worker[{os.getpid()}]: suspend Processor[{holder.pid()}]") self.__start_new_processor() - self._task_active_lock.release() - return True def on_resume_task(self, task_id: bytes) -> bool: - assert self._current_holder is not None + assert self._can_accept_task_lock.locked() + assert self.initialized() if self.current_task() is not None: return False @@ -201,19 +223,10 @@ def on_resume_task(self, task_id: bytes) -> bool: self._current_holder = suspended_holder suspended_holder.resume() - self._heartbeat.set_processor_pid(suspended_holder.pid()) - logging.info(f"Worker[{os.getpid()}]: resume Processor[{self._current_holder.pid()}]") return True - def restart_current_processor(self, reason: str): - assert self._current_holder is not None - - self.__kill_processor(reason, self._current_holder) - - self.__start_new_processor() - def destroy(self, reason: str): self.__kill_all_processors(reason) self._binder_internal.destroy() @@ -240,9 +253,6 @@ def processors(self) -> List[ProcessorHolder]: def num_suspended_processors(self) -> int: return len(self._suspended_holders_by_task_id) - def task_lock(self) -> bool: - return self._task_active_lock.locked() - def __start_new_processor(self): self._current_holder = ProcessorHolder( self._event_loop, @@ -255,16 +265,13 @@ def __start_new_processor(self): ) processor_pid = self._current_holder.pid() - assert processor_pid is not None - self._heartbeat.set_processor_pid(processor_pid) self._profiling_manager.on_process_start(processor_pid) logging.info(f"Worker[{os.getpid()}]: start Processor[{processor_pid}]") def __kill_processor(self, reason: str, holder: ProcessorHolder): processor_pid = holder.pid() - assert processor_pid is not None self._profiling_manager.on_process_end(processor_pid) @@ -276,6 +283,12 @@ def __kill_processor(self, reason: str, holder: ProcessorHolder): logging.info(f"Worker[{os.getpid()}]: stop Processor[{processor_pid}], reason: {reason}") + def __restart_current_processor(self, reason: str): + assert self._current_holder is not None + + self.__kill_processor(reason, self._current_holder) + self.__start_new_processor() + def __kill_all_processors(self, reason: str): if self._current_holder is not None: self.__kill_processor(reason, self._current_holder) @@ -291,9 +304,6 @@ def __end_task(self, processor_holder: ProcessorHolder) -> ProfileResult: profile_result = self._profiling_manager.on_task_end(processor_holder.pid(), processor_holder.task().task_id) processor_holder.set_task(None) - if self._current_holder == processor_holder: - self._task_active_lock.release() - return profile_result async def __on_receive_internal(self, processor_id: bytes, message: Message): @@ -322,7 +332,9 @@ async def __on_internal_processor_initialized(self, processor_id: bytes): return self._holders_by_processor_id[processor_id] = self._current_holder - self._current_holder.set_initialized(processor_id) + self._current_holder.initialize(processor_id) + + self._can_accept_task_lock.release() async def __on_internal_object_request(self, processor_id: bytes, request: ObjectRequest): if not self.__processor_ready_to_process_object(processor_id): @@ -345,6 +357,8 @@ async def __on_internal_task_result(self, processor_id: bytes, task_result: Task assert self._current_holder.processor_id() == processor_id profile_result = self.__end_task(self._current_holder) + + release_task_lock = True elif task_id in self._suspended_holders_by_task_id: # Receiving a task result from a suspended processor is possible as the message might have been queued while # we were suspending the process. @@ -355,6 +369,8 @@ async def __on_internal_task_result(self, processor_id: bytes, task_result: Task profile_result = self.__end_task(holder) self.__kill_processor("task finished in suspended processor", holder) + + release_task_lock = False else: return @@ -367,6 +383,10 @@ async def __on_internal_task_result(self, processor_id: bytes, task_result: Task ) ) + # task lock must be released after calling `TaskManager.on_task_result()` + if release_task_lock: + self._can_accept_task_lock.release() + def __processor_ready_to_process_object(self, processor_id: bytes) -> bool: holder = self._holders_by_processor_id.get(processor_id) diff --git a/scaler/worker/agent/profiling_manager.py b/scaler/worker/agent/profiling_manager.py index 5bd7365..dc6502b 100644 --- a/scaler/worker/agent/profiling_manager.py +++ b/scaler/worker/agent/profiling_manager.py @@ -1,4 +1,5 @@ import dataclasses +import logging import time from typing import Dict, Optional @@ -68,7 +69,13 @@ def on_task_end(self, pid: int, task_id: bytes) -> ProfileResult: process = process_profiler.process time_delta = self.__process_time() - process_profiler.start_time - cpu_time_delta = self.__process_cpu_time(process) - process_profiler.start_cpu_time + + try: + cpu_time_delta = self.__process_cpu_time(process) - process_profiler.start_cpu_time + except psutil.ZombieProcess: + logging.warning(f"profiling zombie process: {pid=}") + cpu_time_delta = 0 + memory_delta = process_profiler.peak_memory_rss - process_profiler.init_memory_rss process_profiler.current_task_id = None @@ -80,9 +87,12 @@ def on_task_end(self, pid: int, task_id: bytes) -> ProfileResult: async def routine(self): for process_profiler in self._process_profiler_by_pid.values(): if process_profiler.current_task_id is not None: - process_profiler.peak_memory_rss = max( - process_profiler.peak_memory_rss, self.__process_memory_rss(process_profiler.process) - ) + try: + process_profiler.peak_memory_rss = max( + process_profiler.peak_memory_rss, self.__process_memory_rss(process_profiler.process) + ) + except psutil.ZombieProcess: + logging.warning(f"profiling zombie process: pid={process_profiler.process.pid}") @staticmethod def __process_time(): diff --git a/scaler/worker/agent/task_manager.py b/scaler/worker/agent/task_manager.py index 325ac80..8bb380c 100644 --- a/scaler/worker/agent/task_manager.py +++ b/scaler/worker/agent/task_manager.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Set from scaler.io.async_connector import AsyncConnector from scaler.protocol.python.common import TaskStatus @@ -19,9 +19,19 @@ def __init__(self, task_timeout_seconds: int): self._queued_task_id_to_task: Dict[bytes, Task] = dict() # Queued tasks are sorted first by task's priorities, then suspended tasks are prioritized over non yet started - # tasks. + # tasks. Finally the sorted queue ensure we execute the oldest tasks first. + # + # For example, if we receive these tasks in this order: + # 1. Task(priority=0) [suspended] + # 2. Task(priority=3) [suspended] + # 3. Task(priority=3) + # 4. Task(priority=0) + # + # We want to execute the tasks in this order: 2-3-1-4. self._queued_task_ids = AsyncSortedPriorityQueue() + self._processing_task_ids: Set[bytes] = set() # Tasks associated with a processor, including suspended tasks + self._connector_external: Optional[AsyncConnector] = None self._processor_manager: Optional[ProcessorManager] = None @@ -30,42 +40,36 @@ def register(self, connector: AsyncConnector, processor_manager: ProcessorManage self._processor_manager = processor_manager async def on_task_new(self, task: Task): - task_priority = self.__get_task_priority(task) - - self._queued_task_id_to_task[task.task_id] = task - await self._queued_task_ids.put(((task_priority, _QUEUED_TASKS_PRIORITY), task.task_id)) + self.__enqueue_task(task, is_suspended=False) - await self.__suspend_if_priority_is_lower(task_priority) + await self.__suspend_if_priority_is_higher(task) async def on_cancel_task(self, task_cancel: TaskCancel): - if task_cancel.task_id in self._queued_task_id_to_task: - task = self._queued_task_id_to_task.pop(task_cancel.task_id) - self._queued_task_ids.remove(task_cancel.task_id) + task_id = task_cancel.task_id - if task_cancel.flags.retrieve_task_object: - result = TaskResult.new_msg( - task_cancel.task_id, TaskStatus.Canceled, b"", [task.get_message().to_bytes()] - ) - else: - result = TaskResult.new_msg(task_cancel.task_id, TaskStatus.Canceled) + task_not_found = task_id not in self._processing_task_ids and task_id not in self._queued_task_id_to_task + task_is_processing = task_id in self._processing_task_ids + if task_not_found or (task_is_processing and not task_cancel.flags.force): + result = TaskResult.new_msg(task_id, TaskStatus.NotFound) await self._connector_external.send(result) return - if not task_cancel.flags.force: - return + # A suspended task will be both processing AND queued - canceled_running_task = self._processor_manager.on_cancel_task(task_cancel.task_id) - if canceled_running_task is not None: - payload = [canceled_running_task.get_message().to_bytes()] if task_cancel.flags.retrieve_task_object else [] - await self._connector_external.send( - TaskResult.new_msg( - task_id=task_cancel.task_id, status=TaskStatus.Canceled, metadata=b"", results=payload - ) - ) - return + if task_cancel.task_id in self._queued_task_id_to_task: + canceled_task = self._queued_task_id_to_task.pop(task_cancel.task_id) + self._queued_task_ids.remove(task_cancel.task_id) + + if task_is_processing: + self._processing_task_ids.remove(task_cancel.task_id) + canceled_task = await self._processor_manager.on_cancel_task(task_cancel.task_id) - await self._connector_external.send(TaskResult.new_msg(task_cancel.task_id, TaskStatus.NotFound)) + assert canceled_task is not None + + payload = [canceled_task.get_message().to_bytes()] if task_cancel.flags.retrieve_task_object else [] + result = TaskResult.new_msg(task_id=task_id, status=TaskStatus.Canceled, metadata=b"", results=payload) + await self._connector_external.send(result) async def on_task_result(self, result: TaskResult): if result.task_id in self._queued_task_id_to_task: @@ -73,6 +77,8 @@ async def on_task_result(self, result: TaskResult): self._queued_task_id_to_task.pop(result.task_id) self._queued_task_ids.remove(result.task_id) + self._processing_task_ids.remove(result.task_id) + await self._connector_external.send(result) async def routine(self): @@ -82,36 +88,51 @@ def get_queued_size(self): return self._queued_task_ids.qsize() async def __processing_task(self): - await self._processor_manager.acquire_task_active_lock() + await self._processor_manager.wait_until_can_accept_task() - priority, task_id = await self._queued_task_ids.get() + _, task_id = await self._queued_task_ids.get() task = self._queued_task_id_to_task.pop(task_id) - if not self.__is_suspended_task(priority): + if task_id not in self._processing_task_ids: + self._processing_task_ids.add(task_id) await self._processor_manager.on_task(task) else: self._processor_manager.on_resume_task(task_id) - async def __suspend_if_priority_is_lower(self, new_task_priority: int): + async def __suspend_if_priority_is_higher(self, new_task: Task): current_task = self._processor_manager.current_task() if current_task is None: return + new_task_priority = self.__get_task_priority(new_task) current_task_priority = self.__get_task_priority(current_task) - if new_task_priority >= current_task_priority: + if new_task_priority <= current_task_priority: return - self._processor_manager.on_suspend_task(current_task.task_id) + self.__enqueue_task(current_task, is_suspended=True) - await self._queued_task_ids.put(((current_task_priority, _SUSPENDED_TASKS_PRIORITY), current_task.task_id)) - self._queued_task_id_to_task[current_task.task_id] = current_task + await self._processor_manager.on_suspend_task(current_task.task_id) + + def __enqueue_task(self, task: Task, is_suspended: bool): + task_priority = self.__get_task_priority(task) + + # Higher-priority tasks have an higher priority value. But as the queue is sorted by increasing order, we negate + # the inserted value so they will be at the head of the queue. + if is_suspended: + queue_priority = (-task_priority, _SUSPENDED_TASKS_PRIORITY) + else: + queue_priority = (-task_priority, _QUEUED_TASKS_PRIORITY) + + self._queued_task_ids.put_nowait((queue_priority, task.task_id)) + self._queued_task_id_to_task[task.task_id] = task @staticmethod def __get_task_priority(task: Task) -> int: - return retrieve_task_flags_from_task(task).priority + priority = retrieve_task_flags_from_task(task).priority - @staticmethod - def __is_suspended_task(priority: Tuple[int, int]) -> bool: - return priority[1] == _SUSPENDED_TASKS_PRIORITY + if priority < 0: + raise ValueError(f"invalid task priority, must be positive or zero, got {priority}") + + return priority diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index c594536..590c17a 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -123,7 +123,6 @@ def __initialize(self): object_tracker=self._object_tracker, connector_external=self._connector_external, ) - self._processor_manager.initialize() self._loop = asyncio.get_event_loop() self.__register_signal() @@ -159,6 +158,8 @@ async def __on_receive_external(self, message: Message): raise TypeError(f"Unknown {message=}") async def __get_loops(self): + await self._processor_manager.initialize() + try: await asyncio.gather( create_async_loop_routine(self._connector_external.routine, 0), diff --git a/tests/test_async_priority_queue.py b/tests/test_async_priority_queue.py index 7431b2e..c5f7d07 100644 --- a/tests/test_async_priority_queue.py +++ b/tests/test_async_priority_queue.py @@ -17,16 +17,18 @@ async def async_test(): await queue.put((5, 5)) await queue.put((2, 2)) await queue.put((6, 6)) + await queue.put((-3, 0)) # supports negative priorities await queue.put((1, 1)) await queue.put((4, 4)) await queue.put((3, 3)) queue.remove(2) queue.remove(3) - self.assertEqual(queue.qsize(), 4) + self.assertEqual(queue.qsize(), 5) queue.decrease_priority(4) # (4, 4) becomes (3, 4) + self.assertEqual(await queue.get(), (-3, 0)) self.assertEqual(await queue.get(), (1, 1)) self.assertEqual(await queue.get(), (3, 4)) self.assertEqual(await queue.get(), (5, 5)) diff --git a/tests/test_async_sorted_priority_queue.py b/tests/test_async_sorted_priority_queue.py index d124f21..f5e64e9 100644 --- a/tests/test_async_sorted_priority_queue.py +++ b/tests/test_async_sorted_priority_queue.py @@ -21,11 +21,13 @@ async def async_test(): await queue.put([1, 1]) await queue.put([3, 6]) await queue.put([2, 4]) + await queue.put([-3, 0]) # supports negative priorities await queue.put([1, 2]) queue.remove(4) - self.assertEqual(queue.qsize(), 5) + self.assertEqual(queue.qsize(), 6) + self.assertEqual(await queue.get(), [-3, 0]) self.assertEqual(await queue.get(), [1, 1]) self.assertEqual(await queue.get(), [1, 2]) self.assertEqual(await queue.get(), [2, 3]) diff --git a/tests/test_future.py b/tests/test_future.py index 011f115..789218b 100644 --- a/tests/test_future.py +++ b/tests/test_future.py @@ -51,7 +51,7 @@ def test_state(self): self.assertTrue(fut.running()) self.assertFalse(fut.done()) - time.sleep(1.5) + fut.result() self.assertFalse(fut.running()) self.assertTrue(fut.done()) diff --git a/tests/test_nested_task.py b/tests/test_nested_task.py index e675027..f2831aa 100644 --- a/tests/test_nested_task.py +++ b/tests/test_nested_task.py @@ -29,7 +29,6 @@ def test_recursive_task(self) -> None: result = client.submit(factorial, client, 10).result() self.assertEqual(result, 3_628_800) - @unittest.skip("this test occasionally never finishes") def test_multiple_recursive_task(self) -> None: with Client(self.address) as client: result = client.submit(fibonacci, client, 8).result()