Skip to content

Commit

Permalink
Ensure that adaptive only stops once (dask#8807)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Jul 31, 2024
1 parent c44ad22 commit 564f28b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
2 changes: 0 additions & 2 deletions distributed/deploy/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def __init__(

self.target_duration = parse_timedelta(target_duration)

logger.info("Adaptive scaling started: minimum=%s maximum=%s", minimum, maximum)

super().__init__(
minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval
)
Expand Down
55 changes: 42 additions & 13 deletions distributed/deploy/adaptive_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections import defaultdict, deque
from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Literal, cast

import tlz as toolz
from tornado.ioloop import IOLoop
Expand All @@ -17,12 +17,21 @@
from distributed.metrics import time

if TYPE_CHECKING:
from distributed.scheduler import WorkerState
from typing_extensions import TypeAlias

from distributed.scheduler import WorkerState

logger = logging.getLogger(__name__)


AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]


class AdaptiveCore:
"""
The core logic for adaptive deployments, with none of the cluster details
Expand Down Expand Up @@ -89,6 +98,8 @@ class AdaptiveCore:
observed: set[WorkerState]
close_counts: defaultdict[WorkerState, int]
_adapting: bool
#: Whether this adaptive strategy is periodically adapting
_state: AdaptiveStateState
log: deque[tuple[float, dict]]

def __init__(
Expand All @@ -107,12 +118,6 @@ def __init__(
self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None

def f():
try:
self.periodic_callback.start()
except AttributeError:
pass

if self.interval:
import weakref

Expand All @@ -124,8 +129,10 @@ async def _adapt():
await core.adapt()

self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self.loop.add_callback(f)

self._state = "starting"
self.loop.add_callback(self._start)
else:
self._state = "inactive"
try:
self.plan = set()
self.requested = set()
Expand All @@ -140,12 +147,34 @@ async def _adapt():
maxlen=dask.config.get("distributed.admin.low-level-log-length")
)

def _start(self) -> None:
if self._state != "starting":
return

assert self.periodic_callback is not None
self.periodic_callback.start()
self._state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

def stop(self) -> None:
logger.info("Adaptive stop")
if self._state in ("inactive", "stopped"):
return

if self.periodic_callback:
if self._state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
self.periodic_callback = None
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

self.periodic_callback = None
self._state = "stopped"

async def target(self) -> int:
"""The target number of workers that should exist"""
Expand Down
19 changes: 16 additions & 3 deletions distributed/deploy/tests/test_adaptive_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,12 @@ def safe_target(self):
raise OSError()

with captured_logger("distributed.deploy.adaptive_core") as log:
adapt = BadAdaptive(minimum=1, maximum=4)
await adapt.adapt()
adapt = BadAdaptive(minimum=1, maximum=4, interval="10ms")
while adapt._state != "stopped":
await asyncio.sleep(0.01)
text = log.getvalue()
assert "Adaptive stopping due to error" in text
assert "Adaptive stop" in text
assert "Adaptive scaling stopped" in text
assert not adapt._adapting
assert not adapt.periodic_callback

Expand Down Expand Up @@ -147,6 +148,18 @@ async def scale_down(self, workers=None):
adapt.stop()


@gen_test()
async def test_adaptive_logs_stopping_once():
with captured_logger("distributed.deploy.adaptive_core") as log:
adapt = MyAdaptive(interval="100ms")
while not adapt.periodic_callback.is_running():
await asyncio.sleep(0.01)
adapt.stop()
adapt.stop()
lines = log.getvalue().splitlines()
assert sum("Adaptive scaling stopped" in line for line in lines) == 1


@gen_test()
async def test_adapt_stop_del():
adapt = MyAdaptive(interval="100ms")
Expand Down

0 comments on commit 564f28b

Please sign in to comment.