diff --git a/t5x/partitioning.py b/t5x/partitioning.py index 910f666fd..4c37d0cbf 100644 --- a/t5x/partitioning.py +++ b/t5x/partitioning.py @@ -17,8 +17,9 @@ import abc import collections import dataclasses +import functools import typing -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union from absl import logging import cached_property @@ -289,9 +290,13 @@ def get_gpu_mesh(num_partitions: int) -> Mesh: return global_mesh -def default_mesh(num_partitions: int, - model_parallel_submesh: Optional[HardwareMesh] = None, - backend: Optional[str] = None) -> Mesh: +def default_mesh( + num_partitions: int, + model_parallel_submesh: Optional[HardwareMesh] = None, + backend: Optional[str] = None, + ici_mesh_shape: Optional[HardwareMesh] = None, + dcn_mesh_shape: Optional[HardwareMesh] = None, +) -> Mesh: """Attempt to return a default mesh for simple cases. Args: @@ -302,15 +307,50 @@ def default_mesh(num_partitions: int, backend: get devices from the pinned backend, if specified. This is useful for explicitly specifying the devices other than relying on jax_platform_name. + ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in each + slice. The meaning of each mesh axis is defined by mesh_axis_names, so + these two params must be the same length. If dcn_mesh_shape is present, + the overall mesh is the product of ici_mesh_shape and dcn_mesh_shape. For + example, an ici_mesh_shape of [2, 3, 4] with mesh_axis_names ['replica', + 'data', 'mdl'] indicates 2-way replica parallelism, 3-way data + parallelism, and 4-way model parallelism over 24 devices. None, the + default, is equivalent to a sequence of ones and means that the model is + placed on a single device. + dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over + multiple slices. The overall mesh is the product of ici_mesh_shape and + dcn_mesh_shape, and the meaning of each mesh axis is defined by + mesh_axis_names, so these three params must be the same length. Returns: - xmap/pjit 2D Mesh with 'data', 'model' mesh axes. + xmap/pjit 2D Mesh with 'data', 'model' mesh axes if single-slice, otherwise + 3D Mesh with 'replica', 'data', and 'model' mesh axes. """ - last_device = jax.devices(backend)[-1] + devices = jax.devices(backend) + last_device = devices[-1] platform = last_device.platform device_kind = last_device.device_kind bounds = bounds_from_last_device(last_device) + if ici_mesh_shape is not None and dcn_mesh_shape is not None: + device_mesh = create_hybrid_device_mesh( + ici_mesh_shape, + dcn_mesh_shape, + devices=devices, + ) + multi_slice_global_mesh = Mesh(device_mesh, ['replica', 'data', 'model']) + logging.info( + 'multi_slice_global_mesh axis_names: %s', + multi_slice_global_mesh.axis_names, + ) + logging.info( + 'multi_slice_global_mesh devices: %s', multi_slice_global_mesh.devices + ) + logging.info( + 'multi_slice_global_mesh devices shape: %s', + multi_slice_global_mesh.devices.shape, + ) + return multi_slice_global_mesh + if model_parallel_submesh: return get_mesh(model_parallel_submesh, backend=backend) @@ -430,6 +470,65 @@ def get_local_chunk_info( chunk_size = size // self.num_chunks[mesh_axis] local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size) + replica_id = self.get_replica_id(sharded_mesh_axes) + + return LocalChunkInfo(tuple(local_slice), replica_id) + + def get_shard_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int: + """Given mesh axes used for sharding, computes current host's shard id. + + To give an example, let's say there are two axes globally: replica, data, + and model, the mesh axes for sharding is ('replica', 'data'), which means we + are going to partition an array along 'replica' and 'data' axes. + The shard_id is to show the index of the current local host along the + sharding axes (in this example, it's 'replica' and 'data' axes). + + More concretely, let's say we have 4 local hosts, and we use 'replica' and + 'data' axes for data parallel (2 hosts along the replica axis, and 2 host + along the data axis). The host located in ('replica': 0, 'data': 0), we + should assign data shard-0 to it. For host ('replica': 0, 'data': 1), we + assign shard-1. For host ('replica': 1, 'data': 0), we assign shard-2. + For host ('replica': 1, 'data': 1), we assign shard-3. + + Note: the host location along 'replica' and 'data' axes, e.g., + ('replica': 0, 'data': 0) is named chunk_id and stored in + self._local_chunker.chunk_ids[axis]. + + Args: + sharded_mesh_axes: the mesh axes for sharding. + + Returns: + the index of the current local host along the sharding axes. + """ + if isinstance(sharded_mesh_axes, str): + sharded_mesh_axes = (sharded_mesh_axes,) + + shard_id = 0 + for mesh_axis in sharded_mesh_axes: + chunk_id = self.chunk_ids[mesh_axis] + shard_id = shard_id * self.num_chunks[mesh_axis] + chunk_id + + return shard_id + + def get_replica_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int: + """Given mesh axes used for sharding, computes current host's replica id. + + To give an example, let's say there are two axes globally: data, and model, + the mesh axes for sharding is ('data', ), which means we are going to + partition an array along 'data' axis and replicate it along 'model' axis. + The replica_id is to show the index of the current local host along the + 'model' axis. + + Args: + sharded_mesh_axes: the mesh axes for sharding. + + Returns: + the index of the current local host along the non-sharding axes (i.e., + replicating axes). + """ + if isinstance(sharded_mesh_axes, str): + sharded_mesh_axes = (sharded_mesh_axes,) + replicated_mesh_axes = [ mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes @@ -439,7 +538,7 @@ def get_local_chunk_info( chunk_id = self.chunk_ids[mesh_axis] replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id - return LocalChunkInfo(tuple(local_slice), replica_id) + return replica_id def standard_logical_axis_rules( @@ -550,11 +649,15 @@ class DataLayout: class BasePartitioner(metaclass=abc.ABCMeta): """Interface for partitioning computations across hardware devices.""" - def __init__(self, - num_partitions: Optional[int] = None, - model_parallel_submesh: Optional[HardwareMesh] = None, - params_on_devices: bool = True, - backend: Optional[str] = None): + def __init__( + self, + num_partitions: Optional[int] = None, + model_parallel_submesh: Optional[HardwareMesh] = None, + params_on_devices: bool = True, + backend: Optional[str] = None, + ici_mesh_shape: Optional[HardwareMesh] = None, + dcn_mesh_shape: Optional[HardwareMesh] = None, + ): """Configures the partitioner. Args: @@ -573,6 +676,19 @@ def __init__(self, backend: get devices from the pinned backend, if specified. This is useful for explicitly specifying the devices other than relying on jax_platform_name. + ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in + each slice. The meaning of each mesh axis is defined by mesh_axis_names, + so these two params must be the same length. If dcn_mesh_shape is + present, the overall mesh is the product of ici_mesh_shape and + dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with + mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica + parallelism, 3-way data parallelism, and 4-way model parallelism over 24 + devices. None, the default, is equivalent to a sequence of ones and + means that the model is placed on a single device. + dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over + multiple slices. The overall mesh is the product of ici_mesh_shape and + dcn_mesh_shape, and the meaning of each mesh axis is defined by + mesh_axis_names, so these three params must be the same length. """ if not num_partitions and not model_parallel_submesh: @@ -601,8 +717,13 @@ def __init__(self, self._num_partitions = num_partitions self._model_parallel_submesh = model_parallel_submesh self._params_on_devices = params_on_devices - self._data_axis = 'data' + if ici_mesh_shape is None or dcn_mesh_shape is None: + self._data_axis = 'data' + else: + self._data_axis = ('replica', 'data') self._backend = backend + self._ici_mesh_shape = ici_mesh_shape + self._dcn_mesh_shape = dcn_mesh_shape @property def mesh(self) -> Mesh: @@ -612,9 +733,63 @@ def mesh(self) -> Mesh: def data_partition_spec(self) -> PartitionSpec: return PartitionSpec(self._data_axis) - def get_data_layout(self, - batch_size: Optional[int] = None, - host_index: Optional[int] = None) -> DataLayout: + @property + def data_mesh_size(self) -> int: + """Data mesh size. + + Data mesh size is defined as the number of global devices involved to + carry out data parallel. Let's say we have a global mesh: ('replica': 2, + 'data': 4, 'model': 2), and axes 'replica' and 'data' are responsible for + the data parallel, that means we have 2*4 = 8 devices involved - i.e., data + mesh size is 8. + + Returns: + the id of the shard for the axes being replicated among the devices used + to shard the sharded_mesh_axes. + """ + data_submesh_sizes = ( + [self.mesh.shape[self._data_axis]] + if isinstance(self._data_axis, str) + else [self.mesh.shape[axis] for axis in self._data_axis] + ) + data_mesh_size = functools.reduce(lambda x, y: x * y, data_submesh_sizes) + return data_mesh_size + + @property + def data_shards(self) -> int: + """Number of data shards. + + Let's say we are dealing with 2 slices of df4x2 TPUs. In data pipeline + we need prepare / send one data shard to each local host. This means, we + need 4 shards since we have 4 local hosts. How to infer the number of hosts + from mesh information? In this case, we have a global mesh: ('replica': 2, + 'data': 8, 'model': 2). Each local host (i.e., df2x2) has this local mesh: + ('replica': 1, 'data': 4, 'model': 2). By dividing global mesh with local + mesh, we can get the count of hosts. + + Returns: + Number of data shards. Each shard will be sent to one local host. + """ + data_chunks = ( + [self._local_chunker.num_chunks[self._data_axis]] + if isinstance(self._data_axis, str) + else [self._local_chunker.num_chunks[axis] for axis in self._data_axis] + ) + data_shards = functools.reduce(lambda x, y: x * y, data_chunks) + return data_shards + + @property + def data_shard_id(self) -> int: + """Data shard id for the current host. + + Returns: + Index of data shard that will be sent to the current local host. + """ + return self._local_chunker.get_shard_id(self._data_axis) + + def get_data_layout( + self, batch_size: Optional[int] = None, host_index: Optional[int] = None + ) -> DataLayout: """Returns filled `DataLayout` based on the partitioned model layout. Args: @@ -637,24 +812,26 @@ def get_data_layout(self, shard_id=0, num_shards=1, is_first_host_in_replica_set=(jax.process_index() == 0)) - mesh_size = self._local_chunker.global_mesh.shape[self._data_axis] - batch_size = batch_size or mesh_size - if batch_size % mesh_size: + + batch_size = batch_size or self.data_mesh_size + if batch_size % self.data_mesh_size: raise ValueError( f'Batch size ({batch_size}) must be divisible by corresponding ' - f'mesh size ({mesh_size}).') - num_shards = self._local_chunker.num_chunks[self._data_axis] - if batch_size % num_shards: + f'data mesh size ({self.data_mesh_size}).' + ) + + if batch_size % self.data_shards: raise ValueError( f'Batch size ({batch_size}) must be divisible by number of ' - f'replicas ({num_shards}).') - replica_id = self._local_chunker.get_local_chunk_info( - (batch_size,), [self._data_axis]).replica_id + f'data shards ({self.data_shards}).' + ) + replica_id = self._local_chunker.get_replica_id(self._data_axis) return DataLayout( batch_size=int(batch_size), - shard_id=int(self._local_chunker.chunk_ids[self._data_axis]), - num_shards=int(num_shards), - is_first_host_in_replica_set=(replica_id == 0)) + shard_id=int(self.data_shard_id), + num_shards=int(self.data_shards), + is_first_host_in_replica_set=(replica_id == 0), + ) def get_local_chunk_info( self, global_shape: Tuple[int, ...], @@ -790,8 +967,13 @@ def _local_chunker(self) -> LocalChunker: @cached_property def mesh(self) -> Mesh: - return default_mesh(self._num_partitions, self._model_parallel_submesh, - self._backend) + return default_mesh( + self._num_partitions, + self._model_parallel_submesh, + self._backend, + self._ici_mesh_shape, + self._dcn_mesh_shape, + ) def partition( self, @@ -825,6 +1007,8 @@ def __init__( model_parallel_submesh: Optional[HardwareMesh] = None, params_on_devices: bool = True, backend: Optional[str] = None, + ici_mesh_shape: Optional[HardwareMesh] = None, + dcn_mesh_shape: Optional[HardwareMesh] = None, logical_axis_rules: Optional[LogicalAxisRules] = None, ): """PjitPartitioner constructor. @@ -856,6 +1040,19 @@ def __init__( backend: get devices from the pinned backend, if specified. This is useful for explicitly specifying the devices other than relying on jax_platform_name. + ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in + each slice. The meaning of each mesh axis is defined by mesh_axis_names, + so these two params must be the same length. If dcn_mesh_shape is + present, the overall mesh is the product of ici_mesh_shape and + dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with + mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica + parallelism, 3-way data parallelism, and 4-way model parallelism over 24 + devices. None, the default, is equivalent to a sequence of ones and + means that the model is placed on a single device. + dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over + multiple slices. The overall mesh is the product of ici_mesh_shape and + dcn_mesh_shape, and the meaning of each mesh axis is defined by + mesh_axis_names, so these three params must be the same length. logical_axis_rules: a priority-ordered sequence of KV tuples that maps logical axis names to either `None` (not sharded), 'model' (to shard across the model-parallel submesh), or 'data' (to shard across the @@ -865,12 +1062,21 @@ def __init__( num_partitions=num_partitions, model_parallel_submesh=model_parallel_submesh, params_on_devices=params_on_devices, - backend=backend) + backend=backend, + ici_mesh_shape=ici_mesh_shape, + dcn_mesh_shape=dcn_mesh_shape, + ) if logical_axis_rules is None: logical_axis_rules = standard_logical_axis_rules() + if ici_mesh_shape is not None and dcn_mesh_shape is not None: + # Split batch over new replica axis. + logical_axis_rules = ( + (k, ('replica', 'data') if k == 'batch' else v) + for k, v in logical_axis_rules + ) self._logical_axis_rules = tuple(logical_axis_rules) (self._data_axis,) = flax_partitioning.logical_to_mesh_axes( - ['batch'], logical_axis_rules + ['batch'], self._logical_axis_rules ) def partition(