Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables multi-slice training in T5X #1409

Merged
merged 1 commit into from
Oct 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 238 additions & 32 deletions t5x/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import abc
import collections
import dataclasses
import functools
import typing
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union

from absl import logging
import cached_property
Expand Down Expand Up @@ -289,9 +290,13 @@ def get_gpu_mesh(num_partitions: int) -> Mesh:
return global_mesh


def default_mesh(num_partitions: int,
model_parallel_submesh: Optional[HardwareMesh] = None,
backend: Optional[str] = None) -> Mesh:
def default_mesh(
num_partitions: int,
model_parallel_submesh: Optional[HardwareMesh] = None,
backend: Optional[str] = None,
ici_mesh_shape: Optional[HardwareMesh] = None,
dcn_mesh_shape: Optional[HardwareMesh] = None,
) -> Mesh:
"""Attempt to return a default mesh for simple cases.

Args:
Expand All @@ -302,15 +307,50 @@ def default_mesh(num_partitions: int,
backend: get devices from the pinned backend, if specified. This is useful
for explicitly specifying the devices other than relying on
jax_platform_name.
ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in each
slice. The meaning of each mesh axis is defined by mesh_axis_names, so
these two params must be the same length. If dcn_mesh_shape is present,
the overall mesh is the product of ici_mesh_shape and dcn_mesh_shape. For
example, an ici_mesh_shape of [2, 3, 4] with mesh_axis_names ['replica',
'data', 'model'] indicates 2-way replica parallelism, 3-way data
parallelism, and 4-way model parallelism over 24 devices. None, the
default, is equivalent to a sequence of ones and means that the model is
placed on a single device.
dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
multiple slices. The overall mesh is the product of ici_mesh_shape and
dcn_mesh_shape, and the meaning of each mesh axis is defined by
mesh_axis_names, so these three params must be the same length.

Returns:
xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
xmap/pjit 2D Mesh with 'data', 'model' mesh axes if single-slice, otherwise
3D Mesh with 'replica', 'data', and 'model' mesh axes.
"""
last_device = jax.devices(backend)[-1]
devices = jax.devices(backend)
last_device = devices[-1]
platform = last_device.platform
device_kind = last_device.device_kind
bounds = bounds_from_last_device(last_device)

if ici_mesh_shape is not None and dcn_mesh_shape is not None:
device_mesh = create_hybrid_device_mesh(
ici_mesh_shape,
dcn_mesh_shape,
devices=devices,
)
multi_slice_global_mesh = Mesh(device_mesh, ['replica', 'data', 'model'])
logging.info(
'multi_slice_global_mesh axis_names: %s',
multi_slice_global_mesh.axis_names,
)
logging.info(
'multi_slice_global_mesh devices: %s', multi_slice_global_mesh.devices
)
logging.info(
'multi_slice_global_mesh devices shape: %s',
multi_slice_global_mesh.devices.shape,
)
return multi_slice_global_mesh

if model_parallel_submesh:
return get_mesh(model_parallel_submesh, backend=backend)

Expand Down Expand Up @@ -430,6 +470,65 @@ def get_local_chunk_info(
chunk_size = size // self.num_chunks[mesh_axis]
local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)

replica_id = self.get_replica_id(sharded_mesh_axes)

return LocalChunkInfo(tuple(local_slice), replica_id)

def get_shard_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int:
"""Given mesh axes used for sharding, computes current host's shard id.

To give an example, let's say there are two axes globally: replica, data,
and model, the mesh axes for sharding is ('replica', 'data'), which means we
are going to partition an array along 'replica' and 'data' axes.
The shard_id is to show the index of the current local host along the
sharding axes (in this example, it's 'replica' and 'data' axes).

More concretely, let's say we have 4 local hosts, and we use 'replica' and
'data' axes for data parallel (2 hosts along the replica axis, and 2 host
along the data axis). The host located in ('replica': 0, 'data': 0), we
should assign data shard-0 to it. For host ('replica': 0, 'data': 1), we
assign shard-1. For host ('replica': 1, 'data': 0), we assign shard-2.
For host ('replica': 1, 'data': 1), we assign shard-3.

Note: the host location along 'replica' and 'data' axes, e.g.,
('replica': 0, 'data': 0) is named chunk_id and stored in
self._local_chunker.chunk_ids[axis].

Args:
sharded_mesh_axes: the mesh axes for sharding.

Returns:
the index of the current local host along the sharding axes.
"""
if isinstance(sharded_mesh_axes, str):
sharded_mesh_axes = (sharded_mesh_axes,)

shard_id = 0
for mesh_axis in sharded_mesh_axes:
chunk_id = self.chunk_ids[mesh_axis]
shard_id = shard_id * self.num_chunks[mesh_axis] + chunk_id

return shard_id

