diff --git a/t5x/partitioning.py b/t5x/partitioning.py
index 910f666fd..4c37d0cbf 100644
--- a/t5x/partitioning.py
+++ b/t5x/partitioning.py
@@ -17,8 +17,9 @@
 import abc
 import collections
 import dataclasses
+import functools
 import typing
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union
 
 from absl import logging
 import cached_property
@@ -289,9 +290,13 @@ def get_gpu_mesh(num_partitions: int) -> Mesh:
   return global_mesh
 
 
-def default_mesh(num_partitions: int,
-                 model_parallel_submesh: Optional[HardwareMesh] = None,
-                 backend: Optional[str] = None) -> Mesh:
+def default_mesh(
+    num_partitions: int,
+    model_parallel_submesh: Optional[HardwareMesh] = None,
+    backend: Optional[str] = None,
+    ici_mesh_shape: Optional[HardwareMesh] = None,
+    dcn_mesh_shape: Optional[HardwareMesh] = None,
+) -> Mesh:
   """Attempt to return a default mesh for simple cases.
 
   Args:
@@ -302,15 +307,50 @@ def default_mesh(num_partitions: int,
     backend: get devices from the pinned backend, if specified. This is useful
       for explicitly specifying the devices other than relying on
       jax_platform_name.
+    ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in each
+      slice. The meaning of each mesh axis is defined by mesh_axis_names, so
+      these two params must be the same length. If dcn_mesh_shape is present,
+      the overall mesh is the product of ici_mesh_shape and dcn_mesh_shape. For
+      example, an ici_mesh_shape of [2, 3, 4] with mesh_axis_names ['replica',
+      'data', 'mdl'] indicates 2-way replica parallelism, 3-way data
+      parallelism, and 4-way model parallelism over 24 devices. None, the
+      default, is equivalent to a sequence of ones and means that the model is
+      placed on a single device.
+    dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
+      multiple slices. The overall mesh is the product of ici_mesh_shape and
+      dcn_mesh_shape, and the meaning of each mesh axis is defined by
+      mesh_axis_names, so these three params must be the same length.
 
   Returns:
-    xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
+    xmap/pjit 2D Mesh with 'data', 'model' mesh axes if single-slice, otherwise
+    3D Mesh with 'replica', 'data', and 'model' mesh axes.
   """
-  last_device = jax.devices(backend)[-1]
+  devices = jax.devices(backend)
+  last_device = devices[-1]
   platform = last_device.platform
   device_kind = last_device.device_kind
   bounds = bounds_from_last_device(last_device)
 
+  if ici_mesh_shape is not None and dcn_mesh_shape is not None:
+    device_mesh = create_hybrid_device_mesh(
+        ici_mesh_shape,
+        dcn_mesh_shape,
+        devices=devices,
+    )
+    multi_slice_global_mesh = Mesh(device_mesh, ['replica', 'data', 'model'])
+    logging.info(
+        'multi_slice_global_mesh axis_names: %s',
+        multi_slice_global_mesh.axis_names,
+    )
+    logging.info(
+        'multi_slice_global_mesh devices: %s', multi_slice_global_mesh.devices
+    )
+    logging.info(
+        'multi_slice_global_mesh devices shape: %s',
+        multi_slice_global_mesh.devices.shape,
+    )
+    return multi_slice_global_mesh
+
   if model_parallel_submesh:
     return get_mesh(model_parallel_submesh, backend=backend)
 
