From bc0c60c5989567bd6748c56a393da0b03a4e9c28 Mon Sep 17 00:00:00 2001
From: Tom <56171752+Flying-Tom@users.noreply.github.com>
Date: Thu, 21 Sep 2023 09:59:25 +0800
Subject: [PATCH] ENH: Add a check whether scheduling is stuck (#695)

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
---
 .../_mars/deploy/oscar/base_config.yml        |   9 ++
 .../services/scheduling/worker/execution.py   | 109 +++++++++++++-----
 .../services/scheduling/worker/service.py     |   1 +
 .../scheduling/worker/tests/test_execution.py |  66 +++++++++--
 4 files changed, 146 insertions(+), 39 deletions(-)

diff --git a/python/xorbits/_mars/deploy/oscar/base_config.yml b/python/xorbits/_mars/deploy/oscar/base_config.yml
index 99754f1bf..51305f766 100644
--- a/python/xorbits/_mars/deploy/oscar/base_config.yml
+++ b/python/xorbits/_mars/deploy/oscar/base_config.yml
@@ -56,6 +56,15 @@ scheduling:
     # Max number of concurrent speculative run for a subtask.
     max_concurrent_run: 3
   subtask_cancel_timeout: 5
+  stage_monitor:
+    enable_check: false
+    refresh_time: 3
+    prepare_data_timeout: 300
+    request_quota_timeout: 300
+    acquire_slot_timeout: 300
+    execution_timeout: null
+    release_slot_timeout: 300
+    finish_timeout: 300
 metrics:
   backend: console
   # If backend is prometheus, then we can add prometheus config as follows:
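# --- Illustrative sketch (not part of the patch) -----------------------------
# The new ``stage_monitor`` section maps one ``*_timeout`` value (in seconds)
# to each subtask stage; ``null`` disables the check for that stage, which is
# why ``execution_timeout`` has no limit by default. The helper below is
# hypothetical and only shows how such a dict is typically interpreted.
from typing import Dict, Optional


def resolve_stage_timeouts(stage_monitor_cfg: Dict) -> Dict[str, Optional[float]]:
    stages = (
        "prepare_data",
        "request_quota",
        "acquire_slot",
        "execution",
        "release_slot",
        "finish",
    )
    # Missing keys and explicit null/None both mean "never flag this stage".
    return {stage: stage_monitor_cfg.get(f"{stage}_timeout") for stage in stages}


timeouts = resolve_stage_timeouts({"prepare_data_timeout": 300, "execution_timeout": None})
assert timeouts["prepare_data"] == 300 and timeouts["execution"] is None
# ------------------------------------------------------------------------------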
diff --git a/python/xorbits/_mars/services/scheduling/worker/execution.py b/python/xorbits/_mars/services/scheduling/worker/execution.py
index 05b8177ba..b90da08fc 100644
--- a/python/xorbits/_mars/services/scheduling/worker/execution.py
+++ b/python/xorbits/_mars/services/scheduling/worker/execution.py
@@ -47,29 +47,79 @@
 
 
 class StageMonitorActor(mo.Actor):
-    def __init__(self):
+    def __init__(
+        self,
+        monitoring_config: Dict = {},
+    ):
         self._records = dict()
-
-    def report_stage(self, keys: Tuple[str, str], stage: SubtaskStage):
-        if keys not in self._records:
-            self._records[keys] = {
-                "history": [],
-            }
-        if stage == SubtaskStage.FINISH:
-            self._records.pop(keys)
-            return
-        self._records[keys]["history"].append((time.time(), stage))
+        self._enable_check = monitoring_config.get("enable_check", False)
+        self._refresh_time = monitoring_config.get("refresh_time", 3)
+        self._kill_timeout = {
+            SubtaskStage.PREPARE_DATA: monitoring_config.get("prepare_data_timeout"),
+            SubtaskStage.REQUEST_QUOTA: monitoring_config.get("request_quota_timeout"),
+            SubtaskStage.ACQUIRE_SLOT: monitoring_config.get("acquire_slot_timeout"),
+            SubtaskStage.EXECUTE: monitoring_config.get("execution_timeout"),
+            SubtaskStage.RELEASE_SLOT: monitoring_config.get("release_slot_timeout"),
+            SubtaskStage.FINISH: monitoring_config.get("finish_timeout"),
+        }
+        self._check_task = None
+
+    async def __post_create__(self):
+        await super().__post_create__()
+        if self._enable_check:
+            self._check_task = self.ref().check_subtasks.tell_delay(
+                delay=self._refresh_time
+            )
 
-    async def get_stale_tasks(self, status: SubtaskStage, timeout: int = 5):
+    async def __pre_destroy__(self):
+        if self._enable_check:
+            self._check_task.cancel()
+        await super().__pre_destroy__()
+
+    async def check_subtasks(self):
+        stale_tasks = await self.get_all_stale_tasks()
+        for task_key, stage in stale_tasks:
+            session_id, subtask_id = task_key
+            try:
+                logger.warning(
+                    "Subtask[session_id: %s, subtask_id: %s] is timeout at stage %s",
+                    session_id,
+                    subtask_id,
+                    stage,
+                )
+            except Exception as e:
+                logger.error(e)
+
+        self._check_task = self.ref().check_subtasks.tell_delay(
+            delay=self._refresh_time
+        )
+
+    async def get_all_stale_tasks(self):
         cur_timestamp = time.time()
-        stale_tasks_keys = []
+        stale_tasks = []
         for k, v in self._records.items():
+            pre_timestamp, cur_stage = v["history"][-1][0], v["history"][-1][1]
             if (
-                cur_timestamp - v["history"][-1][0] >= timeout
-                and v["history"][-1][1] == status
+                self._kill_timeout[cur_stage] is not None
+                and cur_timestamp - pre_timestamp >= self._kill_timeout[cur_stage]
             ):
-                stale_tasks_keys.append(k)
-        return stale_tasks_keys
+                stale_tasks.append((k, cur_stage))
+        return stale_tasks
+
+    async def register_subtask(self, subtask: Subtask, supervisor_address: str):
+        keys = (subtask.session_id, subtask.subtask_id)
+        self._records[keys] = {
+            "subtask": subtask,
+            "history": [],
+            "supervisor_address": supervisor_address,
+        }
+
+    async def report_stage(self, keys: Tuple[str, str], stage: SubtaskStage):
+        if stage == SubtaskStage.FINISH:
+            self._records.pop(keys)
+            return
+        self._records[keys]["history"].append((time.time(), stage))
 
     async def get_records(self):
         return self._records
@@ -404,6 +454,9 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
         )
         try:
             logger.debug("Preparing data for subtask %s", subtask.subtask_id)
+            await self._stat_monitor_ref.report_stage(
+                (subtask.session_id, subtask.subtask_id), SubtaskStage.PREPARE_DATA
+            )
             with Timer() as timer:
                 prepare_data_task = asyncio.create_task(
                     _retry_run(
@@ -414,9 +467,7 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
                         band_name,
                     )
                 )
-                await self._stat_monitor_ref.report_stage(
-                    (subtask.session_id, subtask.subtask_id), SubtaskStage.PREPARE_DATA
-                )
+
                 await asyncio.wait_for(
                     prepare_data_task, timeout=self._data_prepare_timeout
                 )
@@ -446,9 +497,6 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
         except:  # noqa: E722  # pylint: disable=bare-except
             _fill_subtask_result_with_exception(subtask, subtask_info)
         finally:
-            await self._stat_monitor_ref.report_stage(
-                (subtask.session_id, subtask.subtask_id), SubtaskStage.RELEASE_SLOT
-            )
             # make sure new slot usages are uploaded in time
             try:
                 slot_manager_ref = await self._get_slot_manager_ref(band_name)
@@ -487,13 +535,14 @@ async def _run_subtask_once():
                     subtask_info.slot_id = slot_id
                 self._check_cancelling(subtask_info)
 
+                await self._stat_monitor_ref.report_stage(
+                    (subtask.session_id, subtask.subtask_id), SubtaskStage.EXECUTE
+                )
                 subtask_info.result.status = SubtaskStatus.running
                 aiotask = asyncio.create_task(
                     subtask_api.run_subtask_in_slot(band_name, slot_id, subtask)
                 )
-                await self._stat_monitor_ref.report_stage(
-                    (subtask.session_id, subtask.subtask_id), SubtaskStage.EXECUTE
-                )
+
                 return await asyncio.shield(aiotask)
             except asyncio.CancelledError as ex:
                 try:
@@ -554,6 +603,10 @@ async def _run_subtask_once():
                     await slot_manager_ref.release_free_slot(
                         slot_id, (subtask.session_id, subtask.subtask_id)
                     )
+                    await self._stat_monitor_ref.report_stage(
+                        (subtask.session_id, subtask.subtask_id),
+                        SubtaskStage.RELEASE_SLOT,
+                    )
                     logger.debug(
                         "Released slot %d for subtask %s", slot_id, subtask.subtask_id
                     )
@@ -593,9 +646,9 @@ async def run_subtask(
         logger.debug(
             "Start to schedule subtask %s on %s.", subtask.subtask_id, self.address
         )
-        await self._stat_monitor_ref.report_stage(
-            (subtask.session_id, subtask.subtask_id), "subtask start"
-        )
+
+        await self._stat_monitor_ref.register_subtask(subtask, supervisor_address)
+
         self._submitted_subtask_count.record(1, {"band": self.address})
         with mo.debug.no_message_trace():
             task = asyncio.create_task(
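# --- Illustrative sketch (not part of the patch) -----------------------------
# StageMonitorActor.get_all_stale_tasks() flags a subtask as stale when the
# timeout configured for its most recently reported stage has elapsed; stages
# with a None timeout are never flagged. The standalone function below
# (hypothetical names, plain strings instead of SubtaskStage) restates that
# rule so it can be run in isolation.
import time
from typing import Dict, List, Optional, Tuple


def find_stale_tasks(
    records: Dict[Tuple[str, str], List[Tuple[float, str]]],
    stage_timeouts: Dict[str, Optional[float]],
    now: Optional[float] = None,
) -> List[Tuple[Tuple[str, str], str]]:
    now = time.time() if now is None else now
    stale = []
    for key, history in records.items():
        if not history:
            # Registered but no stage reported yet; skipped in this sketch.
            continue
        last_ts, last_stage = history[-1]
        timeout = stage_timeouts.get(last_stage)
        if timeout is not None and now - last_ts >= timeout:
            stale.append((key, last_stage))
    return stale


# A subtask stuck in "execute" for 10s against a 5s execution timeout is stale.
records = {("session_a", "subtask_1"): [(time.time() - 10.0, "execute")]}
assert find_stale_tasks(records, {"execute": 5.0}) == [(("session_a", "subtask_1"), "execute")]
# ------------------------------------------------------------------------------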
diff --git a/python/xorbits/_mars/services/scheduling/worker/service.py b/python/xorbits/_mars/services/scheduling/worker/service.py
index 1049bc69b..0d12fab05 100644
--- a/python/xorbits/_mars/services/scheduling/worker/service.py
+++ b/python/xorbits/_mars/services/scheduling/worker/service.py
@@ -64,6 +64,7 @@ async def start(self):
 
         await mo.create_actor(
             StageMonitorActor,
+            monitoring_config=scheduling_config.get("stage_monitor", {}),
             uid=StageMonitorActor.default_uid(),
             address=address,
         )
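# --- Illustrative sketch (not part of the patch) -----------------------------
# The worker service only forwards scheduling_config["stage_monitor"] to the
# actor. When the section is missing, the actor receives {} and, since
# enable_check defaults to False, the periodic check_subtasks loop is never
# scheduled. The config dict below is made up to show that default-off path.
scheduling_config = {"subtask_cancel_timeout": 5}  # no "stage_monitor" section
monitoring_config = scheduling_config.get("stage_monitor", {})
assert monitoring_config == {}
assert monitoring_config.get("enable_check", False) is False
# ------------------------------------------------------------------------------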
diff --git a/python/xorbits/_mars/services/scheduling/worker/tests/test_execution.py b/python/xorbits/_mars/services/scheduling/worker/tests/test_execution.py
index 26742cae0..e174598c7 100644
--- a/python/xorbits/_mars/services/scheduling/worker/tests/test_execution.py
+++ b/python/xorbits/_mars/services/scheduling/worker/tests/test_execution.py
@@ -50,7 +50,7 @@
 from ....session import MockSessionAPI
 from ....storage import MockStorageAPI
 from ....storage.handler import StorageHandlerActor
-from ....subtask import MockSubtaskAPI, Subtask, SubtaskStage, SubtaskStatus
+from ....subtask import MockSubtaskAPI, Subtask, SubtaskStatus
 from ....task.supervisor.manager import TaskManagerActor
 from ....task.task_info_collector import TaskInfoCollectorActor
 from ...supervisor import GlobalResourceManagerActor
@@ -163,7 +163,7 @@ def collect_task_info_enabled(self):
 
 @pytest.fixture
 async def actor_pool(request):
-    n_slots, enable_kill = request.param
+    n_slots, enable_kill, enable_stage_check = request.param
     pool = await create_actor_pool(
         "127.0.0.1", labels=[None] + ["numa-0"] * n_slots, n_process=n_slots
     )
@@ -189,6 +189,12 @@ async def actor_pool(request):
     # create monitor actor
     monitor_ref = await mo.create_actor(
         StageMonitorActor,
+        monitoring_config={
+            "enable_check": True,
+            "execution_timeout": 5,
+        }
+        if enable_stage_check
+        else {},
         uid=StageMonitorActor.default_uid(),
         address=pool.external_address,
     )
@@ -253,7 +259,7 @@ async def actor_pool(request):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
 async def test_execute_tensor(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
@@ -334,7 +340,7 @@ async def test_execute_tensor(actor_pool):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "actor_pool,cancel_phase",
-    [((1, True), phase) for phase in _cancel_phases],
+    [((1, True, False), phase) for phase in _cancel_phases],
     indirect=["actor_pool"],
 )
 async def test_execute_with_cancel(actor_pool, cancel_phase):
@@ -438,7 +444,7 @@ def delay_fun(delay, _inp1):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
 async def test_execute_with_pure_deps(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
@@ -519,7 +525,7 @@ def test_estimate_size():
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, False)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, False, False)], indirect=True)
 async def test_cancel_without_kill(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
     executed_file = os.path.join(
@@ -625,8 +631,8 @@ def test_fetch_data_from_both_cpu_and_gpu(data_type, chunked, setup_gpu):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
-async def test_status_monitor_actor(actor_pool):
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
+async def test_stage_monitor_actor(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
     subtask_id = f"test_subtask_{uuid.uuid4()}"
     subtask = Subtask(
@@ -642,10 +648,48 @@ async def test_stage_monitor_actor(actor_pool):
     await asyncio.wait_for(
         execution_ref.run_subtask(subtask, "numa-0", pool.external_address), timeout=30
     )
-    for stage in SubtaskStage:
-        stale_tasks = await monitor_ref.get_stale_tasks(stage)
-        assert len(stale_tasks) == 0
+
+    stale_tasks = await monitor_ref.get_all_stale_tasks()
+    assert len(stale_tasks) == 0
 
     # task has been finished
     records = await monitor_ref.get_records()
     assert len(records) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("actor_pool", [(1, True, True)], indirect=True)
+async def test_terminate_stale_tasks(actor_pool, caplog):
+    pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
+
+    def delay_fun(delay):
+        time.sleep(delay)
+        return delay
+
+    remote_result = RemoteFunction(
+        function=delay_fun, function_args=[10], function_kwargs={}
+    ).new_chunk([])
+    chunk_graph = ChunkGraph([remote_result])
+    chunk_graph.add_node(remote_result)
+
+    subtask = Subtask(
+        f"test_subtask_{uuid.uuid4()}",
+        session_id=session_id,
+        task_id=f"test_task_{uuid.uuid4()}",
+        chunk_graph=chunk_graph,
+    )
+
+    with Timer() as timer:
+        aiotask = asyncio.create_task(
+            execution_ref.run_subtask(subtask, "numa-0", pool.external_address)
+        )
+
+        r = await asyncio.wait_for(aiotask, timeout=20)
+        assert r.status == SubtaskStatus.succeeded
+
+    assert 5 < timer.duration < 20
+
+    import re
+
+    match = re.search(r"Subtask\[.*?\].*stage.*", caplog.text)
+    assert match is not None
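# --- Illustrative sketch (not part of the patch) -----------------------------
# test_terminate_stale_tasks asserts that the warning emitted by
# check_subtasks() shows up in caplog. A quick standalone check that the regex
# used in the test matches the log format from the patch (the ids below are
# placeholders):
import re

log_line = (
    "Subtask[session_id: test_session, subtask_id: test_subtask] "
    "is timeout at stage SubtaskStage.EXECUTE"
)
assert re.search(r"Subtask\[.*?\].*stage.*", log_line) is not None
# ------------------------------------------------------------------------------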