def get_replica_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int:
"""Given mesh axes used for sharding, computes current host's replica id.

To give an example, let's say there are two axes globally: data, and model,
the mesh axes for sharding is ('data', ), which means we are going to
partition an array along 'data' axis and replicate it along 'model' axis.
The replica_id is to show the index of the current local host along the
'model' axis.

Args:
sharded_mesh_axes: the mesh axes for sharding.

Returns:
the index of the current local host along the non-sharding axes (i.e.,
replicating axes).
"""
if isinstance(sharded_mesh_axes, str):
sharded_mesh_axes = (sharded_mesh_axes,)

replicated_mesh_axes = [
mesh_axis for mesh_axis in self.mesh_axes
if mesh_axis not in sharded_mesh_axes
Expand All @@ -439,7 +538,7 @@ def get_local_chunk_info(
chunk_id = self.chunk_ids[mesh_axis]
replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id

return LocalChunkInfo(tuple(local_slice), replica_id)
return replica_id


def standard_logical_axis_rules(
Expand Down Expand Up @@ -550,11 +649,15 @@ class DataLayout:
class BasePartitioner(metaclass=abc.ABCMeta):
"""Interface for partitioning computations across hardware devices."""

def __init__(self,
num_partitions: Optional[int] = None,
model_parallel_submesh: Optional[HardwareMesh] = None,
params_on_devices: bool = True,
backend: Optional[str] = None):
def __init__(
self,
num_partitions: Optional[int] = None,
model_parallel_submesh: Optional[HardwareMesh] = None,
params_on_devices: bool = True,
backend: Optional[str] = None,
ici_mesh_shape: Optional[HardwareMesh] = None,
dcn_mesh_shape: Optional[HardwareMesh] = None,
):
"""Configures the partitioner.

Args:
Expand All @@ -573,6 +676,19 @@ def __init__(self,
backend: get devices from the pinned backend, if specified. This is useful
for explicitly specifying the devices other than relying on
jax_platform_name.
ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in
each slice. The meaning of each mesh axis is defined by mesh_axis_names,
so these two params must be the same length. If dcn_mesh_shape is
present, the overall mesh is the product of ici_mesh_shape and
dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with
mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica
parallelism, 3-way data parallelism, and 4-way model parallelism over 24
devices. None, the default, is equivalent to a sequence of ones and
means that the model is placed on a single device.
dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
multiple slices. The overall mesh is the product of ici_mesh_shape and
dcn_mesh_shape, and the meaning of each mesh axis is defined by
mesh_axis_names, so these three params must be the same length.
"""

if not num_partitions and not model_parallel_submesh:
Expand Down Expand Up @@ -601,8 +717,13 @@ def __init__(self,
self._num_partitions = num_partitions
self._model_parallel_submesh = model_parallel_submesh
self._params_on_devices = params_on_devices
self._data_axis = 'data'
if ici_mesh_shape is None or dcn_mesh_shape is None:
self._data_axis = 'data'
else:
self._data_axis = ('replica', 'data')
self._backend = backend
self._ici_mesh_shape = ici_mesh_shape
self._dcn_mesh_shape = dcn_mesh_shape

@property
def mesh(self) -> Mesh:
Expand All @@ -612,9 +733,63 @@ def mesh(self) -> Mesh:
def data_partition_spec(self) -> PartitionSpec:
return PartitionSpec(self._data_axis)

def get_data_layout(self,
batch_size: Optional[int] = None,
host_index: Optional[int] = None) -> DataLayout:
@property
def data_mesh_size(self) -> int:
"""Data mesh size.

Data mesh size is defined as the number of global devices involved to
carry out data parallel. Let's say we have a global mesh: ('replica': 2,
'data': 4, 'model': 2), and axes 'replica' and 'data' are responsible for
the data parallel, that means we have 2*4 = 8 devices involved - i.e., data
mesh size is 8.

Returns:
the id of the shard for the axes being replicated among the devices used
to shard the sharded_mesh_axes.
"""
data_submesh_sizes = (
[self.mesh.shape[self._data_axis]]
if isinstance(self._data_axis, str)
else [self.mesh.shape[axis] for axis in self._data_axis]
)
data_mesh_size = functools.reduce(lambda x, y: x * y, data_submesh_sizes)
return data_mesh_size

@property
def data_shards(self) -> int:
"""Number of data shards.

Let's say we are dealing with 2 slices of df4x2 TPUs. In data pipeline
we need prepare / send one data shard to each local host. This means, we
need 4 shards since we have 4 local hosts. How to infer the number of hosts
from mesh information? In this case, we have a global mesh: ('replica': 2,
'data': 8, 'model': 2). Each local host (i.e., df2x2) has this local mesh:
('replica': 1, 'data': 4, 'model': 2). By dividing global mesh with local
mesh, we can get the count of hosts.

Returns:
Number of data shards. Each shard will be sent to one local host.
"""
data_chunks = (
[self._local_chunker.num_chunks[self._data_axis]]
if isinstance(self._data_axis, str)
else [self._local_chunker.num_chunks[axis] for axis in self._data_axis]
)
data_shards = functools.reduce(lambda x, y: x * y, data_chunks)
return data_shards

@property
def data_shard_id(self) -> int:
"""Data shard id for the current host.

Returns:
Index of data shard that will be sent to the current local host.
"""
return self._local_chunker.get_shard_id(self._data_axis)

def get_data_layout(
self, batch_size: Optional[int] = None, host_index: Optional[int] = None
) -> DataLayout:
"""Returns filled `DataLayout` based on the partitioned model layout.

Args:
Expand All @@ -637,24 +812,26 @@ def get_data_layout(self,
shard_id=0,
num_shards=1,
is_first_host_in_replica_set=(jax.process_index() == 0))
mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
batch_size = batch_size or mesh_size
if batch_size % mesh_size:

batch_size = batch_size or self.data_mesh_size
if batch_size % self.data_mesh_size:
raise ValueError(
f'Batch size ({batch_size}) must be divisible by corresponding '
f'mesh size ({mesh_size}).')
num_shards = self._local_chunker.num_chunks[self._data_axis]
if batch_size % num_shards:
f'data mesh size ({self.data_mesh_size}).'
)

if batch_size % self.data_shards:
raise ValueError(
f'Batch size ({batch_size}) must be divisible by number of '
f'replicas ({num_shards}).')
replica_id = self._local_chunker.get_local_chunk_info(
(batch_size,), [self._data_axis]).replica_id
f'data shards ({self.data_shards}).'
)
replica_id = self._local_chunker.get_replica_id(self._data_axis)
return DataLayout(
batch_size=int(batch_size),
shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
num_shards=int(num_shards),
is_first_host_in_replica_set=(replica_id == 0))
shard_id=int(self.data_shard_id),
num_shards=int(self.data_shards),
is_first_host_in_replica_set=(replica_id == 0),
)

def get_local_chunk_info(
self, global_shape: Tuple[int, ...],
Expand Down Expand Up @@ -790,8 +967,13 @@ def _local_chunker(self) -> LocalChunker:

@cached_property
def mesh(self) -> Mesh:
return default_mesh(self._num_partitions, self._model_parallel_submesh,
self._backend)
return default_mesh(
self._num_partitions,
self._model_parallel_submesh,
self._backend,
self._ici_mesh_shape,
self._dcn_mesh_shape,
)

def partition(
self,
Expand Down Expand Up @@ -825,6 +1007,8 @@ def __init__(
model_parallel_submesh: Optional[HardwareMesh] = None,
params_on_devices: bool = True,
backend: Optional[str] = None,
ici_mesh_shape: Optional[HardwareMesh] = None,
dcn_mesh_shape: Optional[HardwareMesh] = None,
logical_axis_rules: Optional[LogicalAxisRules] = None,
):
"""PjitPartitioner constructor.
Expand Down Expand Up @@ -856,6 +1040,19 @@ def __init__(
backend: get devices from the pinned backend, if specified. This is useful
for explicitly specifying the devices other than relying on
jax_platform_name.
ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in
each slice. The meaning of each mesh axis is defined by mesh_axis_names,
so these two params must be the same length. If dcn_mesh_shape is
present, the overall mesh is the product of ici_mesh_shape and
dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with
mesh_axis_names ['replica', 'data', 'model'] indicates 2-way replica
parallelism, 3-way data parallelism, and 4-way model parallelism over 24
devices. None, the default, is equivalent to a sequence of ones and
means that the model is placed on a single device.
dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
multiple slices. The overall mesh is the product of ici_mesh_shape and
dcn_mesh_shape, and the meaning of each mesh axis is defined by
mesh_axis_names, so these three params must be the same length.
logical_axis_rules: a priority-ordered sequence of KV tuples that maps
logical axis names to either `None` (not sharded), 'model' (to shard
across the model-parallel submesh), or 'data' (to shard across the
Expand All @@ -865,12 +1062,21 @@ def __init__(
num_partitions=num_partitions,
model_parallel_submesh=model_parallel_submesh,
params_on_devices=params_on_devices,
backend=backend)
backend=backend,
ici_mesh_shape=ici_mesh_shape,
dcn_mesh_shape=dcn_mesh_shape,
)
if logical_axis_rules is None:
logical_axis_rules = standard_logical_axis_rules()
if ici_mesh_shape is not None and dcn_mesh_shape is not None:
# Split batch over new replica axis.
logical_axis_rules = (
(k, ('replica', 'data') if k == 'batch' else v)
for k, v in logical_axis_rules
)
self._logical_axis_rules = tuple(logical_axis_rules)
(self._data_axis,) = flax_partitioning.logical_to_mesh_axes(
['batch'], logical_axis_rules
['batch'], self._logical_axis_rules
)

def partition(
Expand Down
Loading