ENH: Add a check whether scheduling is stuck (#695)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Flying-Tom and mergify[bot] authored Sep 21, 2023
1 parent fe7caba commit bc0c60c
Showing 4 changed files with 146 additions and 39 deletions.
9 changes: 9 additions & 0 deletions python/xorbits/_mars/deploy/oscar/base_config.yml
@@ -56,6 +56,15 @@ scheduling:
     # Max number of concurrent speculative run for a subtask.
     max_concurrent_run: 3
   subtask_cancel_timeout: 5
+  stage_monitor:
+    enable_check: false
+    refresh_time: 3
+    prepare_data_timeout: 300
+    request_quota_timeout: 300
+    acquire_slot_timeout: 300
+    execution_timeout: null
+    release_slot_timeout: 300
+    finish_timeout: 300
 metrics:
   backend: console
   # If backend is prometheus, then we can add prometheus config as follows:
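For orientation: this block is exactly what `scheduling_config.get("stage_monitor", {})` in the worker service (diffed below) hands to `StageMonitorActor` as `monitoring_config`. A hypothetical override that enables the watchdog could look like the following sketch; keys mirror the YAML above, and a timeout of `None`/`null` means that stage is never flagged.

```python
# Hypothetical stage_monitor override, mirroring the YAML keys above; this is
# the dict shape the worker service passes to StageMonitorActor.
stage_monitor_config = {
    "enable_check": True,       # the watchdog is opt-in; the shipped default is false
    "refresh_time": 3,          # seconds between scans of the stage records
    "prepare_data_timeout": 300,
    "request_quota_timeout": 300,
    "acquire_slot_timeout": 300,
    "execution_timeout": None,  # unbounded: subtasks may legitimately run long
    "release_slot_timeout": 300,
    "finish_timeout": 300,
}
```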
109 changes: 81 additions & 28 deletions python/xorbits/_mars/services/scheduling/worker/execution.py
@@ -47,29 +47,79 @@


 class StageMonitorActor(mo.Actor):
-    def __init__(self):
+    def __init__(
+        self,
+        monitoring_config: Dict = {},
+    ):
         self._records = dict()
-
-    def report_stage(self, keys: Tuple[str, str], stage: SubtaskStage):
-        if keys not in self._records:
-            self._records[keys] = {
-                "history": [],
-            }
-        if stage == SubtaskStage.FINISH:
-            self._records.pop(keys)
-            return
-        self._records[keys]["history"].append((time.time(), stage))
+        self._enable_check = monitoring_config.get("enable_check", False)
+        self._refresh_time = monitoring_config.get("refresh_time", 3)
+        self._kill_timeout = {
+            SubtaskStage.PREPARE_DATA: monitoring_config.get("prepare_data_timeout"),
+            SubtaskStage.REQUEST_QUOTA: monitoring_config.get("request_quota_timeout"),
+            SubtaskStage.ACQUIRE_SLOT: monitoring_config.get("acquire_slot_timeout"),
+            SubtaskStage.EXECUTE: monitoring_config.get("execution_timeout"),
+            SubtaskStage.RELEASE_SLOT: monitoring_config.get("release_slot_timeout"),
+            SubtaskStage.FINISH: monitoring_config.get("finish_timeout"),
+        }
+        self._check_task = None
+
+    async def __post_create__(self):
+        await super().__post_create__()
+        if self._enable_check:
+            self._check_task = self.ref().check_subtasks.tell_delay(
+                delay=self._refresh_time
+            )

-    async def get_stale_tasks(self, status: SubtaskStage, timeout: int = 5):
+    async def __pre_destroy__(self):
+        if self._enable_check:
+            self._check_task.cancel()
+        await super().__pre_destroy__()
+
+    async def check_subtasks(self):
+        stale_tasks = await self.get_all_stale_tasks()
+        for task_key, stage in stale_tasks:
+            session_id, subtask_id = task_key
+            try:
+                logger.warning(
+                    "Subtask[session_id: %s, subtask_id: %s] is timeout at stage %s",
+                    session_id,
+                    subtask_id,
+                    stage,
+                )
+            except Exception as e:
+                logger.error(e)
+
+        self._check_task = self.ref().check_subtasks.tell_delay(
+            delay=self._refresh_time
+        )
+
+    async def get_all_stale_tasks(self):
         cur_timestamp = time.time()
-        stale_tasks_keys = []
+        stale_tasks = []
         for k, v in self._records.items():
+            pre_timestamp, cur_stage = v["history"][-1][0], v["history"][-1][1]
             if (
-                cur_timestamp - v["history"][-1][0] >= timeout
-                and v["history"][-1][1] == status
+                self._kill_timeout[cur_stage] is not None
+                and cur_timestamp - pre_timestamp >= self._kill_timeout[cur_stage]
             ):
-                stale_tasks_keys.append(k)
-        return stale_tasks_keys
+                stale_tasks.append((k, cur_stage))
+        return stale_tasks
+
+    async def register_subtask(self, subtask: Subtask, supervisor_address: str):
+        keys = (subtask.session_id, subtask.subtask_id)
+        self._records[keys] = {
+            "subtask": subtask,
+            "history": [],
+            "supervisor_address": supervisor_address,
+        }
+
+    async def report_stage(self, keys: Tuple[str, str], stage: SubtaskStage):
+        if stage == SubtaskStage.FINISH:
+            self._records.pop(keys)
+            return
+        self._records[keys]["history"].append((time.time(), stage))

     async def get_records(self):
         return self._records
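The actor uses a self-rescheduling watchdog: when `enable_check` is true, `__post_create__` arms a delayed `check_subtasks` message through `tell_delay`, and each run re-arms itself after scanning the records, so a stage whose timeout is `None` is never flagged. The staleness rule itself is compact; below is a framework-free sketch of it, with illustrative names not taken from the codebase.

```python
import time
from typing import Dict, List, Optional, Tuple

# A framework-free sketch of the staleness rule implemented by
# StageMonitorActor.get_all_stale_tasks above: a record is stale when its
# latest reported stage has a configured timeout and that stage was reported
# longer ago than the timeout allows.


def find_stale(
    records: Dict[Tuple[str, str], dict],
    kill_timeout: Dict[str, Optional[float]],
    now: Optional[float] = None,
) -> List[Tuple[Tuple[str, str], str]]:
    now = time.time() if now is None else now
    stale = []
    for key, record in records.items():
        reported_at, stage = record["history"][-1]
        timeout = kill_timeout.get(stage)
        if timeout is not None and now - reported_at >= timeout:
            stale.append((key, stage))
    return stale


# One subtask stuck in "acquire_slot" for 400s against a 300s budget:
records = {
    ("session-1", "subtask-1"): {"history": [(time.time() - 400, "acquire_slot")]},
}
print(find_stale(records, {"acquire_slot": 300.0}))
# -> [(('session-1', 'subtask-1'), 'acquire_slot')]
```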
@@ -404,6 +454,9 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
             )
         try:
             logger.debug("Preparing data for subtask %s", subtask.subtask_id)
+            await self._stat_monitor_ref.report_stage(
+                (subtask.session_id, subtask.subtask_id), SubtaskStage.PREPARE_DATA
+            )
             with Timer() as timer:
                 prepare_data_task = asyncio.create_task(
                     _retry_run(
@@ -414,9 +467,7 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
                         band_name,
                     )
                 )
-            await self._stat_monitor_ref.report_stage(
-                (subtask.session_id, subtask.subtask_id), SubtaskStage.PREPARE_DATA
-            )

             await asyncio.wait_for(
                 prepare_data_task, timeout=self._data_prepare_timeout
             )
@@ -446,9 +497,6 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
         except:  # noqa: E722  # pylint: disable=bare-except
             _fill_subtask_result_with_exception(subtask, subtask_info)
         finally:
-            await self._stat_monitor_ref.report_stage(
-                (subtask.session_id, subtask.subtask_id), SubtaskStage.RELEASE_SLOT
-            )
             # make sure new slot usages are uploaded in time
             try:
                 slot_manager_ref = await self._get_slot_manager_ref(band_name)
@@ -487,13 +535,14 @@ async def _run_subtask_once():
                 subtask_info.slot_id = slot_id
                 self._check_cancelling(subtask_info)

-                await self._stat_monitor_ref.report_stage(
-                    (subtask.session_id, subtask.subtask_id), SubtaskStage.EXECUTE
-                )
                 subtask_info.result.status = SubtaskStatus.running
                 aiotask = asyncio.create_task(
                     subtask_api.run_subtask_in_slot(band_name, slot_id, subtask)
                 )
+                await self._stat_monitor_ref.report_stage(
+                    (subtask.session_id, subtask.subtask_id), SubtaskStage.EXECUTE
+                )
+
                 return await asyncio.shield(aiotask)
             except asyncio.CancelledError as ex:
                 try:
@@ -554,6 +603,10 @@ async def _run_subtask_once():
                     await slot_manager_ref.release_free_slot(
                         slot_id, (subtask.session_id, subtask.subtask_id)
                     )
+                    await self._stat_monitor_ref.report_stage(
+                        (subtask.session_id, subtask.subtask_id),
+                        SubtaskStage.RELEASE_SLOT,
+                    )
                     logger.debug(
                         "Released slot %d for subtask %s", slot_id, subtask.subtask_id
                     )
@@ -593,9 +646,9 @@ async def run_subtask(
         logger.debug(
             "Start to schedule subtask %s on %s.", subtask.subtask_id, self.address
         )
-        await self._stat_monitor_ref.report_stage(
-            (subtask.session_id, subtask.subtask_id), "subtask start"
-        )
+
+        await self._stat_monitor_ref.register_subtask(subtask, supervisor_address)

         self._submitted_subtask_count.record(1, {"band": self.address})
         with mo.debug.no_message_trace():
             task = asyncio.create_task(
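Taken together, these hunks make stage reporting line up with what actually happens: `register_subtask` replaces the old free-form `"subtask start"` report, `PREPARE_DATA` is now reported before data fetching begins, `EXECUTE` once the subtask is launched in its slot, and `RELEASE_SLOT` only after the slot is actually handed back. A framework-free simulation of the resulting record lifecycle:

```python
import time

# Plain functions stand in for the actor methods above; stages are strings here.
records = {}


def register_subtask(key, subtask, supervisor_address):
    records[key] = {
        "subtask": subtask,
        "history": [],
        "supervisor_address": supervisor_address,
    }


def report_stage(key, stage):
    if stage == "finish":  # FINISH drops the record entirely
        records.pop(key)
        return
    records[key]["history"].append((time.time(), stage))


key = ("session-1", "subtask-1")
register_subtask(key, subtask=object(), supervisor_address="127.0.0.1:7777")
for stage in ("prepare_data", "execute", "release_slot", "finish"):
    report_stage(key, stage)
assert key not in records  # a finished subtask leaves nothing to monitor
```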
1 change: 1 addition & 0 deletions python/xorbits/_mars/services/scheduling/worker/service.py
@@ -64,6 +64,7 @@ async def start(self):

         await mo.create_actor(
             StageMonitorActor,
+            monitoring_config=scheduling_config.get("stage_monitor", {}),
             uid=StageMonitorActor.default_uid(),
             address=address,
         )
66 changes: 55 additions & 11 deletions python/xorbits/_mars/services/scheduling/worker/tests/test_execution.py
@@ -50,7 +50,7 @@
 from ....session import MockSessionAPI
 from ....storage import MockStorageAPI
 from ....storage.handler import StorageHandlerActor
-from ....subtask import MockSubtaskAPI, Subtask, SubtaskStage, SubtaskStatus
+from ....subtask import MockSubtaskAPI, Subtask, SubtaskStatus
 from ....task.supervisor.manager import TaskManagerActor
 from ....task.task_info_collector import TaskInfoCollectorActor
 from ...supervisor import GlobalResourceManagerActor
@@ -163,7 +163,7 @@ def collect_task_info_enabled(self):

 @pytest.fixture
 async def actor_pool(request):
-    n_slots, enable_kill = request.param
+    n_slots, enable_kill, enable_stage_check = request.param
     pool = await create_actor_pool(
         "127.0.0.1", labels=[None] + ["numa-0"] * n_slots, n_process=n_slots
     )
@@ -189,6 +189,12 @@ async def actor_pool(request):
     # create monitor actor
     monitor_ref = await mo.create_actor(
         StageMonitorActor,
+        monitoring_config={
+            "enable_check": True,
+            "execution_timeout": 5,
+        }
+        if enable_stage_check
+        else {},
         uid=StageMonitorActor.default_uid(),
         address=pool.external_address,
     )
@@ -253,7 +259,7 @@


 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
 async def test_execute_tensor(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool

@@ -334,7 +340,7 @@ async def test_execute_tensor(actor_pool):
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "actor_pool,cancel_phase",
-    [((1, True), phase) for phase in _cancel_phases],
+    [((1, True, False), phase) for phase in _cancel_phases],
     indirect=["actor_pool"],
 )
 async def test_execute_with_cancel(actor_pool, cancel_phase):
@@ -438,7 +444,7 @@ def delay_fun(delay, _inp1):


 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
 async def test_execute_with_pure_deps(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool

@@ -519,7 +525,7 @@ def test_estimate_size():


 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, False)], indirect=True)
+@pytest.mark.parametrize("actor_pool", [(1, False, False)], indirect=True)
 async def test_cancel_without_kill(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
     executed_file = os.path.join(
@@ -625,8 +631,8 @@ def test_fetch_data_from_both_cpu_and_gpu(data_type, chunked, setup_gpu):


 @pytest.mark.asyncio
-@pytest.mark.parametrize("actor_pool", [(1, True)], indirect=True)
-async def test_status_monitor_actor(actor_pool):
+@pytest.mark.parametrize("actor_pool", [(1, True, False)], indirect=True)
+async def test_stage_monitor_actor(actor_pool):
     pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
     subtask_id = f"test_subtask_{uuid.uuid4()}"
     subtask = Subtask(
@@ -642,10 +648,48 @@ async def test_status_monitor_actor(actor_pool):
     await asyncio.wait_for(
         execution_ref.run_subtask(subtask, "numa-0", pool.external_address), timeout=30
     )
-    for stage in SubtaskStage:
-        stale_tasks = await monitor_ref.get_stale_tasks(stage)
-        assert len(stale_tasks) == 0
+
+    stale_tasks = await monitor_ref.get_all_stale_tasks()
+    assert len(stale_tasks) == 0
+
     # task has been finished
     records = await monitor_ref.get_records()
     assert len(records) == 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("actor_pool", [(1, True, True)], indirect=True)
+async def test_terminate_stale_tasks(actor_pool, caplog):
+    pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref = actor_pool
+
+    def delay_fun(delay):
+        time.sleep(delay)
+        return delay
+
+    remote_result = RemoteFunction(
+        function=delay_fun, function_args=[10], function_kwargs={}
+    ).new_chunk([])
+    chunk_graph = ChunkGraph([remote_result])
+    chunk_graph.add_node(remote_result)
+
+    subtask = Subtask(
+        f"test_subtask_{uuid.uuid4()}",
+        session_id=session_id,
+        task_id=f"test_task_{uuid.uuid4()}",
+        chunk_graph=chunk_graph,
+    )
+
+    with Timer() as timer:
+        aiotask = asyncio.create_task(
+            execution_ref.run_subtask(subtask, "numa-0", pool.external_address)
+        )
+
+        r = await asyncio.wait_for(aiotask, timeout=20)
+        assert r.status == SubtaskStatus.succeeded
+
+    assert 5 < timer.duration < 20
+
+    import re
+
+    match = re.search(r"Subtask\[.*?\].*stage.*", caplog.text)
+    assert match is not None
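The new `test_terminate_stale_tasks` runs a 10-second remote function against the fixture's 5-second `execution_timeout`, so the watchdog fires mid-run and its warning lands in `caplog`. A standalone sanity check (illustrative) that the asserted regex matches the warning format from `check_subtasks`:

```python
import re

# The warning emitted by StageMonitorActor.check_subtasks, rendered with
# sample values, should satisfy the regex asserted in the test above.
message = "Subtask[session_id: %s, subtask_id: %s] is timeout at stage %s" % (
    "session-1",
    "subtask-1",
    "execute",
)
assert re.search(r"Subtask\[.*?\].*stage.*", message) is not None
```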
