diff --git a/python/xoscar/backends/communication/ucx.py b/python/xoscar/backends/communication/ucx.py
index 09a17c4e..cc66f7fb 100644
--- a/python/xoscar/backends/communication/ucx.py
+++ b/python/xoscar/backends/communication/ucx.py
@@ -312,8 +312,7 @@ async def send_buffers(self, buffers: list, meta: Optional[_MessageBase] = None)
             for buf in meta_buffers:
                 await self.ucp_endpoint.send(buf)
             for buffer in buffers:
-                if buffer.nbytes if hasattr(buffer, "nbytes") else len(buffer) > 0:
-                    await self.ucp_endpoint.send(buffer)
+                await self.ucp_endpoint.send(buffer)
         except ucp.exceptions.UCXBaseException:  # pragma: no cover
             self.abort()
             raise ChannelClosed("While writing, the connection was closed")
diff --git a/python/xoscar/collective/__init__.py b/python/xoscar/collective/__init__.py
index 37f6558d..c1c9f82e 100644
--- a/python/xoscar/collective/__init__.py
+++ b/python/xoscar/collective/__init__.py
@@ -11,3 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .core import (
+    RankActor,
+    allgather,
+    allreduce,
+    alltoall,
+    broadcast,
+    gather,
+    init_process_group,
+    new_group,
+    reduce,
+    reduce_scatter,
+    scatter,
+)
diff --git a/python/xoscar/collective/common.py b/python/xoscar/collective/common.py
new file mode 100644
index 00000000..77ce1409
--- /dev/null
+++ b/python/xoscar/collective/common.py
@@ -0,0 +1,72 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from enum import IntEnum
+from typing import Dict, Type
+
+import numpy as np
+
+from . import xoscar_pygloo as xp
+
+ReduceOpMappingGloo: Dict["CollectiveReduceOp", "xp.ReduceOp"] = {}
+AllReduceAlgorithmMappingGloo: Dict["AllReduceAlgorithm", "xp.AllreduceAlgorithm"] = {}
+
+
+def _register_reduce_op(reduce_op):
+    for op_type in reduce_op:
+        ReduceOpMappingGloo[op_type] = xp.ReduceOp(op_type)
+    return reduce_op
+
+
+def _register_allreduce_algo(algorithms):
+    for algo in algorithms:
+        AllReduceAlgorithmMappingGloo[algo] = xp.AllreduceAlgorithm(algo)
+    return algorithms
+
+
+@_register_reduce_op
+class CollectiveReduceOp(IntEnum):
+    SUM = 0
+    PRODUCT = 1
+    MIN = 2
+    MAX = 3
+    BAND = 4
+    BOR = 5
+    BXOR = 6
+    UNUSED = 7
+
+
+@_register_allreduce_algo
+class AllReduceAlgorithm(IntEnum):
+    UNSPECIFIED = 0
+    RING = 1
+    BCUBE = 2
+
+
+TypeMappingGloo: Dict[Type[np.generic], "xp.GlooDataType_t"] = {
+    np.int8: xp.GlooDataType_t.glooInt8,
+    np.uint8: xp.GlooDataType_t.glooUint8,
+    np.int32: xp.GlooDataType_t.glooInt32,
+    np.uint32: xp.GlooDataType_t.glooUint32,
+    np.int64: xp.GlooDataType_t.glooInt64,
+    np.uint64: xp.GlooDataType_t.glooUint64,
+    np.float16: xp.GlooDataType_t.glooFloat16,
+    np.float32: xp.GlooDataType_t.glooFloat32,
+    np.float64: xp.GlooDataType_t.glooFloat64,
+}
+
+# Some static variables
+INVOKE_ERROR_MESSAGE = "Collective-related functions must be called in a process that is involved in collective communication."
+RANK_ADDRESS_ENV_KEY = "COLLECTIVE_RANK_ADDRESS"
+RENDEZVOUS_MASTER_IP_ENV_KEY = "COLLECTIVE_MASTER_IP"
+RENDEZVOUS_MASTER_PORT_ENV_KEY = "COLLECTIVE_MASTER_PORT"
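(Aside, not part of the patch.) A minimal sketch of how the registration decorators and mappings in `common.py` are meant to be consumed; it assumes the compiled `xoscar_pygloo` extension is importable, and relies on the fact that the mapping dicts are populated as a side effect of importing the module:

```python
import numpy as np

from xoscar.collective.common import (
    CollectiveReduceOp,
    ReduceOpMappingGloo,
    TypeMappingGloo,
)

# The class decorators run at import time, so every enum member already
# has its gloo counterpart registered:
gloo_op = ReduceOpMappingGloo[CollectiveReduceOp.SUM]

# Buffer dtypes are resolved through ``dtype.type`` (a numpy scalar type),
# which is exactly the key shape TypeMappingGloo expects:
buf = np.zeros(4, dtype=np.float32)
gloo_dtype = TypeMappingGloo[buf.dtype.type]
```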
diff --git a/python/xoscar/collective/core.py b/python/xoscar/collective/core.py
new file mode 100644
index 00000000..2d180778
--- /dev/null
+++ b/python/xoscar/collective/core.py
@@ -0,0 +1,323 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import hashlib
+import os
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
+
+from .. import Actor, actor_ref
+from ..context import get_context
+from .common import (
+    INVOKE_ERROR_MESSAGE,
+    RANK_ADDRESS_ENV_KEY,
+    AllReduceAlgorithm,
+    CollectiveReduceOp,
+)
+from .process_group import ProcessGroup, ProcessGroupGloo
+from .utils import get_rank_address_via_env
+
+
+class RankActor(Actor):
+    def __init__(
+        self,
+        rank: int,
+        world: int,
+        backend: str = "gloo",
+        pg_options: Optional[ProcessGroup.Options] = None,
+        *args,
+        **kwargs,
+    ):
+        self._rank = rank
+        self._world = world
+        self._backend = backend
+        self.name_to_pg: Dict[str, Dict[str, "ProcessGroup"]] = defaultdict(dict)
+        self._pg_options = pg_options
+
+    @classmethod
+    def default_uid(cls):
+        return "RankActor"
+
+    async def __post_create__(self):
+        os.environ[RANK_ADDRESS_ENV_KEY] = self.address
+        _ip = self._get_ip()
+        if self._backend == "gloo":
+            pg = ProcessGroupGloo(
+                _ip,
+                self._rank,
+                self._world,
+                group_name="default",
+                pg_options=self._pg_options,
+            )
+            self.name_to_pg["gloo"]["default"] = pg
+        else:
+            raise NotImplementedError("Only the gloo backend is implemented for now!")
+
+    def process_group(self, pg_name: str) -> ProcessGroup:
+        return self.name_to_pg[self._backend][pg_name]
+
+    def rank(self) -> int:
+        return self._rank
+
+    def world(self) -> int:
+        return self._world
+
+    def backend(self) -> str:
+        return self._backend
+
+    def _get_ip(self) -> str:
+        return self.address.split(":")[0]
+
+    def _process_group_name(self, ranks: List[int]) -> str:
+        return hashlib.sha1(
+            bytes(self._backend + "_".join(map(str, ranks)), "utf-8")
+        ).hexdigest()
+
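(Aside, not part of the patch.) `_process_group_name` makes subgroup naming coordination-free: because `new_group` below sorts the ranks before hashing, every member derives the same name independently. A standalone illustration, with `process_group_name` as a hypothetical mirror of the method:

```python
import hashlib


def process_group_name(backend: str, ranks: list) -> str:
    # mirrors RankActor._process_group_name; new_group passes sorted ranks
    return hashlib.sha1(
        bytes(backend + "_".join(map(str, ranks)), "utf-8")
    ).hexdigest()


# every member of the subgroup computes the same name on its own
assert process_group_name("gloo", sorted([2, 0, 1])) == process_group_name(
    "gloo", [0, 1, 2]
)
```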
+    def new_group(
+        self, ranks: List[int], pg_options: Optional[ProcessGroup.Options] = None
+    ) -> Optional[str]:
+        assert (
+            len(ranks) <= self._world
+        ), "``ranks`` in new_group cannot be larger than the world size."
+        assert all(
+            [self._world > rank >= 0 for rank in ranks]
+        ), "every rank in ``ranks`` must be in [0, world size)."
+        assert len({rank for rank in ranks}) == len(
+            ranks
+        ), "there can be no duplicate ranks in ``ranks``."
+        if self._rank not in ranks:
+            return None
+        if len(ranks) == self._world:
+            return "default"
+        global_ranks = sorted(ranks)
+        group_rank = global_ranks.index(self._rank)
+        group_world = len(global_ranks)
+        group_name = self._process_group_name(global_ranks)
+        if group_name in self.name_to_pg[self._backend]:
+            return group_name
+        _ip = self._get_ip()
+        pg = ProcessGroupGloo(
+            _ip, group_rank, group_world, group_name=group_name, pg_options=pg_options
+        )
+        self.name_to_pg[self._backend][group_name] = pg
+        return group_name
+
+    def reduce(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].reduce(
+            send_data, recv_data, op=op, root=root, tag=tag
+        )
+
+    def allreduce(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+        algorithm: AllReduceAlgorithm = AllReduceAlgorithm.RING,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].allreduce(
+            send_data, recv_data, op=op, algorithm=algorithm, tag=tag
+        )
+
+    def gather(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].gather(
+            send_data, recv_data, root=root, tag=tag
+        )
+
+    def allgather(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].allgather(send_data, recv_data, tag=tag)
+
+    def scatter(
+        self,
+        send_data: List[Any],
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].scatter(
+            send_data, recv_data, root=root, tag=tag
+        )
+
+    def reduce_scatter(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        recv_elems: List[int],
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+        pg_name: str = "default",
+    ):  # pragma: no cover
+        self.name_to_pg[self._backend][pg_name].reduce_scatter(
+            send_data, recv_data, recv_elems, op
+        )
+
+    def alltoall(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].alltoall(send_data, recv_data, tag=tag)
+
+    def broadcast(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+        pg_name: str = "default",
+    ):
+        self.name_to_pg[self._backend][pg_name].broadcast(
+            send_data, recv_data, root, tag=tag
+        )
+
+
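(Aside, not part of the patch.) One subtlety in `new_group` above: world ranks are translated to subgroup ranks by position in the sorted rank list. A tiny worked example, illustration only:

```python
# world size 4, subgroup of world ranks {1, 3}, seen from world rank 3
global_ranks = sorted([3, 1])       # -> [1, 3]
group_rank = global_ranks.index(3)  # -> 1 (this process's rank inside the subgroup)
group_world = len(global_ranks)     # -> 2
```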
+async def init_process_group(
+    rank: int, world_size: int, backend: str = "gloo", address: Optional[str] = None
+):
+    address = address or os.environ.get(RANK_ADDRESS_ENV_KEY, None)
+    if address is None:
+        raise RuntimeError(
+            "Cannot decide which process to involve in the collective communication."
+        )
+    ctx = get_context()
+    await ctx.create_actor(
+        RankActor, rank, world_size, backend=backend, address=address, uid="RankActor"
+    )
+
+
+async def new_group(
+    ranks: List[int], pg_options: Optional[ProcessGroup.Options] = None
+):
+    address = os.environ.get(RANK_ADDRESS_ENV_KEY, None)
+    if address is None:
+        raise RuntimeError(INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    return await ref.new_group(ranks, pg_options)
+
+
+async def reduce(
+    send_data: Any,
+    recv_data: Any,
+    op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+    root: Optional[int] = 0,
+    tag: Optional[int] = 0,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.reduce(
+        send_data, recv_data, op=op, root=root, tag=tag, pg_name=group_name
+    )
+
+
+async def allreduce(
+    send_data: Any,
+    recv_data: Any,
+    op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+    algorithm: AllReduceAlgorithm = AllReduceAlgorithm.RING,
+    tag: Optional[int] = 0,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.allreduce(
+        send_data, recv_data, op=op, algorithm=algorithm, tag=tag, pg_name=group_name
+    )
+
+
+async def gather(
+    send_data: Any,
+    recv_data: Any,
+    root: Optional[int] = 0,
+    tag: Optional[int] = 0,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.gather(send_data, recv_data, root=root, tag=tag, pg_name=group_name)
+
+
+async def allgather(
+    send_data: Any, recv_data: Any, tag: Optional[int] = 0, group_name: str = "default"
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.allgather(send_data, recv_data, tag=tag, pg_name=group_name)
+
+
+async def scatter(
+    send_data: List[Any],
+    recv_data: Any,
+    root: Optional[int] = 0,
+    tag: Optional[int] = 0,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.scatter(send_data, recv_data, root=root, tag=tag, pg_name=group_name)
+
+
+async def reduce_scatter(
+    send_data: Any,
+    recv_data: Any,
+    recv_elems: List[int],
+    op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.reduce_scatter(send_data, recv_data, recv_elems, op, pg_name=group_name)
+
+
+async def alltoall(
+    send_data: Any, recv_data: Any, tag: Optional[int] = 0, group_name: str = "default"
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.alltoall(send_data, recv_data, tag=tag, pg_name=group_name)
+
+
+async def broadcast(
+    send_data: Any,
+    recv_data: Any,
+    root: Optional[int] = 0,
+    tag: Optional[int] = 0,
+    group_name: str = "default",
+):
+    address = get_rank_address_via_env(RANK_ADDRESS_ENV_KEY, INVOKE_ERROR_MESSAGE)
+    ref = await actor_ref(address=address, uid="RankActor")
+    await ref.broadcast(send_data, recv_data, root, tag, pg_name=group_name)
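(Aside, not part of the patch.) Putting the module-level API together, a minimal usage sketch modeled on the tests further down; it runs inside each participating pool process, and the `worker` name is illustrative:

```python
import numpy as np

from xoscar.collective import allreduce, init_process_group


async def worker(rank: int, world_size: int):
    # assumes COLLECTIVE_RANK_ADDRESS is set (RankActor.__post_create__ also
    # sets it); an explicit ``address=`` argument works as well
    await init_process_group(rank, world_size, backend="gloo")
    sendbuf = np.array([1, 2, 3], dtype=np.int32)
    recvbuf = np.zeros_like(sendbuf)
    # defaults: SUM over the implicit "default" group
    await allreduce(sendbuf, recvbuf)
    # recvbuf now holds the element-wise sum across all ranks
```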
diff --git a/python/xoscar/collective/process_group.py b/python/xoscar/collective/process_group.py
new file mode 100644
index 00000000..c208b030
--- /dev/null
+++ b/python/xoscar/collective/process_group.py
@@ -0,0 +1,323 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+
+from ..utils import is_linux
+from . import xoscar_pygloo as xp
+from .common import (
+    RENDEZVOUS_MASTER_IP_ENV_KEY,
+    RENDEZVOUS_MASTER_PORT_ENV_KEY,
+    AllReduceAlgorithm,
+    AllReduceAlgorithmMappingGloo,
+    CollectiveReduceOp,
+    ReduceOpMappingGloo,
+    TypeMappingGloo,
+)
+from .utils import convert_data_to_np_array
+
+
+class _World:
+    def __init__(self):
+        self._store = None
+        self._device = None
+
+    @property
+    def store(self):
+        return self._store
+
+    @property
+    def device(self):
+        return self._device
+
+    @store.setter  # type: ignore
+    def store(self, store):
+        self._store = store
+
+    @device.setter  # type: ignore
+    def device(self, device):
+        self._device = device
+
+
+_world = _World()
+
+
+class ProcessGroup(ABC):
+    class Options:
+        master_ip: Optional[str] = None
+        master_port: Optional[int] = None
+
+    def __init__(
+        self,
+        rank: int,
+        world_size: int,
+        group_name: Optional[str] = None,
+        pg_options: Optional[Options] = None,
+    ):
+        self._rank = rank
+        self._world_size = world_size
+        self._group_name = group_name
+        self._option = pg_options
+
+    @property
+    def rank(self):
+        return self._rank
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    @property
+    def name(self):
+        return self._group_name
+
+    @property
+    def options(self):
+        return self._option
+
+    @abstractmethod
+    def allreduce(self, *args, **kwargs):
+        """All reduce function"""
+
+    @abstractmethod
+    def reduce(self, *args, **kwargs):
+        """Reduce function"""
+
+    @abstractmethod
+    def allgather(self, *args, **kwargs):
+        """All gather function"""
+
+    @abstractmethod
+    def gather(self, *args, **kwargs):
+        """Gather function"""
+
+    @abstractmethod
+    def scatter(self, *args, **kwargs):
+        """Scatter function"""
+
+    @abstractmethod
+    def reduce_scatter(self, *args, **kwargs):
+        """Reduce scatter function"""
+
+    @abstractmethod
+    def alltoall(self, *args, **kwargs):
+        """All to all function"""
+
+    @abstractmethod
+    def broadcast(self, *args, **kwargs):
+        """Broadcast function"""
+
+
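(Aside, not part of the patch.) Before `ProcessGroupGloo` below: rendezvous needs a master IP and port, taken either from `ProcessGroup.Options` or from the environment. A sketch of the environment-based setup used by the tests, with example values:

```python
import os

from xoscar.collective.common import (
    RENDEZVOUS_MASTER_IP_ENV_KEY,
    RENDEZVOUS_MASTER_PORT_ENV_KEY,
)

# rank 0 will act as the TCP store server; all ranks dial the same endpoint
os.environ[RENDEZVOUS_MASTER_IP_ENV_KEY] = "127.0.0.1"
os.environ[RENDEZVOUS_MASTER_PORT_ENV_KEY] = "25001"
```

The store and transport device are created once per process and cached in the module-level `_World`, so additional subgroups only pay for a `PrefixStore` plus a new full-mesh context.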
+class ProcessGroupGloo(ProcessGroup):
+    def __init__(
+        self,
+        ip: str,
+        rank: int,
+        world_size: int,
+        group_name: Optional[str] = None,
+        pg_options: Optional[ProcessGroup.Options] = None,
+    ):
+        super().__init__(rank, world_size, group_name, pg_options)
+        if _world.store is None:
+            master_ip = (
+                pg_options.master_ip
+                if pg_options is not None
+                else os.environ.get(RENDEZVOUS_MASTER_IP_ENV_KEY, None)
+            )
+            master_port = (
+                pg_options.master_port
+                if pg_options is not None
+                else os.environ.get(RENDEZVOUS_MASTER_PORT_ENV_KEY, None)
+            )
+            if master_ip is None or master_port is None:
+                raise ValueError("Cannot find master ip or port for rendezvous")
+
+            opt = xp.rendezvous.TCPStoreOptions()
+            opt.port = int(master_port)
+            opt.numWorkers = world_size
+            opt.isServer = rank == 0
+
+            store = xp.rendezvous.TCPStore(master_ip, opt)
+            if not is_linux():
+                attr = xp.transport.uv.attr(ip)  # type: ignore
+                dev = xp.transport.uv.CreateDevice(attr)  # type: ignore
+            else:
+                attr = xp.transport.tcp.attr(ip)
+                dev = xp.transport.tcp.CreateDevice(attr)  # type: ignore
+            _world.store = store  # type: ignore
+            _world.device = dev  # type: ignore
+        else:
+            store = _world.store
+            dev = _world.device
+
+        prefix_store = xp.rendezvous.PrefixStore(group_name or str(world_size), store)  # type: ignore
+        context = xp.rendezvous.Context(rank, world_size)
+        context.connectFullMesh(prefix_store, dev)
+        self._context = context
+
+    def reduce(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+    ):
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = send_buf.size
+        dtype = send_buf.dtype
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.reduce(
+            self._context,
+            sendptr,
+            recvptr,
+            size,
+            gloo_type,
+            ReduceOpMappingGloo[op],
+            root,
+            tag,
+        )
+
+    def allreduce(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+        algorithm: AllReduceAlgorithm = AllReduceAlgorithm.RING,
+        tag: Optional[int] = 0,
+    ):
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = send_buf.size
+        dtype = send_buf.dtype
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.allreduce(
+            self._context,
+            sendptr,
+            recvptr,
+            size,
+            gloo_type,
+            ReduceOpMappingGloo[op],
+            AllReduceAlgorithmMappingGloo[algorithm],
+            tag,  # type: ignore
+        )
+
+    def gather(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+    ):
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = send_buf.size
+        dtype = send_buf.dtype
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.gather(self._context, sendptr, recvptr, size, gloo_type, root, tag)
+
+    def allgather(self, send_data: Any, recv_data: Any, tag: Optional[int] = 0):
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = send_buf.size
+        dtype = send_buf.dtype
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.allgather(self._context, sendptr, recvptr, size, gloo_type, tag)
+
+    def scatter(
+        self,
+        send_data: List[Any],
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+    ):
+        send_bufs = [convert_data_to_np_array(d) for d in send_data]
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = sum([d.size for d in send_bufs])
+        dtype = recv_buf.dtype
+        sendptrs = [d.ctypes.data for d in send_bufs]
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.scatter(self._context, sendptrs, recvptr, size, gloo_type, root, tag)  # type: ignore
+
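(Aside, not part of the patch.) All of the methods above share one buffer-handling pattern, sketched here standalone: inputs are normalized to numpy arrays and handed to gloo as raw addresses, so results appear in the receive buffer by mutation rather than by return value.

```python
import numpy as np

recv = np.zeros(4, dtype=np.float32)
ptr = recv.ctypes.data  # integer address handed to xoscar_pygloo
key = recv.dtype.type   # numpy scalar type, the TypeMappingGloo key
assert key is np.float32
# gloo writes through ``ptr``; the caller simply reads ``recv`` afterwards
```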
+    def reduce_scatter(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        recv_elems: List[int],
+        op: CollectiveReduceOp = CollectiveReduceOp.SUM,
+    ):  # pragma: no cover
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        size = send_buf.size
+        dtype = send_buf.dtype
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.reduce_scatter(
+            self._context,
+            sendptr,
+            recvptr,
+            size,
+            recv_elems,
+            gloo_type,
+            ReduceOpMappingGloo[op],
+        )
+
+    def alltoall(self, send_data: Any, recv_data: Any, tag: Optional[int] = 0):
+        send_buf = convert_data_to_np_array(send_data)
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = send_buf.size
+        dtype = send_buf.dtype
+        sendptr = send_buf.ctypes.data
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.all_to_all(self._context, sendptr, recvptr, size, gloo_type, tag)
+
+    def broadcast(
+        self,
+        send_data: Any,
+        recv_data: Any,
+        root: Optional[int] = 0,
+        tag: Optional[int] = 0,
+    ):
+        if send_data is not None:
+            send_buf = convert_data_to_np_array(send_data)
+            sendptr = send_buf.ctypes.data
+        else:
+            sendptr = None
+        recv_buf = convert_data_to_np_array(recv_data)
+        size = recv_buf.size
+        dtype = recv_buf.dtype
+        recvptr = recv_buf.ctypes.data
+        gloo_type = TypeMappingGloo[dtype.type]
+        xp.broadcast(
+            self._context,
+            recvptr if sendptr is None else sendptr,
+            recvptr,
+            size,
+            gloo_type,
+            root,
+            tag,
+        )
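(Aside, not part of the patch.) A note on `broadcast` above: only the root needs real send data; other ranks may pass `send_data=None`, in which case the receive buffer's address is reused as the source pointer argument. A hedged sketch, where `pg` stands for an already-connected `ProcessGroupGloo`:

```python
import numpy as np

root, my_rank = 1, 0
sendbuf = np.full((2, 3), 7, dtype=np.int64) if my_rank == root else None
recvbuf = np.zeros((2, 3), dtype=np.int64)
# pg.broadcast(sendbuf, recvbuf, root=root)
# afterwards recvbuf on every rank equals the root's sendbuf
```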
diff --git a/python/xoscar/collective/tests/test_core.py b/python/xoscar/collective/tests/test_core.py
new file mode 100644
index 00000000..ec3c8adf
--- /dev/null
+++ b/python/xoscar/collective/tests/test_core.py
@@ -0,0 +1,262 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+
+import numpy as np
+import pytest
+
+from ... import Actor, ActorRefType, actor_ref, create_actor_pool, get_pool_config
+from ...context import get_context
+from ...tests.core import require_unix
+from ...utils import is_linux
+from ..common import (
+    RANK_ADDRESS_ENV_KEY,
+    RENDEZVOUS_MASTER_IP_ENV_KEY,
+    RENDEZVOUS_MASTER_PORT_ENV_KEY,
+)
+from ..core import (
+    RankActor,
+    allgather,
+    allreduce,
+    alltoall,
+    broadcast,
+    gather,
+    init_process_group,
+    new_group,
+    reduce,
+    reduce_scatter,
+    scatter,
+)
+from ..process_group import ProcessGroup
+
+
+class WorkerActor(Actor):
+    def __init__(self, rank, world, *args, **kwargs):
+        self._rank = rank
+        self._world = world
+
+    async def init_process_group(self):
+        os.environ[RANK_ADDRESS_ENV_KEY] = self.address
+        return await init_process_group(self._rank, self._world)
+
+    async def init_process_group_without_env(self):
+        with pytest.raises(RuntimeError):
+            await init_process_group(self._rank, self._world)
+
+    async def test_params(self):
+        rank_ref: ActorRefType[RankActor] = await actor_ref(
+            address=self.address, uid="RankActor"
+        )
+        uid = rank_ref.uid
+        assert uid == bytes(RankActor.default_uid(), "utf-8")
+
+        rank = await rank_ref.rank()
+        assert rank == self._rank
+
+        world = await rank_ref.world()
+        assert world == self._world
+
+        backend = await rank_ref.backend()
+        assert backend == "gloo"
+
+        pg: ProcessGroup = await rank_ref.process_group("default")
+        assert pg is not None
+
+        assert pg.rank == self._rank
+        assert pg.name == "default"
+        assert pg.world_size == self._world
+        assert pg.options is None
+
+    async def test_reduce(self):
+        sendbuf = np.array([1, 2, 3, 4], dtype=np.int32)
+        recvbuf = np.zeros((4,), dtype=np.int32)
+        _group = [0, 1, 2]
+        group = await new_group(_group)
+        root = 1
+        if group is not None:
+            await reduce(sendbuf, recvbuf, group_name=group, root=root)
+
+        if self._rank == _group[root]:
+            np.testing.assert_array_equal(recvbuf, sendbuf * 3)
+        else:
+            np.testing.assert_array_equal(recvbuf, np.zeros_like(sendbuf))
+
+    async def test_allreduce(self):
+        sendbuf = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.int32)
+        recvbuf = np.zeros_like(sendbuf)
+        _group = [0, 1]
+        group = await new_group(_group)
+        if group is not None:
+            await allreduce(sendbuf, recvbuf, group_name=group)
+        if self._rank in _group:
+            np.testing.assert_array_equal(recvbuf, sendbuf * len(_group))
+        else:
+            np.testing.assert_array_equal(recvbuf, np.zeros_like(sendbuf))
+
+    async def test_gather(self):
+        sendbuf = np.array([self._rank], dtype=np.int32)
+        recvbuf = np.zeros((2,), dtype=np.int32)
+        root = 0
+        _group = [1, 2]
+        group = await new_group(_group)
+        if group is not None:
+            await gather(sendbuf, recvbuf, group_name=group, root=root)
+
+        if self._rank == _group[root]:
+            np.testing.assert_array_equal(recvbuf, np.array(_group, dtype=np.int32))
+        else:
+            np.testing.assert_array_equal(recvbuf, np.zeros_like(recvbuf))
+
+    async def test_allgather(self):
+        sendbuf = np.array([self._rank], dtype=np.int32)
+        recvbuf = np.zeros((3,), dtype=np.int32)
+        _group = [0, 1, 2]
+        group = await new_group(_group)
+        if group is not None:
+            await allgather(sendbuf, recvbuf, group_name=group)
+        if self._rank in _group:
+            np.testing.assert_array_equal(recvbuf, np.array(_group, dtype=np.int32))
+        else:
+            np.testing.assert_array_equal(recvbuf, np.zeros_like(recvbuf))
+
+    async def test_scatter(self):
+        _group = [1, 2]
+        root = 0
+        if self._rank == _group[root]:
+            sendbuf1 = np.array([10, 11], dtype=np.int32)
+            sendbuf2 = np.array([12, 13], dtype=np.int32)
+        else:
+            sendbuf1 = np.zeros((2,), dtype=np.int32)
+            sendbuf2 = np.zeros((2,), dtype=np.int32)
+        recvbuf = np.zeros((2,), dtype=np.int32)
+        send_list = [sendbuf1, sendbuf2]
+        group = await new_group(_group)
+        if group is not None:
+            await scatter(send_list, recvbuf, group_name=group, root=root)
+
+        if self._rank == _group[0]:
+            np.testing.assert_array_equal(recvbuf, np.array([10, 11], dtype=np.int32))
+        elif self._rank == _group[1]:
+            np.testing.assert_array_equal(recvbuf, np.array([12, 13], dtype=np.int32))
+        else:
+            np.testing.assert_array_equal(recvbuf, np.zeros_like(recvbuf))
+
+    async def test_reduce_scatter(self):
+        data = [self._rank, self._rank + 1, self._rank + 2]
+        sendbuf = np.array(data, dtype=np.int32)
+        recvbuf = np.zeros((1,), dtype=np.int32)
+        recv_elems = [1, 1, 1]
+        group = await new_group([0, 1, 2])
+        if group is not None:
+            await reduce_scatter(sendbuf, recvbuf, recv_elems, group_name=group)
+        np.testing.assert_array_equal(recvbuf, np.array([sum(data)], dtype=np.int32))
+
+    async def test_alltoall(self):
+        sendbuf = np.zeros((3,), dtype=np.float32) + self._rank
+        recvbuf = np.zeros(sendbuf.shape, dtype=np.float32)
+        group = await new_group([0, 1, 2])
+        if group is not None:
+            await alltoall(sendbuf, recvbuf, group_name=group)
+        np.testing.assert_array_equal(recvbuf, np.array([0, 1, 2], dtype=np.float32))
+
+    async def test_broadcast(self):
+        root = 1
+        _group = [0, 1, 2]
+        sendbuf = np.zeros((2, 3), dtype=np.int64)
+        if self._rank == _group[root]:
+            sendbuf = sendbuf + self._rank
+        recvbuf = np.zeros_like(sendbuf, dtype=np.int64)
+        group = await new_group(_group)
+        if group is not None:
+            await broadcast(sendbuf, recvbuf, root=root, group_name=group)
+        np.testing.assert_array_equal(recvbuf, np.zeros_like(recvbuf) + _group[root])
+
+    async def test_collective_np(self):
+        await self.test_params()
+        await self.test_reduce()
+        await self.test_allreduce()
+        await self.test_gather()
+        await self.test_allgather()
+        await self.test_scatter()
+        # reduce_scatter is problematic on non-Linux OSes since uv has an issue in gloo
+        if is_linux():
+            await self.test_reduce_scatter()
+        await self.test_alltoall()
+        await self.test_broadcast()
+
+
+@pytest.mark.asyncio
+@require_unix
+async def test_collective():
+    pool = await create_actor_pool(
+        "127.0.0.1",
+        n_process=3,
+        envs=[
+            {
+                RENDEZVOUS_MASTER_IP_ENV_KEY: "127.0.0.1",
+                RENDEZVOUS_MASTER_PORT_ENV_KEY: "25001",
+            }
+        ]
+        * 3,
+    )
+    main_addr = pool.external_address
+    config = (await get_pool_config(pool.external_address)).as_dict()
+    all_addrs = list(config["mapping"].keys())
+    all_addrs.remove(main_addr)
+
+    async with pool:
+        ctx = get_context()
+        r0 = await ctx.create_actor(WorkerActor, 0, 3, address=all_addrs[0])
+        r1 = await ctx.create_actor(WorkerActor, 1, 3, address=all_addrs[1])
+        r2 = await ctx.create_actor(WorkerActor, 2, 3, address=all_addrs[2])
+        t0 = r0.init_process_group()
+        t1 = r1.init_process_group()
+        t2 = r2.init_process_group()
+        await asyncio.gather(*[t0, t1, t2])
+
+        t0 = r0.test_collective_np()
+        t1 = r1.test_collective_np()
+        t2 = r2.test_collective_np()
+        await asyncio.gather(*[t0, t1, t2])
+
+
+@pytest.mark.asyncio
+@require_unix
+async def test_collective_without_env():
+    pool = await create_actor_pool(
+        "127.0.0.1",
+        n_process=3,
+    )
+    main_addr = pool.external_address
+    config = (await get_pool_config(pool.external_address)).as_dict()
+    all_addrs = list(config["mapping"].keys())
+    all_addrs.remove(main_addr)
+
+    async with pool:
+        ctx = get_context()
+        r0 = await ctx.create_actor(WorkerActor, 0, 3, address=all_addrs[0])
+        r1 = await ctx.create_actor(WorkerActor, 1, 3, address=all_addrs[1])
+        r2 = await ctx.create_actor(WorkerActor, 2, 3, address=all_addrs[2])
+        t0 = r0.init_process_group_without_env()
+        t1 = r1.init_process_group_without_env()
+        t2 = r2.init_process_group_without_env()
+        await asyncio.gather(*[t0, t1, t2])
+
+        t0 = r0.init_process_group()
+        t1 = r1.init_process_group()
+        t2 = r2.init_process_group()
+        with pytest.raises(ValueError):
+            await asyncio.gather(*[t0, t1, t2])
diff --git a/python/xoscar/collective/utils.py b/python/xoscar/collective/utils.py
new file mode 100644
index 00000000..646feb8c
--- /dev/null
+++ b/python/xoscar/collective/utils.py
@@ -0,0 +1,30 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import numpy as np
+
+
+def convert_data_to_np_array(data):
+    if isinstance(data, np.ndarray):
+        return data
+    else:
+        return np.frombuffer(data, dtype="u1")
+
+
+def get_rank_address_via_env(env_key: str, err_message: str) -> str:
+    address = os.environ.get(env_key, None)
+    if address is None:
+        raise RuntimeError(err_message)
+    return address
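(Aside, not part of the patch.) The helpers in `utils.py` are small enough to pin down with a doctest-style sketch, illustration only:

```python
import numpy as np

from xoscar.collective.utils import convert_data_to_np_array

# bytes-like input is viewed as a uint8 array
arr = convert_data_to_np_array(b"\x01\x02\x03")
assert arr.dtype == np.uint8 and arr.tolist() == [1, 2, 3]

# ndarrays pass through untouched
existing = np.arange(3)
assert convert_data_to_np_array(existing) is existing
```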