Refactor RPC connection in server-client mode #83

Merged · 10 commits · Sep 12, 2023
7 changes: 5 additions & 2 deletions graphlearn_torch/python/distributed/dist_client.py
@@ -23,7 +23,7 @@

def init_client(num_servers: int, num_clients: int, client_rank: int,
master_addr: str, master_port: int, num_rpc_threads: int = 4,
client_group_name: Optional[str] = None):
client_group_name: Optional[str] = None, is_dynamic: bool = False):
r""" Initialize the current process as a client and establish connections
with all other servers and clients. Note that this method should be called
only in the server-client distribution mode.
@@ -44,11 +44,14 @@ def init_client(num_servers: int, num_clients: int, client_rank: int,
client_group_name (str): A unique name of the client group that current
process belongs to. If set to ``None``, a default name will be used.
(Default: ``None``).
is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``).
"""
if client_group_name:
client_group_name = client_group_name.replace('-', '_')
_set_client_context(num_servers, num_clients, client_rank, client_group_name)
# Note that a client RPC agent will never receive remote requests, thus setting
# the number of rpc threads to ``1`` is enough.
init_rpc(master_addr, master_port, num_rpc_threads=num_rpc_threads)
init_rpc(master_addr, master_port, num_rpc_threads=num_rpc_threads, is_dynamic=is_dynamic)


def shutdown_client():
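For context, a minimal sketch of bringing up a client with the new ``is_dynamic`` flag is shown below. It assumes ``init_client`` and ``shutdown_client`` are importable from ``graphlearn_torch.distributed``, and all addresses, ports, and counts are placeholder values.

```python
from graphlearn_torch.distributed import init_client, shutdown_client

# Placeholder deployment values; adjust to your cluster.
init_client(
    num_servers=2,
    num_clients=1,
    client_rank=0,
    master_addr='127.0.0.1',
    master_port=11234,
    client_group_name='demo_clients',
    is_dynamic=True,  # join an RPC group whose world size is not fixed up front
)
# ... issue remote sampling requests here ...
shutdown_client()
```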
42 changes: 36 additions & 6 deletions graphlearn_torch/python/distributed/dist_context.py
@@ -14,7 +14,7 @@
# ==============================================================================

from enum import Enum
from typing import Optional
from typing import Optional, List


class DistRole(Enum):
@@ -101,19 +101,24 @@ def num_clients(self) -> int:
def worker_name(self) -> str:
r""" Get worker name of the current process of this context.
"""
return f"{self.group_name}-{self.rank}"
return f"{self.group_name}_{self.rank}"


_dist_context: DistContext = None
r""" Distributed context on the current process.
"""

_clients_to_servers: dict = None
r""" A dict mapping from client rank to server ranks. int -> List[int]"""

def get_context() -> DistContext:
r""" Get distributed context info of the current process.
"""
return _dist_context

def get_clients_to_servers() -> dict:
r""" Get client to servers mapping.
"""
return _clients_to_servers

def _set_worker_context(world_size: int, rank: int,
group_name: Optional[str] = None):
@@ -132,11 +137,11 @@ def _set_worker_context(world_size: int, rank: int,
)


def _set_server_context(num_servers: int, num_clients: int, server_rank: int,
server_group_name: Optional[str] = None):
def _set_server_context(num_servers: int, server_rank: int,
server_group_name: Optional[str] = None, num_clients: int = 0):
r""" Set distributed context info as a server on the current process.
"""
assert num_servers > 0 and num_clients > 0
assert num_servers > 0
global _dist_context
_dist_context = DistContext(
role=DistRole.SERVER,
@@ -164,6 +169,31 @@ def _set_client_context(num_servers: int, num_clients: int, client_rank: int,
global_world_size=num_servers+num_clients,
global_rank=num_servers+client_rank
)
assign_server_by_order()

def assign_server_by_order():
r"""Assign servers to each client in round-robin order.
E.g., with 2 clients and 4 servers the assignment is {0: [0, 1], 1: [2, 3]};
with 5 clients and 2 servers it is {0: [0], 1: [1], 2: [0], 3: [1], 4: [0]}."""
ctx = get_context()
assert ctx is not None and ctx.is_client()
client_num, server_num = ctx.world_size, ctx.global_world_size - ctx.world_size
global _clients_to_servers
_clients_to_servers = {}
cur_server = 0
for i in range(client_num):
if i not in _clients_to_servers:
_clients_to_servers[i] = []
for j in range(server_num // client_num):
_clients_to_servers[i].append(cur_server)
cur_server = (cur_server + 1) % server_num
if i < server_num % client_num:
_clients_to_servers[i].append(cur_server)
cur_server = (cur_server + 1) % server_num
if len(_clients_to_servers[i]) == 0:
_clients_to_servers[i].append(cur_server)
cur_server = (cur_server + 1) % server_num
return _clients_to_servers[ctx.rank]


def init_worker_group(world_size: int, rank: int,
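To make the round-robin policy of ``assign_server_by_order`` concrete, here is a self-contained sketch that reproduces the same assignment logic outside the distributed context. ``round_robin_assignment`` is a hypothetical helper, not part of the library; the asserts check the two examples from the docstring.

```python
from typing import Dict, List

def round_robin_assignment(client_num: int, server_num: int) -> Dict[int, List[int]]:
    # Deal servers to clients in turn; a client that gets nothing from the even
    # split still receives one server, mirroring assign_server_by_order.
    mapping: Dict[int, List[int]] = {}
    cur = 0
    for i in range(client_num):
        mapping[i] = []
        for _ in range(server_num // client_num):
            mapping[i].append(cur)
            cur = (cur + 1) % server_num
        if i < server_num % client_num:
            mapping[i].append(cur)
            cur = (cur + 1) % server_num
        if not mapping[i]:
            mapping[i].append(cur)
            cur = (cur + 1) % server_num
    return mapping

assert round_robin_assignment(2, 4) == {0: [0, 1], 1: [2, 3]}
assert round_robin_assignment(5, 2) == {0: [0], 1: [1], 2: [0], 3: [1], 4: [0]}
```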
15 changes: 9 additions & 6 deletions graphlearn_torch/python/distributed/dist_options.py
@@ -20,7 +20,7 @@

from ..utils import assign_device

from .dist_context import DistContext
from .dist_context import DistContext, assign_server_by_order


class _BasicDistSamplingWorkerOptions(object):
@@ -208,8 +208,9 @@ class RemoteDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
produced by those remote sampling workers and consumed by the current process.

Args:
server_rank (int): The rank of server to launch sampling workers. If set
to ``None``, it will be automatically assigned. (default: ``None``).
server_rank (int or List[int], optional): The rank(s) of the server(s) on
which to launch sampling workers; multiple ranks may be given. If set to
``None``, server ranks will be assigned automatically. (default: ``None``).
num_workers (int): How many workers to launch on the remote server for
distributed neighbor sampling of the current process. (default: ``1``).
worker_devices (torch.device or List[torch.device], optional): List of
@@ -231,7 +232,7 @@ class RemoteDistSamplingWorkerOptions(_BasicDistSamplingWorkerOptions):
the client side. (default: ``4``).
"""
def __init__(self,
server_rank: Optional[int] = None,
server_rank: Optional[Union[int, List[int]]] = None,
num_workers: int = 1,
worker_devices: Optional[List[torch.device]] = None,
worker_concurrency: int = 4,
@@ -244,8 +245,10 @@ def __init__(self,
worker_key: str = None):
super().__init__(num_workers, worker_devices, worker_concurrency,
master_addr, master_port, num_rpc_threads, rpc_timeout)

self.server_rank = server_rank
if server_rank is not None:
self.server_rank = server_rank
else:
self.server_rank = assign_server_by_order()
self.buffer_capacity = self.num_workers * self.worker_concurrency
if buffer_size is None:
self.buffer_size = f'{self.num_workers * 64}MB'
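A brief usage sketch of the widened ``server_rank`` parameter follows. It assumes ``RemoteDistSamplingWorkerOptions`` is importable from ``graphlearn_torch.distributed`` and that the remaining constructor arguments keep their defaults; in the automatic case, ``init_client`` must already have been called, since ``assign_server_by_order`` reads the client context.

```python
from graphlearn_torch.distributed import RemoteDistSamplingWorkerOptions

# Pin this client's sampling workers to servers 0 and 1 explicitly.
pinned = RemoteDistSamplingWorkerOptions(
    server_rank=[0, 1],
    num_workers=2,
    worker_concurrency=4,
)

# Or omit server_rank so that assign_server_by_order() picks servers for this
# client in round-robin order (requires an initialized client context).
auto = RemoteDistSamplingWorkerOptions(num_workers=2)
```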
16 changes: 10 additions & 6 deletions graphlearn_torch/python/distributed/dist_server.py
@@ -177,17 +177,16 @@ def get_server() -> DistServer:
return _dist_server


def init_server(num_servers: int, num_clients: int, server_rank: int,
dataset: DistDataset, master_addr: str, master_port: int,
def init_server(num_servers: int, server_rank: int, dataset: DistDataset,
master_addr: str, master_port: int, num_clients: int = 0,
num_rpc_threads: int = 16, request_timeout: int = 180,
server_group_name: Optional[str] = None,):
server_group_name: Optional[str] = None, is_dynamic: bool = False):
r""" Initialize the current process as a server and establish connections
with all other servers and clients. Note that this method should be called
only in the server-client distribution mode.

Args:
num_servers (int): Number of processes participating in the server group.
num_clients (int): Number of processes participating in the client group.
server_rank (int): Rank of the current process within the server group (it
should be a number between 0 and ``num_servers``-1).
dataset (DistDataset): The ``DistDataset`` object of a partition of graph
@@ -198,18 +197,23 @@ def init_server(num_servers: int, num_clients: int, server_rank: int,
master_port (int): The master TCP port for RPC connection between all
servers and clients, the value of this parameter should be same for all
servers and clients.
num_clients (int): Number of processes participating in the client group.
If ``is_dynamic`` is ``True``, this parameter will be ignored.
num_rpc_threads (int): The number of RPC worker threads used for the
current server to respond remote requests. (Default: ``16``).
request_timeout (int): The max timeout in seconds for remote requests; an
exception will be raised when a request exceeds it. (Default: ``180``).
server_group_name (str): A unique name of the server group that current
process belongs to. If set to ``None``, a default name will be used.
(Default: ``None``).
is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``).
"""
_set_server_context(num_servers, num_clients, server_rank, server_group_name)
if server_group_name:
server_group_name = server_group_name.replace('-', '_')
_set_server_context(num_servers, server_rank, server_group_name, num_clients)
global _dist_server
_dist_server = DistServer(dataset=dataset)
init_rpc(master_addr, master_port, num_rpc_threads, request_timeout)
init_rpc(master_addr, master_port, num_rpc_threads, request_timeout, is_dynamic=is_dynamic)


def wait_and_shutdown_server():
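Correspondingly, a server process could be brought up as sketched below; ``init_server`` and ``wait_and_shutdown_server`` are assumed importable from ``graphlearn_torch.distributed``, the ``DistDataset`` construction is elided, and the addresses are placeholders.

```python
from graphlearn_torch.distributed import init_server, wait_and_shutdown_server

dataset = ...  # a DistDataset holding this server's partition of graph and feature data

init_server(
    num_servers=2,
    server_rank=0,
    dataset=dataset,
    master_addr='127.0.0.1',
    master_port=11234,
    server_group_name='demo_servers',
    is_dynamic=True,  # num_clients can be omitted; clients join dynamically
)
wait_and_shutdown_server()  # serve requests until clients signal exit, then shut down
```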
71 changes: 64 additions & 7 deletions graphlearn_torch/python/distributed/rpc.py
@@ -26,6 +26,8 @@

from .dist_context import DistRole, get_context

SERVER_INIT_CHECK_INTERVAL = 3.0
MAX_RETRY_TIMES = 60

_rpc_init_lock = threading.RLock()

@@ -37,8 +39,9 @@
r""" Dict from role type to all rpc worker names in this role group.
"""

_rpc_current_group_worker_names: Set[str] = set()
r""" Set of rpc worker names in the current role group.
_rpc_current_group_worker_names: Set[str] = None
r""" Set of rpc worker names in the current role group. Used in all_gather
in a role group.
"""

_rpc_master_addr: str = None
@@ -134,7 +137,7 @@ def _role_based_broadcast_to_followers(sequence_id, objects_map):
def all_gather(obj, timeout=None):
r""" Gathers objects only from the current role group in a list. This
function blocks until all workers in the current role group have received
the gathered results. The implementation of this methid is refer to
the gathered results. The implementation of this method refers to
``torch.distributed.rpc.api._all_gather``.
"""
assert (
Expand Down Expand Up @@ -237,7 +240,8 @@ def global_barrier(timeout=None):
def init_rpc(master_addr: str,
master_port: int,
num_rpc_threads: int = 16,
rpc_timeout: float = 180):
rpc_timeout: float = 180,
is_dynamic: bool = False):
r""" Initialize rpc on the current process.
"""
with _rpc_init_lock:
Expand All @@ -257,19 +261,65 @@ def init_rpc(master_addr: str,
rpc_timeout=rpc_timeout,
init_method=f'tcp://{master_addr}:{master_port}'
)

rpc.init_rpc(
name=ctx.worker_name,
rank=ctx.global_rank,
world_size=ctx.global_world_size,
world_size=None if is_dynamic else ctx.global_world_size,
rpc_backend_options=options
)

global _rpc_inited
_rpc_inited = True

global _rpc_current_group_worker_names
global _rpc_worker_names
_rpc_worker_names = {}

if is_dynamic:
_rpc_worker_names[DistRole.SERVER] = []
_rpc_worker_names[DistRole.CLIENT] = []

if ctx.is_server():
# ensure all servers are initialized
for server_rank in range(ctx.world_size):
if server_rank == ctx.rank:
_rpc_worker_names[DistRole.SERVER].append(ctx.group_name + '_' + str(server_rank))
continue
times = 0
is_avail = False
while not is_avail:
try:
is_avail = rpc_global_request_by_rank(server_rank, rpc.is_available)
except:
time.sleep(SERVER_INIT_CHECK_INTERVAL)
logging.info(f"RETRY {times}: server {ctx.rank} waits server {server_rank}...")
times += 1
if times >= MAX_RETRY_TIMES:
raise RuntimeError(f"TIMEOUT: server {ctx.rank} waits server {server_rank} timeout. "
f"Check if server {server_rank} is ready.")
_rpc_worker_names[DistRole.SERVER].append(ctx.group_name + '_' + str(server_rank))
_rpc_current_group_worker_names = set(_rpc_worker_names[DistRole.SERVER])
return
if ctx.is_client():
for server_rank in range(ctx.global_rank - ctx.rank):
times = 0
is_avail = False
while not is_avail:
try:
is_avail = rpc_global_request_by_rank(server_rank, rpc.is_available)
except:
time.sleep(SERVER_INIT_CHECK_INTERVAL)
logging.info(f"RETRY {times}: client {ctx.rank} waits server {server_rank}...")
times += 1
if times >= MAX_RETRY_TIMES:
raise RuntimeError(f"TIMEOUT: client {ctx.rank} waits server {server_rank} timeout. "
f"Check if server {server_rank} is ready.")
server_name = rpc_global_request_by_rank(server_rank, rpc.get_worker_info).name
_rpc_worker_names[DistRole.SERVER].append(server_name)
_rpc_current_group_worker_names = set([ctx.group_name + '_' + str(client_rank) for client_rank in range(ctx.world_size)])
return

gathered_results = global_all_gather(
obj=(ctx.role, ctx.world_size, ctx.rank), timeout=rpc_timeout
)
@@ -287,7 +337,6 @@ def init_rpc(master_addr: str,
worker_list[role_rank] = worker_name
_rpc_worker_names[role] = worker_list

global _rpc_current_group_worker_names
_rpc_current_group_worker_names = set(_rpc_worker_names[ctx.role])

global_barrier(timeout=rpc_timeout)
@@ -469,3 +518,11 @@ def rpc_global_request(target_role: DistRole, role_rank: int,
"""
fut = rpc_global_request_async(target_role, role_rank, func, args, kwargs)
return fut.wait()

@_require_initialized
def rpc_global_request_by_rank(global_rank: int, func, args=None, kwargs=None):
r""" Perform a rpc request synchronously to other rpc worker by rank
and return the results.
"""
fut = rpc.rpc_async(global_rank, func, args, kwargs)
return fut.wait()
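Finally, a minimal usage sketch for the new rank-addressed helper, assuming RPC has already been initialized on the calling process; the calls mirror the availability probe and worker-name lookup used by the dynamic handshake in ``init_rpc`` above.

```python
import torch.distributed.rpc as rpc

# Probe whether the peer at global rank 0 is reachable, then fetch its
# registered worker name, as the dynamic handshake in init_rpc does.
if rpc_global_request_by_rank(0, rpc.is_available):
    peer_name = rpc_global_request_by_rank(0, rpc.get_worker_info).name
    print(f'server 0 is available as {peer_name}')
```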