[Serve] Shared LongPollClient for Routers #48807

Draft: wants to merge 9 commits into master
38 changes: 35 additions & 3 deletions python/ray/serve/_private/long_poll.py
@@ -2,6 +2,7 @@
import logging
import os
import random
from asyncio import sleep
from asyncio.events import AbstractEventLoop
from collections import defaultdict
from collections.abc import Mapping
@@ -78,8 +79,8 @@ def __init__(
host_actor,
key_listeners: Dict[KeyType, UpdateStateCallable],
call_in_event_loop: AbstractEventLoop,
only_once: bool = False,
) -> None:
assert len(key_listeners) > 0
# We used to allow this to be optional, but due to a Ray Client issue,
# we now require all long poll clients to post callbacks to the event loop.
# See https://github.com/ray-project/ray/issues/20971
@@ -88,6 +89,8 @@ def __init__(
self.host_actor = host_actor
self.key_listeners = key_listeners
self.event_loop = call_in_event_loop
self.only_once = only_once

self.snapshot_ids: Dict[KeyType, int] = {
# The initial snapshot id for each key is < 0,
# but real snapshot keys in the long poll host are always >= 0,
@@ -99,6 +102,23 @@ def __init__(

self._poll_next()

def add_key_listeners(
self, key_listeners: Dict[KeyType, UpdateStateCallable]
) -> None:
"""Add more key listeners to the client.
The new listeners will only be included in the *next* long poll request;
the current request will continue with the existing listeners.

If a key is already in the client, the new listener will replace the old one,
but the snapshot ID will be preserved, so the new listener will only be called
on the *next* update to that key.
"""
# Only initialize snapshot ids for *new* keys.
self.snapshot_ids.update(
{key: -1 for key in key_listeners.keys() if key not in self.key_listeners}
)
self.key_listeners.update(key_listeners)

def _on_callback_completed(self, trigger_at: int):
"""Called after a single callback is completed.

@@ -108,7 +128,7 @@ def _on_callback_completed(self, trigger_at: int):
way to serialize the callback invocations between object versions.
"""
self._callbacks_processed_count += 1
if self._callbacks_processed_count == trigger_at:
if not self.only_once and self._callbacks_processed_count == trigger_at:
self._poll_next()

def _poll_next(self):
@@ -162,6 +182,8 @@ def _process_update(self, updates: Dict[str, UpdatedObject]):
f"{list(updates.keys())}.",
extra={"log_to_stderr": False},
)
if not updates: # no updates, no callbacks to run, just poll again
self._schedule_to_event_loop(self._poll_next)
for key, update in updates.items():
self.snapshot_ids[key] = update.snapshot_id
callback = self.key_listeners[key]
@@ -246,10 +268,20 @@ async def listen_for_change(
) -> Union[LongPollState, Dict[KeyType, UpdatedObject]]:
"""Listen for changed objects.

This method will returns a dictionary of updated objects. It returns
This method will return a dictionary of updated objects. It returns
immediately if the snapshot_ids are outdated, otherwise it will block
until there's an update.
"""
# If there are no keys to listen for,
# just wait for a short time to provide backpressure,
# then return an empty update.
if not keys_to_snapshot_ids:
await sleep(1)

updated_objects = {}
self._count_send(updated_objects)
return updated_objects

# If there are any keys with outdated snapshot ids,
# return their updated values immediately.
updated_objects = {}
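Taken together, the long_poll.py changes give `LongPollClient` two new capabilities: an `only_once` mode that stops polling after the first batch of callbacks completes, and `add_key_listeners`, which grows the listened key set between polls (a client with zero keys just gets an empty update back after a short backpressure sleep until keys arrive). Below is a minimal sketch of how a caller might combine them; `controller_handle` is an assumed handle to the Serve controller actor, the callbacks are hypothetical, and the `DeploymentID` construction is approximate:

```python
# Sketch only, not code from this PR. Assumes `controller_handle` is a
# handle to the Serve controller actor, obtained elsewhere; import paths
# follow the files touched by this diff.
import asyncio

from ray.serve._private.common import DeploymentID
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace

loop = asyncio.new_event_loop()
dep_id = DeploymentID(name="my_deployment", app_name="my_app")


def on_config(config) -> None:
    print("deployment config updated:", config)


def on_replicas(replicas) -> None:
    print("running replicas updated:", replicas)


# One-shot client: fetches a fast initial update, then stops polling,
# because only_once=True short-circuits _on_callback_completed().
initial_client = LongPollClient(
    controller_handle,
    key_listeners={(LongPollNamespace.DEPLOYMENT_CONFIG, dep_id): on_config},
    call_in_event_loop=loop,
    only_once=True,
)

# Shared client: starts with zero keys (the host replies with an empty
# update after the ~1s backpressure sleep), then picks up the new
# listener on the next poll after add_key_listeners().
shared_client = LongPollClient(
    controller_handle,
    key_listeners={},
    call_in_event_loop=loop,
)
shared_client.add_key_listeners(
    {(LongPollNamespace.RUNNING_REPLICAS, dep_id): on_replicas}
)
```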
86 changes: 85 additions & 1 deletion python/ray/serve/_private/router.py
@@ -4,10 +4,13 @@
import threading
import time
import uuid
import weakref
from abc import ABC, abstractmethod
from asyncio import AbstractEventLoop
from collections import defaultdict
from collections.abc import MutableMapping
from contextlib import contextmanager
from functools import partial
from functools import lru_cache, partial
from typing import Any, Coroutine, DefaultDict, Dict, List, Optional, Tuple, Union

import ray
@@ -398,6 +401,16 @@ def __init__(
),
)

# The Router needs to stay informed about changes to the target deployment's
# running replicas and deployment config. We do this via the long poll system.
# However, for efficiency, we don't want to create a LongPollClient for every
# DeploymentHandle, so we use a shared LongPollClient that all Routers
# register themselves with. But first, the router needs to get a fast initial
# update so that it can start serving requests, which we do with a
# LongPollClient that is told to run only once. This client gets the
# first update quickly, and then future updates are handled
# by the SharedRouterLongPollClient.

self.long_poll_client = LongPollClient(
controller_handle,
{
@@ -411,7 +424,13 @@
): self.update_deployment_config,
},
call_in_event_loop=self._event_loop,
only_once=True,
)

shared = SharedRouterLongPollClient.get_or_create(
controller_handle, self._event_loop
)
shared.register(self)

def running_replicas_populated(self) -> bool:
return self._running_replicas_populated
@@ -684,3 +703,68 @@ def shutdown(self):
asyncio.run_coroutine_threadsafe(
self._asyncio_router.shutdown(), loop=self._asyncio_loop
).result()


class SharedRouterLongPollClient:
def __init__(self, controller_handle: ActorHandle, event_loop: AbstractEventLoop):
self.controller_handle = controller_handle

# We use a WeakSet to store the Routers so that we don't prevent them
# from being garbage-collected.
self.routers: MutableMapping[
DeploymentID, weakref.WeakSet[Router]
] = defaultdict(weakref.WeakSet)

# Creating the LongPollClient implicitly starts it
self.long_poll_client = LongPollClient(
controller_handle,
key_listeners={},
call_in_event_loop=event_loop,
)

@classmethod
@lru_cache(maxsize=None)
def get_or_create(
cls, controller_handle: ActorHandle, event_loop: AbstractEventLoop
) -> "SharedRouterLongPollClient":
shared = cls(controller_handle=controller_handle, event_loop=event_loop)
logger.info(f"Started {shared}.")
return shared

def update_running_replicas(
self, running_replicas: List[RunningReplicaInfo], deployment_id: DeploymentID
) -> None:
for router in self.routers[deployment_id]:
router.update_running_replicas(running_replicas)

def update_deployment_config(
self, deployment_config: DeploymentConfig, deployment_id: DeploymentID
) -> None:
for router in self.routers[deployment_id]:
router.update_deployment_config(deployment_config)

def register(self, router: Router) -> None:
self.routers[router.deployment_id].add(router)

# Remove the entries for any deployment ids that no longer have any routers.
# The WeakSets will automatically lose track of Routers that get GC'd,
# but the outer dict will keep the key around, so we need to clean up manually.
# Note the list(...) to avoid mutating self.routers while iterating over it.
for deployment_id, routers in list(self.routers.items()):
if not routers:
self.routers.pop(deployment_id)

# Register the new listeners on the long poll client.
# Some of these listeners may already exist, but it's safe to add them again.
key_listeners = {
(LongPollNamespace.RUNNING_REPLICAS, deployment_id): partial(
self.update_running_replicas, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
} | {
(LongPollNamespace.DEPLOYMENT_CONFIG, deployment_id): partial(
self.update_deployment_config, deployment_id=deployment_id
)
for deployment_id in self.routers.keys()
}
self.long_poll_client.add_key_listeners(key_listeners)
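
For reference, the sharing in `get_or_create` comes entirely from `functools.lru_cache`: every call with the same (controller handle, event loop) pair returns the same instance, so all Routers created on one event loop funnel their long polling through a single client. Here is a small, self-contained sketch of the memoization-plus-WeakSet pattern, using simplified stand-in classes rather than the PR's actual types:

```python
# Self-contained sketch of the sharing pattern used by
# SharedRouterLongPollClient; FakeRouter and SharedClient are
# simplified stand-ins, not the PR's classes.
import weakref
from collections import defaultdict
from functools import lru_cache


class FakeRouter:
    def __init__(self, deployment_id: str):
        self.deployment_id = deployment_id


class SharedClient:
    def __init__(self, controller_key: str):
        self.controller_key = controller_key
        # WeakSets let routers be garbage-collected while registered.
        self.routers = defaultdict(weakref.WeakSet)

    @classmethod
    @lru_cache(maxsize=None)
    def get_or_create(cls, controller_key: str) -> "SharedClient":
        # lru_cache memoizes on the arguments, so every call with the
        # same key returns the same instance: one client per controller.
        return cls(controller_key)

    def register(self, router: FakeRouter) -> None:
        self.routers[router.deployment_id].add(router)
        # Drop deployment ids whose routers have all been collected;
        # list(...) avoids mutating the dict while iterating.
        for dep_id, routers in list(self.routers.items()):
            if not routers:
                self.routers.pop(dep_id)


a = SharedClient.get_or_create("controller-1")
b = SharedClient.get_or_create("controller-1")
assert a is b  # same shared instance

router = FakeRouter("deployment-x")
a.register(router)
del router  # after GC, the WeakSet no longer holds this router
```

One consequence of this design worth noting: the `lru_cache` holds a strong reference to each created instance, so the shared client lives for the lifetime of the process even after all its routers are garbage-collected.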