[serve] use linear backoff for the proxy ready check (ray-project#39738)
There are some core issues that can cause the proxy to take a long time to start (longer than the default 5s). This change adds logic to apply a linear backoff to the ready check, so the timeout grows after each failure and the proxy can eventually start.
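A minimal standalone sketch of the backoff computation this change introduces (the 5-second base value is an assumption taken from the default mentioned above; in Serve the real constant lives in the private constants module):

    # Hedged sketch, not the committed code.
    PROXY_READY_CHECK_TIMEOUT_S = 5.0  # assumed default, per the description above

    def ready_check_timeout_s(proxy_restart_count: int) -> float:
        # Linear backoff: each restart widens the timeout by one base period.
        return (proxy_restart_count + 1) * PROXY_READY_CHECK_TIMEOUT_S

    assert ready_check_timeout_s(0) == 5.0   # first start
    assert ready_check_timeout_s(2) == 15.0  # after two restarts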
GeneDer authored Oct 12, 2023
1 parent 9075c1c commit 01e786f
Showing 4 changed files with 130 additions and 12 deletions.
44 changes: 33 additions & 11 deletions python/ray/serve/_private/proxy_state.py
@@ -2,7 +2,6 @@
import logging
import os
import random
import time
import traceback
from abc import ABC, abstractmethod
from enum import Enum
@@ -25,7 +24,7 @@
SERVE_PROXY_NAME,
)
from ray.serve._private.proxy import ProxyActor
from ray.serve._private.utils import format_actor_name
from ray.serve._private.utils import Timer, TimerBase, format_actor_name
from ray.serve.config import DeploymentMode, HTTPOptions, gRPCOptions
from ray.serve.schema import ProxyDetails
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
@@ -318,15 +317,19 @@ def __init__(
actor_name: str,
node_id: str,
node_ip: str,
proxy_restart_count: int = 0,
timer: TimerBase = Timer(),
):
self._actor_proxy_wrapper = actor_proxy_wrapper
self._actor_proxy_wrapper.start_new_ready_check()
self._actor_name = actor_name
self._node_id = node_id
self._status = ProxyStatus.STARTING
self._last_health_check_time: float = time.time()
self._timer = timer
self._last_health_check_time: float = self._timer.time()
self._shutting_down = False
self._consecutive_health_check_failures: int = 0
self._proxy_restart_count = proxy_restart_count
self._last_drain_check_time: float = None

self._actor_details = ProxyDetails(
@@ -353,6 +356,10 @@ def status(self) -> ProxyStatus:
def actor_details(self) -> ProxyDetails:
return self._actor_details

@property
def proxy_restart_count(self) -> int:
return self._proxy_restart_count

def set_status(self, status: ProxyStatus) -> None:
"""Sets _status and updates _actor_details with the new status."""
self._status = status
@@ -412,7 +419,7 @@ def _health_check(self):
elif healthy_call_status == ProxyWrapperCallStatus.FINISHED_FAILED:
self.try_update_status(ProxyStatus.UNHEALTHY)
elif (
time.time() - self._last_health_check_time
self._timer.time() - self._last_health_check_time
> PROXY_HEALTH_CHECK_TIMEOUT_S
):
# Health check hasn't returned and the timeout is up, consider it
@@ -432,8 +439,8 @@
if self._actor_proxy_wrapper.health_check_ongoing:
return
randomized_period_s = PROXY_HEALTH_CHECK_PERIOD_S * random.uniform(0.9, 1.1)
if time.time() - self._last_health_check_time > randomized_period_s:
self._last_health_check_time = time.time()
if self._timer.time() - self._last_health_check_time > randomized_period_s:
self._last_health_check_time = self._timer.time()
self._actor_proxy_wrapper.start_new_health_check()

def _drain_check(self):
@@ -447,8 +454,11 @@
self.set_status(ProxyStatus.DRAINED)
except Exception as e:
logger.warning(f"Drain check for proxy {self._actor_name} failed: {e}.")
elif time.time() - self._last_drain_check_time > PROXY_DRAIN_CHECK_PERIOD_S:
self._last_drain_check_time = time.time()
elif (
self._timer.time() - self._last_drain_check_time
> PROXY_DRAIN_CHECK_PERIOD_S
):
self._last_drain_check_time = self._timer.time()
self._actor_proxy_wrapper.start_new_drained_check()

def update(self, draining: bool = False):
@@ -486,7 +496,10 @@ def update(self, draining: bool = False):
):
return

ready_check_timeout = PROXY_READY_CHECK_TIMEOUT_S
# Doing a linear backoff for the ready check timeout.
ready_check_timeout = (
self.proxy_restart_count + 1
) * PROXY_READY_CHECK_TIMEOUT_S
if self._status == ProxyStatus.STARTING:
try:
ready_call_status = self._actor_proxy_wrapper.is_ready()
@@ -503,7 +516,10 @@
"Unexpected actor death when checking readiness of "
f"proxy on node {self._node_id}:\n{traceback.format_exc()}"
)
elif time.time() - self._last_health_check_time > ready_check_timeout:
elif (
self._timer.time() - self._last_health_check_time
> ready_check_timeout
):
# Ready check hasn't returned and the timeout is up, consider it
# failed.
self.set_status(ProxyStatus.UNHEALTHY)
@@ -533,7 +549,7 @@ def update(self, draining: bool = False):
self._actor_proxy_wrapper.update_draining(draining=True)
assert self._actor_proxy_wrapper.is_draining is False
assert self._last_drain_check_time is None
self._last_drain_check_time = time.time()
self._last_drain_check_time = self._timer.time()

if (self._status == ProxyStatus.DRAINING) and not draining:
logger.info(f"Stop draining the proxy actor on node {self._node_id}")
@@ -577,6 +593,7 @@ def __init__(
grpc_options: Optional[gRPCOptions] = None,
proxy_actor_class: Type[ProxyActor] = ProxyActor,
actor_proxy_wrapper_class: Type[ProxyWrapper] = ActorProxyWrapper,
timer: TimerBase = Timer(),
):
self._controller_name = controller_name
if config is not None:
@@ -585,9 +602,11 @@
self._config = HTTPOptions()
self._grpc_options = grpc_options or gRPCOptions()
self._proxy_states: Dict[NodeId, ProxyState] = dict()
self._proxy_restart_counts: Dict[NodeId, int] = dict()
self._head_node_id: str = head_node_id
self._proxy_actor_class = proxy_actor_class
self._actor_proxy_wrapper_class = actor_proxy_wrapper_class
self._timer = timer

self._cluster_node_info_cache = cluster_node_info_cache

@@ -750,6 +769,8 @@ def _start_proxies_if_needed(self, target_nodes) -> None:
actor_name=name,
node_id=node_id,
node_ip=node_ip_address,
proxy_restart_count=self._proxy_restart_counts.get(node_id, 0),
timer=self._timer,
)

def _stop_proxies_if_needed(self) -> bool:
@@ -775,4 +796,5 @@ def _stop_proxies_if_needed(self) -> bool:

for node_id in to_stop:
proxy_state = self._proxy_states.pop(node_id)
self._proxy_restart_counts[node_id] = proxy_state.proxy_restart_count + 1
proxy_state.shutdown()
13 changes: 13 additions & 0 deletions python/ray/serve/_private/utils.py
@@ -10,6 +10,7 @@
import threading
import time
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
@@ -634,3 +635,15 @@ def is_running_in_asyncio_loop() -> bool:
return True
except RuntimeError:
return False


class TimerBase(ABC):
@abstractmethod
def time(self) -> float:
"""Return the current time."""
raise NotImplementedError


class Timer(TimerBase):
def time(self) -> float:
return time.time()
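The TimerBase/Timer indirection above lets the proxy state code read the clock through an injectable object, so tests can advance time manually instead of sleeping. A minimal sketch of that pattern (the FakeTimer class is illustrative only; Serve's tests use MockTimer from ray.serve.tests.common.utils):

    from ray.serve._private.utils import TimerBase

    class FakeTimer(TimerBase):
        # Illustrative fake clock; mirrors, but is not, the MockTimer used in tests.
        def __init__(self, start_time: float = 0.0):
            self._now = start_time

        def time(self) -> float:
            return self._now

        def advance(self, by: float) -> None:
            # Move the fake clock forward without sleeping.
            self._now += by

    fake = FakeTimer()
    fake.advance(0.11)  # simulate 110ms passing, past a 0.1s ready-check timeout
    assert fake.time() == 0.11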
3 changes: 2 additions & 1 deletion python/ray/serve/tests/common/utils.py
@@ -13,13 +13,14 @@
from ray.serve._private.constants import SERVE_NAMESPACE
from ray.serve._private.proxy import DRAINED_MESSAGE
from ray.serve._private.usage import ServeUsageTag
from ray.serve._private.utils import TimerBase
from ray.serve.generated import serve_pb2, serve_pb2_grpc

TELEMETRY_ROUTE_PREFIX = "/telemetry"
STORAGE_ACTOR_NAME = "storage"


class MockTimer:
class MockTimer(TimerBase):
def __init__(self, start_time=None):
if start_time is None:
start_time = time.time()
82 changes: 82 additions & 0 deletions python/ray/serve/tests/test_proxy_state.py
@@ -18,7 +18,9 @@
ProxyWrapper,
ProxyWrapperCallStatus,
)
from ray.serve._private.utils import Timer
from ray.serve.config import DeploymentMode, HTTPOptions
from ray.serve.tests.common.utils import MockTimer

HEAD_NODE_ID = "node_id-index-head"

@@ -102,6 +104,7 @@ def _create_proxy_state_manager(
head_node_id: str = HEAD_NODE_ID,
cluster_node_info_cache=MockClusterNodeInfoCache(),
actor_proxy_wrapper_class=FakeProxyWrapper,
timer=Timer(),
) -> (ProxyStateManager, ClusterNodeInfoCache):
return (
ProxyStateManager(
@@ -110,6 +113,7 @@ def _create_proxy_state_manager(
head_node_id=head_node_id,
cluster_node_info_cache=cluster_node_info_cache,
actor_proxy_wrapper_class=actor_proxy_wrapper_class,
timer=timer,
),
cluster_node_info_cache,
)
@@ -119,13 +123,15 @@ def _create_proxy_state(
actor_proxy_wrapper_class=FakeProxyWrapper,
status: ProxyStatus = ProxyStatus.STARTING,
node_id: str = "mock_node_id",
timer=Timer(),
**kwargs,
) -> ProxyState:
state = ProxyState(
actor_proxy_wrapper=actor_proxy_wrapper_class(),
actor_name="alice",
node_id=node_id,
node_ip="mock_node_ip",
timer=timer,
)
state.set_status(status=status)
return state
@@ -824,6 +830,82 @@ def check_is_ready_for_shutdown():
wait_for_condition(check_is_ready_for_shutdown)


@patch("ray.serve._private.proxy_state.PROXY_READY_CHECK_TIMEOUT_S", 0.1)
@pytest.mark.parametrize("number_of_worker_nodes", [1])
def test_proxy_starting_timeout_longer_than_env(number_of_worker_nodes, all_nodes):
"""Test update method on ProxyStateManager when the proxy state is STARTING and
when the ready call takes longer than PROXY_READY_CHECK_TIMEOUT_S.
The proxy state starts as STARTING. After update is called, the ready call takes
some time to finish, so the proxy state manager restarts the proxy state after
PROXY_READY_CHECK_TIMEOUT_S. On the next check_health period, the proxy state
manager applies the backed-off timeout instead of immediately restarting the
proxy states, and eventually sets the proxy state to HEALTHY.
"""
fake_time = MockTimer()
proxy_state_manager, cluster_node_info_cache = _create_proxy_state_manager(
http_options=HTTPOptions(location=DeploymentMode.EveryNode),
timer=fake_time,
)
cluster_node_info_cache.alive_nodes = all_nodes

node_ids = {node[0] for node in all_nodes}

# Run update to create proxy states.
proxy_state_manager.update(proxy_nodes=node_ids)
old_proxy_states = {
node_id: state for node_id, state in proxy_state_manager._proxy_states.items()
}

# Ensure 2 proxies are created, one for the head node and another for the worker.
assert len(proxy_state_manager._proxy_states) == len(node_ids)

# Ensure the proxy state statuses before the update are STARTING. Also, set the
# ready call status to pending to simulate a call that never responds.
def check_proxy_state_starting(_proxy_state_manager: ProxyStateManager):
for proxy_state in _proxy_state_manager._proxy_states.values():
assert proxy_state.status == ProxyStatus.STARTING
proxy_state._actor_proxy_wrapper.ready = ProxyWrapperCallStatus.PENDING

# Trigger an update and ensure the proxy states are restarted because the time
# advanced past PROXY_READY_CHECK_TIMEOUT_S (0.1s).
check_proxy_state_starting(_proxy_state_manager=proxy_state_manager)
fake_time.advance(0.11)
proxy_state_manager.update(proxy_nodes=node_ids)
assert all(
[
proxy_state_manager._proxy_states[node_id] != old_proxy_states[node_id]
for node_id in node_ids
]
)

# Trigger another update with the same time advance; this time the proxy states
# should not be restarted again because of the backed-off timeout.
old_proxy_states = {
node_id: state for node_id, state in proxy_state_manager._proxy_states.items()
}
check_proxy_state_starting(_proxy_state_manager=proxy_state_manager)
fake_time.advance(0.11)
proxy_state_manager.update(proxy_nodes=node_ids)
assert all(
[
proxy_state_manager._proxy_states[node_id] == old_proxy_states[node_id]
for node_id in node_ids
]
)

# Ensure the proxy states turn healthy after the ready call is unblocked.
for proxy_state in proxy_state_manager._proxy_states.values():
proxy_state._actor_proxy_wrapper.ready = ProxyWrapperCallStatus.FINISHED_SUCCEED
proxy_state_manager.update(proxy_nodes=node_ids)
assert all(
[
proxy_state_manager._proxy_states[node_id].status == ProxyStatus.HEALTHY
for node_id in node_ids
]
)


if __name__ == "__main__":
import sys