@@ -430,6 +470,65 @@ def get_local_chunk_info(
       chunk_size = size // self.num_chunks[mesh_axis]
       local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
 
+    replica_id = self.get_replica_id(sharded_mesh_axes)
+
+    return LocalChunkInfo(tuple(local_slice), replica_id)
+
+  def get_shard_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int:
+    """Given mesh axes used for sharding, computes current host's shard id.
+
+    To give an example, let's say there are two axes globally: replica, data,
+    and model, the mesh axes for sharding is ('replica', 'data'), which means we
+    are going to partition an array along 'replica' and 'data' axes.
+    The shard_id is to show the index of the current local host along the
+    sharding axes (in this example, it's 'replica' and 'data' axes).
+
+    More concretely, let's say we have 4 local hosts, and we use 'replica' and
+    'data' axes for data parallel (2 hosts along the replica axis, and 2 host
+    along the data axis). The host located in ('replica': 0, 'data': 0), we
+    should assign data shard-0 to it. For host ('replica': 0, 'data': 1), we
+    assign shard-1. For host ('replica': 1, 'data': 0), we assign shard-2.
+    For host ('replica': 1, 'data': 1), we assign shard-3.
+
+    Note: the host location along 'replica' and 'data' axes, e.g.,
+    ('replica': 0, 'data': 0) is named chunk_id and stored in
+    self._local_chunker.chunk_ids[axis].
+
+    Args:
+      sharded_mesh_axes: the mesh axes for sharding.
+
+    Returns:
+      the index of the current local host along the sharding axes.
+    """
+    if isinstance(sharded_mesh_axes, str):
+      sharded_mesh_axes = (sharded_mesh_axes,)
+
+    shard_id = 0
+    for mesh_axis in sharded_mesh_axes:
+      chunk_id = self.chunk_ids[mesh_axis]
+      shard_id = shard_id * self.num_chunks[mesh_axis] + chunk_id
+
+    return shard_id
+
+  def get_replica_id(self, sharded_mesh_axes: str | Set[Optional[str]]) -> int:
+    """Given mesh axes used for sharding, computes current host's replica id.
+
+    To give an example, let's say there are two axes globally: data, and model,
+    the mesh axes for sharding is ('data', ), which means we are going to
+    partition an array along 'data' axis and replicate it along 'model' axis.
+    The replica_id is to show the index of the current local host along the
+    'model' axis.
+
+    Args:
+      sharded_mesh_axes: the mesh axes for sharding.
+
+    Returns:
+      the index of the current local host along the non-sharding axes (i.e.,
+      replicating axes).
+    """
+    if isinstance(sharded_mesh_axes, str):
+      sharded_mesh_axes = (sharded_mesh_axes,)
+
     replicated_mesh_axes = [
         mesh_axis for mesh_axis in self.mesh_axes
         if mesh_axis not in sharded_mesh_axes
@@ -439,7 +538,7 @@ def get_local_chunk_info(
       chunk_id = self.chunk_ids[mesh_axis]
       replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
 
-    return LocalChunkInfo(tuple(local_slice), replica_id)
+    return replica_id
 
 
 def standard_logical_axis_rules(
@@ -550,11 +649,15 @@ class DataLayout:
 class BasePartitioner(metaclass=abc.ABCMeta):
   """Interface for partitioning computations across hardware devices."""
 
-  def __init__(self,
-               num_partitions: Optional[int] = None,
-               model_parallel_submesh: Optional[HardwareMesh] = None,
-               params_on_devices: bool = True,
-               backend: Optional[str] = None):
+  def __init__(
+      self,
+      num_partitions: Optional[int] = None,
+      model_parallel_submesh: Optional[HardwareMesh] = None,
+      params_on_devices: bool = True,
+      backend: Optional[str] = None,
+      ici_mesh_shape: Optional[HardwareMesh] = None,
+      dcn_mesh_shape: Optional[HardwareMesh] = None,
+  ):
     """Configures the partitioner.
 
     Args:
@@ -573,6 +676,19 @@ def __init__(self,
       backend: get devices from the pinned backend, if specified. This is useful
         for explicitly specifying the devices other than relying on
         jax_platform_name.
+      ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in
+        each slice. The meaning of each mesh axis is defined by mesh_axis_names,
+        so these two params must be the same length. If dcn_mesh_shape is
+        present, the overall mesh is the product of ici_mesh_shape and
+        dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with
+        mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica
+        parallelism, 3-way data parallelism, and 4-way model parallelism over 24
+        devices. None, the default, is equivalent to a sequence of ones and
+        means that the model is placed on a single device.
+      dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
+        multiple slices. The overall mesh is the product of ici_mesh_shape and
+        dcn_mesh_shape, and the meaning of each mesh axis is defined by
+        mesh_axis_names, so these three params must be the same length.
     """
 
     if not num_partitions and not model_parallel_submesh:
@@ -601,8 +717,13 @@ def __init__(self,
     self._num_partitions = num_partitions
     self._model_parallel_submesh = model_parallel_submesh
     self._params_on_devices = params_on_devices
-    self._data_axis = 'data'
+    if ici_mesh_shape is None or dcn_mesh_shape is None:
+      self._data_axis = 'data'
+    else:
+      self._data_axis = ('replica', 'data')
     self._backend = backend
+    self._ici_mesh_shape = ici_mesh_shape
+    self._dcn_mesh_shape = dcn_mesh_shape
 
   @property
   def mesh(self) -> Mesh:
@@ -612,9 +733,63 @@ def mesh(self) -> Mesh:
   def data_partition_spec(self) -> PartitionSpec:
     return PartitionSpec(self._data_axis)
 
-  def get_data_layout(self,
-                      batch_size: Optional[int] = None,
-                      host_index: Optional[int] = None) -> DataLayout:
+  @property
+  def data_mesh_size(self) -> int:
+    """Data mesh size.
+
+    Data mesh size is defined as the number of global devices involved to
+    carry out data parallel. Let's say we have a global mesh: ('replica': 2,
+    'data': 4, 'model': 2), and axes 'replica' and 'data' are responsible for
+    the data parallel, that means we have 2*4 = 8 devices involved - i.e., data
+    mesh size is 8.
+
+    Returns:
+      the id of the shard for the axes being replicated among the devices used
+      to shard the sharded_mesh_axes.
+    """
+    data_submesh_sizes = (
+        [self.mesh.shape[self._data_axis]]
+        if isinstance(self._data_axis, str)
+        else [self.mesh.shape[axis] for axis in self._data_axis]
+    )
+    data_mesh_size = functools.reduce(lambda x, y: x * y, data_submesh_sizes)
+    return data_mesh_size
+
+  @property
+  def data_shards(self) -> int:
+    """Number of data shards.
+
+    Let's say we are dealing with 2 slices of df4x2 TPUs. In data pipeline
+    we need prepare / send one data shard to each local host. This means, we
+    need 4 shards since we have 4 local hosts. How to infer the number of hosts
+    from mesh information? In this case, we have a global mesh: ('replica': 2,
+    'data': 8, 'model': 2). Each local host (i.e., df2x2) has this local mesh:
+    ('replica': 1, 'data': 4, 'model': 2). By dividing global mesh with local
+    mesh, we can get the count of hosts.
+
+    Returns:
+      Number of data shards. Each shard will be sent to one local host.
+    """
+    data_chunks = (
+        [self._local_chunker.num_chunks[self._data_axis]]
+        if isinstance(self._data_axis, str)
+        else [self._local_chunker.num_chunks[axis] for axis in self._data_axis]
+    )
+    data_shards = functools.reduce(lambda x, y: x * y, data_chunks)
+    return data_shards
+
+  @property
+  def data_shard_id(self) -> int:
+    """Data shard id for the current host.
+
+    Returns:
+      Index of data shard that will be sent to the current local host.
+    """
+    return self._local_chunker.get_shard_id(self._data_axis)
+
+  def get_data_layout(
+      self, batch_size: Optional[int] = None, host_index: Optional[int] = None
+  ) -> DataLayout:
     """Returns filled `DataLayout` based on the partitioned model layout.
 
     Args:
@@ -637,24 +812,26 @@ def get_data_layout(self,
           shard_id=0,
           num_shards=1,
           is_first_host_in_replica_set=(jax.process_index() == 0))
-    mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
-    batch_size = batch_size or mesh_size
-    if batch_size % mesh_size:
+
+    batch_size = batch_size or self.data_mesh_size
+    if batch_size % self.data_mesh_size:
       raise ValueError(
           f'Batch size ({batch_size}) must be divisible by corresponding '
-          f'mesh size ({mesh_size}).')
-    num_shards = self._local_chunker.num_chunks[self._data_axis]
-    if batch_size % num_shards:
+          f'data mesh size ({self.data_mesh_size}).'
+      )
+
+    if batch_size % self.data_shards:
       raise ValueError(
           f'Batch size ({batch_size}) must be divisible by number of '
-          f'replicas ({num_shards}).')
-    replica_id = self._local_chunker.get_local_chunk_info(
-        (batch_size,), [self._data_axis]).replica_id
+          f'data shards ({self.data_shards}).'
+      )
+    replica_id = self._local_chunker.get_replica_id(self._data_axis)
     return DataLayout(
         batch_size=int(batch_size),
-        shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
-        num_shards=int(num_shards),
-        is_first_host_in_replica_set=(replica_id == 0))
+        shard_id=int(self.data_shard_id),
+        num_shards=int(self.data_shards),
+        is_first_host_in_replica_set=(replica_id == 0),
+    )
 
   def get_local_chunk_info(
       self, global_shape: Tuple[int, ...],
@@ -790,8 +967,13 @@ def _local_chunker(self) -> LocalChunker:
 
   @cached_property
   def mesh(self) -> Mesh:
-    return default_mesh(self._num_partitions, self._model_parallel_submesh,
-                        self._backend)
+    return default_mesh(
+        self._num_partitions,
+        self._model_parallel_submesh,
+        self._backend,
+        self._ici_mesh_shape,
+        self._dcn_mesh_shape,
+    )
 
   def partition(
       self,
@@ -825,6 +1007,8 @@ def __init__(
       model_parallel_submesh: Optional[HardwareMesh] = None,
       params_on_devices: bool = True,
       backend: Optional[str] = None,
+      ici_mesh_shape: Optional[HardwareMesh] = None,
+      dcn_mesh_shape: Optional[HardwareMesh] = None,
       logical_axis_rules: Optional[LogicalAxisRules] = None,
   ):
     """PjitPartitioner constructor.
@@ -856,6 +1040,19 @@ def __init__(
       backend: get devices from the pinned backend, if specified. This is useful
         for explicitly specifying the devices other than relying on
         jax_platform_name.
+      ici_mesh_shape: Shape of the logical mesh used for SPMD parallelism in
+        each slice. The meaning of each mesh axis is defined by mesh_axis_names,
+        so these two params must be the same length. If dcn_mesh_shape is
+        present, the overall mesh is the product of ici_mesh_shape and
+        dcn_mesh_shape. For example, an ici_mesh_shape of [2, 3, 4] with
+        mesh_axis_names ['replica', 'data', 'mdl'] indicates 2-way replica
+        parallelism, 3-way data parallelism, and 4-way model parallelism over 24
+        devices. None, the default, is equivalent to a sequence of ones and
+        means that the model is placed on a single device.
+      dcn_mesh_shape: Shape of the logical mesh used for SPMD parallelism over
+        multiple slices. The overall mesh is the product of ici_mesh_shape and
+        dcn_mesh_shape, and the meaning of each mesh axis is defined by
+        mesh_axis_names, so these three params must be the same length.
       logical_axis_rules: a priority-ordered sequence of KV tuples that maps
         logical axis names to either `None` (not sharded), 'model' (to shard
         across the model-parallel submesh), or 'data' (to shard across the
@@ -865,12 +1062,21 @@ def __init__(
         num_partitions=num_partitions,
         model_parallel_submesh=model_parallel_submesh,
         params_on_devices=params_on_devices,
-        backend=backend)
+        backend=backend,
+        ici_mesh_shape=ici_mesh_shape,
+        dcn_mesh_shape=dcn_mesh_shape,
+    )
     if logical_axis_rules is None:
       logical_axis_rules = standard_logical_axis_rules()
+    if ici_mesh_shape is not None and dcn_mesh_shape is not None:
+      # Split batch over new replica axis.
+      logical_axis_rules = (
+          (k, ('replica', 'data') if k == 'batch' else v)
+          for k, v in logical_axis_rules
+      )
     self._logical_axis_rules = tuple(logical_axis_rules)
     (self._data_axis,) = flax_partitioning.logical_to_mesh_axes(
-        ['batch'], logical_axis_rules
+        ['batch'], self._logical_axis_rules
     )
 
   def partition(